-from __future__ import unicode_literals
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+from __future__ import unicode_literals
+
+from future_builtins import ascii, filter, hex, map, oct, zip
import inspect
import os
import psycopg2.pool
+# Make psycopg2 return text values (and, via UNICODEARRAY, text arrays) as
+# unicode objects instead of byte strings.  register_type() is called without
+# a scope argument, so the typecasters are registered globally.
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
+
+class connect(object):
+    """A psycopg2 connection whose helpers bind the caller's variables.
+
+    Statements passed to execute() and the helpers built on it may use
+    ``%(name)s`` (looked up in the calling function's locals, or in an
+    explicit *context* mapping) and ``%{expr}s`` (a Python expression
+    evaluated in the caller's namespace).  Instances are context managers
+    that close the connection on exit.
+    """
+
+    def __init__(self, dsn):
+        """Open a connection from the keyword arguments in *dsn*.
+
+        Up to three attempts are made; the OperationalError of the last one
+        propagates.  The connection is then switched to UNICODE client
+        encoding and autocommit mode, and hstore support is registered on it
+        unless register_hstore() raises ProgrammingError.
+        """
+        attempt = 0
+        while True:
+            try:
+                self.driver = psycopg2.connect(**dsn)
+                break
+            except psycopg2.OperationalError:
+                if attempt == 2:
+                    # bare raise keeps the original traceback
+                    raise
+                attempt = attempt + 1
+
+        try:
+            self.driver.set_client_encoding('UNICODE')
+            self.driver.autocommit = True
+        except:
+            # release the socket, then propagate: an instance wrapping a
+            # closed, half-configured connection is useless to the caller
+            self.driver.close()
+            raise
+
+        try:
+            # NOTE(review): psycopg2.extras is not among the imports visible
+            # in this patch -- confirm `import psycopg2.extras` exists.
+            psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True)
+        except psycopg2.ProgrammingError:
+            # register_hstore raises ProgrammingError when the hstore type is
+            # not found; hstore support is optional, so continue without it
+            pass
-class connection(object):
-    def __init__(self, driver):
-        self.driver = driver
+
+    def close(self):
+        """Close the underlying psycopg2 connection."""
+        self.driver.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # always close; returning None lets any exception propagate
+        self.close()
+
+    def begin(self):
+        """Leave autocommit mode so following statements share a transaction.
+
+        NOTE(review): commit() and rollback() do not switch autocommit back
+        on, so statements issued after them open a new implicit transaction
+        that must also be committed -- confirm this is intended.
+        """
+        self.driver.autocommit = False
+
+    def commit(self):
+        """Commit the current transaction."""
+        self.driver.commit()
+
+    def rollback(self):
+        """Roll back the current transaction."""
+        self.driver.rollback()

     @contextmanager
     def cursor(self):
+        """Yield a DictCursor on this connection and close it afterwards."""
+        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
         try:
-            cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
             yield cursor
         finally:
             cursor.close()

     @contextmanager
-    def execute(self, statement, depth=0):
+    def execute(self, statement, depth=0, context=None):
+        """Run *statement* and yield the cursor holding its results.
+
+        Placeholders understood in *statement*:
+
+        * ``%(name)s`` -- left for the driver; *name* is looked up in
+          *context*, which defaults to a copy of the caller's locals.
+        * ``%{expr}s`` -- *expr* is evaluated with eval() in the caller's
+          globals and locals, and the result is bound as a generated
+          ``__cyql__<offset>`` parameter.
+        * ``%%`` and ``%s`` -- passed through unchanged.
+
+        *depth* counts the extra frames between execute() and the code whose
+        variables should be used: one per plain wrapper method, two per
+        @contextmanager wrapper.  A caller-supplied *context* is copied,
+        never modified.  Raises ValueError on a malformed placeholder.
+        """
+        # skip execute() itself and contextmanager's __enter__, plus *depth*
+        frame = inspect.currentframe(depth + 2)
+        f_globals = frame.f_globals
+        f_locals = frame.f_locals
+        # keeping the frame alive in this suspended generator would create a
+        # reference cycle with the caller; only its namespaces are needed
+        del frame
+
         with self.cursor() as cursor:
-            # two frames, accounting for execute() and @contextmanager
-            locals = inspect.currentframe(depth + 2).f_locals
-            cursor.execute(statement.format(**locals), locals)
+            if context is None:
+                context = dict(f_locals)
+            else:
+                # copy so the generated __cyql__ keys never leak into the
+                # caller's mapping
+                context = dict(context)
+
+            start = 0
+            while True:
+                percent = statement.find('%', start)
+                if percent == -1:
+                    break
+
+                # slice rather than index so a trailing '%' yields ''
+                marker = statement[percent + 1:percent + 2]
+                if marker == '(':
+                    close = statement.find(')', percent + 2)
+                    if close == -1 or statement[close + 1:close + 2] != 's':
+                        raise ValueError('malformed %%(name)s placeholder at offset %i' % (percent,))
+                    start = close + 2
+                elif marker == '{':
+                    close = statement.find('}', percent + 2)
+                    if close == -1 or statement[close + 1:close + 2] != 's':
+                        raise ValueError('malformed %%{expr}s placeholder at offset %i' % (percent,))
+                    code = statement[percent + 2:close]
+                    key = '__cyql__%i' % (percent,)
+                    # SECURITY: eval() runs arbitrary Python; statements must
+                    # come from source code, never from untrusted input.
+                    # XXX: compile() in the frame's context
+                    context[key] = eval(code, f_globals, f_locals)
+                    # rewrite %{expr}s as %(key)s and resume right after it
+                    statement = '%s%%(%s)s%s' % (statement[:percent], key, statement[close + 2:])
+                    start = percent + len(key) + 4
+                elif marker in ('%', 's'):
+                    start = percent + 2
+                else:
+                    # raise rather than assert: with asserts stripped (-O)
+                    # start would never advance and the loop would spin
+                    raise ValueError('unsupported placeholder %%%s at offset %i' % (marker, percent))
+
+            cursor.execute(statement, context)
+
+            # drop references to the caller's namespace before suspending
+            del context
+            del f_locals
+            del f_globals
+
             yield cursor

     @contextmanager
     def transact(self, synchronous_commit=True):
