]>
Commit | Line | Data |
---|---|---|
1 | from __future__ import absolute_import | |
2 | from __future__ import division | |
3 | from __future__ import print_function | |
4 | from __future__ import unicode_literals | |
5 | ||
6 | from future_builtins import ascii, filter, hex, map, oct, zip | |
7 | ||
8 | import inspect | |
9 | import os | |
10 | ||
11 | from contextlib import contextmanager | |
12 | ||
13 | import psycopg2 | |
14 | import psycopg2.extras | |
15 | import psycopg2.pool | |
16 | ||
# Install psycopg2's UNICODE and UNICODEARRAY typecasters process-wide (no
# connection or cursor scope is given), so text values and arrays of text are
# decoded to unicode objects; per the psycopg2 docs this matters on Python 2,
# where strings otherwise come back as byte strings.
for _typecaster in (psycopg2.extensions.UNICODE,
                    psycopg2.extensions.UNICODEARRAY):
    psycopg2.extensions.register_type(_typecaster)
del _typecaster  # keep the module namespace unchanged
19 | ||
class ConnectionError(Exception):
    """Raised by ``connect`` when the database server cannot be reached.

    Derives from ``Exception`` so it is a proper new-style exception: it can
    be raised on Python 3 and is caught by generic ``except Exception``
    handlers. Within this module the name shadows Python 3's builtin
    ``ConnectionError``.
    """
22 | ||
class connect(object):
    """Wrapper around a psycopg2 connection with caller-scoped parameters.

    Statements given to :meth:`execute` (and to :meth:`run`, :meth:`all`,
    :meth:`one`, :meth:`has` and :meth:`set`) may contain:

    * ``%(name)s`` -- bound from *context* when one is given, otherwise from
      the calling function's local variables;
    * ``%{expr}s`` -- *expr* is ``eval``-ed in the calling function's frame
      and its value is bound as a query parameter;
    * ``%s`` and ``%%`` -- left as-is for psycopg2.

    The connection is in autocommit mode except inside :meth:`transact` or
    after :meth:`begin`. Instances are context managers that close the
    connection on exit.
    """

    def __init__(self, dsn):
        """Connect using the psycopg2 keyword arguments in the dict *dsn*.

        An optional ``'cache'`` entry in *dsn* is not passed to psycopg2; it
        memoizes the hstore type lookup. When *dsn* has no cache yet, one is
        created and stored back into ``dsn['cache']`` after connecting.

        Raises:
            ConnectionError: psycopg2 reported that it could not connect to
                the server.
            psycopg2.OperationalError: any other connection failure, re-raised
                after the third failed attempt.
        """
        options = dict(dsn)
        cached = 'cache' in options
        if cached:
            cache = options.pop('cache')
        else:
            cache = {'hstore': None}

        for attempt in range(3):
            try:
                self.driver = psycopg2.connect(**options)
                break
            except psycopg2.OperationalError as e:
                # An unreachable server is reported immediately instead of
                # being retried. str(e) replaces e.message, which is
                # deprecated on Python 2 and absent on Python 3.
                if str(e).startswith('could not connect to server: '):
                    raise ConnectionError()
                if attempt == 2:
                    raise

        self.driver.autocommit = True

        # XXX: all of my databases default to this...
        #try:
        #    self.driver.set_client_encoding('UNICODE')
        #except:
        #    self.driver.close()
        #    raise

        hstore = cache['hstore']
        if hstore is None:
            # Per the psycopg2 docs, get_oids returns the OIDs of the hstore
            # and hstore[] types; only the former is used here.
            oids = psycopg2.extras.HstoreAdapter.get_oids(self.driver)
            if oids is not None:
                hstore = oids[0]
            cache['hstore'] = hstore

        # Only register when an OID was actually found (not None and,
        # presumably when the extension is missing, not an empty tuple).
        if hstore:
            try:
                psycopg2.extras.register_hstore(
                    self.driver, globally=False, unicode=True, oid=hstore)
            except psycopg2.ProgrammingError:
                # Deliberately best-effort: hstore support is optional.
                pass

        if not cached:
            dsn['cache'] = cache

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection; exceptions from the body propagate."""
        self.close()

    def begin(self):
        """Leave autocommit mode so following statements share a transaction.

        End it with :meth:`commit` or :meth:`rollback`; neither switches
        autocommit back on.
        """
        self.driver.autocommit = False

    def commit(self):
        """Commit the current transaction."""
        self.driver.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.driver.rollback()

    @contextmanager
    def cursor(self):
        """Yield a ``DictCursor`` on this connection, closing it afterwards."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @staticmethod
    def _expand(statement, context, frame):
        """Rewrite ``%{expr}s`` placeholders in *statement* into named ones.

        Each ``%{expr}s`` becomes ``%(__cyql__N)s`` (N is the placeholder's
        offset when it is rewritten) and *expr* is evaluated with *frame*'s
        globals and locals; the value is stored under that key. ``%(name)s``,
        ``%s`` and ``%%`` are left untouched.

        Returns ``(statement, params)`` where *params* is a new dict: a copy
        of *context*, or of *frame*'s locals when *context* is None. The
        caller's *context* is never modified.

        Malformed placeholders trip an ``assert`` or raise ``ValueError``
        from ``str.index``.
        """
        f_globals = frame.f_globals
        f_locals = frame.f_locals
        params = dict(f_locals if context is None else context)

        start = 0
        while True:
            percent = statement.find('%', start)
            if percent == -1:
                break
            # A lone trailing '%' is malformed; fail like the other checks
            # rather than with an opaque IndexError.
            assert percent + 1 < len(statement), 'dangling % in statement'
            marker = statement[percent + 1]
            if marker == '(':
                # %(name)s: resume just past the closing ')s'.
                start = statement.index(')', percent + 2) + 2
                assert statement[start - 1] == 's'
            elif marker == '{':
                close = statement.index('}', percent + 2)
                assert statement[close + 1] == 's'
                code = statement[percent + 2:close]
                key = '__cyql__%i' % (percent,)
                # SECURITY: eval() runs arbitrary code; statements must come
                # from trusted source code, never from user input.
                params[key] = eval(code, f_globals, f_locals)
                statement = '%s%%(%s)%s' % (
                    statement[:percent], key, statement[close + 1:])
                # Resume after the '%(key)s' just written: 2 + len(key) + 2.
                start = percent + len(key) + 4
            elif marker in ('%', 's'):
                start = percent + 2
            else:
                assert False, 'unsupported placeholder in statement'

        return statement, params

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Run *statement* and yield the cursor holding its results.

        Placeholders are bound as described in the class docstring. Names and
        ``%{expr}s`` expressions are resolved in the frame of whoever called
        ``execute``; *depth* is the number of extra wrapper frames to skip, so
        a method that calls ``execute`` on behalf of its own caller passes
        ``depth=1``. The cursor is closed when the ``with`` block exits.
        """
        import sys

        # +2 skips this generator's frame and contextlib's __enter__ frame.
        # sys._getframe is what inspect.currentframe aliased on Python 2 and,
        # unlike inspect.currentframe on Python 3, accepts a depth argument.
        frame = sys._getframe(depth + 2)
        try:
            statement, params = self._expand(statement, context, frame)
        finally:
            # Drop the frame promptly so the caller's locals are not kept
            # alive through a reference cycle.
            del frame

        with self.cursor() as cursor:
            cursor.execute(statement, params)
            # Don't pin the bound values while the caller consumes results.
            del params
            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the ``with`` body in a single transaction.

        Commits when the body finishes normally, rolls back and re-raises on
        any exception, and always restores autocommit mode afterwards. With
        ``synchronous_commit=False`` it issues
        ``set local synchronous_commit = off`` for this transaction only.
        """
        self.driver.autocommit = False
        try:
            if not synchronous_commit:
                with self.cursor() as cursor:
                    cursor.execute('set local synchronous_commit = off')
            yield
            self.driver.commit()
        except BaseException:
            self.driver.rollback()
            raise
        finally:
            self.driver.autocommit = True

    def one_(self, statement, context=None):
        """Return the single row produced by *statement*, or None if empty.

        Meant to be reached through exactly one wrapper method (``one``,
        ``has``), hence ``depth=2``. Asserts that at most one row came back.
        """
        with self.execute(statement, 2, context) as cursor:
            row = cursor.fetchone()
            if row is None:
                return None
            assert cursor.fetchone() is None
            return row

    def __call__(self, procedure, *parameters):
        """Call stored procedure *procedure* with positional *parameters*.

        Returns what ``cursor.callproc`` returns. The cursor is closed before
        this method returns, so result rows cannot be fetched afterwards.
        """
        with self.cursor() as cursor:
            # callproc takes the parameters as one sequence, not spread out.
            return cursor.callproc(procedure, parameters)

    def run(self, statement, context=None):
        """Execute *statement* and return the cursor's ``rowcount``."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement, context=None):
        """Execute *statement* and yield its open cursor for iteration."""
        with self.execute(statement, 2, context) as cursor:
            yield cursor

    def all(self, statement, context=None):
        """Execute *statement* and return all rows (``DictRow`` objects)."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.fetchall()

    def one(self, statement, context=None):
        """Return the single row from *statement*, or None if there is none."""
        return self.one_(statement, context)

    def has(self, statement, context=None):
        """Return whether *statement* yields any row (``select exists(...)``)."""
        exists, = self.one_('select exists(%s)' % (statement,), context)
        return exists
