]> git.saurik.com Git - cyql.git/blob - __init__.py
bc205292e79719b98b91b5460204f423cb30602a
[cyql.git] / __init__.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5
6 from future_builtins import ascii, filter, hex, map, oct, zip
7
import functools
import inspect
import os
10
11 from contextlib import contextmanager
12
13 import psycopg2
14 import psycopg2.extras
15 import psycopg2.pool
16
# Have psycopg2 return text and text[] values as unicode objects rather than
# byte strings (only relevant on Python 2, which this module targets).
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
19
class connect(object):
    """Convenience wrapper around a single psycopg2 connection.

    Statements given to :meth:`execute` and the helpers built on it
    (:meth:`run`, :meth:`all`, :meth:`one`, :meth:`has`, :meth:`set`) may use:

    * ``%(name)s`` -- bound from ``context`` or, when ``context`` is None,
      from the local variables of the calling function;
    * ``%{expr}s`` -- ``expr`` is evaluated with the calling function's
      globals and locals, and the result is bound as a query parameter;
    * ``%s`` and ``%%`` -- passed through to psycopg2 unchanged.

    The connection runs in autocommit mode, except inside :meth:`transact`
    or after :meth:`begin`.
    """

    def __init__(self, dsn, attempts=3):
        """Open a connection, retrying on ``psycopg2.OperationalError``.

        :param dsn: mapping of keyword arguments for ``psycopg2.connect``.
        :param attempts: total number of connection attempts. The default of 3
            matches the previous hard-coded behaviour, and at least one
            attempt is always made.
        :raises psycopg2.OperationalError: if every attempt fails.
        """
        attempt = 0
        while True:
            try:
                self.driver = psycopg2.connect(**dsn)
                break
            except psycopg2.OperationalError:
                attempt += 1
                if attempt >= attempts:
                    raise

        # Do not leak the freshly opened connection if the remaining setup
        # fails unexpectedly.
        try:
            self.driver.autocommit = True

            # XXX: the client encoding is not set explicitly; all of my
            # databases already default to UNICODE.

            try:
                psycopg2.extras.register_hstore(
                    self.driver, globally=False, unicode=True)
            except psycopg2.ProgrammingError:
                # hstore support is optional; register_hstore raises this
                # when the extension is not available in the database.
                pass
        except BaseException:
            self.driver.close()
            raise

    def close(self):
        """Close the underlying psycopg2 connection."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Always close; returning None lets any exception propagate.
        self.close()

    def begin(self):
        """Leave autocommit mode.

        Autocommit stays off until it is changed again; :meth:`commit` and
        :meth:`rollback` do not restore it.
        """
        self.driver.autocommit = False

    def commit(self):
        """Commit the current transaction."""
        self.driver.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.driver.rollback()

    @contextmanager
    def cursor(self):
        """Yield a ``DictCursor`` on this connection and close it afterwards."""
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Run ``statement`` and yield the cursor holding its results.

        :param statement: SQL using the placeholders described on the class.
        :param depth: number of wrapper functions between the code whose
            variables should be used and the function that contains
            ``with self.execute(...)``. Use 0 when calling it directly.
        :param context: parameters passed to ``cursor.execute``. Defaults to
            a copy of the caller's local variables. A caller-supplied context
            is never modified; a private copy is made before ``%{expr}s``
            results are added.
        :raises ValueError: on an unsupported ``%`` placeholder.
        """
        # Skip two extra frames: this generator and contextlib's __enter__.
        frame = inspect.currentframe(depth + 2)

        with self.cursor() as cursor:
            f_globals = None
            f_locals = frame.f_locals

            # True once ``context`` is a dict created here and safe to mutate.
            owned = context is None
            if owned:
                context = dict(f_locals)

            start = 0
            while True:
                percent = statement.find('%', start)
                if percent == -1:
                    break

                # Slicing yields '' for a trailing '%' instead of IndexError.
                next_char = statement[percent + 1:percent + 2]
                if next_char == '(':
                    # %(name)s: left for psycopg2 to bind from context.
                    start = statement.index(')', percent + 2) + 2
                    assert statement[start - 1] == 's'
                elif next_char == '{':
                    # %{expr}s: evaluate expr in the caller's frame, then
                    # rewrite it as a named %(__cyql__N)s parameter.
                    end = statement.index('}', percent + 2)
                    assert statement[end + 1] == 's'
                    code = statement[percent + 2:end]

                    if f_globals is None:
                        f_globals = frame.f_globals
                    if not owned:
                        context = dict(context)
                        owned = True

                    key = '__cyql__%i' % (percent,)
                    # XXX: compile() in the frame's context
                    context[key] = eval(code, f_globals, f_locals)

                    statement = '%s%%(%s)%s' % (
                        statement[:percent], key, statement[end + 1:])
                    # Resume after the inserted "%(" + key + ")s".
                    start = percent + len(key) + 4
                elif next_char in ('%', 's'):
                    start = percent + 2
                else:
                    # An explicit raise (not assert) so that `python -O` cannot
                    # turn this into an infinite loop.
                    raise ValueError(
                        'unsupported placeholder %r in statement'
                        % (statement[percent:percent + 2],))

            cursor.execute(statement, context)

            # Drop references to the caller's frame and namespaces before
            # suspending, so they are not kept alive while the cursor is used.
            del context
            del f_locals
            del f_globals
            del frame

            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the ``with`` body inside a transaction.

        Commits when the body completes. Rolls back and re-raises if the body
        or the commit raises. Always restores autocommit mode.

        :param synchronous_commit: when False, issue
            ``set local synchronous_commit = off`` for this transaction.
        """
        self.driver.autocommit = False
        try:
            with self.cursor() as cursor:
                if not synchronous_commit:
                    cursor.execute('set local synchronous_commit = off')

                yield
                self.driver.commit()
        except BaseException:
            self.driver.rollback()
            raise
        finally:
            self.driver.autocommit = True

    def one_(self, statement, context=None):
        """Return the single row produced by ``statement``, or None if empty.

        This is shared by :meth:`one` and :meth:`has`, so both sit the same
        number of frames above :meth:`execute` (hence ``depth=2``). When
        assertions are enabled, it checks that no second row exists.
        """
        with self.execute(statement, 2, context) as cursor:
            one = cursor.fetchone()
            if one is None:
                return None

            assert cursor.fetchone() is None
            return one

    def __call__(self, procedure, *parameters):
        """Call stored ``procedure`` with ``parameters`` via ``callproc``.

        Returns whatever ``cursor.callproc`` returns.
        """
        with self.cursor() as cursor:
            # callproc expects the arguments as a single sequence.
            return cursor.callproc(procedure, parameters)

    def run(self, statement, context=None):
        """Execute ``statement`` and return the affected row count."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement, context=None):
        """Execute ``statement`` and yield its cursor for iteration.

        Depth 2 skips both this method and its own @contextmanager wrapper,
        so variables are looked up in the caller's frame.
        """
        with self.execute(statement, 2, context) as cursor:
            yield cursor

    def all(self, statement, context=None):
        """Execute ``statement`` and return every row."""
        with self.execute(statement, 1, context) as cursor:
            return cursor.fetchall()

    def one(self, statement, context=None):
        """Return the single row produced by ``statement``, or None."""
        return self.one_(statement, context)

    def has(self, statement, context=None):
        """Return whether ``statement`` produces at least one row."""
        exists, = self.one_('select exists(%s)' % (statement,), context)
        return exists
