]> git.saurik.com Git - cyql.git/blame - __init__.py
Use s formatting flag for %{} and assert it for %()s.
[cyql.git] / __init__.py
CommitLineData
a056a391
JF
1from __future__ import absolute_import
2from __future__ import division
426cccf1 3from __future__ import print_function
a056a391
JF
4from __future__ import unicode_literals
5
6from future_builtins import ascii, filter, hex, map, oct, zip
736932f0
JF
7
8import inspect
9import os
10
11from contextlib import contextmanager
12
13import psycopg2
426cccf1 14import psycopg2.extras
736932f0
JF
15import psycopg2.pool
16
736932f0
JF
# Module-import side effect: register psycopg2's UNICODE typecaster.  No
# connection/cursor scope argument is passed, so (per psycopg2's docs) this
# applies process-wide.  NOTE(review): presumably so text columns come back as
# unicode objects under Python 2 — confirm.
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
18
c8a72a64
JF
class connect(object):
    """Thin wrapper around a psycopg2 connection.

    ``dsn`` is a dict of keyword arguments for ``psycopg2.connect``.  The
    connection is switched to the UNICODE client encoding and to autocommit
    mode; use :meth:`transact` to group statements into a transaction.

    Query helpers accept ``%{expr}s`` placeholders, where ``expr`` is
    evaluated with the globals/locals of the frame that issued the query, in
    addition to the usual psycopg2 ``%(name)s`` / ``%s`` parameters.
    """

    def __init__(self, dsn):
        # Up to three attempts; only OperationalError is retried.
        attempt = 0
        while True:
            try:
                self.driver = psycopg2.connect(**dsn)
                break
            except psycopg2.OperationalError:
                if attempt == 2:
                    raise
                attempt = attempt + 1

        try:
            self.driver.set_client_encoding('UNICODE')
            self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        except:
            # Release the connection, then propagate: handing back an object
            # whose driver is already closed would only defer the failure.
            self.driver.close()
            raise

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always close; returning None lets any exception propagate.
        self.close()

    @contextmanager
    def cursor(self):
        """Yield a ``DictCursor`` on this connection, closing it afterwards."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @staticmethod
    def _expand(statement, f_globals, f_locals):
        """Rewrite ``%{expr}s`` placeholders in *statement* as named parameters.

        Returns ``(statement, bound)``: each ``%{expr}s`` is replaced by
        ``%(__cyql__N)s`` (N = the placeholder's offset in the rewritten
        statement, which keeps keys unique) and ``bound`` maps that key to
        ``eval(expr, f_globals, f_locals)``.  ``%(name)s``, ``%s`` and ``%%``
        are left untouched; any other ``%`` sequence trips an assertion.
        """
        bound = {}
        start = 0
        while True:
            percent = statement.find('%', start)
            if percent == -1:
                return statement, bound

            flag = statement[percent + 1]
            if flag == '(':
                # %(name)s: left for psycopg2; resume right after the 's'.
                start = statement.index(')', percent + 2) + 2
                assert statement[start - 1] == 's'
            elif flag == '{':
                close = statement.index('}', percent + 2)
                assert statement[close + 1] == 's'
                code = statement[percent + 2:close]

                key = '__cyql__%i' % (percent,)
                # NOTE: eval() of text taken from the statement -- statements
                # must never contain untrusted input.
                # XXX: compile() in the frame's context
                bound[key] = eval(code, f_globals, f_locals)

                # statement[close + 1:] starts with the asserted 's'.
                statement = '%s%%(%s)%s' % (statement[0:percent], key, statement[close + 1:])
                # Skip '%(' + key + ')s'.
                start = percent + len(key) + 4
            elif flag in ('%', 's'):
                start = percent + 2
            else:
                assert False

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Execute *statement* and yield the (open) cursor.

        ``depth`` is the number of helper frames between this call and the
        code that wrote the statement; ``%{expr}s`` expressions are evaluated
        in that frame.  When *context* is None the query parameters default
        to a copy of that frame's locals, so ``%(name)s`` names a local
        variable.  A caller-supplied *context* is passed through as-is unless
        ``%{}`` values must be added, in which case a copy is extended -- the
        caller's dict is never modified.
        """
        # Skip this generator frame and contextmanager's __enter__ (two
        # frames) plus ``depth`` helper frames.  Walking f_back works on both
        # Python 2 and 3 (on 3, inspect.currentframe() takes no depth).
        frame = inspect.currentframe()
        f_locals = None
        try:
            for _ in range(depth + 2):
                frame = frame.f_back
            f_locals = frame.f_locals
            statement, bound = self._expand(statement, frame.f_globals, f_locals)
            if context is None or bound:
                context = dict(f_locals if context is None else context)
                context.update(bound)
        finally:
            # Drop frame references (also on error) so the caller's frame is
            # not kept alive while the with-body runs.
            del frame, f_locals

        with self.cursor() as cursor:
            cursor.execute(statement, context)
            del context
            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the with-body inside a READ COMMITTED transaction.

        Commits on normal exit, rolls back and re-raises if the body raises,
        and always restores autocommit mode.  With ``synchronous_commit=False``
        the transaction first runs ``set local synchronous_commit = off``.
        """
        self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
        try:
            with self.cursor() as cursor:
                if not synchronous_commit:
                    cursor.execute('set local synchronous_commit = off')

            yield
            self.driver.commit()
        except:
            self.driver.rollback()
            raise
        finally:
            self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

    def one_(self, statement):
        """Return the single row of *statement*, or None if it yields no rows.

        Helper for one()/has(): ``depth=2`` resolves ``%{}`` expressions in the
        frame that called one()/has().
        """
        with self.execute(statement, 2) as cursor:
            one = cursor.fetchone()
            if one is None:
                return None

            # More than one row means the statement itself is wrong.
            assert cursor.fetchone() is None
            return one

    def __call__(self, procedure, *parameters):
        """Call stored procedure *procedure* with the positional *parameters*.

        Returns whatever ``cursor.callproc`` returns.
        """
        # NOTE(review): the cursor is closed on return, so any result rows the
        # procedure produces are discarded -- confirm this is intended.
        with self.cursor() as cursor:
            return cursor.callproc(procedure, parameters)

    def run(self, statement, context=None):
        """Execute *statement* and return the cursor's ``rowcount``."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement):
        """Execute *statement* and yield the cursor to the with-body."""
        # depth=2: this generator and its own contextmanager __enter__ both sit
        # between execute() and the caller's frame.
        with self.execute(statement, 2) as cursor:
            yield cursor

    def all(self, statement):
        """Execute *statement* and return all rows."""
        with self.execute(statement, 1) as cursor:
            return cursor.fetchall()

    def one(self, statement):
        """Return the single row of *statement*, or None if there is none."""
        return self.one_(statement)

    def has(self, statement):
        """Return whether *statement* yields any rows (``select exists(...)``)."""
        exists, = self.one_('select exists(%s)' % (statement,))
        return exists
