]> git.saurik.com Git - cyql.git/blobdiff - __init__.py
Commit a path to access result rows via generator.
[cyql.git] / __init__.py
index 36548c263c6f984c71226f6d4a79f3c3ed3ffd17..d40d9b1eb57869e0bf9efb42cc241fae9a777f73 100644 (file)
@@ -1,5 +1,9 @@
-#from __future__ import unicode_literals
-#from __future__ import print_function
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from future_builtins import ascii, filter, hex, map, oct, zip
 
 import inspect
 import os
@@ -7,77 +11,214 @@ import os
 from contextlib import contextmanager
 
 import psycopg2
+import psycopg2.extras
 import psycopg2.pool
 
-from psycopg2.extras import DictCursor
-
 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
+
+class ConnectionError(Exception):
+    """Raised by connect() when the server cannot be reached at all (not retried)."""
+
+class connect(object):  # one psycopg2 connection; autocommit unless begin()/transact() is active
+    def __init__(self, dsn):  # dsn: psycopg2.connect() kwargs, plus an optional 'cache' entry
+        options = dsn.copy()
+        if 'cache' in options:
+            del options['cache']
+        # dsn['cache'] memoizes the hstore OID lookup for later connect(dsn) calls
+        if 'cache' in dsn:
+            cached = True
+            cache = dsn['cache']
+        else:
+            cached = False
+            cache = {
+                'hstore': None,
+            }
+        # retry OperationalError twice; an unreachable server fails fast
+        attempt = 0
+        while True:
+            try:
+                self.driver = psycopg2.connect(**options)
+                break
+            except psycopg2.OperationalError, e:
+                if e.message.startswith('could not connect to server: '):
+                    raise ConnectionError()
+                if attempt == 2:
+                    raise
+                attempt = attempt + 1
+
+        self.driver.autocommit = True  # begin()/transact() switch this off
+
+        # XXX: all of my databases default to this...
+        #try:
+        #    self.driver.set_client_encoding('UNICODE')
+        #except:
+        #    self.driver.close()
+        #    raise
+
+        hstore = cache['hstore']
+        if hstore is None:
+            hstore = psycopg2.extras.HstoreAdapter.get_oids(self.driver)
+            if hstore is not None:
+                hstore = hstore[0]
+                cache['hstore'] = hstore
+
+        if hstore is not None:
+            try:
+                psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True, oid=hstore)
+            except psycopg2.ProgrammingError:
+                pass  # best effort: carry on without hstore support
 
-def Cursor(sql):
-    return sql.cursor(cursor_factory=DictCursor)
+        if not cached:
+            dsn['cache'] = cache
 
-class Transaction(object):
-    def __init__(self, connection):
-        self.connection = connection
+    def close(self):
+        self.driver.close()
 
-    def pull(self, statement):
-        locals = inspect.currentframe(1).f_locals
-        cursor = Cursor(self.connection)
+    def __enter__(self):
+        return self
 
-        try:
-            cursor.execute(statement.format(**locals), locals)
-            return cursor.fetchall()
-        finally:
-            cursor.close()
+    def __exit__(self, exc_type, exc_value, traceback):  # always closes; never suppresses
+        self.close()
 
-    def yank(self, statement):
-        locals = inspect.currentframe(1).f_locals
-        cursor = Cursor(self.connection)
+    def begin(self):  # leave autocommit; finish with commit() or rollback()
+        self.driver.autocommit = False
 
-        try:
-            cursor.execute(statement.format(**locals), locals)
-            rows = cursor.fetchall()
-            return rows[0] if len(rows) != 0 else None
-        finally:
-            cursor.close()
+    def commit(self):
+        self.driver.commit()
 
-    def push(self, statement):
-        locals = inspect.currentframe(1).f_locals
-        cursor = Cursor(self.connection)
+    def rollback(self):
+        self.driver.rollback()
 
+    @contextmanager
+    def cursor(self):  # yields a DictCursor, closed when the with-block exits
+        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
         try:
-            cursor.execute(statement.format(**locals), locals)
+            yield cursor
         finally:
             cursor.close()
 
