]>
Commit | Line | Data |
---|---|---|
426cccf1 JF |
1 | from __future__ import unicode_literals |
2 | from __future__ import print_function | |
736932f0 JF |
3 | |
4 | import inspect | |
5 | import os | |
6 | ||
7 | from contextlib import contextmanager | |
8 | ||
9 | import psycopg2 | |
426cccf1 | 10 | import psycopg2.extras |
736932f0 JF |
11 | import psycopg2.pool |
12 | ||
736932f0 JF |
# Register psycopg2's UNICODE typecaster at import time. This is a
# module-level side effect with no connection/cursor scope given, so it
# presumably applies process-wide; text results are decoded to unicode objects.
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
14 | ||
426cccf1 JF |
class connection(object):
    """Thin wrapper around an open psycopg2 driver connection.

    Statements given to ``execute``/``execute_`` are first expanded with
    ``str.format`` against a mapping of names, and the same mapping is then
    passed to the driver as query parameters (for ``%(name)s`` placeholders).

    NOTE: the ``str.format`` step splices values directly into the SQL text;
    only use ``{name}`` fields for trusted values such as table names, and
    ``%(name)s`` placeholders for data.
    """

    def __init__(self, driver):
        # driver: the raw psycopg2 connection object being wrapped.
        self.driver = driver

    @contextmanager
    def cursor(self):
        """Yield a ``DictCursor`` on the driver and always close it afterwards."""
        # Create the cursor *before* entering ``try``: if creation fails there
        # is nothing to close, and touching an unbound name in ``finally``
        # would replace the real error with an UnboundLocalError.
        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @contextmanager
    def execute_(self, statement, locals):
        """Execute ``statement`` using an explicit mapping of names.

        Runs ``statement.format(**locals)`` with ``locals`` as the query
        parameters and yields the cursor so results can be fetched.
        """
        # ``locals`` shadows the builtin, but callers may pass it by keyword,
        # so the parameter name is kept.
        with self.cursor() as cursor:
            cursor.execute(statement.format(**locals), locals)
            yield cursor

    @contextmanager
    def execute(self, statement, depth=0):
        """Execute ``statement`` using the local variables of a calling frame.

        With ``depth=0`` the names come from the code containing the ``with``
        statement; each extra level of ``depth`` walks one more frame up, so
        wrappers (see ``transaction``) can forward their caller's variables.
        Yields the cursor so results can be fetched.
        """
        with self.cursor() as cursor:
            # currentframe() returns this generator's own frame on both
            # Python 2 and 3 (passing an argument only worked on Python 2,
            # where it is an alias of sys._getframe).
            frame = inspect.currentframe()
            names = None
            try:
                # Two extra hops: this generator frame -> @contextmanager's
                # __enter__ -> the frame that entered the ``with``.
                for _ in range(depth + 2):
                    frame = frame.f_back
                names = frame.f_locals
                cursor.execute(statement.format(**names), names)
            finally:
                # Drop frame/locals references promptly to avoid keeping
                # frames alive through reference cycles.
                frame = names = None
            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the ``with`` body in a transaction; commit on success.

        Yields a ``transaction`` helper bound to this connection. Any
        exception from the body (or from ``commit`` itself) triggers a
        rollback and is re-raised. With ``synchronous_commit=False`` the
        transaction runs under ``set local synchronous_commit = off``.
        """
        with self.cursor() as cursor:
            if not synchronous_commit:
                cursor.execute('set local synchronous_commit = off')

            try:
                yield transaction(self)
                self.driver.commit()
            except BaseException:
                # BaseException on purpose: interrupts and GeneratorExit must
                # also roll back before propagating.
                self.driver.rollback()
                raise
class transaction(object):
    """Query helpers bound to a ``connection``, yielded by ``connection.transact``.

    Statements are expanded against the *caller's* local variables through
    ``connection.execute``. The depth values passed below skip each helper's
    own frame so the names are looked up in the code that called the helper.
    """

    def __init__(self, connection):
        self.connection = connection

    def pull(self, statement):
        """Execute ``statement`` and return all result rows."""
        with self.connection.execute(statement, 1) as cursor:
            return cursor.fetchall()

    def yank(self, statement, offset=0):
        """Execute ``statement`` and return its first row, or None if there is none.

        ``offset`` is the number of additional wrapper frames between this
        method and the code whose local variables should be used.
        """
        with self.connection.execute(statement, 1 + offset) as cursor:
            # fetchone() already yields None for an empty result, so there is
            # no need to build the full list of rows just to take the first.
            return cursor.fetchone()

    def push(self, statement):
        """Execute ``statement``, discarding any result."""
        with self.connection.execute(statement, 1):
            pass

    def push_(self, statement, locals):
        """Execute ``statement`` expanded with the explicit mapping ``locals``."""
        with self.connection.execute_(statement, locals):
            pass

    def exists(self, statement):
        """Return the first column of ``select exists (<statement>)``."""
        wrapped = '''
            select exists (
                {statement}
            )
        '''.format(statement=statement)
        # offset=1 skips this frame, so yank expands names from our caller.
        return self.yank(wrapped, 1)[0]
f1df255a | 84 | |
@contextmanager
def connect(dsn, attempts=3):
    """Open a psycopg2 connection, yield it wrapped in ``connection``, then close it.

    ``dsn`` is a mapping of keyword arguments for ``psycopg2.connect``.
    Connecting is tried up to ``attempts`` times (default 3, and always at
    least once); an ``OperationalError`` from the final attempt is re-raised.
    The driver is closed when the ``with`` block exits, even on error.
    """
    attempt = 0
    while True:
        try:
            driver = psycopg2.connect(**dsn)
            break
        except psycopg2.OperationalError:
            attempt += 1
            if attempt >= attempts:
                # Bare ``raise`` preserves the original traceback; ``raise e``
                # resets it on Python 2.
                raise

    try:
        driver.set_client_encoding('UNICODE')
        yield connection(driver)
    finally:
        driver.close()
736932f0 | 102 | |
91d72c6c JF |
def connected(dsn):
    """Decorator factory that supplies a fresh connection to the wrapped function.

    The decorated function is called with a ``connection`` (opened via
    ``connect(dsn)`` and closed afterwards) as its first positional argument,
    followed by the caller's own arguments; its return value is passed through.
    """
    import functools

    def wrapped(method):
        # functools.wraps keeps the decorated function's name and docstring.
        @functools.wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as connection:
                return method(connection, *args, **kw)
        return replaced
    return wrapped
110 | ||
7d11917e JF |
@contextmanager
def transact(dsn, **args):
    """Connect to ``dsn`` and run the ``with`` body inside one transaction.

    Keyword arguments (e.g. ``synchronous_commit``) are forwarded to
    ``connection.transact``; the value yielded is the transaction helper it
    produces. The connection is closed when the block exits.
    """
    # Both managers are nested: the transaction is entered after, and exited
    # before, the connection.
    with connect(dsn) as conn, conn.transact(**args) as tx:
        yield tx
116 | ||
# NOTE(review): the triple-quoted string below is disabled legacy code (an
# upsert helper ``slap_``: insert under a savepoint, and on IntegrityError
# roll back to it and update instead). It is a bare expression statement, not
# a docstring, so it has no runtime effect. It cannot be re-enabled as-is.
# It calls ``path_``, which is not defined anywhere in this file as shown.
# It also interpolates table and column names straight into the SQL text.
# Consider deleting it or restoring it properly.
"""
def slap_(sql, table, keys, values, path):
    csr = sql.cursor()
    try:
        csr.execute('savepoint iou')
        try:
            both = dict(keys, **values)
            fields = both.keys()

            csr.execute('''
                insert into %s (%s) values (%s)
            ''' % (
                table,
                ', '.join(fields),
                ', '.join(['%s' for key in fields])
            ), both.values())
        except psycopg2.IntegrityError, e:
            csr.execute('rollback to savepoint iou')

            csr.execute('''
                update %s set %s where %s
            ''' % (
                table,
                ', '.join([
                    key + ' = %s'
                    for key in values.keys()]),
                ' and '.join([
                    key + ' = %s'
                    for key in keys.keys()])
            ), values.values() + keys.values())

        return path_(csr, path)
    finally:
        csr.close()
"""