X-Git-Url: https://git.saurik.com/cyql.git/blobdiff_plain/77034649a12bfc857886558935a404499e2105fe..refs/heads/master:/__init__.py?ds=sidebyside diff --git a/__init__.py b/__init__.py index 10383d3..d40d9b1 100644 --- a/__init__.py +++ b/__init__.py @@ -1,5 +1,9 @@ -from __future__ import unicode_literals +from __future__ import absolute_import +from __future__ import division from __future__ import print_function +from __future__ import unicode_literals + +from future_builtins import ascii, filter, hex, map, oct, zip import inspect import os @@ -11,98 +15,208 @@ import psycopg2.extras import psycopg2.pool psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) +psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY) + +class ConnectionError(): + pass + +class connect(object): + def __init__(self, dsn): + options = dsn.copy() + if 'cache' in options: + del options['cache'] + + if 'cache' in dsn: + cached = True + cache = dsn['cache'] + else: + cached = False + cache = { + 'hstore': None, + } + + attempt = 0 + while True: + try: + self.driver = psycopg2.connect(**options) + break + except psycopg2.OperationalError, e: + if e.message.startswith('could not connect to server: '): + raise ConnectionError() + if attempt == 2: + raise + attempt = attempt + 1 + + self.driver.autocommit = True + + # XXX: all of my databases default to this... + #try: + # self.driver.set_client_encoding('UNICODE') + #except: + # self.driver.close() + # raise + + hstore = cache['hstore'] + if hstore == None: + hstore = psycopg2.extras.HstoreAdapter.get_oids(self.driver) + if hstore != None: + hstore = hstore[0] + cache['hstore'] = hstore + + if hstore != None: + try: + psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True, oid=hstore) + except psycopg2.ProgrammingError, e: + pass + + if not cached: + dsn['cache'] = cache + + def close(self): + self.driver.close() + + def __enter__(self): + return self -class connection(object): - def __init__(self, driver): - self.driver = driver + def __exit__(self, type, value, traceback): + self.close() + + def begin(self): + self.driver.autocommit = False + + def commit(self): + self.driver.commit() + + def rollback(self): + self.driver.rollback() @contextmanager def cursor(self): + cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor) try: - cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor) yield cursor finally: cursor.close() @contextmanager - def execute(self, statement, depth=0): + def execute(self, statement, depth=0, context=None): + # two frames, accounting for execute() and @contextmanager + frame = inspect.currentframe(depth + 2) + with self.cursor() as cursor: - # two frames, accounting for execute() and @contextmanager - locals = inspect.currentframe(depth + 2).f_locals - try: - cursor.execute(statement.format(**locals), locals) - finally: - del locals + f_globals = None + f_locals = frame.f_locals + + if context == None: + context = dict(**f_locals) + + start = 0 + while True: + percent = statement.find('%', start) + if percent == -1: + break + + next = statement[percent + 1] + if next == '(': + start = statement.index(')', percent + 2) + 2 + assert statement[start - 1] == 's' + elif next == '{': + start = statement.index('}', percent + 2) + assert statement[start + 1] == 's' + code = statement[percent + 2:start] + + if f_globals == None: + f_globals = frame.f_globals + + key = '__cyql__%i' % (percent,) + # XXX: compile() in the frame's context + context[key] = eval(code, f_globals, f_locals) + + statement = '%s%%(%s)%s' % (statement[0:percent], key, statement[start + 1:]) + start = percent + len(key) + 4 + elif next in ('%', 's'): + start = percent + 2 + else: + assert False + + cursor.execute(statement, context) + + del context + del f_locals + del f_globals + yield cursor @contextmanager def transact(self, synchronous_commit=True): - with self.cursor() as cursor: - if not synchronous_commit: - cursor.execute('set local synchronous_commit = off') - + self.driver.autocommit = False try: - yield transaction(self) + with self.cursor() as cursor: + if not synchronous_commit: + cursor.execute('set local synchronous_commit = off') + + yield self.driver.commit() except: self.driver.rollback() raise + finally: + self.driver.autocommit = True -class transaction(object): - def __init__(self, connection): - self.connection = connection + def one_(self, statement, context=None): + with self.execute(statement, 2, context) as cursor: + one = cursor.fetchone() + if one == None: + return None - def pull(self, statement): - with self.connection.execute(statement, 1) as cursor: - return cursor.fetchall() + assert cursor.fetchone() == None + return one - def yank(self, statement, offset=0): - with self.connection.execute(statement, 1 + offset) as cursor: - rows = cursor.fetchall() - return rows[0] if len(rows) != 0 else None + def __call__(self, procedure, *parameters): + with self.execute(statement, 1) as cursor: + return cursor.callproc(procedure, *parameters) - def push(self, statement): - with self.connection.execute(statement, 1) as cursor: - pass + def run(self, statement, context=None): + with self.execute(statement, 1, context) as cursor: + return cursor.rowcount - def exists(self, statement): - return self.yank(''' - select exists ( - {statement} - ) - '''.format(**locals()), 1)[0] + def gen(self, statement): + with self.execute(statement, 1) as cursor: + while True: + fetch = cursor.fetchone() + if fetch == None: + break + yield fetch -@contextmanager -def connect(dsn): - attempt = 0 - while True: - try: - driver = psycopg2.connect(**dsn) - break - except psycopg2.OperationalError, e: - if attempt == 2: - raise e - attempt = attempt + 1 + @contextmanager + def set(self, statement): + with self.execute(statement, 2) as cursor: + yield cursor - try: - driver.set_client_encoding('UNICODE') - yield connection(driver) - finally: - driver.close() + def all(self, statement, context=None): + with self.execute(statement, 1, context) as cursor: + return cursor.fetchall() + + def one(self, statement, context=None): + return self.one_(statement, context) + + def has(self, statement): + exists, = self.one_('select exists(%s)' % (statement,)) + return exists def connected(dsn): def wrapped(method): def replaced(*args, **kw): - with connect(dsn) as connection: - return method(connection, *args, **kw) + with connect(dsn) as sql: + return method(*args, sql=sql, **kw) return replaced return wrapped @contextmanager -def transact(dsn, **args): +def transact(dsn, *args, **kw): with connect(dsn) as connection: - with connection.transact(**args) as cursor: - yield cursor + with connection.transact(*args, **kw): + yield connection """ def slap_(sql, table, keys, values, path):