X-Git-Url: https://git.saurik.com/cyql.git/blobdiff_plain/7d11917e79db64a1ac4a26747874b6ac61c17b4a..f392523d5d61f3447a05a319ec2e703dcc7a9596:/__init__.py?ds=sidebyside diff --git a/__init__.py b/__init__.py index 0c29a58..9fc8d31 100644 --- a/__init__.py +++ b/__init__.py @@ -1,5 +1,9 @@ -from __future__ import unicode_literals +from __future__ import absolute_import +from __future__ import division from __future__ import print_function +from __future__ import unicode_literals + +from future_builtins import ascii, filter, hex, map, oct, zip import inspect import os @@ -11,80 +15,168 @@ import psycopg2.extras import psycopg2.pool psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) +psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY) + +class connect(object): + def __init__(self, dsn): + attempt = 0 + while True: + try: + self.driver = psycopg2.connect(**dsn) + break + except psycopg2.OperationalError, e: + if attempt == 2: + raise e + attempt = attempt + 1 + + try: + self.driver.set_client_encoding('UNICODE') + self.driver.autocommit = True + except: + self.driver.close() + + try: + psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True) + except psycopg2.ProgrammingError, e: + pass -class connection(object): - def __init__(self, driver): - self.driver = driver + def close(self): + self.driver.close() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + def begin(self): + self.driver.autocommit = False + + def commit(self): + self.driver.commit() + + def rollback(self): + self.driver.rollback() @contextmanager def cursor(self): + cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor) try: - cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor) yield cursor finally: cursor.close() @contextmanager - def execute(self, statement, depth=0): + def execute(self, statement, depth=0, context=None): + # two frames, accounting for execute() and @contextmanager + frame = inspect.currentframe(depth + 2) + with self.cursor() as cursor: - # two frames, accounting for execute() and @contextmanager - locals = inspect.currentframe(depth + 2).f_locals - cursor.execute(statement.format(**locals), locals) + f_globals = None + f_locals = frame.f_locals + + if context == None: + context = dict(**f_locals) + + start = 0 + while True: + percent = statement.find('%', start) + if percent == -1: + break + + next = statement[percent + 1] + if next == '(': + start = statement.index(')', percent + 2) + 2 + assert statement[start - 1] == 's' + elif next == '{': + start = statement.index('}', percent + 2) + assert statement[start + 1] == 's' + code = statement[percent + 2:start] + + if f_globals == None: + f_globals = frame.f_globals + + key = '__cyql__%i' % (percent,) + # XXX: compile() in the frame's context + context[key] = eval(code, f_globals, f_locals) + + statement = '%s%%(%s)%s' % (statement[0:percent], key, statement[start + 1:]) + start = percent + len(key) + 4 + elif next in ('%', 's'): + start = percent + 2 + else: + assert False + + cursor.execute(statement, context) + + del context + del f_locals + del f_globals + yield cursor @contextmanager def transact(self, synchronous_commit=True): - with self.cursor() as cursor: - if not synchronous_commit: - cursor.execute('set local synchronous_commit = off') - + self.driver.autocommit = False try: - yield transaction(self) + with self.cursor() as cursor: + if not synchronous_commit: + cursor.execute('set local synchronous_commit = off') + + yield self.driver.commit() except: self.driver.rollback() raise + finally: + self.driver.autocommit = True -class transaction(object): - def __init__(self, connection): - self.connection = connection + def one_(self, statement, context=None): + with self.execute(statement, 2, context) as cursor: + one = cursor.fetchone() + if one == None: + return None - def pull(self, statement): - with self.connection.execute(statement, 1) as cursor: - return cursor.fetchall() + assert cursor.fetchone() == None + return one - def yank(self, statement): - with self.connection.execute(statement, 1) as cursor: - rows = cursor.fetchall() - return rows[0] if len(rows) != 0 else None + def __call__(self, procedure, *parameters): + with self.execute(statement, 1) as cursor: + return cursor.callproc(procedure, *parameters) - def push(self, statement): - with self.connection.execute(statement, 1) as cursor: - pass + def run(self, statement, context=None): + with self.execute(statement, 1, context) as cursor: + return cursor.rowcount -@contextmanager -def connect(dsn): - attempt = 0 - while True: - try: - driver = psycopg2.connect(**dsn) - break - except psycopg2.OperationalError, e: - if attempt == 2: - raise e - attempt = attempt + 1 + @contextmanager + def set(self, statement): + with self.execute(statement, 1) as cursor: + yield cursor - try: - driver.set_client_encoding('UNICODE') - yield connection(driver) - finally: - driver.close() + def all(self, statement, context=None): + with self.execute(statement, 1, context) as cursor: + return cursor.fetchall() + + def one(self, statement, context=None): + return self.one_(statement, context) + + def has(self, statement): + exists, = self.one_('select exists(%s)' % (statement,)) + return exists + +def connected(dsn): + def wrapped(method): + def replaced(*args, **kw): + with connect(dsn) as sql: + return method(*args, sql=sql, **kw) + return replaced + return wrapped @contextmanager -def transact(dsn, **args): +def transact(dsn, *args, **kw): with connect(dsn) as connection: - with connection.transact(**args) as cursor: - yield cursor + with connection.transact(*args, **kw): + yield connection """ def slap_(sql, table, keys, values, path):