# Retrieved from the git.saurik.com gitweb view of cyql.git, file __init__.py
# (revision ca9ae63228ae4e0ee95d1d0e851c8b465e2b346f).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from future_builtins import ascii, filter, hex, map, oct, zip

import functools
import inspect
import os

from contextlib import contextmanager

import psycopg2
import psycopg2.extras
import psycopg2.pool
16
# Install psycopg2's UNICODE and UNICODEARRAY typecasters.  No scope argument
# is passed, so they are registered globally: text (and arrays of text) come
# back as unicode objects on every connection instead of byte strings.
for _typecaster in (psycopg2.extensions.UNICODE,
                    psycopg2.extensions.UNICODEARRAY):
    psycopg2.extensions.register_type(_typecaster)
del _typecaster
19
class connect(object):
    """A psycopg2 connection whose query helpers bind parameters from the
    caller's local variables.

    Statements use psycopg2's ``%(name)s`` placeholders.  Unless an explicit
    ``context`` mapping is passed, names are looked up in a snapshot of the
    local variables of the function that called the helper.  A ``%{expr}s``
    placeholder evaluates the Python expression ``expr`` (with ``eval``) in
    that caller's globals/locals and binds the result as a parameter, so
    statement text must never be built from untrusted input.

    The connection is in autocommit mode except after ``begin()`` or inside
    ``transact()``.
    """

    def __init__(self, dsn):
        """Connect using the keyword arguments in the mapping *dsn*.

        Every entry except ``'cache'`` is forwarded to ``psycopg2.connect``.
        The connection is attempted up to three times; after the third
        failure the ``psycopg2.OperationalError`` is re-raised.

        ``dsn['cache']`` memoizes the hstore type lookup.  When *dsn* has no
        such entry, one is stored into it after a successful connect, so
        later connections built from the same mapping skip the lookup.
        """
        options = dsn.copy()
        options.pop('cache', None)

        cached = 'cache' in dsn
        cache = dsn['cache'] if cached else {'hstore': None}

        attempt = 0
        while True:
            try:
                self.driver = psycopg2.connect(**options)
                break
            except psycopg2.OperationalError:
                if attempt == 2:
                    raise
                attempt += 1

        try:
            self.driver.autocommit = True

            # XXX: the client encoding is not set explicitly; the author's
            # databases all default to UNICODE.

            hstore = cache['hstore']
            if hstore is None:
                hstore = psycopg2.extras.HstoreAdapter.get_oids(self.driver)
                if hstore is not None:
                    hstore = hstore[0]
                # Write the lookup back; otherwise the cache stays empty and
                # every connection repeats the catalog query.
                cache['hstore'] = hstore

            if hstore is not None:
                try:
                    psycopg2.extras.register_hstore(
                        self.driver, globally=False, unicode=True, oid=hstore)
                except psycopg2.ProgrammingError:
                    # hstore support is optional; the connection stays usable
                    # without it.
                    pass
        except BaseException:
            # Don't leak the freshly opened connection if setup fails.
            self.driver.close()
            raise

        if not cached:
            dsn['cache'] = cache

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        """Close the connection when leaving a ``with`` block."""
        self.close()

    def begin(self):
        """Leave autocommit mode so following statements share a transaction.

        ``commit`` and ``rollback`` end the transaction but do not turn
        autocommit back on.
        """
        self.driver.autocommit = False

    def commit(self):
        self.driver.commit()

    def rollback(self):
        self.driver.rollback()

    @contextmanager
    def cursor(self):
        """Yield a new ``DictCursor``; it is closed when the block exits."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @staticmethod
    def _expand(statement, f_globals, f_locals):
        """Rewrite every ``%{expr}s`` in *statement* into a named placeholder.

        Each ``expr`` is evaluated with ``eval(expr, f_globals, f_locals)``
        and stored under a generated ``__cyql__<offset>`` key.  The directive
        itself becomes ``%(__cyql__<offset>)s``.  ``%(name)s``, ``%s`` and
        ``%%`` are left untouched.

        Returns ``(statement, values)``, where *values* maps the generated
        keys to the evaluated results.

        Raises ValueError on an incomplete or unsupported ``%`` directive.
        """
        values = {}
        start = 0
        while True:
            percent = statement.find('%', start)
            if percent == -1:
                return statement, values

            # a slice (not an index) so a trailing '%' yields '' rather than
            # IndexError and is reported below as an unsupported directive
            kind = statement[percent + 1:percent + 2]
            if kind == '(':
                close = statement.find(')', percent + 2)
                if close == -1 or statement[close + 1:close + 2] != 's':
                    raise ValueError(
                        'malformed %%(name)s directive at offset %i' % (percent,))
                start = close + 2
            elif kind == '{':
                close = statement.find('}', percent + 2)
                if close == -1 or statement[close + 1:close + 2] != 's':
                    raise ValueError(
                        'malformed %%{expr}s directive at offset %i' % (percent,))
                key = '__cyql__%i' % (percent,)
                # XXX: compile() in the frame's context
                values[key] = eval(statement[percent + 2:close], f_globals, f_locals)
                statement = '%s%%(%s)%s' % (
                    statement[:percent], key, statement[close + 1:])
                # continue after the '%(key)s' just written
                start = percent + len(key) + 4
            elif kind in ('%', 's'):
                start = percent + 2
            else:
                raise ValueError(
                    'unsupported %% directive at offset %i' % (percent,))

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Execute *statement* and yield the cursor holding its result.

        Parameters come from *context* when it is given; the mapping is never
        modified.  Otherwise they come from a snapshot of the local variables
        of the code that entered ``with self.execute(...)``, or of a frame
        *depth* levels further up the stack.  Wrappers such as ``run`` pass
        ``depth=1`` so the lookup skips the wrapper itself.

        Raises ValueError if *statement* has a malformed ``%`` directive.
        """
        # two frames, accounting for execute() and @contextmanager
        frame = inspect.currentframe(depth + 2)
        try:
            f_locals = frame.f_locals
            statement, values = self._expand(statement, frame.f_globals, f_locals)
        finally:
            # Release the caller's frame immediately; a suspended generator
            # holding it would keep the caller's locals alive.
            del frame

        if context is None or values:
            # copy so generated keys never leak into the caller's mapping
            params = dict(f_locals if context is None else context)
            params.update(values)
        else:
            params = context
        del f_locals, values, context

        with self.cursor() as cursor:
            cursor.execute(statement, params)
            del params
            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the ``with`` body inside a transaction.

        The transaction commits when the body finishes normally.  On any
        exception it rolls back and re-raises.  Autocommit is switched back
        on in both cases.  ``synchronous_commit=False`` issues
        ``set local synchronous_commit = off`` for this transaction.
        """
        self.driver.autocommit = False
        try:
            with self.cursor() as cursor:
                if not synchronous_commit:
                    cursor.execute('set local synchronous_commit = off')

                yield
            self.driver.commit()
        except BaseException:
            self.driver.rollback()
            raise
        finally:
            self.driver.autocommit = True

    def one_(self, statement, context=None):
        """Return the single row produced by *statement*, or None if empty.

        Shared by ``one`` and ``has``.  It passes ``depth=2`` so the parameter
        lookup skips both this method and the public method that called it.

        Raises AssertionError if the statement yields more than one row.
        """
        with self.execute(statement, 2, context) as cursor:
            one = cursor.fetchone()
            if one is None:
                return None

            # explicit raise rather than ``assert`` so the check survives -O;
            # the exception type stays the same for existing callers
            if cursor.fetchone() is not None:
                raise AssertionError('statement returned more than one row')
            return one

    def __call__(self, procedure, *parameters):
        """Call the stored procedure *procedure* with *parameters*.

        Returns the value of psycopg2's ``cursor.callproc``.
        """
        with self.cursor() as cursor:
            # callproc takes the arguments as one sequence, not unpacked
            return cursor.callproc(procedure, parameters)

    def run(self, statement, context=None):
        """Execute *statement* and return the cursor's ``rowcount``."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement):
        """Execute *statement* and yield its cursor to the ``with`` body."""
        # depth 2 skips this generator and its @contextmanager wrapper
        with self.execute(statement, 2) as cursor:
            yield cursor

    def all(self, statement, context=None):
        """Execute *statement* and return every row via ``fetchall``."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.fetchall()

    def one(self, statement, context=None):
        """Return the only row of *statement*, or None when there is none."""
        return self.one_(statement, context)

    def has(self, statement):
        """Return the boolean result of ``select exists(<statement>)``."""
        exists, = self.one_('select exists(%s)' % (statement,))
        return exists
192
def connected(dsn):
    """Decorator factory: give the wrapped function its own connection.

    Each call of the decorated function opens ``connect(dsn)`` and passes it
    as the ``sql`` keyword argument.  The connection is closed when the
    function returns or raises.  The wrapper keeps the wrapped function's
    name and docstring.
    """
    def wrapped(method):
        @functools.wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
200
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection to *dsn* and run the ``with`` body in a transaction.

    Extra arguments are forwarded to ``connect.transact``.  The connection
    is yielded to the body and closed once the transaction has finished.
    """
    with connect(dsn) as sql:
        transaction = sql.transact(*args, **kw)
        with transaction:
            yield sql
