git.saurik.com Git - cyql.git/blobdiff - __init__.py
Switch to psycopg2 autocommit syntax.
[cyql.git] / __init__.py
index 90beb40190c3133c8ff7ef6cefb7301a275a60ec..9fc8d31c388bea89919ee55c414d99311c962b24 100644 (file)
@@ -1,5 +1,9 @@
-from __future__ import unicode_literals
+from __future__ import absolute_import
+from __future__ import division
 from __future__ import print_function
+from __future__ import unicode_literals
+
+from future_builtins import ascii, filter, hex, map, oct, zip
 
 import inspect
 import os
@@ -11,18 +15,48 @@ import psycopg2.extras
 import psycopg2.pool
 
 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
+
+class connect(object):
+    def __init__(self, dsn):
+        attempt = 0
+        while True:
+            try:
+                self.driver = psycopg2.connect(**dsn)
+                break
+            except psycopg2.OperationalError, e:
+                if attempt == 2:
+                    raise e
+                attempt = attempt + 1
+
+        try:
+            self.driver.set_client_encoding('UNICODE')
+            self.driver.autocommit = True
+        except:
+            self.driver.close()
+
+        try:
+            psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True)
+        except psycopg2.ProgrammingError, e:
+            pass
+
+    def close(self):
+        self.driver.close()
+
+    def __enter__(self):
+        return self
 
-def one(values):
-    if values == None or len(values) == 0:
-        return None
-    else:
-        assert len(values) == 1
-        return values[0]
+    def __exit__(self, type, value, traceback):
+        self.close()
 
-class connection(object):
-    def __init__(self, driver):
-        self.driver = driver
-        self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
+    def begin(self):
+        self.driver.autocommit = False
+
+    def commit(self):
+        self.driver.commit()
+
+    def rollback(self):
+        self.driver.rollback()
 
     @contextmanager
     def cursor(self):
@@ -34,20 +68,56 @@ class connection(object):
 
     @contextmanager
     def execute(self, statement, depth=0, context=None):
+        # two frames, accounting for execute() and @contextmanager
+        frame = inspect.currentframe(depth + 2)
+
         with self.cursor() as cursor:
-            # two frames, accounting for execute() and @contextmanager
-            locals = inspect.currentframe(depth + 2).f_locals
-            try:
-                if context == None:
-                    context = locals
-                cursor.execute(statement.format(**locals), context)
-            finally:
-                del locals
+            f_globals = None
+            f_locals = frame.f_locals
+
+            if context == None:
+                context = dict(**f_locals)
+
+            start = 0
+            while True:
+                percent = statement.find('%', start)
+                if percent == -1:
+                    break
+
+                next = statement[percent + 1]
+                if next == '(':
+                    start = statement.index(')', percent + 2) + 2
+                    assert statement[start - 1] == 's'
+                elif next == '{':
+                    start = statement.index('}', percent + 2)
+                    assert statement[start + 1] == 's'
+                    code = statement[percent + 2:start]
+
+                    if f_globals == None:
+                        f_globals = frame.f_globals
+
+                    key = '__cyql__%i' % (percent,)
+                    # XXX: compile() in the frame's context
+                    context[key] = eval(code, f_globals, f_locals)
+
+                    statement = '%s%%(%s)%s' % (statement[0:percent], key, statement[start + 1:])
+                    start = percent + len(key) + 4
+                elif next in ('%', 's'):
+                    start = percent + 2
+                else:
+                    assert False
+
+            cursor.execute(statement, context)
+
+            del context
+            del f_locals
+            del f_globals
+
             yield cursor
 
     @contextmanager
     def transact(self, synchronous_commit=True):
-        self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
+        self.driver.autocommit = False
         try:
             with self.cursor() as cursor:
                 if not synchronous_commit:
@@ -59,10 +129,10 @@ class connection(object):
             self.driver.rollback()
             raise
         finally:
-            self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
+            self.driver.autocommit = True
 
-    def one_(self, statement):
-        with self.execute(statement, 2) as cursor:
+    def one_(self, statement, context=None):
+        with self.execute(statement, 2, context) as cursor:
             one = cursor.fetchone()
             if one == None:
                 return None
@@ -74,8 +144,8 @@ class connection(object):
         with self.execute(statement, 1) as cursor:
             return cursor.callproc(procedure, *parameters)
 
-    def run(self, statement, locals=None):
-        with self.execute(statement, 1, locals) as cursor:
+    def run(self, statement, context=None):
+        with self.execute(statement, 1, context) as cursor:
             return cursor.rowcount
 
     @contextmanager
@@ -83,33 +153,16 @@ class connection(object):
         with self.execute(statement, 1) as cursor:
             yield cursor
 
-    def all(self, statement):
-        with self.execute(statement, 1) as cursor:
+    def all(self, statement, context=None):
+        with self.execute(statement, 1, context) as cursor:
             return cursor.fetchall()
 
-    def one(self, statement):
-        return self.one_(statement)
+    def one(self, statement, context=None):
+        return self.one_(statement, context)
 
     def has(self, statement):
-        return one(self.one_('select exists(%s)' % (statement,)))
-
-@contextmanager
-def connect(dsn):
-    attempt = 0
-    while True:
-        try:
-            driver = psycopg2.connect(**dsn)
-            break
-        except psycopg2.OperationalError, e:
-            if attempt == 2:
-                raise e
-            attempt = attempt + 1
-
-    try:
-        driver.set_client_encoding('UNICODE')
-        yield connection(driver)
-    finally:
-        driver.close()
+        exists, = self.one_('select exists(%s)' % (statement,))
+        return exists
 
 def connected(dsn):
     def wrapped(method):