-from __future__ import unicode_literals
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+from __future__ import unicode_literals
+
+from future_builtins import ascii, filter, hex, map, oct, zip
import inspect
import os
import psycopg2.pool
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)  # registered with no scope argument, alongside the scalar UNICODE typecaster above, so array elements are covered too
+
+class ConnectionError():
+ pass
+
+class connect(object):
+ def __init__(self, dsn):  # dsn: dict of psycopg2.connect() keyword arguments, plus an optional 'cache' entry managed here
+ options = dsn.copy()  # work on a copy so the caller's dict keeps its 'cache' entry
+ if 'cache' in options:
+ del options['cache']  # 'cache' is bookkeeping for this class and is not passed to psycopg2.connect()
+
+ if 'cache' in dsn:
+ cached = True
+ cache = dsn['cache']
+ else:
+ cached = False
+ cache = {
+ 'hstore': None,
+ }
+
+ attempt = 0
+ while True:  # retry loop: the third OperationalError (attempt == 2) is re-raised
+ try:
+ self.driver = psycopg2.connect(**options)
+ break
+ except psycopg2.OperationalError, e:
+ if e.message.startswith('could not connect to server: '):
+ raise ConnectionError()  # this message is not retried: it is reported immediately as ConnectionError
+ if attempt == 2:
+ raise
+ attempt = attempt + 1
+
+ self.driver.autocommit = True  # autocommit is the default mode; begin() and transact() switch it off
+
+ # XXX: all of my databases default to this...
+ #try:
+ # self.driver.set_client_encoding('UNICODE')
+ #except:
+ # self.driver.close()
+ # raise
+
+ hstore = cache['hstore']
+ if hstore == None:  # nothing cached yet: look the hstore OIDs up on this connection
+ hstore = psycopg2.extras.HstoreAdapter.get_oids(self.driver)
+ if hstore != None:
+ hstore = hstore[0]  # only the first element of the get_oids() result is kept
+ cache['hstore'] = hstore
+
+ if hstore != None:
+ try:
+ psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True, oid=hstore)
+ except psycopg2.ProgrammingError, e:
+ pass  # NOTE(review): a failed registration is silently ignored; presumably best-effort -- confirm
+
+ if not cached:
+ dsn['cache'] = cache  # written back onto the caller's dsn so a later connect(dsn) reuses the lookup
+
+ def close(self):  # close the underlying psycopg2 connection
+ self.driver.close()
+
+ def __enter__(self):  # context-manager entry: the connection was already opened in __init__, so just hand back self
+ return self
-class connection(object):
- def __init__(self, driver):
- self.driver = driver
+ def __exit__(self, type, value, traceback):  # always closes the connection; returns None, so an exception from the with-block propagates
+ self.close()
+
+ def begin(self):  # NOTE(review): commit() and rollback() below do not switch autocommit back on -- confirm callers restore it
+ self.driver.autocommit = False  # turn off the autocommit mode set in __init__
+
+ def commit(self):  # commit on the underlying driver connection
+ self.driver.commit()
+
+ def rollback(self):  # roll back on the underlying driver connection
+ self.driver.rollback()
@contextmanager
def cursor(self):
+ cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)  # created before the try so the finally clause never runs with `cursor` unbound
try:
- cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
yield cursor
finally:
cursor.close()
@contextmanager
- def execute_(self, statement, locals):
- with self.cursor() as cursor:
- cursor.execute(statement.format(**locals), locals)
- yield cursor
+ def execute(self, statement, depth=0, context=None):
+ # two frames, accounting for execute() and @contextmanager
+ frame = inspect.currentframe(depth + 2)
- @contextmanager
- def execute(self, statement, depth=0):
with self.cursor() as cursor:
- # two frames, accounting for execute() and @contextmanager
- locals = inspect.currentframe(depth + 2).f_locals
- try:
- cursor.execute(statement.format(**locals), locals)
- finally:
- del locals
+ f_globals = None
+ f_locals = frame.f_locals
+
+ if context == None:
+ context = dict(**f_locals)
+
+ start = 0
+ while True:
+ percent = statement.find('%', start)
+ if percent == -1:
+ break
+
+ next = statement[percent + 1]
+ if next == '(':
+ start = statement.index(')', percent + 2) + 2
+ assert statement[start - 1] == 's'
+ elif next == '{':
+ start = statement.index('}', percent + 2)
+ assert statement[start + 1] == 's'
+ code = statement[percent + 2:start]
+
+ if f_globals == None:
+ f_globals = frame.f_globals
+
+ key = '__cyql__%i' % (percent,)
+ # XXX: compile() in the frame's context
+ context[key] = eval(code, f_globals, f_locals)
+
+ statement = '%s%%(%s)%s' % (statement[0:percent], key, statement[start + 1:])
+ start = percent + len(key) + 4
+ elif next in ('%', 's'):
+ start = percent + 2
+ else:
+ assert False
+
+ cursor.execute(statement, context)
+
+ del context
+ del f_locals
+ del f_globals
+
yield cursor
@contextmanager
def transact(self, synchronous_commit=True):
- with self.cursor() as cursor:
- if not synchronous_commit:
- cursor.execute('set local synchronous_commit = off')
-
+ self.driver.autocommit = False  # leave autocommit for the duration of the with-block
try:
- yield transaction(self)
+ with self.cursor() as cursor:
+ if not synchronous_commit:
+ cursor.execute('set local synchronous_commit = off')  # only issued when the caller passes synchronous_commit=False
+
+ yield  # yields no value; the caller keeps using this connection object
self.driver.commit()
except:
self.driver.rollback()
raise
+ finally:
+ self.driver.autocommit = True  # restore the autocommit mode set in __init__, on success and on failure
-class transaction(object):
- def __init__(self, connection):
- self.connection = connection
+ def one_(self, statement, context=None):  # shared body of one() and has(): returns the first row, or None when there is none
+ with self.execute(statement, 2, context) as cursor:  # depth 2 where run()/all() pass 1: this method is called through one() or has(), so one more frame is skipped
+ one = cursor.fetchone()
+ if one == None:
+ return None
- def pull(self, statement):
- with self.connection.execute(statement, 1) as cursor:
- return cursor.fetchall()
+ assert cursor.fetchone() == None  # a second row is treated as a programming error
+ return one
- def yank(self, statement, offset=0):
- with self.connection.execute(statement, 1 + offset) as cursor:
- rows = cursor.fetchall()
- return rows[0] if len(rows) != 0 else None
+ def __call__(self, procedure, *parameters):
+ with self.execute(statement, 1) as cursor:
+ return cursor.callproc(procedure, *parameters)
- def push(self, statement):
- with self.connection.execute(statement, 1) as cursor:
- pass
+ def run(self, statement, context=None):  # execute a statement and return cursor.rowcount; no rows are fetched
+ with self.execute(statement, 1, context) as cursor:  # depth 1 skips this method's own frame
+ return cursor.rowcount
- def push_(self, statement, locals):
- with self.connection.execute_(statement, locals) as cursor:
- pass
+ def gen(self, statement):  # generator: yields rows one at a time via fetchone(); nothing runs until the generator is first advanced
+ with self.execute(statement, 1) as cursor:  # NOTE(review): placeholders are resolved against the frame that first advances the generator -- confirm this is always the intended caller
+ while True:
+ fetch = cursor.fetchone()
+ if fetch == None:
+ break
+ yield fetch
- def exists(self, statement):
- return self.yank('''
- select exists (
- {statement}
- )
- '''.format(**locals()), 1)[0]
+ @contextmanager
+ def set(self, statement):  # context-manager form: yields the cursor after execution so the caller can fetch from it
+ with self.execute(statement, 2) as cursor:  # depth 2 skips this generator's frame and its own @contextmanager wrapper
+ yield cursor
-@contextmanager
-def connect(dsn):
- attempt = 0
- while True:
- try:
- driver = psycopg2.connect(**dsn)
- break
- except psycopg2.OperationalError, e:
- if attempt == 2:
- raise e
- attempt = attempt + 1
+ def all(self, statement, context=None):  # execute a statement and return every row via fetchall()
+ with self.execute(statement, 1, context) as cursor:  # depth 1 skips this method's own frame
+ return cursor.fetchall()
- try:
- driver.set_client_encoding('UNICODE')
- yield connection(driver)
- finally:
- driver.close()
+ def one(self, statement, context=None):  # public wrapper around one_(): a single row or None
+ return self.one_(statement, context)  # the extra call level is what one_()'s frame depth of 2 accounts for
+
+ def has(self, statement):  # returns the single column of `select exists(<statement>)`
+ exists, = self.one_('select exists(%s)' % (statement,))  # statement is spliced in as a subquery; any %(...)s / %{...}s placeholders it contains are still resolved later by execute()
+ return exists
def connected(dsn):
def wrapped(method):
def replaced(*args, **kw):
- with connect(dsn) as connection:
- return method(connection, *args, **kw)
+ with connect(dsn) as sql:  # a new connection for every call, closed when the call returns or raises
+ return method(*args, sql=sql, **kw)  # the connection is passed as the keyword argument `sql`
return replaced
return wrapped
@contextmanager
-def transact(dsn, **args):
+def transact(dsn, *args, **kw):  # connect to dsn and run the with-block inside connection.transact(*args, **kw)
with connect(dsn) as connection:
- with connection.transact(**args) as cursor:
- yield cursor
+ with connection.transact(*args, **kw):
+ yield connection  # the with-block receives the connection object itself
"""
def slap_(sql, table, keys, values, path):