git.saurik.com Git - cyql.git/blame - __init__.py
On production websites, connections are by pipes.
[cyql.git] / __init__.py
CommitLineData
a056a391
JF
1from __future__ import absolute_import
2from __future__ import division
426cccf1 3from __future__ import print_function
a056a391
JF
4from __future__ import unicode_literals
5
6from future_builtins import ascii, filter, hex, map, oct, zip
736932f0
JF
7
8import inspect
9import os
10
11from contextlib import contextmanager
12
13import psycopg2
426cccf1 14import psycopg2.extras
736932f0
JF
15import psycopg2.pool
16
# Register psycopg2's UNICODE and UNICODEARRAY typecasters.  No scope argument
# is passed to register_type().  NOTE(review): per psycopg2's documentation
# this makes the registration global and makes text (and text-array) values
# come back as unicode objects on Python 2 -- confirm against the psycopg2
# version in use.
for _typecaster in (
    psycopg2.extensions.UNICODE,
    psycopg2.extensions.UNICODEARRAY,
):
    psycopg2.extensions.register_type(_typecaster)
del _typecaster
736932f0 19
18e35f46
JF
class ConnectionError(Exception):
    """Raised by connect() when the database server cannot be reached.

    Derives from Exception so that it can be caught by ordinary
    ``except Exception`` handlers and raised on interpreters that reject
    non-exception classes.
    """