f1df255a 151
91d72c6c
JF
def connected(dsn):
    """Decorator factory: run the decorated function with a fresh connection.

    Each call opens ``connect(dsn)``, passes it to the decorated function as
    the ``sql`` keyword argument, and closes it when the function returns or
    raises.  The wrapper keeps the decorated function's name and docstring.
    """
    import functools

    def wrapped(method):
        @functools.wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
159
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection to *dsn*, start a transaction on it and yield it.

    Extra arguments are forwarded to ``connect.transact`` (for example
    ``synchronous_commit``).  The connection is closed afterwards whether the
    body completes or raises; exceptions propagate.
    """
    sql = connect(dsn)
    try:
        with sql.transact(*args, **kw):
            yield sql
    finally:
        sql.close()
7d11917e 165
# NOTE(review): disabled legacy "insert, else update" helper kept as an inert
# string literal -- it is never executed.  It inserts under a savepoint and, on
# psycopg2.IntegrityError, rolls back to the savepoint and issues an UPDATE
# keyed on `keys`.  It calls path_(), which is not defined in this file, and
# treats sql.cursor() as a plain cursor although connect.cursor() is a context
# manager, so it would need rework before being revived.
"""
def slap_(sql, table, keys, values, path):
    csr = sql.cursor()
    try:
        csr.execute('savepoint iou')
        try:
            both = dict(keys, **values)
            fields = both.keys()

            csr.execute('''
                insert into %s (%s) values (%s)
            ''' % (
                table,
                ', '.join(fields),
                ', '.join(['%s' for key in fields])
            ), both.values())
        except psycopg2.IntegrityError, e:
            csr.execute('rollback to savepoint iou')

            csr.execute('''
                update %s set %s where %s
            ''' % (
                table,
                ', '.join([
                    key + ' = %s'
                    for key in values.keys()]),
                ' and '.join([
                    key + ' = %s'
                    for key in keys.keys()])
            ), values.values() + keys.values())

        return path_(csr, path)
    finally:
        csr.close()
"""