169
def connected(dsn):
    """Decorator factory that injects a fresh connection as ``sql=``.

    Each call to the decorated function opens ``connect(dsn)``, passes it as
    the ``sql`` keyword argument, and closes it when the call returns or
    raises. ``functools.wraps`` preserves the wrapped function's name and
    docstring.
    """
    def wrapped(method):
        @functools.wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
177
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection to ``dsn`` and run the ``with`` body in a transaction.

    Extra positional and keyword arguments are forwarded to the connection's
    ``transact`` method. The connection is yielded, then closed on exit.
    """
    with connect(dsn) as sql, sql.transact(*args, **kw):
        yield sql
183
# NOTE(review): the string below is disabled legacy code. It is never
# executed and is kept only for reference. It appears to implement an upsert:
# try an INSERT under a savepoint, and on IntegrityError roll back to the
# savepoint and UPDATE instead. It would not work as-is against this module:
# `connect.cursor()` is now a context manager rather than a plain cursor, and
# `path_` is not defined in this file (as far as is visible) -- confirm before
# reviving. It also interpolates table/column names directly into the SQL.
"""
def slap_(sql, table, keys, values, path):
    csr = sql.cursor()
    try:
        csr.execute('savepoint iou')
        try:
            both = dict(keys, **values)
            fields = both.keys()

            csr.execute('''
                insert into %s (%s) values (%s)
            ''' % (
                table,
                ', '.join(fields),
                ', '.join(['%s' for key in fields])
            ), both.values())
        except psycopg2.IntegrityError, e:
            csr.execute('rollback to savepoint iou')

            csr.execute('''
                update %s set %s where %s
            ''' % (
                table,
                ', '.join([
                    key + ' = %s'
                for key in values.keys()]),
                ' and '.join([
                    key + ' = %s'
                for key in keys.keys()])
            ), values.values() + keys.values())

        return path_(csr, path)
    finally:
        csr.close()
"""