]> git.saurik.com Git - cyql.git/blob - __init__.py
47f54fd19972e4a9f2e09a25e1430fab5e826247
[cyql.git] / __init__.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5
6 from future_builtins import ascii, filter, hex, map, oct, zip
7
8 import inspect
9 import os
10
11 from contextlib import contextmanager
12
13 import psycopg2
14 import psycopg2.extras
15 import psycopg2.pool
16
# Register psycopg2's UNICODE and UNICODEARRAY typecasters with no scope
# argument, i.e. globally for this process, so text values (and arrays of
# them) come back as unicode objects (relevant on Python 2).
for _typecaster in (psycopg2.extensions.UNICODE,
                    psycopg2.extensions.UNICODEARRAY):
    psycopg2.extensions.register_type(_typecaster)
del _typecaster
19
class connect(object):
    """Thin wrapper around a psycopg2 connection with query helpers.

    Statements passed to execute() and the helpers built on it (run, all,
    one, set, has) may use psycopg2's ``%s`` / ``%(name)s`` placeholders as
    well as inline ``%{expression}s`` placeholders.  When no explicit
    context is given, ``%(name)s`` is resolved against a copy of the calling
    function's local variables; ``%{expression}s`` is always evaluated in
    the calling frame's globals and locals.

    Autocommit is enabled when connecting.  begin() switches it off;
    transact() switches it off for one block and back on afterwards.
    """

    def __init__(self, dsn):
        """Connect to the database described by the *dsn* mapping.

        Every entry of *dsn* except ``'cache'`` is passed to
        ``psycopg2.connect()`` as a keyword argument.  ``dsn['cache']`` is a
        dict used to remember the hstore type OID between connections; when
        *dsn* has no cache, a new one is created and stored back into
        ``dsn['cache']`` (so *dsn* is modified).

        Up to three connection attempts are made; the third
        psycopg2.OperationalError is re-raised.  If setup fails after the
        connection is opened, the connection is closed before the error
        propagates.
        """
        options = dict(dsn)
        options.pop('cache', None)

        if 'cache' in dsn:
            cached = True
            cache = dsn['cache']
        else:
            cached = False
            cache = {'hstore': None}

        for attempt in range(3):
            try:
                self.driver = psycopg2.connect(**options)
                break
            except psycopg2.OperationalError:
                if attempt == 2:
                    raise

        try:
            self.driver.autocommit = True

            # XXX: all of my databases default to the UNICODE client
            # encoding, so set_client_encoding('UNICODE') is not called.

            hstore = cache.get('hstore')
            if hstore is None:
                hstore = psycopg2.extras.HstoreAdapter.get_oids(self.driver)
                if hstore is not None:
                    hstore = hstore[0]
                    cache['hstore'] = hstore

            if hstore is not None:
                try:
                    psycopg2.extras.register_hstore(
                        self.driver, globally=False, unicode=True, oid=hstore)
                except psycopg2.ProgrammingError:
                    # best effort: carry on without hstore support
                    pass
        except BaseException:
            # do not leak the freshly opened connection
            self.driver.close()
            raise

        if not cached:
            dsn['cache'] = cache

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        """Close the connection when the with-block ends, however it ends."""
        self.close()

    def begin(self):
        """Switch autocommit off so later statements share a transaction.

        Note that commit() and rollback() do not switch autocommit back on.
        """
        self.driver.autocommit = False

    def commit(self):
        """Commit the current transaction."""
        self.driver.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.driver.rollback()

    @contextmanager
    def cursor(self):
        """Yield a new DictCursor and close it when the with-block ends."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Execute *statement* and yield the DictCursor holding its result.

        *context* supplies the query parameters.  When it is None, a copy of
        the calling frame's local variables is used, so ``%(name)s`` refers
        to one of the caller's locals.  Inline ``%{expression}s``
        placeholders are evaluated with eval() in the calling frame's
        globals and locals.  Statements must therefore come from trusted
        code, never from external input.  A caller-supplied *context* is
        never modified.

        *depth* is the number of wrapper frames between the code whose
        variables should be used and the ``with self.execute(...)``
        statement.  For example, run() passes 1 so that its own caller's
        frame is used.
        """
        # Start at this generator's own frame, then step out past execute()
        # and the @contextmanager __enter__ wrapper, plus *depth* more
        # frames.  Walking f_back works on Python 2 and 3, whereas
        # inspect.currentframe() accepts a depth argument only on Python 2.
        frame = inspect.currentframe()
        for _ in range(depth + 2):
            frame = frame.f_back
        f_globals = frame.f_globals
        f_locals = frame.f_locals
        # Frame references easily create reference cycles; drop this one as
        # soon as its namespaces have been read.
        del frame

        if context is None:
            context = dict(f_locals)

        statement, values = self._expand(statement, f_globals, f_locals)
        if values:
            # Copy so the generated __cyql__N keys never leak into a dict
            # owned by the caller.
            context = dict(context)
            context.update(values)

        with self.cursor() as cursor:
            cursor.execute(statement, context)

            # Release the caller's namespaces before suspending for the
            # duration of the with-body.
            del context, f_locals, f_globals

            yield cursor

    @staticmethod
    def _expand(statement, f_globals, f_locals):
        """Rewrite inline ``%{expr}s`` placeholders as named parameters.

        ``%(name)s``, ``%s`` and ``%%`` are left unchanged.  Each
        ``%{expr}s`` becomes ``%(__cyql__N)s``, where N is the offset of that
        placeholder in the rewritten statement.  *expr* is evaluated with
        ``eval(expr, f_globals, f_locals)``.  The expression must not itself
        contain ``}``.

        Returns ``(statement, values)``, where *values* maps each generated
        ``__cyql__N`` key to its evaluated result.

        Raises ValueError for a malformed or unsupported ``%`` directive.
        """
        values = {}
        start = 0
        while True:
            percent = statement.find('%', start)
            if percent == -1:
                return statement, values

            # Slicing instead of indexing yields '' for a trailing '%', so it
            # is reported as unsupported rather than raising IndexError.
            kind = statement[percent + 1:percent + 2]
            if kind == '(':
                start = statement.index(')', percent + 2) + 2
                if statement[start - 1:start] != 's':
                    raise ValueError(
                        'expected %%(name)s at offset %i' % (percent,))
            elif kind == '{':
                close = statement.index('}', percent + 2)
                if statement[close + 1:close + 2] != 's':
                    raise ValueError(
                        'expected %%{expression}s at offset %i' % (percent,))
                code = statement[percent + 2:close]

                key = '__cyql__%i' % (percent,)
                # XXX: compile() in the frame's context
                values[key] = eval(code, f_globals, f_locals)

                # Keep the original trailing 's' so '%{...}s' -> '%(key)s'.
                statement = '%s%%(%s)%s' % (
                    statement[:percent], key, statement[close + 1:])
                # resume just past the inserted '%(key)s'
                start = percent + len(key) + 4
            elif kind in ('%', 's'):
                start = percent + 2
            else:
                raise ValueError(
                    'unsupported %% directive at offset %i' % (percent,))

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the with-body inside a single transaction.

        Autocommit is switched off for the block.  The transaction is
        committed when the block completes and rolled back if it raises.
        Autocommit is switched back on either way.  With
        ``synchronous_commit=False``, ``set local synchronous_commit = off``
        is issued for this transaction first.
        """
        self.driver.autocommit = False
        try:
            with self.cursor() as cursor:
                if not synchronous_commit:
                    cursor.execute('set local synchronous_commit = off')

                yield
                self.driver.commit()
        except BaseException:
            # roll back on anything (including KeyboardInterrupt and
            # GeneratorExit), then let it propagate
            self.driver.rollback()
            raise
        finally:
            self.driver.autocommit = True

    def one_(self, statement, context=None):
        """Return the only row produced by *statement*, or None if it has none.

        Shared by one() and has().  It passes depth 2 to execute() so that the
        frame of *their* caller supplies the statement's variables.
        """
        with self.execute(statement, 2, context) as cursor:
            row = cursor.fetchone()
            if row is None:
                return None

            # a second row means the statement was written wrongly
            assert cursor.fetchone() is None
            return row

    def __call__(self, procedure, *parameters):
        """Call the stored *procedure* with the positional *parameters*.

        Returns all rows of the procedure's result set, as all() does.
        """
        with self.cursor() as cursor:
            # callproc() takes the argument list as one sequence
            cursor.callproc(procedure, parameters)
            return cursor.fetchall()

    def run(self, statement, context=None):
        """Execute *statement* and return the cursor's rowcount."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement):
        """Execute *statement* and yield its cursor for the with-body to read."""
        with self.execute(statement, 2) as cursor:
            yield cursor

    def all(self, statement, context=None):
        """Execute *statement* and return every result row."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.fetchall()

    def one(self, statement, context=None):
        """Return the only row of *statement*'s result, or None if it is empty."""
        return self.one_(statement, context)

    def has(self, statement):
        """Return the value of ``select exists(<statement>)``."""
        exists, = self.one_('select exists(%s)' % (statement,))
        return exists
