]> git.saurik.com Git - cyql.git/blob - __init__.py
Generalize context as argument to execute().
[cyql.git] / __init__.py
1 from __future__ import unicode_literals
2 from __future__ import print_function
3
4 import inspect
5 import os
6
7 from contextlib import contextmanager
8
9 import psycopg2
10 import psycopg2.extras
11 import psycopg2.pool
12
# Register psycopg2's UNICODE typecaster globally so that text values are
# handed back as unicode objects instead of byte strings.
# NOTE(review): this matters on Python 2; presumably a no-op concern on
# Python 3 -- confirm against the psycopg2 documentation.
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14
class connection(object):
    """Thin wrapper around a DB-API connection object (the "driver").

    Offers context managers for cursors, statement execution and
    transactions.  Every statement is first passed through ``str.format``,
    so ``{name}`` fields are substituted textually, while ``%(name)s``
    placeholders are left for the driver to bind from the parameter mapping.
    """

    def __init__(self, driver):
        """Remember *driver*, the underlying connection object."""
        self.driver = driver

    @contextmanager
    def cursor(self):
        """Yield a ``DictCursor`` created on the driver; close it on exit."""
        # Create the cursor *before* entering the try block: if creation
        # fails there is nothing to close, and the real error must not be
        # replaced by an unbound-name error raised from the cleanup clause.
        handle = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield handle
        finally:
            handle.close()

    @contextmanager
    def execute_(self, statement, locals):
        """Execute *statement* with an explicit mapping and yield the cursor.

        The *locals* mapping is used twice: to fill ``{name}`` fields through
        ``str.format`` and as the parameter mapping given to the driver.
        """
        with self.cursor() as cursor:
            cursor.execute(statement.format(**locals), locals)
            yield cursor

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Execute *statement* using a calling frame's local variables.

        ``{name}`` fields in *statement* are always filled from the local
        variables of the selected calling frame.

        :param statement: statement text, formatted before execution.
        :param depth: number of additional frames to skip above the code
            that entered this context manager (0 selects that code itself).
        :param context: parameter mapping handed to the driver; when it is
            ``None`` the selected frame's locals are used instead.
        :raises ValueError: if the call stack has fewer frames than requested.
        """
        with self.cursor() as cursor:
            # Two frames separate this generator from the code that entered
            # it: the generator's own frame and @contextmanager's __enter__.
            # Walking f_back works on both Python 2 and Python 3, whereas
            # passing a depth to inspect.currentframe() is Python 2 only.
            frame = inspect.currentframe()
            scope = None
            try:
                hops = depth + 2
                while frame is not None and hops > 0:
                    frame = frame.f_back
                    hops -= 1
                if frame is None:
                    raise ValueError('call stack is not deep enough')

                scope = frame.f_locals
                if context is None:
                    context = scope
                cursor.execute(statement.format(**scope), context)
            finally:
                # Drop the frame references so the suspended generator does
                # not keep the caller's frame alive.
                del frame, scope, context
            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Yield a ``transaction``; commit on success, roll back on error.

        :param synchronous_commit: when false, ``set local
            synchronous_commit = off`` is issued before the body runs.
        """
        with self.cursor() as cursor:
            if not synchronous_commit:
                cursor.execute('set local synchronous_commit = off')

        try:
            yield transaction(self)
            self.driver.commit()
        except:
            # Deliberately bare: whatever interrupted the body, undo the
            # work and let the original exception continue to propagate.
            self.driver.rollback()
            raise
58
class transaction(object):
    """Statement helpers bound to a single connection wrapper.

    Each helper forwards to ``connection.execute`` with a frame depth that
    skips the helper's own frame, so the statement is evaluated against the
    local variables of the code that called the helper.
    """

    def __init__(self, connection):
        """Keep a reference to the *connection* wrapper to run against."""
        self.connection = connection

    def pull(self, statement):
        """Execute *statement* and return all rows it produced."""
        with self.connection.execute(statement, 1) as results:
            rows = results.fetchall()
        return rows

    def yank(self, statement, offset=0):
        """Execute *statement* and return its first row, or ``None``.

        *offset* is the number of extra frames between this method and the
        code whose locals should be used (non-zero when called via a helper).
        """
        with self.connection.execute(statement, 1 + offset) as results:
            rows = results.fetchall()
        if len(rows) == 0:
            return None
        return rows[0]

    def push(self, statement):
        """Execute *statement* and discard any result."""
        with self.connection.execute(statement, 1):
            pass

    def push_(self, statement, locals):
        """Execute *statement* with the explicit mapping *locals*."""
        with self.connection.execute_(statement, locals):
            pass

    def exists(self, statement):
        """Return the value of ``select exists (<statement>)``."""
        # Only the {statement} field is filled here; placeholders inside the
        # inner statement are resolved later, when yank() executes the query.
        query = '''
            select exists (
                {statement}
            )
        '''.format(statement=statement)
        # One extra frame (this method) sits between yank() and the caller.
        first_row = self.yank(query, 1)
        return first_row[0]
86
@contextmanager
def connect(dsn):
    """Open a database connection and yield it wrapped in ``connection``.

    :param dsn: mapping of keyword arguments passed to ``psycopg2.connect``.
    :raises psycopg2.OperationalError: if the third consecutive connection
        attempt fails; earlier failures are retried immediately.

    The underlying driver connection is closed when the context exits.
    """
    attempts = 3
    for remaining in range(attempts - 1, -1, -1):
        try:
            driver = psycopg2.connect(**dsn)
        except psycopg2.OperationalError:
            if remaining == 0:
                # Bare raise keeps the original traceback intact.
                raise
        else:
            break

    try:
        driver.set_client_encoding('UNICODE')
        yield connection(driver)
    finally:
        driver.close()
104
def connected(dsn):
    """Return a decorator that supplies a fresh connection to a function.

    The decorated function is called with a ``connection`` (opened through
    ``connect(dsn)`` for the duration of the call) as its first argument,
    followed by the caller's own positional and keyword arguments.  Its
    return value is passed through unchanged.
    """
    from functools import wraps

    def wrapped(method):
        # wraps() keeps the decorated function's name and docstring visible
        # to introspection, logging and debuggers.
        @wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as connection:
                return method(connection, *args, **kw)
        return replaced
    return wrapped
112
@contextmanager
def transact(dsn, **args):
    """Connect using *dsn* and yield a transaction on that connection.

    Keyword arguments are forwarded to the connection's ``transact`` method.
    The connection is opened via ``connect`` and lives only for the duration
    of the context.
    """
    with connect(dsn) as link:
        with link.transact(**args) as active:
            # What is yielded is the object produced by link.transact().
            yield active
118
119 """
120 def slap_(sql, table, keys, values, path):
121 csr = sql.cursor()
122 try:
123 csr.execute('savepoint iou')
124 try:
125 both = dict(keys, **values)
126 fields = both.keys()
127
128 csr.execute('''
129 insert into %s (%s) values (%s)
130 ''' % (
131 table,
132 ', '.join(fields),
133 ', '.join(['%s' for key in fields])
134 ), both.values())
135 except psycopg2.IntegrityError, e:
136 csr.execute('rollback to savepoint iou')
137
138 csr.execute('''
139 update %s set %s where %s
140 ''' % (
141 table,
142 ', '.join([
143 key + ' = %s'
144 for key in values.keys()]),
145 ' and '.join([
146 key + ' = %s'
147 for key in keys.keys()])
148 ), values.values() + keys.values())
149
150 return path_(csr, path)
151 finally:
152 csr.close()
153 """