]> git.saurik.com Git - cyql.git/blob - __init__.py
Switch to psycopg2 autocommit syntax.
[cyql.git] / __init__.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5
6 from future_builtins import ascii, filter, hex, map, oct, zip
7
8 import inspect
9 import os
10
11 from contextlib import contextmanager
12
13 import psycopg2
14 import psycopg2.extras
15 import psycopg2.pool
16
# Register psycopg2's UNICODE and UNICODEARRAY typecasters with no scope
# argument (i.e. globally), so text values and arrays of text are decoded to
# unicode objects on every connection this process opens.
for _unicode_caster in (psycopg2.extensions.UNICODE,
                        psycopg2.extensions.UNICODEARRAY):
    psycopg2.extensions.register_type(_unicode_caster)
# keep the module namespace exactly as before: no leftover loop variable
del _unicode_caster
19
class connect(object):
    """Wrap a psycopg2 connection with caller-scoped query parameters.

    Statements use psycopg2's ``%(name)s`` placeholders.  Unless an explicit
    ``context`` mapping is supplied, the parameter values are taken from the
    local variables of the code that called the query method (found by frame
    inspection).  The extension ``%{expr}s`` evaluates the Python expression
    ``expr`` in that caller's frame and binds the result as a parameter.

    The connection is put in autocommit mode; use :meth:`transact` (or
    :meth:`begin` / :meth:`commit` / :meth:`rollback`) for explicit
    transactions.  Usable as a context manager: leaving the ``with`` block
    closes the connection.
    """

    def __init__(self, dsn, attempts=3):
        """Open the connection.

        :param dsn: mapping of keyword arguments passed to ``psycopg2.connect``.
        :param attempts: number of connection attempts before the last
            ``psycopg2.OperationalError`` is re-raised (default 3).
        """
        attempt = 0
        while True:
            try:
                self.driver = psycopg2.connect(**dsn)
                break
            except psycopg2.OperationalError:
                attempt += 1
                if attempt >= attempts:
                    # bare ``raise`` keeps the original traceback, which
                    # ``raise e`` would discard under Python 2
                    raise

        try:
            self.driver.set_client_encoding('UNICODE')
            self.driver.autocommit = True
            try:
                psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True)
            except psycopg2.ProgrammingError:
                # psycopg2 raises this when the hstore type is not present in
                # the database; hstore support is optional here
                pass
        except:
            # deliberately catch everything: never leak a half-configured
            # connection, and never hide why setup failed
            self.driver.close()
            raise

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # returns None, so exceptions raised in the with-body propagate
        self.close()

    def begin(self):
        """Leave autocommit mode so later statements share one transaction.

        Autocommit is not switched back on by :meth:`commit` or
        :meth:`rollback`; :meth:`transact` does restore it.
        """
        self.driver.autocommit = False

    def commit(self):
        """Commit the current transaction."""
        self.driver.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.driver.rollback()

    @contextmanager
    def cursor(self):
        """Yield a ``DictCursor`` that is closed when the with-block exits."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @staticmethod
    def _expand_placeholders(statement, context, f_globals, f_locals):
        """Rewrite every ``%{expr}s`` in *statement* into a named placeholder.

        Each ``expr`` is evaluated with *f_globals* / *f_locals*; the value is
        stored in *context* (mutated in place) under a generated
        ``__cyql__N`` key, and the placeholder becomes ``%(__cyql__N)s``.
        ``%(name)s``, ``%s`` and ``%%`` are left untouched.

        :returns: the rewritten statement.
        :raises ValueError: on a ``%`` that starts none of the forms above,
            or a placeholder missing its closing bracket.
        """
        start = 0
        while True:
            percent = statement.find('%', start)
            if percent == -1:
                return statement

            # slicing (not indexing) yields '' for a trailing '%'
            kind = statement[percent + 1:percent + 2]
            if kind == '(':
                start = statement.index(')', percent + 2) + 2
                if statement[start - 1:start] != 's':
                    raise ValueError('expected %%(name)s at offset %i in %r' % (percent, statement))
            elif kind == '{':
                close = statement.index('}', percent + 2)
                if statement[close + 1:close + 2] != 's':
                    raise ValueError('expected %%{expr}s at offset %i in %r' % (percent, statement))
                code = statement[percent + 2:close]

                # offsets only grow as the scan advances, so keys are unique
                key = '__cyql__%i' % (percent,)
                # SECURITY: evaluates arbitrary Python taken from the statement
                # text; statements must never be built from untrusted input.
                # XXX: compile() in the frame's context
                context[key] = eval(code, f_globals, f_locals)

                statement = '%s%%(%s)s%s' % (statement[:percent], key, statement[close + 2:])
                # skip past the '%(' + key + ')s' just written
                start = percent + len(key) + 4
            elif kind in ('%', 's'):
                start = percent + 2
            else:
                raise ValueError('unsupported placeholder %r at offset %i in %r' % (
                    statement[percent:percent + 2], percent, statement))

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Execute *statement* and yield the cursor holding its results.

        :param statement: SQL using ``%(name)s`` / ``%{expr}s`` placeholders.
        :param depth: number of wrapper frames between the user's code and
            this call (1 for ``run``/``all``, 2 for ``one``/``has`` via
            ``one_`` and for the ``@contextmanager`` method ``set``).
        :param context: optional mapping of parameter values.  When omitted,
            the caller's local variables are used.  The mapping is copied, so
            the caller's object is never modified.
        """
        # two frames, accounting for execute() and @contextmanager
        frame = inspect.currentframe(depth + 2)
        try:
            f_locals = frame.f_locals
            if context is None:
                context = dict(f_locals)
            else:
                # copy so generated __cyql__N keys never leak into the
                # mapping the caller passed in
                context = dict(context)
            statement = self._expand_placeholders(statement, context, frame.f_globals, f_locals)
        finally:
            # release the caller's frame promptly, as the inspect module
            # documentation recommends for frame references
            del frame

        with self.cursor() as cursor:
            cursor.execute(statement, context)

            # parameters are not needed while the caller holds the cursor
            del context
            del f_locals

            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the with-body as a single transaction.

        Commits when the body completes, rolls back and re-raises when it
        raises, and always switches autocommit back on afterwards.

        :param synchronous_commit: when false, issue
            ``set local synchronous_commit = off`` inside this transaction.
        """
        self.driver.autocommit = False
        try:
            if not synchronous_commit:
                with self.cursor() as cursor:
                    cursor.execute('set local synchronous_commit = off')

            yield
            self.driver.commit()
        except:
            # deliberately catch everything: undo the work, then re-raise
            self.driver.rollback()
            raise
        finally:
            self.driver.autocommit = True

    def one_(self, statement, context=None):
        """Return the only row of *statement*, or None when it has no rows.

        Internal helper.  It passes depth 2 to :meth:`execute`, assuming
        exactly one public wrapper (``one`` or ``has``) sits between the
        user's code and this method.  Asserts that at most one row exists.
        """
        with self.execute(statement, 2, context) as cursor:
            row = cursor.fetchone()
            if row is None:
                return None

            assert cursor.fetchone() is None, 'statement returned more than one row'
            return row

    def __call__(self, procedure, *parameters):
        """Call the stored procedure *procedure* with *parameters*.

        :returns: whatever ``cursor.callproc`` returns.
        """
        # callproc takes the arguments as a single sequence, not spread out
        with self.cursor() as cursor:
            return cursor.callproc(procedure, parameters)

    def run(self, statement, context=None):
        """Execute *statement* and return ``cursor.rowcount``."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement, context=None):
        """Execute *statement* and yield the cursor for incremental fetching."""
        # depth 2: set() is itself a @contextmanager, whose __enter__ adds a
        # frame between the caller and this generator
        with self.execute(statement, 2, context) as cursor:
            yield cursor

    def all(self, statement, context=None):
        """Execute *statement* and return every row (DictCursor rows)."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.fetchall()

    def one(self, statement, context=None):
        """Return the only row of *statement*, or None when it has no rows."""
        return self.one_(statement, context)

    def has(self, statement, context=None):
        """Return the value of ``select exists(<statement>)``."""
        exists, = self.one_('select exists(%s)' % (statement,), context)
        return exists
166
def connected(dsn):
    """Return a decorator that gives the wrapped function its own connection.

    Each call of the decorated function opens a :class:`connect` on *dsn*,
    passes it as the ``sql`` keyword argument, and closes it when the call
    returns or raises.
    """
    from functools import wraps

    def wrapped(method):
        # preserve the wrapped function's __name__, __doc__, etc.
        @wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
174
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection on *dsn* and run the with-body in one transaction.

    Extra positional/keyword arguments are forwarded to
    ``connect.transact``.  The connection is yielded to the with-body and
    closed on exit.
    """
    connection = connect(dsn)
    with connection:
        with connection.transact(*args, **kw):
            yield connection
180
# NOTE(review): the triple-quoted string below is disabled code kept for
# reference only; as a bare string expression it is never executed.  It
# predates the connect class (it calls sql.cursor() directly and manages the
# cursor by hand) and calls path_, which is not defined in this file, so it
# would need rework before being re-enabled.
"""
def slap_(sql, table, keys, values, path):
    csr = sql.cursor()
    try:
        csr.execute('savepoint iou')
        try:
            both = dict(keys, **values)
            fields = both.keys()

            csr.execute('''
                insert into %s (%s) values (%s)
            ''' % (
                table,
                ', '.join(fields),
                ', '.join(['%s' for key in fields])
            ), both.values())
        except psycopg2.IntegrityError, e:
            csr.execute('rollback to savepoint iou')

            csr.execute('''
                update %s set %s where %s
            ''' % (
                table,
                ', '.join([
                    key + ' = %s'
                for key in values.keys()]),
                ' and '.join([
                    key + ' = %s'
                for key in keys.keys()])
            ), values.values() + keys.values())

        return path_(csr, path)
    finally:
        csr.close()
"""