]>
Commit | Line | Data |
---|---|---|
a056a391 JF |
1 | from __future__ import absolute_import |
2 | from __future__ import division | |
426cccf1 | 3 | from __future__ import print_function |
a056a391 JF |
4 | from __future__ import unicode_literals |
5 | ||
6 | from future_builtins import ascii, filter, hex, map, oct, zip | |
736932f0 JF |
7 | |
8 | import inspect | |
9 | import os | |
10 | ||
11 | from contextlib import contextmanager | |
12 | ||
13 | import psycopg2 | |
426cccf1 | 14 | import psycopg2.extras |
736932f0 JF |
15 | import psycopg2.pool |
16 | ||
# Have psycopg2 decode text values -- both plain text columns and arrays of
# text -- into unicode objects instead of byte strings.
for _caster in (psycopg2.extensions.UNICODE,
                psycopg2.extensions.UNICODEARRAY):
    psycopg2.extensions.register_type(_caster)
del _caster
736932f0 | 19 | |
18e35f46 JF |
class ConnectionError(Exception):
    """Raised by connect() when the database server cannot be reached at all.

    Derives from Exception (the original was a bare old-style class) so that
    callers can catch it with ``except Exception`` and it carries the usual
    exception behaviour (args, str(), tracebacks).  Code that catches
    ``ConnectionError`` explicitly keeps working unchanged.
    """
