]> git.saurik.com Git - cyql.git/blob - __init__.py
dd150ae19ea6a64aed2b26103205267c773e21d0
[cyql.git] / __init__.py
1 from __future__ import unicode_literals
2 from __future__ import print_function
3
4 import inspect
5 import os
6
7 from contextlib import contextmanager
8
9 import psycopg2
10 import psycopg2.extras
11 import psycopg2.pool
12
# Install psycopg2's UNICODE typecaster (no scope argument is passed), so
# that -- per psycopg2's documentation -- text columns come back as unicode
# objects rather than byte strings under Python 2.
psycopg2.extensions.register_type(
    psycopg2.extensions.UNICODE
)
14
class connection(object):
    """Thin wrapper around a psycopg2 connection object (``driver``).

    Statements given to :meth:`execute` are first expanded with
    ``str.format`` using the local variables of a calling frame. The same
    mapping (or an explicit ``context``) is then handed to
    ``cursor.execute`` as the query parameters, so psycopg2 named
    ``%(name)s`` placeholders bind to those locals.
    """

    def __init__(self, driver):
        # the raw connection returned by psycopg2.connect() (see connect())
        self.driver = driver

    @contextmanager
    def cursor(self):
        """Yield a ``DictCursor`` on the driver and close it on exit."""
        # Create the cursor *before* the try/finally: if creation fails there
        # is nothing to close, and touching an unbound name in ``finally``
        # would mask the real error with an UnboundLocalError.
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @contextmanager
    def execute_(self, statement, locals):
        """Run ``statement`` with an explicit mapping and yield the cursor.

        ``locals`` is used twice: as ``str.format`` keywords to expand
        ``{name}`` fields in ``statement``, and as the parameter mapping
        passed to ``cursor.execute``.
        """
        with self.cursor() as cursor:
            cursor.execute(statement.format(**locals), locals)
            yield cursor

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Run ``statement`` against the local variables of a calling frame.

        :param statement: SQL text. ``{name}`` fields are expanded with
            ``str.format`` from the selected frame's locals before execution.
        :param depth: number of *extra* frames between the code whose locals
            should be used and this call (0 = the code that wrote
            ``with conn.execute(...)``).
        :param context: parameter mapping for ``cursor.execute``. Defaults
            to the same frame locals.
        :yields: the cursor the statement was executed on.
        """
        import sys

        with self.cursor() as cursor:
            # Skip two frames: this generator and contextmanager's __enter__.
            # sys._getframe(n) is exactly what Python 2's
            # inspect.currentframe aliased; Python 3's inspect.currentframe()
            # accepts no argument. Copy into a plain dict so we never hand the
            # frame's own locals mapping to format()/psycopg2.
            scope = dict(sys._getframe(depth + 2).f_locals)
            try:
                if context is None:
                    context = scope
                cursor.execute(statement.format(**scope), context)
            finally:
                # release this name's reference before yielding
                del scope
            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the ``with`` body as one transaction on this connection.

        Yields a :class:`transaction` helper bound to this connection.
        Commits when the body finishes normally. On any exception it rolls
        back and re-raises.

        :param synchronous_commit: when false, issues
            ``set local synchronous_commit = off`` before the body runs.
        """
        try:
            # Inside the try, so that a failure of this setup statement is
            # rolled back too instead of being left pending on the connection.
            if not synchronous_commit:
                with self.cursor() as cursor:
                    cursor.execute('set local synchronous_commit = off')
            yield transaction(self)
            self.driver.commit()
        except:
            # Deliberately bare: roll back on *any* exception, including
            # non-Exception ones such as KeyboardInterrupt, then propagate.
            self.driver.rollback()
            raise
