]>
git.saurik.com Git - cyql.git/blob - __init__.py
6e53915a73de9443299dee42d256c8cb51236390
1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
6 from future_builtins
import ascii
, filter, hex, map, oct, zip
11 from contextlib
import contextmanager
14 import psycopg2
.extras
# Register psycopg2's UNICODE typecasters (scalar and array flavours) with no
# scope argument, i.e. for every connection; per the psycopg2 documentation
# this makes text values come back as unicode objects on Python 2.
for _typecaster in (
    psycopg2.extensions.UNICODE,
    psycopg2.extensions.UNICODEARRAY,
):
    psycopg2.extensions.register_type(_typecaster)
del _typecaster  # do not leave the loop variable in the module namespace
20 class ConnectionError():
23 class connect(object):
24 def __init__(self
, dsn
):
26 if 'cache' in options
:
41 self
.driver
= psycopg2
.connect(**options
)
43 except psycopg2
.OperationalError
, e
:
44 if e
.message
.startswith('could not connect to server: Connection refused\n'):
45 raise ConnectionError()
50 self
.driver
.autocommit
= True
52 # XXX: all of my databases default to this...
54 # self.driver.set_client_encoding('UNICODE')
59 hstore
= cache
['hstore']
61 hstore
= psycopg2
.extras
.HstoreAdapter
.get_oids(self
.driver
)
64 cache
['hstore'] = hstore
68 psycopg2
.extras
.register_hstore(self
.driver
, globally
=False, unicode=True, oid
=hstore
)
69 except psycopg2
.ProgrammingError
, e
:
81 def __exit__(self
, type, value
, traceback
):
85 self
.driver
.autocommit
= False
91 self
.driver
.rollback()
95 cursor
= self
.driver
.cursor(cursor_factory
=psycopg2
.extras
.DictCursor
)
102 def execute(self
, statement
, depth
=0, context
=None):
103 # two frames, accounting for execute() and @contextmanager
104 frame
= inspect
.currentframe(depth
+ 2)
106 with self
.cursor() as cursor
:
108 f_locals
= frame
.f_locals
111 context
= dict(**f_locals
)
115 percent
= statement
.find('%', start
)
119 next
= statement
[percent
+ 1]
121 start
= statement
.index(')', percent
+ 2) + 2
122 assert statement
[start
- 1] == 's'
124 start
= statement
.index('}', percent
+ 2)
125 assert statement
[start
+ 1] == 's'
126 code
= statement
[percent
+ 2:start
]
128 if f_globals
== None:
129 f_globals
= frame
.f_globals
131 key
= '__cyql__%i' % (percent
,)
132 # XXX: compile() in the frame's context
133 context
[key
] = eval(code
, f_globals
, f_locals
)
135 statement
= '%s%%(%s)%s' % (statement
[0:percent
], key
, statement
[start
+ 1:])
136 start
= percent
+ len(key
) + 4
137 elif next
in ('%', 's'):
142 cursor
.execute(statement
, context
)
151 def transact(self
, synchronous_commit
=True):
152 self
.driver
.autocommit
= False
154 with self
.cursor() as cursor
:
155 if not synchronous_commit
:
156 cursor
.execute('set local synchronous_commit = off')
161 self
.driver
.rollback()
164 self
.driver
.autocommit
= True
166 def one_(self
, statement
, context
=None):
167 with self
.execute(statement
, 2, context
) as cursor
:
168 one
= cursor
.fetchone()
172 assert cursor
.fetchone() == None
def __call__(self, procedure, *parameters):
    """Call the stored procedure *procedure* with *parameters*.

    The positional *parameters* are forwarded to ``cursor.callproc()`` as a
    single sequence, which is the form psycopg2 expects, and whatever
    ``callproc()`` returns is handed back to the caller.

    A plain ``self.cursor()`` is used instead of ``self.execute()``: there is
    no SQL statement to interpolate here, and the previous code referred to
    an undefined name ``statement``, so it failed with NameError before the
    procedure was ever invoked.
    """
    with self.cursor() as cursor:
        # callproc() takes (procname, parameters); unpacking the tuple with
        # ``*`` would only work for exactly one argument.
        return cursor.callproc(procedure, parameters)
def run(self, statement, context=None):
    """Execute *statement* and return the cursor's ``rowcount``.

    *context* is handed to ``execute()`` unchanged.
    """
    # depth=1 asks execute() to look one frame beyond this method (i.e. at
    # our caller), so this call must stay directly in this body rather than
    # move into a helper.
    with self.execute(statement, depth=1, context=context) as csr:
        affected = csr.rowcount
    return affected
184 def set(self
, statement
):
185 with self
.execute(statement
, 2) as cursor
:
def all(self, statement, context=None):
    """Execute *statement* and return everything ``fetchall()`` yields.

    *context* is handed to ``execute()`` unchanged.
    """
    # depth=1 asks execute() to look one frame beyond this method (i.e. at
    # our caller), so this call must stay directly in this body rather than
    # move into a helper.
    with self.execute(statement, depth=1, context=context) as csr:
        rows = csr.fetchall()
    return rows
def one(self, statement, context=None):
    """Return the result of ``one_()`` for *statement* and *context*."""
    # one_() passes depth 2 to execute(), which already accounts for this
    # wrapper's frame; it therefore has to be called directly from here.
    row = self.one_(statement, context=context)
    return row
195 def has(self
, statement
):
196 exists
, = self
.one_('select exists(%s)' % (statement
,))
201 def replaced(*args
, **kw
):
202 with connect(dsn
) as sql
:
203 return method(*args
, sql
=sql
, **kw
)
208 def transact(dsn
, *args
, **kw
):
209 with connect(dsn
) as connection
:
210 with connection
.transact(*args
, **kw
):
214 def slap_(sql, table, keys, values, path):
217 csr.execute('savepoint iou')
219 both = dict(keys, **values)
223 insert into %s (%s) values (%s)
227 ', '.join(['%s' for key in fields])
229 except psycopg2.IntegrityError, e:
230 csr.execute('rollback to savepoint iou')
233 update %s set %s where %s
238 for key in values.keys()]),
241 for key in keys.keys()])
242 ), values.values() + keys.values())
244 return path_(csr, path)