22
c8a72a64
JF
class connect(object):
    """One psycopg2 connection plus helpers for running SQL statements.

    The underlying psycopg2 connection is stored on ``self.driver`` and is
    put into autocommit mode once opened.  Statements passed to execute()
    and the helpers built on it may contain ``%(name)s`` placeholders, which
    are looked up in a context dict (by default a copy of the calling
    frame's local variables), and ``%{expression}s`` placeholders, whose
    expression is evaluated in the calling frame.

    Instances are context managers: leaving the ``with`` block closes the
    connection.
    """

    def __init__(self, dsn):
        """Open a connection using the keyword arguments in *dsn*.

        dsn -- dict of keyword arguments for psycopg2.connect().  An
               optional 'cache' entry is stripped before connecting; it
               holds values discovered on a previous connection (currently
               only the 'hstore' OIDs).  When *dsn* has no 'cache' entry,
               one is stored into it so later connections can reuse it.

        Raises ConnectionError when psycopg2 reports that it could not
        connect to the server; any other OperationalError is retried and
        re-raised on the third failed attempt.
        """
        options = dsn.copy()
        options.pop('cache', None)

        cached = 'cache' in dsn
        if cached:
            cache = dsn['cache']
        else:
            cache = {
                'hstore': None,
            }

        attempt = 0
        while True:
            try:
                self.driver = psycopg2.connect(**options)
                break
            except psycopg2.OperationalError as e:
                if e.message.startswith('could not connect to server: '):
                    raise ConnectionError()
                if attempt == 2:
                    raise
                attempt = attempt + 1

        # From here on a connection is open: close it again if any of the
        # remaining setup fails, so a failed constructor does not leak it.
        try:
            self.driver.autocommit = True

            # XXX: all of my databases default to this...
            #try:
            #    self.driver.set_client_encoding('UNICODE')
            #except:
            #    self.driver.close()
            #    raise

            hstore = cache['hstore']
            if hstore is None:
                hstore = psycopg2.extras.HstoreAdapter.get_oids(self.driver)
                if hstore is not None:
                    # keep only the first element of the get_oids() result
                    hstore = hstore[0]
                    cache['hstore'] = hstore

            if hstore is not None:
                try:
                    psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True, oid=hstore)
                except psycopg2.ProgrammingError:
                    # best effort: the connection stays usable without the
                    # hstore adapter
                    pass

            if not cached:
                dsn['cache'] = cache
        except BaseException:
            self.driver.close()
            raise

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def begin(self):
        """Leave autocommit mode so following statements share a transaction."""
        self.driver.autocommit = False

    def commit(self):
        """Commit on the underlying connection."""
        self.driver.commit()

    def rollback(self):
        """Roll back on the underlying connection."""
        self.driver.rollback()

    @contextmanager
    def cursor(self):
        """Yield a new DictCursor and close it when the block exits."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @staticmethod
    def _rewrite(statement, context, f_globals, f_locals):
        """Replace every ``%{code}s`` in *statement* with a named placeholder.

        Each ``%{code}s`` is evaluated with eval(code, f_globals, f_locals);
        the result is stored in *context* (which is modified in place) under
        the key ``__cyql__<offset>``, where <offset> is the position of the
        ``%`` in the statement as rewritten so far, and the text becomes
        ``%(__cyql__<offset>)s``.  ``%(name)s``, ``%s`` and ``%%`` are left
        untouched.  Returns the rewritten statement.

        Raises ValueError for any other use of ``%`` (including a trailing
        ``%`` or a placeholder that is not followed by ``s``).

        NOTE: the embedded code is passed to eval(), so statements must come
        from trusted program text, never from external input.
        """
        start = 0
        while True:
            percent = statement.find('%', start)
            if percent == -1:
                return statement

            # slicing (not indexing) yields '' for a trailing '%'
            marker = statement[percent + 1:percent + 2]
            if marker == '(':
                # skip past the closing ')' and the conversion character
                start = statement.index(')', percent + 2) + 2
                if statement[start - 1:start] != 's':
                    raise ValueError('placeholder at offset %i is not followed by "s"' % (percent,))
            elif marker == '{':
                close = statement.index('}', percent + 2)
                if statement[close + 1:close + 2] != 's':
                    raise ValueError('placeholder at offset %i is not followed by "s"' % (percent,))
                code = statement[percent + 2:close]

                key = '__cyql__%i' % (percent,)
                # XXX: compile() in the frame's context
                context[key] = eval(code, f_globals, f_locals)

                statement = '%s%%(%s)%s' % (statement[0:percent], key, statement[close + 1:])
                # resume after "%(" + key + ")s"
                start = percent + len(key) + 4
            elif marker in ('%', 's'):
                start = percent + 2
            else:
                raise ValueError('unsupported "%%" sequence at offset %i' % (percent,))

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Execute *statement* and yield the cursor holding its results.

        statement -- SQL text; see _rewrite() for the placeholder syntax.
        depth     -- number of additional call frames to skip when locating
                     the frame whose variables feed the statement (0 means
                     the code containing the ``with`` statement).
        context   -- parameters for the statement; defaults to a copy of the
                     located frame's local variables.  A caller-supplied
                     context receives the ``__cyql__*`` keys in place.

        The cursor is closed when the ``with`` block exits.
        """
        # two frames, accounting for execute() and @contextmanager
        frame = inspect.currentframe(depth + 2)

        with self.cursor() as cursor:
            f_locals = frame.f_locals
            if context is None:
                context = dict(f_locals)

            statement = self._rewrite(statement, context, frame.f_globals, f_locals)
            cursor.execute(statement, context)

            # Do not keep the caller's frame or variables alive while this
            # generator is suspended inside the caller's ``with`` block.
            del context
            del f_locals
            del frame

            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the ``with`` body inside a transaction.

        Autocommit is switched off for the duration.  The transaction is
        committed when the body finishes normally; on any exception it is
        rolled back and the exception is re-raised.  Autocommit is switched
        back on in every case.

        synchronous_commit -- when false, ``set local synchronous_commit =
                              off`` is issued at the start of the transaction.
        """
        self.driver.autocommit = False
        try:
            with self.cursor() as cursor:
                if not synchronous_commit:
                    cursor.execute('set local synchronous_commit = off')

                yield
            self.driver.commit()
        except:
            # deliberately catches everything: always roll back, then re-raise
            self.driver.rollback()
            raise
        finally:
            self.driver.autocommit = True

    def one_(self, statement, context=None):
        """Return the single row produced by *statement*, or None if none.

        Uses depth 2, so it must be called through exactly one intermediate
        method (as one() and has() do) for the statement's variables to be
        taken from that method's caller.  Asserts that there is no second row.
        """
        with self.execute(statement, 2, context) as cursor:
            one = cursor.fetchone()
            if one is None:
                return None

            assert cursor.fetchone() is None
            return one

    def __call__(self, procedure, *parameters):
        """Call the stored procedure *procedure* with *parameters*.

        The parameters are handed to the cursor's callproc() as one sequence.
        Returns whatever callproc() returns.
        NOTE(review): psycopg2 documents callproc() as currently returning
        None, with any result set available through fetch*() on the cursor,
        which is closed on return -- confirm what callers expect here.
        """
        with self.cursor() as cursor:
            return cursor.callproc(procedure, parameters)

    def run(self, statement, context=None):
        """Execute *statement* and return the cursor's rowcount."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement):
        """Execute *statement* and yield the open cursor for iteration."""
        # depth 2: skip this generator and its @contextmanager wrapper
        with self.execute(statement, 2) as cursor:
            yield cursor

    def all(self, statement, context=None):
        """Execute *statement* and return every row via fetchall()."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.fetchall()

    def one(self, statement, context=None):
        """Return the single row produced by *statement*, or None."""
        return self.one_(statement, context)

    def has(self, statement):
        """Return the result of ``select exists(<statement>)``."""
        exists, = self.one_('select exists(%s)' % (statement,))
        return exists
f1df255a 198
91d72c6c
JF
def connected(dsn):
    """Return a decorator that supplies a fresh connection as ``sql``.

    Every call of the decorated function opens ``connect(dsn)``, passes it
    to the function as the keyword argument ``sql``, and closes it when the
    function returns or raises.  The function's return value is passed
    through unchanged.
    """
    import functools

    def wrapped(method):
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)

        # Carry over __name__/__doc__/__module__ so the decorated function
        # still identifies itself in tracebacks and introspection.  Some
        # callables lack those attributes (which makes wraps() raise on
        # Python 2); they are returned unwrapped, as before.
        try:
            replaced = functools.wraps(method)(replaced)
        except AttributeError:
            pass
        return replaced
    return wrapped
206
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection for *dsn* and yield it inside a transaction.

    Extra positional and keyword arguments are forwarded to the
    connection's own transact().  The connection is closed when the
    ``with`` block exits.
    """
    with connect(dsn) as sql:
        transaction = sql.transact(*args, **kw)
        with transaction:
            yield sql
7d11917e 212
426cccf1 213"""
736932f0
JF
214def slap_(sql, table, keys, values, path):
215 csr = sql.cursor()
216 try:
217 csr.execute('savepoint iou')
218 try:
219 both = dict(keys, **values)
220 fields = both.keys()
221
222 csr.execute('''
223 insert into %s (%s) values (%s)
224 ''' % (
225 table,
226 ', '.join(fields),
227 ', '.join(['%s' for key in fields])
228 ), both.values())
229 except psycopg2.IntegrityError, e:
230 csr.execute('rollback to savepoint iou')
231
232 csr.execute('''
233 update %s set %s where %s
234 ''' % (
235 table,
236 ', '.join([
237 key + ' = %s'
238 for key in values.keys()]),
239 ' and '.join([
240 key + ' = %s'
241 for key in keys.keys()])
242 ), values.values() + keys.values())
243
244 return path_(csr, path)
245 finally:
246 csr.close()
426cccf1 247"""