198 | ||
def connected(dsn):
    """Decorator factory: run the wrapped function with a fresh connection.

    Each call opens ``connect(dsn)``, passes it to the wrapped function as the
    ``sql`` keyword argument, and closes it when the function returns or
    raises. The wrapper keeps the wrapped function's name and docstring.
    """
    import functools

    def wrapped(method):
        @functools.wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
206 | ||
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection for *dsn* and run the ``with`` body in a transaction.

    Yields the connection. Extra arguments are forwarded to
    ``connect.transact`` (e.g. ``synchronous_commit=False``), which commits or
    rolls back; the connection itself is closed on exit.
    """
    with connect(dsn) as connection, connection.transact(*args, **kw):
        yield connection
212 | ||
# NOTE(review): the triple-quoted string below is disabled code, not a
# docstring -- the expression is evaluated and discarded at import time. It
# sketches an insert-or-update helper ("slap_"): insert under a savepoint and,
# on IntegrityError, roll back to the savepoint and update instead. It uses a
# name not defined in this file as visible (``path_``) and Python 2-only
# syntax (``except X, e``, list concatenation of ``.values()``), so it cannot
# simply be uncommented.
"""
def slap_(sql, table, keys, values, path):
    csr = sql.cursor()
    try:
        csr.execute('savepoint iou')
        try:
            both = dict(keys, **values)
            fields = both.keys()

            csr.execute('''
                insert into %s (%s) values (%s)
            ''' % (
                table,
                ', '.join(fields),
                ', '.join(['%s' for key in fields])
            ), both.values())
        except psycopg2.IntegrityError, e:
            csr.execute('rollback to savepoint iou')

            csr.execute('''
                update %s set %s where %s
            ''' % (
                table,
                ', '.join([
                    key + ' = %s'
                    for key in values.keys()]),
                ' and '.join([
                    key + ' = %s'
                    for key in keys.keys()])
            ), values.values() + keys.values())

        return path_(csr, path)
    finally:
        csr.close()
"""