22 | ||
c8a72a64 JF |
class connect(object):
    """Wrapper around a psycopg2 connection with query-templating helpers.

    ``dsn`` is a dict of keyword arguments for ``psycopg2.connect``, plus an
    optional ``'cache'`` entry that is used internally and not passed on.
    The connection runs in autocommit mode except inside ``transact()`` or
    after ``begin()``.  Statements given to the query helpers may contain
    ``%{expr}s`` placeholders that are evaluated in the calling code's frame
    (see ``execute``).  Instances are context managers that close the
    connection on exit.
    """

    def __init__(self, dsn):
        """Open a connection described by the keyword dict ``dsn``.

        ``dsn['cache']``, when present, is a dict remembering the hstore type
        OID between connections; when absent, a new cache is created and
        stored back into ``dsn`` once setup succeeds.

        Raises ConnectionError when the server cannot be reached, and
        re-raises psycopg2.OperationalError after three failed attempts.
        """
        options = dsn.copy()
        options.pop('cache', None)

        cached = 'cache' in dsn
        if cached:
            cache = dsn['cache']
        else:
            cache = {
                'hstore': None,
            }

        attempt = 0
        while True:
            try:
                self.driver = psycopg2.connect(**options)
                break
            except psycopg2.OperationalError as e:
                # An unreachable server is reported distinctly and is not
                # retried; other operational errors get up to three attempts.
                if e.message.startswith('could not connect to server: '):
                    raise ConnectionError()
                if attempt == 2:
                    raise
                attempt += 1

        try:
            self.driver.autocommit = True
            self._setup_hstore(cache)
        except Exception:
            # Don't leak the freshly opened connection if setup fails.
            self.driver.close()
            raise

        if not cached:
            dsn['cache'] = cache

    def _setup_hstore(self, cache):
        """Register the hstore adapter on this connection when available.

        The OID (first element of ``HstoreAdapter.get_oids``) is looked up
        once and memoised in ``cache['hstore']``; it stays None when the
        lookup yields nothing.
        """
        hstore = cache['hstore']
        if hstore is None:
            oids = psycopg2.extras.HstoreAdapter.get_oids(self.driver)
            if oids is not None:
                hstore = oids[0]
            cache['hstore'] = hstore

        if hstore is not None:
            try:
                psycopg2.extras.register_hstore(
                    self.driver, globally=False, unicode=True, oid=hstore)
            except psycopg2.ProgrammingError:
                # Best effort: the connection stays usable without hstore.
                pass

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always close; returning None lets any exception propagate.
        self.close()

    def begin(self):
        """Leave autocommit mode so following statements share a transaction."""
        self.driver.autocommit = False

    def commit(self):
        """Commit the current transaction."""
        self.driver.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.driver.rollback()

    @contextmanager
    def cursor(self):
        """Yield a DictCursor on this connection and close it afterwards."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @staticmethod
    def _expand(statement, context, f_globals, f_locals):
        """Rewrite every ``%{expr}s`` in ``statement`` to ``%(__cyql__N)s``.

        Each ``expr`` is evaluated against ``f_globals``/``f_locals`` and its
        value is stored in ``context`` under the generated key (N is the
        placeholder's offset, so keys are unique).  ``%s``, ``%%`` and
        ``%(name)s`` are left untouched.  Returns the rewritten statement.

        Raises ValueError for a malformed or unsupported ``%`` sequence.  An
        explicit raise (not ``assert``) matters here: under ``python -O`` a
        stripped assert would leave the scan position unchanged and loop
        forever.
        """
        start = 0
        while True:
            percent = statement.find('%', start)
            if percent == -1:
                return statement

            marker = statement[percent + 1:percent + 2]
            if marker == '(':
                close = statement.find(')', percent + 2)
                if close == -1 or statement[close + 1:close + 2] != 's':
                    raise ValueError(
                        'malformed %%(name)s placeholder at offset %i' % (percent,))
                start = close + 2
            elif marker == '{':
                close = statement.find('}', percent + 2)
                if close == -1 or statement[close + 1:close + 2] != 's':
                    raise ValueError(
                        'malformed %%{expr}s placeholder at offset %i' % (percent,))
                code = statement[percent + 2:close]

                key = '__cyql__%i' % (percent,)
                # SECURITY: eval() runs text taken from the SQL statement, so
                # statements must come from trusted code, never from user input.
                context[key] = eval(code, f_globals, f_locals)

                # statement[close + 1:] begins with the placeholder's 's'.
                statement = '%s%%(%s)%s' % (statement[:percent], key, statement[close + 1:])
                start = percent + len(key) + 4
            elif marker in ('%', 's'):
                start = percent + 2
            else:
                raise ValueError(
                    'unsupported %% placeholder at offset %i' % (percent,))

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Execute ``statement`` and yield the cursor holding its result.

        Besides the DB-API ``%s`` / ``%(name)s`` placeholders, ``statement``
        may contain ``%{expr}s``: ``expr`` is evaluated with the globals and
        locals of the frame ``depth`` levels above the code that entered this
        context manager, and its value is bound as a query parameter.  When
        ``context`` is None that frame's locals also serve as the
        ``%(name)s`` namespace.  A caller-supplied ``context`` is copied and
        never modified.

        Raises ValueError on a malformed or unsupported ``%`` placeholder.
        """
        import sys

        # Frame 0 is this generator and frame 1 is contextmanager's __enter__,
        # so +2 reaches the caller.  (On Python 2 inspect.currentframe is
        # sys._getframe; on Python 3 it no longer accepts a depth.)
        frame = sys._getframe(depth + 2)
        try:
            if context is None:
                context = dict(frame.f_locals)
            else:
                # Copy so the generated __cyql__ keys don't leak into the
                # caller's dict.
                context = dict(context)
            statement = self._expand(statement, context,
                                     frame.f_globals, frame.f_locals)
        finally:
            # Drop the reference to the caller's frame before handing out the
            # cursor, so it cannot outlive the expansion.
            del frame

        with self.cursor() as cursor:
            cursor.execute(statement, context)
            del context
            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the ``with`` body as one transaction on this connection.

        Commits when the body finishes normally, rolls back and re-raises on
        any exception, and restores autocommit afterwards either way.  With
        ``synchronous_commit=False`` it issues
        ``set local synchronous_commit = off`` first.
        """
        self.driver.autocommit = False
        try:
            with self.cursor() as cursor:
                if not synchronous_commit:
                    cursor.execute('set local synchronous_commit = off')

                yield
            self.driver.commit()
        except:
            # Deliberately bare: roll back for anything (including
            # GeneratorExit and old-style exceptions), then re-raise it.
            self.driver.rollback()
            raise
        finally:
            self.driver.autocommit = True

    def one_(self, statement, context=None):
        """Return the only row of ``statement``, or None when it has no rows.

        Intended to be called from one()/has(): depth 2 makes ``%{expr}s``
        placeholders resolve in the frame that called *those* methods.
        Raises AssertionError when the statement yields more than one row.
        """
        with self.execute(statement, 2, context) as cursor:
            row = cursor.fetchone()
            if row is None:
                return None

            # Explicit check instead of ``assert`` so it survives python -O.
            if cursor.fetchone() is not None:
                raise AssertionError('statement returned more than one row')
            return row

    def __call__(self, procedure, *parameters):
        """Call stored procedure ``procedure`` with positional ``parameters``.

        The parameters are handed to ``cursor.callproc`` as one sequence, as
        the DB-API requires; its return value is returned.
        """
        with self.cursor() as cursor:
            return cursor.callproc(procedure, parameters)

    def run(self, statement, context=None):
        """Execute ``statement`` and return the affected row count."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement, context=None):
        """Execute ``statement`` and yield the cursor for iterating its rows."""
        with self.execute(statement, 2, context) as cursor:
            yield cursor

    def all(self, statement, context=None):
        """Execute ``statement`` and return all result rows."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.fetchall()

    def one(self, statement, context=None):
        """Return the single row of ``statement`` (None when it has no rows)."""
        return self.one_(statement, context)

    def has(self, statement, context=None):
        """Return whether ``statement`` yields at least one row (SQL EXISTS)."""
        exists, = self.one_('select exists(%s)' % (statement,), context)
        return exists
f1df255a | 198 | |
91d72c6c JF |
def connected(dsn):
    """Decorator factory: call the wrapped function with a fresh connection.

    The decorated function receives an open ``connect(dsn)`` instance as the
    ``sql`` keyword argument; the connection is closed when the call returns
    or raises.  The wrapper preserves the wrapped function's name and
    docstring (via functools.wraps).
    """
    import functools

    def wrapped(method):
        @functools.wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
206 | ||
@contextmanager
def transact(dsn, *args, **kw):
    """Connect to ``dsn`` and run the ``with`` body inside one transaction.

    Extra positional/keyword arguments are forwarded to the connection's
    ``transact`` method (e.g. ``synchronous_commit``).  Yields the
    connection; commit/rollback is handled by ``connect.transact`` and the
    connection is closed when the block exits.
    """
    with connect(dsn) as connection, connection.transact(*args, **kw):
        yield connection
7d11917e | 212 | |
426cccf1 | 213 | """ |
736932f0 JF |
214 | def slap_(sql, table, keys, values, path): |
215 | csr = sql.cursor() | |
216 | try: | |
217 | csr.execute('savepoint iou') | |
218 | try: | |
219 | both = dict(keys, **values) | |
220 | fields = both.keys() | |
221 | ||
222 | csr.execute(''' | |
223 | insert into %s (%s) values (%s) | |
224 | ''' % ( | |
225 | table, | |
226 | ', '.join(fields), | |
227 | ', '.join(['%s' for key in fields]) | |
228 | ), both.values()) | |
229 | except psycopg2.IntegrityError, e: | |
230 | csr.execute('rollback to savepoint iou') | |
231 | ||
232 | csr.execute(''' | |
233 | update %s set %s where %s | |
234 | ''' % ( | |
235 | table, | |
236 | ', '.join([ | |
237 | key + ' = %s' | |
238 | for key in values.keys()]), | |
239 | ' and '.join([ | |
240 | key + ' = %s' | |
241 | for key in keys.keys()]) | |
242 | ), values.values() + keys.values()) | |
243 | ||
244 | return path_(csr, path) | |
245 | finally: | |
246 | csr.close() | |
426cccf1 | 247 | """ |