58
class transaction(object):
    """Query helpers bound to one connection inside a transaction.

    Every helper forwards to ``connection.execute`` with a frame ``depth``
    chosen so that the statement is resolved against the local variables of
    whoever called the helper, not the helper itself.
    """

    def __init__(self, connection):
        # object providing execute(statement, depth, context) and
        # execute_(statement, locals) context managers
        self.connection = connection

    def __call__(self, statement, locals=None):
        """Execute ``statement`` and discard any result.

        ``locals``, when given, is passed to ``connection.execute`` as its
        ``context`` (the query parameter mapping).
        """
        # depth 1: skip this method's own frame
        with self.connection.execute(statement, 1, locals):
            pass

    @contextmanager
    def set(self, statement):
        """Execute ``statement`` and yield the open cursor for iteration."""
        # depth 2: because this helper is itself a @contextmanager, both its
        # generator frame and contextmanager's __enter__ frame sit between
        # execute() and the caller. Depth 1 would capture __enter__'s locals.
        with self.connection.execute(statement, 2) as cursor:
            yield cursor

    def all(self, statement):
        """Return every row produced by ``statement``."""
        with self.connection.execute(statement, 1) as cursor:
            return cursor.fetchall()

    def one(self, statement):
        """Return the single row produced by ``statement``, or None if none.

        :raises AssertionError: if the statement produces more than one row.
        """
        with self.connection.execute(statement, 1) as cursor:
            row = cursor.fetchone()
            if row is None:
                return None
            # explicit raise instead of ``assert`` so the check survives -O
            if cursor.fetchone() is not None:
                raise AssertionError('expected at most one row')
            return row

    def pull(self, statement):
        """Return every row produced by ``statement`` (same as :meth:`all`)."""
        with self.connection.execute(statement, 1) as cursor:
            return cursor.fetchall()

    def yank(self, statement, offset=0):
        """Return the first row produced by ``statement``, or None if empty.

        ``offset`` adds extra frames to skip, for helpers (such as
        :meth:`exists`) that call ``yank`` on behalf of their own caller.
        """
        with self.connection.execute(statement, 1 + offset) as cursor:
            rows = cursor.fetchall()
            return rows[0] if rows else None

    def push(self, statement):
        """Execute ``statement`` and discard any result."""
        with self.connection.execute(statement, 1):
            pass

    def push_(self, statement, locals):
        """Execute ``statement`` using the explicit mapping ``locals``."""
        with self.connection.execute_(statement, locals):
            pass

    def exists(self, statement):
        """Return the first column of ``select exists (<statement>)``."""
        wrapped = '''
            select exists (
                {statement}
            )
        '''.format(statement=statement)
        # offset 1: skip this frame so yank() resolves our caller's locals
        return self.yank(wrapped, 1)[0]
107
@contextmanager
def connect(dsn, attempts=3):
    """Open a psycopg2 connection and yield it wrapped in :class:`connection`.

    :param dsn: mapping of keyword arguments passed to ``psycopg2.connect``.
    :param attempts: total number of connection attempts. On
        ``psycopg2.OperationalError`` the connect is retried immediately (no
        delay) until this many attempts have failed, then the last error is
        re-raised. Values below 1 behave like 1.
    :yields: a :class:`connection`. The driver is closed on exit, whether
        or not the ``with`` body raised.
    """
    failures = 0
    while True:
        try:
            driver = psycopg2.connect(**dsn)
            break
        except psycopg2.OperationalError:
            failures += 1
            if failures >= attempts:
                # Bare raise keeps the original traceback; ``raise e`` would
                # restart it at this line on Python 2.
                raise

    try:
        driver.set_client_encoding('UNICODE')
        yield connection(driver)
    finally:
        driver.close()
125
def connected(dsn):
    """Decorator factory: run the decorated function with a fresh connection.

    Each call to the decorated function opens ``connect(dsn)``. The
    resulting :class:`connection` is passed as the first positional
    argument, followed by the caller's own arguments, and the function's
    return value is passed back.
    """
    import functools

    def wrapped(method):
        # functools.wraps preserves the decorated function's __name__/__doc__
        @functools.wraps(method)
        def replaced(*args, **kw):
            # ``conn`` rather than ``connection`` to avoid shadowing the class
            with connect(dsn) as conn:
                return method(conn, *args, **kw)
        return replaced
    return wrapped
133
@contextmanager
def transact(dsn, **args):
    """Open a connection to ``dsn`` and run the ``with`` body in a transaction.

    Keyword arguments are forwarded to ``connection.transact``. The yielded
    value is the transaction helper that method produces.
    """
    with connect(dsn) as conn, conn.transact(**args) as txn:
        yield txn
139
# NOTE(review): the triple-quoted string below is disabled code, not a
# docstring; it is evaluated and discarded at import time. The draft
# ``slap_`` inserts a row and, on IntegrityError, rolls back to a savepoint
# and updates the matching row instead. It calls ``path_``, which is not
# defined in this file, and uses Python-2-only ``except X, e`` syntax, so it
# needs fixing before it can be re-enabled.
"""
def slap_(sql, table, keys, values, path):
    csr = sql.cursor()
    try:
        csr.execute('savepoint iou')
        try:
            both = dict(keys, **values)
            fields = both.keys()

            csr.execute('''
                insert into %s (%s) values (%s)
            ''' % (
                table,
                ', '.join(fields),
                ', '.join(['%s' for key in fields])
            ), both.values())
        except psycopg2.IntegrityError, e:
            csr.execute('rollback to savepoint iou')

            csr.execute('''
                update %s set %s where %s
            ''' % (
                table,
                ', '.join([
                    key + ' = %s'
                    for key in values.keys()]),
                ' and '.join([
                    key + ' = %s'
                    for key in keys.keys()])
            ), values.values() + keys.values())

        return path_(csr, path)
    finally:
        csr.close()
"""