193
def connected(dsn):
    """Decorator factory that hands the wrapped function a fresh connection.

    Each call of the decorated function opens ``connect(dsn)``, passes it as
    the ``sql`` keyword argument, and closes it when the function returns or
    raises.  The wrapper keeps the wrapped function's name and docstring.
    """
    import functools

    def wrapped(method):
        @functools.wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
201
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection to *dsn* and run the with-body in one transaction.

    Extra arguments are forwarded to connect.transact().  The connection is
    yielded to the with-body and closed once the transaction ends.
    """
    with connect(dsn) as sql:
        transaction = sql.transact(*args, **kw)
        with transaction:
            yield sql
207
# NOTE: the triple-quoted string below is disabled legacy code, never
# executed.  It is an insert-or-update helper, slap_: try an insert under a
# savepoint, and on IntegrityError roll back to the savepoint and update the
# row matching *keys* instead.  It calls path_(), which is not defined
# anywhere visible in this file.  It also treats sql.cursor() as returning a
# plain cursor, whereas connect.cursor() is a context manager.  Both would
# need fixing before it is revived.
208 """
209 def slap_(sql, table, keys, values, path):
210 csr = sql.cursor()
211 try:
212 csr.execute('savepoint iou')
213 try:
214 both = dict(keys, **values)
215 fields = both.keys()
216
217 csr.execute('''
218 insert into %s (%s) values (%s)
219 ''' % (
220 table,
221 ', '.join(fields),
222 ', '.join(['%s' for key in fields])
223 ), both.values())
224 except psycopg2.IntegrityError, e:
225 csr.execute('rollback to savepoint iou')
226
227 csr.execute('''
228 update %s set %s where %s
229 ''' % (
230 table,
231 ', '.join([
232 key + ' = %s'
233 for key in values.keys()]),
234 ' and '.join([
235 key + ' = %s'
236 for key in keys.keys()])
237 ), values.values() + keys.values())
238
239 return path_(csr, path)
240 finally:
241 csr.close()
242 """