-@contextmanager
-def ConnectSQL(dsn):
-    attempt = 0
-    while True:
+    @contextmanager
+    def execute(self, statement, depth=0, context=None):
+        # skip execute() and contextmanager's __enter__ frames, then `depth` more
+        frame = inspect.currentframe(depth + 2)
+        # %(name)s binds from context (default: caller's locals); %{expr}s evals expr in the caller
+        with self.cursor() as cursor:
+            f_globals = None
+            f_locals = frame.f_locals
+
+            if context is None:
+                context = dict(**f_locals)
+
+            start = 0
+            while True:
+                percent = statement.find('%', start)
+                if percent == -1:
+                    break
+
+                marker = statement[percent + 1:percent + 2]  # '' if '%' ends the statement
+                if marker == '(':
+                    start = statement.index(')', percent + 2) + 2
+                    assert statement[start - 1] == 's'
+                elif marker == '{':
+                    start = statement.index('}', percent + 2)
+                    assert statement[start + 1] == 's'
+                    code = statement[percent + 2:start]
+                    # first %{...}: load globals, copy context (never mutate the caller's dict)
+                    if f_globals is None:
+                        f_globals, context = frame.f_globals, dict(context)
+
+                    key = '__cyql__%i' % (percent,)
+                    # XXX: compile() in the frame's context
+                    context[key] = eval(code, f_globals, f_locals)  # statement text must be trusted
+
+                    statement = '%s%%(%s)%s' % (statement[0:percent], key, statement[start + 1:])
+                    start = percent + len(key) + 4
+                elif marker in ('%', 's'):
+                    start = percent + 2
+                else:
+                    raise ValueError('unsupported %%-sequence at offset %i' % (percent,))
+
+            cursor.execute(statement, context)
+            # drop caller-frame references before handing control back to the caller
+            del context
+            del f_locals
+            del f_globals, frame
+
+            yield cursor
+
+    @contextmanager
+    def transact(self, synchronous_commit=True):  # commit on success, roll back on any exception
+        self.driver.autocommit = False
         try:
-            sql = psycopg2.connect(**dsn)
-            break
-        except psycopg2.OperationalError, e:
-            if attempt == 2:
-                raise e
-            attempt = attempt + 1
+            with self.cursor() as cursor:
+                if not synchronous_commit:
+                    cursor.execute('set local synchronous_commit = off')
+            # run the caller's block inside the transaction
+            yield
+            self.driver.commit()
+        except:  # bare on purpose: also catches Python 2 classic-class exceptions
+            self.driver.rollback()
+            raise
+        finally:
+            self.driver.autocommit = True
+
+    def one_(self, statement, context=None):  # depth 2: skips the one()/has() wrapper frame
+        with self.execute(statement, 2, context) as cursor:
+            one = cursor.fetchone()
+            if one is None:
+                return None
+            # more than one row is a caller error
+            assert cursor.fetchone() is None
+            return one
+
+    def __call__(self, procedure, *parameters):  # sql('proc', a, b) -> callproc('proc', (a, b))
+        with self.cursor() as cursor:  # no statement text to expand, so execute() is not needed
+            return cursor.callproc(procedure, parameters)
+
+    def run(self, statement, context=None):  # returns the affected row count
+        with self.execute(statement, 1, context) as cursor:
+            return cursor.rowcount
+
+    def gen(self, statement, context=None):  # yields rows lazily instead of fetching all
+        with self.execute(statement, 1, context) as cursor:
+            while True:
+                fetch = cursor.fetchone()
+                if fetch is None:
+                    break
+                yield fetch
+
+    @contextmanager
+    def set(self, statement, context=None):  # yields the executed cursor itself
+        with self.execute(statement, 2, context) as cursor:
+            yield cursor
+
+    def all(self, statement, context=None):  # returns every row via fetchall()
+        with self.execute(statement, 1, context) as cursor:
+            return cursor.fetchall()
 
-    try:
-        sql.set_client_encoding('UNICODE')
+    def one(self, statement, context=None):  # first row or None; asserts there is no second row
+        return self.one_(statement, context)
 
-        @contextmanager
-        def transact():
-            try:
-                yield Transaction(sql)
-                sql.commit()
-            except:
-                sql.rollback()
-                raise
+    def has(self, statement):  # wraps statement as select exists(...) and returns that column
+        exists, = self.one_('select exists(%s)' % (statement,))
+        return exists
 
-        yield transact
-    finally:
-        sql.close()
+def connected(dsn):  # decorator factory: each call of the wrapped function gets connect(dsn) as sql=
+    def wrapped(method):
+        def replaced(*args, **kw):  # NOTE(review): no functools.wraps, so method's __name__/__doc__ are lost
+            with connect(dsn) as sql:  # connection is closed when the call returns or raises
+                return method(*args, sql=sql, **kw)
+        return replaced
+    return wrapped
+
+@contextmanager
+def transact(dsn, *args, **kw):  # connect(dsn), run the block in connection.transact(*args, **kw), then close
+    with connect(dsn) as connection:
+        with connection.transact(*args, **kw):  # commits on success; rolls back and re-raises on error
+            yield connection
 
+"""
 def slap_(sql, table, keys, values, path):
     csr = sql.cursor()
     try:
@@ -111,3 +252,4 @@ def slap_(sql, table, keys, values, path):
         return path_(csr, path)
     finally:
         csr.close()
+"""