-from __future__ import unicode_literals
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+from __future__ import unicode_literals
+
+from future_builtins import ascii, filter, hex, map, oct, zip
import inspect
import os
# Register psycopg2's UNICODE typecaster. No scope argument is passed, so this
# is not tied to one connection or cursor.
# NOTE(review): presumably this makes text columns come back as unicode
# objects on Python 2 -- confirm against the psycopg2 documentation.
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
-class connection(object):
- def __init__(self, driver):
- self.driver = driver
class connect(object):
    """Wrapper around a psycopg2 connection, usable as a context manager.

    The connection is opened in ``__init__`` and switched to autocommit
    isolation. Leaving a ``with connect(dsn) as sql:`` block closes it.

    Query helpers (``all``, ``one``, ``has``, ``run``, ``set``) take a
    statement that is first passed through ``str.format`` with the *calling
    function's* local variables and then executed with a mapping as its
    parameters (the caller's locals unless ``run`` is given another mapping).

    NOTE(review): because ``str.format`` splices caller locals directly into
    the SQL text, only trusted values should be referenced with ``{name}``;
    untrusted data should go through the parameter mapping instead.
    """

    def __init__(self, dsn):
        """Open the connection described by *dsn*.

        *dsn* is a mapping expanded as keyword arguments to
        ``psycopg2.connect``. The connection is attempted up to three times;
        the third ``psycopg2.OperationalError`` propagates to the caller.

        Raises whatever the encoding/isolation setup raises, after closing
        the freshly opened driver connection.
        """
        for remaining in (2, 1, 0):
            try:
                self.driver = psycopg2.connect(**dsn)
                break
            except psycopg2.OperationalError:
                if not remaining:
                    # Bare raise keeps the original traceback intact.
                    raise

        try:
            self.driver.set_client_encoding('UNICODE')
            self.driver.set_isolation_level(
                psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        except:
            # Do not leak the connection, but do not hide the failure either:
            # returning normally here would hand back a closed connection.
            self.driver.close()
            raise

    def close(self):
        """Close the underlying driver connection."""
        self.driver.close()

    def __enter__(self):
        """Return this object so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, type, value, traceback):
        """Close the connection; exceptions from the block propagate."""
        self.close()

    @contextmanager
    def cursor(self):
        """Yield a ``DictCursor`` on the connection and close it afterwards."""
        # Created outside the try so a failure here never reaches the
        # ``finally`` with ``cursor`` unbound.
        cursor = self.driver.cursor(
            cursor_factory=psycopg2.extras.DictCursor)
        try:
            yield cursor
        finally:
            cursor.close()

    @contextmanager
    def execute(self, statement, depth=0, context=None):
        """Execute *statement* and yield the cursor holding its result.

        *depth* is how many frames above the direct caller the locals used
        for formatting are taken from (0 = the function containing the
        ``with self.execute(...)``). *context*, when given, replaces those
        locals as the parameter mapping passed to ``cursor.execute``; the
        ``str.format`` step always uses the frame locals.
        """
        with self.cursor() as cursor:
            # two frames, accounting for execute() and @contextmanager
            frame_locals = inspect.currentframe(depth + 2).f_locals
            try:
                if context is None:
                    context = frame_locals
                cursor.execute(statement.format(**frame_locals), context)
            finally:
                # Drop the reference to the foreign frame's locals promptly.
                del frame_locals
            yield cursor

    @contextmanager
    def transact(self, synchronous_commit=True):
        """Run the ``with`` body inside a READ COMMITTED transaction.

        Commits when the body finishes normally, rolls back and re-raises on
        any exception, and always restores autocommit isolation afterwards.
        With ``synchronous_commit=False`` the transaction first issues
        ``set local synchronous_commit = off``.
        """
        self.driver.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
        try:
            with self.cursor() as cursor:
                if not synchronous_commit:
                    cursor.execute('set local synchronous_commit = off')

            yield
            self.driver.commit()
        except:
            # Deliberately bare: roll back on anything, then re-raise.
            self.driver.rollback()
            raise
        finally:
            self.driver.set_isolation_level(
                psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

    def one_(self, statement):
        """Shared implementation of ``one`` and ``has``.

        Uses depth 2 so the locals come from the caller of the public method
        that called this helper. Returns the single row, or ``None`` when the
        statement produced no rows. Raises ``AssertionError`` when it
        produced more than one.
        """
        with self.execute(statement, 2) as cursor:
            one = cursor.fetchone()
            if one is None:
                return None

            # Explicit raise instead of ``assert`` so the check survives -O.
            if cursor.fetchone() is not None:
                raise AssertionError('statement returned more than one row')
            return one

    def __call__(self, procedure, *parameters):
        """Call stored procedure *procedure* with *parameters*.

        Returns whatever ``cursor.callproc`` returns.
        """
        # callproc takes the arguments as one sequence, and there is no SQL
        # statement to format here, so a plain cursor is used.
        with self.cursor() as cursor:
            return cursor.callproc(procedure, parameters)

    def run(self, statement, locals=None):
        """Execute *statement* and return the cursor's ``rowcount``.

        *locals*, when given, is used as the parameter mapping instead of
        the caller's local variables.
        """
        with self.execute(statement, 1, locals) as cursor:
            return cursor.rowcount

    @contextmanager
    def set(self, statement):
        """Execute *statement* and yield the open cursor for iteration."""
        # Depth 2, not 1: this method is itself a @contextmanager generator,
        # so contextlib's __enter__ frame sits between it and the caller.
        with self.execute(statement, 2) as cursor:
            yield cursor

    def all(self, statement):
        """Execute *statement* and return every row via ``fetchall``."""
        with self.execute(statement, 1) as cursor:
            return cursor.fetchall()

    def one(self, statement):
        """Execute *statement* and return its only row, or ``None``."""
        return self.one_(statement)

    def has(self, statement):
        """Return the result of ``select exists(<statement>)``."""
        exists, = self.one_('select exists(%s)' % (statement,))
        return exists
-@contextmanager
-def connect(dsn):
- attempt = 0
- while True:
- try:
- driver = psycopg2.connect(**dsn)
- break
- except psycopg2.OperationalError, e:
- if attempt == 2:
- raise e
- attempt = attempt + 1
def connected(dsn):
    """Decorator factory that supplies a fresh connection to each call.

    The decorated function is invoked inside ``with connect(dsn) as sql:``
    and receives that object as the keyword argument ``sql``; the
    connection is closed when the call returns or raises.
    """
    from functools import wraps

    def wrapped(method):
        # wraps() keeps the decorated function's __name__ and __doc__.
        @wraps(method)
        def replaced(*args, **kw):
            with connect(dsn) as sql:
                return method(*args, sql=sql, **kw)
        return replaced
    return wrapped
- try:
- driver.set_client_encoding('UNICODE')
- yield connection(driver)
- finally:
- driver.close()
@contextmanager
def transact(dsn, *args, **kw):
    """Open a connection for *dsn* and yield it inside a transaction.

    Extra positional and keyword arguments are forwarded unchanged to the
    connection's own ``transact`` method. The connection is closed when the
    ``with`` block is left.
    """
    with connect(dsn) as sql:
        transaction = sql.transact(*args, **kw)
        with transaction:
            yield sql
"""
def slap_(sql, table, keys, values, path):