git.saurik.com Git - cyql.git/blobdiff - __init__.py
Switch to psycopg2 autocommit syntax.
[cyql.git] / __init__.py
index 48bc92ca90d921e8d0e789b833f228b8d6635f8d..9fc8d31c388bea89919ee55c414d99311c962b24 100644 (file)
@@ -15,6 +15,7 @@ import psycopg2.extras
 import psycopg2.pool
 
 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
 
 class connect(object):
     def __init__(self, dsn):
@@ -30,10 +31,15 @@ class connect(object):
 
         try:
             self.driver.set_client_encoding('UNICODE')
-            self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
+            self.driver.autocommit = True
         except:
             self.driver.close()
 
+        try:
+            psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True)
+        except psycopg2.ProgrammingError, e:
+            pass
+
     def close(self):
         self.driver.close()
 
@@ -43,6 +49,15 @@ class connect(object):
     def __exit__(self, type, value, traceback):
         self.close()
 
+    def begin(self):
+        self.driver.autocommit = False
+
+    def commit(self):
+        self.driver.commit()
+
+    def rollback(self):
+        self.driver.rollback()
+
     @contextmanager
     def cursor(self):
         cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
@@ -52,15 +67,16 @@ class connect(object):
             cursor.close()
 
     @contextmanager
-    def execute(self, statement, depth=0):
+    def execute(self, statement, depth=0, context=None):
         # two frames, accounting for execute() and @contextmanager
         frame = inspect.currentframe(depth + 2)
 
         with self.cursor() as cursor:
             f_globals = None
-
             f_locals = frame.f_locals
-            context = dict(**f_locals)
+
+            if context == None:
+                context = dict(**f_locals)
 
             start = 0
             while True:
@@ -71,8 +87,10 @@ class connect(object):
                 next = statement[percent + 1]
                 if next == '(':
                     start = statement.index(')', percent + 2) + 2
+                    assert statement[start - 1] == 's'
                 elif next == '{':
                     start = statement.index('}', percent + 2)
+                    assert statement[start + 1] == 's'
                     code = statement[percent + 2:start]
 
                     if f_globals == None:
@@ -82,9 +100,9 @@ class connect(object):
                     # XXX: compile() in the frame's context
                     context[key] = eval(code, f_globals, f_locals)
 
-                    statement = '%s%%(%s)s%s' % (statement[0:percent], key, statement[start + 1:])
+                    statement = '%s%%(%s)%s' % (statement[0:percent], key, statement[start + 1:])
                     start = percent + len(key) + 4
-                elif next == '%':
+                elif next in ('%', 's'):
                     start = percent + 2
                 else:
                     assert False
@@ -99,7 +117,7 @@ class connect(object):
 
     @contextmanager
     def transact(self, synchronous_commit=True):
-        self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
+        self.driver.autocommit = False
         try:
             with self.cursor() as cursor:
                 if not synchronous_commit:
@@ -111,10 +129,10 @@ class connect(object):
             self.driver.rollback()
             raise
         finally:
-            self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
+            self.driver.autocommit = True
 
-    def one_(self, statement):
-        with self.execute(statement, 2) as cursor:
+    def one_(self, statement, context=None):
+        with self.execute(statement, 2, context) as cursor:
             one = cursor.fetchone()
             if one == None:
                 return None
@@ -126,8 +144,8 @@ class connect(object):
         with self.execute(statement, 1) as cursor:
             return cursor.callproc(procedure, *parameters)
 
-    def run(self, statement):
-        with self.execute(statement, 1) as cursor:
+    def run(self, statement, context=None):
+        with self.execute(statement, 1, context) as cursor:
             return cursor.rowcount
 
     @contextmanager
@@ -135,12 +153,12 @@ class connect(object):
         with self.execute(statement, 1) as cursor:
             yield cursor
 
-    def all(self, statement):
-        with self.execute(statement, 1) as cursor:
+    def all(self, statement, context=None):
+        with self.execute(statement, 1, context) as cursor:
             return cursor.fetchall()
 
-    def one(self, statement):
-        return self.one_(statement)
+    def one(self, statement, context=None):
+        return self.one_(statement, context)
 
     def has(self, statement):
         exists, = self.one_('select exists(%s)' % (statement,))