+        """Run the body of the with-statement as one transaction.
+
+        Autocommit is switched off; the transaction is committed if the body
+        completes and rolled back (and the exception re-raised) if it raises;
+        autocommit is switched back on in either case.  With
+        synchronous_commit=False the transaction first issues
+        ``set local synchronous_commit = off``.
+        """
-        with self.cursor() as cursor:
-            if not synchronous_commit:
-                cursor.execute('set local synchronous_commit = off')
-
+        self.driver.autocommit = False
         try:
-            yield transaction(self)
+            with self.cursor() as cursor:
+                if not synchronous_commit:
+                    cursor.execute('set local synchronous_commit = off')
+
+            yield
             self.driver.commit()
         except:
             self.driver.rollback()
             raise
+        finally:
+            self.driver.autocommit = True
-class transaction(object):
-    def __init__(self, connection):
-        self.connection = connection
+
+    def one_(self, statement, context=None):
+        """Return the single row *statement* produces, or None if it has none.
+
+        Shared by one() and has(); execute() is told to skip two frames so
+        variables resolve in *their* caller.  Raises ValueError if the
+        statement produces more than one row.
+        """
+        with self.execute(statement, 2, context) as cursor:
+            one = cursor.fetchone()
+            if one is None:
+                return None
-    def pull(self, statement):
-        with self.connection.execute(statement, 1) as cursor:
-            return cursor.fetchall()
+            # explicit check instead of assert, which -O would strip
+            if cursor.fetchone() is not None:
+                raise ValueError('statement returned more than one row')
+            return one
-    def yank(self, statement):
-        with self.connection.execute(statement, 1) as cursor:
-            rows = cursor.fetchall()
-            return rows[0] if len(rows) != 0 else None
+
+    def __call__(self, procedure, *parameters):
+        """Call the stored *procedure* with *parameters* via cursor.callproc().
+
+        Returns whatever callproc() returns.  NOTE(review): any result set
+        is discarded when the cursor closes -- confirm that is intended.
+        """
+        # callproc() takes the arguments as one sequence, not splatted
+        with self.cursor() as cursor:
+            return cursor.callproc(procedure, parameters)
-    def push(self, statement):
-        with self.connection.execute(statement, 1) as cursor:
-            pass
+
+    def run(self, statement, context=None):
+        """Execute *statement* and return the cursor's rowcount."""
+        with self.execute(statement, 1, context) as cursor:
+            return cursor.rowcount
-@contextmanager
-def connect(dsn):
-    attempt = 0
-    while True:
-        try:
-            driver = psycopg2.connect(**dsn)
-            break
-        except psycopg2.OperationalError, e:
-            if attempt == 2:
-                raise e
-            attempt = attempt + 1
+
+    @contextmanager
+    def set(self, statement, context=None):
+        """Execute *statement* and yield the still-open cursor to the caller."""
+        # depth 2: set() and its own @contextmanager __enter__ both sit
+        # between execute() and the caller whose variables should be used
+        with self.execute(statement, 2, context) as cursor:
+            yield cursor
-    try:
-        driver.set_client_encoding('UNICODE')
-        yield connection(driver)
-    finally:
-        driver.close()
+
+    def all(self, statement, context=None):
+        """Return every row produced by *statement*."""
+        with self.execute(statement, 1, context) as cursor:
+            return cursor.fetchall()
+
+    def one(self, statement, context=None):
+        """Return the only row of *statement*, or None if it yields no rows."""
+        return self.one_(statement, context)
+
+    def has(self, statement, context=None):
+        """Return the result of ``select exists(<statement>)``."""
+        exists, = self.one_('select exists(%s)' % (statement,), context)
+        return exists
+
+def connected(dsn):
+    """Decorator factory that supplies a fresh connection to each call.
+
+    Every call of the decorated function opens ``connect(dsn)``, passes it as
+    the ``sql`` keyword argument, and closes it when the function returns or
+    raises.  functools.wraps preserves the wrapped function's name and
+    docstring.
+    """
+    import functools
+
+    def wrapped(method):
+        @functools.wraps(method)
+        def replaced(*args, **kw):
+            with connect(dsn) as sql:
+                return method(*args, sql=sql, **kw)
+        return replaced
+    return wrapped
@contextmanager
-def transact(dsn, **args):
+def transact(dsn, *args, **kw):
+    """Yield a fresh connection to *dsn* that is inside a transaction.
+
+    Extra positional and keyword arguments are forwarded to the connection's
+    transact() method; the connection is closed when the block exits.
+    """
-    with connect(dsn) as connection:
-        with connection.transact(**args) as cursor:
-            yield cursor
+    with connect(dsn) as sql:
+        with sql.transact(*args, **kw):
+            yield sql
"""
def slap_(sql, table, keys, values, path):