]> git.saurik.com Git - cyql.git/blame - __init__.py
Add features to support database.py.
[cyql.git] / __init__.py
Commit Line Data
a056a391
JF
1from __future__ import absolute_import
2from __future__ import division
426cccf1 3from __future__ import print_function
a056a391
JF
4from __future__ import unicode_literals
5
6from future_builtins import ascii, filter, hex, map, oct, zip
736932f0
JF
7
8import inspect
9import os
10
11from contextlib import contextmanager
12
13import psycopg2
426cccf1 14import psycopg2.extras
736932f0
JF
15import psycopg2.pool
16
# Make psycopg2 hand back text values, and arrays of text values, as unicode
# objects; registered globally, in this order, at import time.
for _caster in (psycopg2.extensions.UNICODE, psycopg2.extensions.UNICODEARRAY):
    psycopg2.extensions.register_type(_caster)
del _caster
736932f0 19
c8a72a64
JF
class connect(object):
    """A psycopg2 connection plus small query helpers.

    ``dsn`` is a mapping of keyword arguments for ``psycopg2.connect``.  The
    connection is kept in autocommit mode; ``begin()`` and ``transact()``
    switch it to READ COMMITTED.

    Statements given to ``execute`` and the helpers built on it (``run``,
    ``set``, ``all``, ``one``, ``has``) may contain ``%{expr}s`` placeholders.
    ``expr`` is evaluated with the globals and locals of the code that called
    the helper, and its value is sent to the database as a bound parameter.
    It is never spliced into the SQL text.  ``%(name)s`` placeholders are
    looked up in that code's locals unless an explicit ``context`` is given.
    ``%%`` is passed through for a literal percent sign.
    """

    def __init__(self, dsn):
        """Open the connection described by ``dsn`` and configure it.

        Up to three connection attempts are made; the third
        ``psycopg2.OperationalError`` propagates.  If configuring the new
        connection fails, the connection is closed and the error propagates.
        """
        attempt = 0
        while True:
            try:
                self.driver = psycopg2.connect(**dsn)
                break
            except psycopg2.OperationalError:
                if attempt == 2:
                    # bare raise keeps the original traceback (`raise e`
                    # would restart it here under Python 2)
                    raise
                attempt = attempt + 1

        try:
            self.driver.set_client_encoding('UNICODE')
            self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        except:
            # Close so the socket is not leaked, then propagate: an object
            # whose driver is already closed is of no use to the caller.
            self.driver.close()
            raise

        try:
            psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True)
        except psycopg2.ProgrammingError:
            # Deliberately best-effort: presumably the database has no hstore
            # type available; the connection works without the adapter.
            pass

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # always close; exceptions from the with-body are not suppressed
        self.close()

    def begin(self):
        """Leave autocommit mode (READ COMMITTED) so statements accumulate
        into a transaction that ``commit``/``rollback`` ends."""
        self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)

    def commit(self):
        self.driver.commit()

    def rollback(self):
        self.driver.rollback()

    @contextmanager
    def cursor(self):
        """Yield a ``DictCursor`` on this connection and close it afterwards."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Execute ``statement`` and yield the cursor holding its results.

        :param statement: SQL text, possibly containing ``%{expr}s``.
        :param depth: number of extra wrapper frames between this call and
            the code whose variables placeholders should see (helpers such
            as ``run`` pass 1).
        :param context: query parameters.  When omitted, a copy of the
            target frame's locals is used.  A caller-supplied mapping is
            never modified.
        :raises ValueError: on a malformed or unsupported placeholder.
        """
        # two frames, accounting for execute() and @contextmanager
        frame = inspect.currentframe(depth + 2)
        try:
            f_locals = frame.f_locals
            f_globals = frame.f_globals
        finally:
            # frame references create reference cycles; drop ours promptly
            del frame

        # `owned` is True once `context` is a dict private to this call
        owned = context is None
        if owned:
            context = dict(f_locals)

        start = 0
        while True:
            percent = statement.find('%', start)
            if percent == -1:
                break

            # slicing yields '' for a trailing '%' instead of IndexError
            marker = statement[percent + 1:percent + 2]
            if marker == '(':
                # psycopg2's own %(name)s placeholder: leave it for the driver
                end = statement.index(')', percent + 2)
                if statement[end + 1:end + 2] != 's':
                    raise ValueError('expected %%(...)s at offset %i' % (percent,))
                start = end + 2
            elif marker == '{':
                # %{expr}s: evaluate expr now and bind it as a named parameter
                end = statement.index('}', percent + 2)
                if statement[end + 1:end + 2] != 's':
                    raise ValueError('expected %%{...}s at offset %i' % (percent,))
                code = statement[percent + 2:end]

                if not owned:
                    # copy before adding keys so the caller's mapping is untouched
                    context = dict(context)
                    owned = True

                key = '__cyql__%i' % (percent,)
                # XXX: compile() in the frame's context
                # NOTE: eval() -- statements must come from trusted source
                # code, never from user input.
                context[key] = eval(code, f_globals, f_locals)

                placeholder = '%%(%s)s' % (key,)
                statement = statement[:percent] + placeholder + statement[end + 2:]
                start = percent + len(placeholder)
            elif marker in ('%', 's'):
                start = percent + 2
            else:
                # also a trailing lone '%'; the former `assert False` here
                # left `start` unchanged and looped forever under python -O
                raise ValueError('unsupported placeholder at offset %i' % (percent,))

        with self.cursor() as cursor:
            cursor.execute(statement, context)

            # release references to the caller's namespaces before yielding
            del context, f_locals, f_globals

            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the with-body inside a READ COMMITTED transaction.

        Commits on normal exit.  On any exception it rolls back and
        re-raises.  Autocommit is restored in every case.  With
        ``synchronous_commit=False`` the transaction runs with
        ``set local synchronous_commit = off``.
        """
        self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
        try:
            with self.cursor() as cursor:
                if not synchronous_commit:
                    # per PostgreSQL, SET LOCAL lasts only for this transaction
                    cursor.execute('set local synchronous_commit = off')

            yield
            self.driver.commit()
        except:
            self.driver.rollback()
            raise
        finally:
            self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

    def one_(self, statement, context=None):
        """Return the only row ``statement`` produces, or None for no rows.

        Intended to be called from a public wrapper (``one``, ``has``):
        depth 2 skips this method and that wrapper.
        """
        with self.execute(statement, 2, context) as cursor:
            one = cursor.fetchone()
            if one is None:
                return None

            # more than one row is a caller error
            assert cursor.fetchone() is None
            return one

    def __call__(self, procedure, *parameters):
        """Call stored procedure ``procedure`` with ``parameters``.

        Returns whatever psycopg2's ``cursor.callproc`` returns.  The cursor
        is closed on return, so any result set the procedure produces is
        not kept.
        """
        # The former body referenced an undefined `statement` and spread the
        # arguments; callproc takes them as one sequence.
        with self.cursor() as cursor:
            return cursor.callproc(procedure, parameters)

    def run(self, statement, context=None):
        """Execute ``statement`` and return the cursor's ``rowcount``."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement, context=None):
        """Execute ``statement`` and yield the cursor for iterating rows."""
        # depth 2: skip this generator AND the contextlib __enter__ driving
        # it, so placeholders resolve in the frame doing `with sql.set(...)`
        with self.execute(statement, 2, context) as cursor:
            yield cursor

    def all(self, statement, context=None):
        """Execute ``statement`` and return all rows (``fetchall``)."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.fetchall()

    def one(self, statement, context=None):
        """Return the single row ``statement`` produces, or None."""
        return self.one_(statement, context)

    def has(self, statement, context=None):
        """Return the value of ``select exists(statement)``."""
        exists, = self.one_('select exists(%s)' % (statement,), context)
        return exists
f1df255a 166
91d72c6c
JF
def connected(dsn):
    """Decorator factory that supplies a fresh connection to each call.

    The decorated function is called with an extra ``sql`` keyword argument
    bound to ``connect(dsn)``.  That connection is closed when the call
    returns or raises.  The wrapper keeps the wrapped function's name and
    docstring.
    """
    import functools

    def wrapped(method):
        @functools.wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
174
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection to ``dsn`` and yield it inside ``connection.transact``.

    Extra arguments are forwarded to ``connect.transact``.  The transaction
    is finished (commit, or rollback on error) before the connection is
    closed.
    """
    # one `with` over two managers: exits run innermost-first, as nested withs
    with connect(dsn) as connection, connection.transact(*args, **kw):
        yield connection
7d11917e 180
426cccf1 181"""
736932f0
JF
182def slap_(sql, table, keys, values, path):
183 csr = sql.cursor()
184 try:
185 csr.execute('savepoint iou')
186 try:
187 both = dict(keys, **values)
188 fields = both.keys()
189
190 csr.execute('''
191 insert into %s (%s) values (%s)
192 ''' % (
193 table,
194 ', '.join(fields),
195 ', '.join(['%s' for key in fields])
196 ), both.values())
197 except psycopg2.IntegrityError, e:
198 csr.execute('rollback to savepoint iou')
199
200 csr.execute('''
201 update %s set %s where %s
202 ''' % (
203 table,
204 ', '.join([
205 key + ' = %s'
206 for key in values.keys()]),
207 ' and '.join([
208 key + ' = %s'
209 for key in keys.keys()])
210 ), values.values() + keys.values())
211
212 return path_(csr, path)
213 finally:
214 csr.close()
426cccf1 215"""