]>
Commit | Line | Data |
---|---|---|
a056a391 JF |
1 | from __future__ import absolute_import |
2 | from __future__ import division | |
426cccf1 | 3 | from __future__ import print_function |
a056a391 JF |
4 | from __future__ import unicode_literals |
5 | ||
6 | from future_builtins import ascii, filter, hex, map, oct, zip | |
736932f0 JF |
7 | |
8 | import inspect | |
9 | import os | |
10 | ||
11 | from contextlib import contextmanager | |
12 | ||
13 | import psycopg2 | |
426cccf1 | 14 | import psycopg2.extras |
736932f0 JF |
15 | import psycopg2.pool |
16 | ||
736932f0 JF |
# Register psycopg2's UNICODE typecaster globally (no connection/cursor scope
# is passed), so text values come back as unicode objects under Python 2.
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
18 | ||
c8a72a64 JF |
class connect(object):
    """A psycopg2 connection wrapper whose queries can interpolate Python values.

    Statements given to ``execute`` and the helpers built on it (``run``,
    ``set``, ``all``, ``one``, ``has``) understand these placeholders:

    * ``%(name)s`` -- named parameter looked up in ``context`` (by default a
      copy of the calling frame's local variables);
    * ``%{expr}s`` -- ``expr`` is evaluated in the calling frame and its value
      is bound as a query parameter;
    * ``%%`` and ``%s`` -- passed through unchanged to psycopg2.

    Values always reach the database as bound parameters, never as SQL text.
    The expressions inside ``%{...}s`` are run through ``eval``, so
    statements must be written by the programmer and never assembled from
    untrusted input.

    The connection is in autocommit mode except inside ``transact()``; use
    the object as a context manager to close it on exit.
    """

    def __init__(self, dsn, attempts=3):
        """Open the connection, retrying operational failures.

        :param dsn: mapping of keyword arguments for ``psycopg2.connect``.
        :param attempts: how many connection attempts to make (at least one;
            defaults to 3).
        :raises psycopg2.OperationalError: if every attempt fails.
        """
        attempts = max(attempts, 1)
        for attempt in range(attempts):
            try:
                self.driver = psycopg2.connect(**dsn)
                break
            except psycopg2.OperationalError:
                if attempt + 1 == attempts:
                    # bare raise keeps the original traceback
                    raise

        try:
            self.driver.set_client_encoding('UNICODE')
            self.driver.set_isolation_level(
                psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        except BaseException:
            # Close the socket, but still report the failure: returning here
            # would hand the caller an object whose connection is closed.
            self.driver.close()
            raise

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    @contextmanager
    def cursor(self):
        """Yield a DictCursor on this connection and close it afterwards."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @staticmethod
    def _expand_placeholders(statement, context, frame):
        """Turn ``%{expr}s`` placeholders into bound ``%(key)s`` parameters.

        :param statement: SQL text using the placeholder syntax described on
            the class.
        :param context: mapping of named parameters, or ``None`` to use the
            local variables of ``frame``.
        :param frame: the caller's frame; ``%{expr}s`` expressions are
            evaluated against its globals and locals.
        :returns: ``(statement, params)`` where ``params`` is a new dict, so
            the caller's ``context`` is never modified.
        :raises ValueError: on a malformed or unsupported ``%`` sequence.
        """
        f_locals = frame.f_locals
        params = dict(f_locals if context is None else context)

        start = 0
        while True:
            percent = statement.find('%', start)
            if percent == -1:
                return statement, params

            # slice rather than index so a trailing '%' gives '' (and then a
            # ValueError below) instead of an IndexError
            marker = statement[percent + 1:percent + 2]
            if marker == '(':
                start = statement.index(')', percent + 2) + 2
                if statement[start - 1:start] != 's':
                    raise ValueError('expected %%(name)s at offset %i in %r'
                                     % (percent, statement))
            elif marker == '{':
                close = statement.index('}', percent + 2)
                if statement[close + 1:close + 2] != 's':
                    raise ValueError('expected %%{expr}s at offset %i in %r'
                                     % (percent, statement))
                code = statement[percent + 2:close]
                # Offsets only increase as the scan moves right, so they give
                # each generated parameter a unique name.
                key = '__cyql__%i' % (percent,)
                # SECURITY: eval() of statement text; never feed untrusted
                # input into a statement.
                params[key] = eval(code, frame.f_globals, f_locals)
                statement = '%s%%(%s)%s' % (
                    statement[:percent], key, statement[close + 1:])
                # continue after the '%(' + key + ')s' just written
                start = percent + len(key) + 4
            elif marker in ('%', 's'):
                start = percent + 2
            else:
                # An assert here is stripped under -O, which would leave
                # `start` unchanged and loop forever.
                raise ValueError('unsupported %% sequence at offset %i in %r'
                                 % (percent, statement))

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Run ``statement`` and yield the cursor holding its results.

        :param statement: SQL using the placeholder syntax described on the
            class.
        :param depth: how many frames above the immediate caller of
            ``execute()`` to look for the variables the statement refers to
            (0 = the immediate caller).
        :param context: named parameters; defaults to the caller's locals.
            It is copied, never modified.
        :raises ValueError: on a malformed placeholder.
        """
        import sys

        # two frames, accounting for execute() and @contextmanager's __enter__.
        # sys._getframe is what Python 2's inspect.currentframe(n) aliases,
        # and it also accepts a depth on Python 3.
        frame = sys._getframe(depth + 2)
        try:
            statement, params = self._expand_placeholders(statement, context, frame)
        finally:
            # a frame held in a local creates a reference cycle
            del frame

        with self.cursor() as cursor:
            cursor.execute(statement, params)
            # drop our references to the caller's variables before control
            # returns to the with-body
            del params
            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the with-body inside one transaction.

        Commits when the body finishes, rolls back and re-raises on any
        exception, and always restores autocommit mode afterwards.

        :param synchronous_commit: when false, issues
            ``set local synchronous_commit = off`` for this transaction.
        """
        self.driver.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
        try:
            with self.cursor() as cursor:
                if not synchronous_commit:
                    cursor.execute('set local synchronous_commit = off')

                yield
            self.driver.commit()
        except BaseException:
            self.driver.rollback()
            raise
        finally:
            self.driver.set_isolation_level(
                psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

    def one_(self, statement):
        """Return the single row produced by ``statement``, or ``None``.

        Only for use by the public wrappers (``one``, ``has``): the depth of 2
        skips this method and its wrapper, so the statement sees the wrapper's
        caller's variables.
        """
        with self.execute(statement, 2) as cursor:
            one = cursor.fetchone()
            if one is None:
                return None

            # the statement must not produce more than one row
            assert cursor.fetchone() is None
            return one

    def __call__(self, procedure, *parameters):
        """Call stored procedure ``procedure`` with ``parameters``.

        Returns whatever the cursor's ``callproc`` returns.
        """
        with self.cursor() as cursor:
            # callproc takes the arguments as one sequence, not splatted
            return cursor.callproc(procedure, parameters)

    def run(self, statement, context=None):
        """Execute ``statement`` and return the affected row count."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement):
        """Execute ``statement`` and yield the cursor for iterating its rows."""
        # depth 2: this generator frame and the @contextmanager __enter__ that
        # runs it both sit between execute() and the user's frame
        with self.execute(statement, 2) as cursor:
            yield cursor

    def all(self, statement):
        """Execute ``statement`` and return every row it produces."""
        with self.execute(statement, 1) as cursor:
            return cursor.fetchall()

    def one(self, statement):
        """Return the only row produced by ``statement``, or ``None``."""
        return self.one_(statement)

    def has(self, statement):
        """Return the boolean result of ``select exists(<statement>)``."""
        exists, = self.one_('select exists(%s)' % (statement,))
        return exists
f1df255a | 151 | |
91d72c6c JF |
def connected(dsn):
    """Decorator factory that gives each call its own connection.

    The decorated function is called with an extra ``sql=`` keyword argument
    holding a fresh ``connect(dsn)``. The connection is closed when the call
    returns or raises.

    :param dsn: keyword arguments for ``psycopg2.connect``.
    """
    import functools

    def wrapped(method):
        # keep the wrapped function's name/docstring for introspection and logs
        @functools.wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
159 | ||
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection to ``dsn`` and run the with-body in one transaction.

    Extra positional/keyword arguments are forwarded to ``connect.transact``
    (e.g. ``synchronous_commit=False``). The connection object is yielded.
    It is closed when the block exits, whether or not the transaction
    committed.
    """
    connection = connect(dsn)
    with connection:
        transaction = connection.transact(*args, **kw)
        with transaction:
            yield connection
7d11917e | 165 | |
# NOTE(review): disabled code. Everything below is a bare module-level string
# literal, so slap_ is never defined or executed. It appears to be an
# "insert, else update" helper: insert under savepoint `iou`, and on
# IntegrityError roll back to the savepoint and update the row matched by
# `keys` instead. Before reviving it, note that:
#   - path_() is not defined anywhere in the visible file;
#   - with this module's connect class, sql.cursor() is a context manager,
#     not a cursor;
#   - table and column names are %-interpolated straight into the SQL text.
# (Indentation inside the string was reconstructed; the original was lost.)
"""
def slap_(sql, table, keys, values, path):
    csr = sql.cursor()
    try:
        csr.execute('savepoint iou')
        try:
            both = dict(keys, **values)
            fields = both.keys()

            csr.execute('''
                insert into %s (%s) values (%s)
            ''' % (
                table,
                ', '.join(fields),
                ', '.join(['%s' for key in fields])
            ), both.values())
        except psycopg2.IntegrityError, e:
            csr.execute('rollback to savepoint iou')

            csr.execute('''
                update %s set %s where %s
            ''' % (
                table,
                ', '.join([
                    key + ' = %s'
                for key in values.keys()]),
                ' and '.join([
                    key + ' = %s'
                for key in keys.keys()])
            ), values.values() + keys.values())

        return path_(csr, path)
    finally:
        csr.close()
"""