206
# NOTE(review): disabled code kept for reference.  The triple-quoted string
# below is a bare expression statement, so the `slap_` helper inside it is
# never defined or run.  It appears to be an insert-or-update: it tries an
# INSERT under the savepoint `iou`.  On IntegrityError it rolls back to that
# savepoint and issues an UPDATE instead.
#
# It would not work as written:
#   - it calls `path_`, which is not defined in this module;
#   - if `sql` is a `connect` instance, `sql.cursor()` is a context manager,
#     not a cursor.
#
# `table` and the column names are interpolated directly into the SQL text
# (only the values are bound as parameters), so they must be trusted
# identifiers.
"""
def slap_(sql, table, keys, values, path):
    csr = sql.cursor()
    try:
        csr.execute('savepoint iou')
        try:
            both = dict(keys, **values)
            fields = both.keys()

            csr.execute('''
                insert into %s (%s) values (%s)
            ''' % (
                table,
                ', '.join(fields),
                ', '.join(['%s' for key in fields])
            ), both.values())
        except psycopg2.IntegrityError, e:
            csr.execute('rollback to savepoint iou')

            csr.execute('''
                update %s set %s where %s
            ''' % (
                table,
                ', '.join([
                    key + ' = %s'
                    for key in values.keys()]),
                ' and '.join([
                    key + ' = %s'
                    for key in keys.keys()])
            ), values.values() + keys.values())

        return path_(csr, path)
    finally:
        csr.close()
"""