]> git.saurik.com Git - cyql.git/blobdiff - __init__.py
Move back to Python 2.6 and fix up all of the naming.
[cyql.git] / __init__.py
index 36548c263c6f984c71226f6d4a79f3c3ed3ffd17..6eafaee13b6d2f4e60736a1540231bca0577d748 100644 (file)
@@ -1,5 +1,5 @@
-#from __future__ import unicode_literals
-#from __future__ import print_function
+from __future__ import unicode_literals
+from __future__ import print_function
 
 import inspect
 import os
@@ -7,55 +7,66 @@ import os
 from contextlib import contextmanager
 
 import psycopg2
+import psycopg2.extras
 import psycopg2.pool
 
-from psycopg2.extras import DictCursor
-
 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 
-def Cursor(sql):
-    return sql.cursor(cursor_factory=DictCursor)
+class connection(object):
+    def __init__(self, driver):
+        self.driver = driver  # the psycopg2 connection object handed in by connect()
+
+    @contextmanager
+    def cursor(self):
+        cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)  # acquire outside try: a failed open leaves nothing to close
+        try:
+            yield cursor
+        finally:
+            cursor.close()
+
+    @contextmanager
+    def execute(self, statement, depth=0):
+        with self.cursor() as cursor:
+            scope = inspect.currentframe(depth + 2).f_locals  # +2 skips this generator frame and contextlib's __enter__, so depth=0 is the with-statement's frame
+            cursor.execute(statement.format(**scope), scope)  # NOTE(review): {name} fields are spliced verbatim into the SQL; only the params mapping is escaped by psycopg2
+            yield cursor
 
-class Transaction(object):
+    @contextmanager
+    def transact(self, synchronous_commit=True):
+        try:
+            with self.cursor() as cursor:
+                if not synchronous_commit:
+                    cursor.execute('set local synchronous_commit = off')
+
+            yield transaction(self)
+            self.driver.commit()
+        except:  # any exception (incl. KeyboardInterrupt) undoes the transaction, then re-raises
+            self.driver.rollback()
+            raise
+
+class transaction(object):
     def __init__(self, connection):
         self.connection = connection
 
     def pull(self, statement):
-        locals = inspect.currentframe(1).f_locals
-        cursor = Cursor(self.connection)
-
-        try:
-            cursor.execute(statement.format(**locals), locals)
+        with self.connection.execute(statement, 1) as cursor:
             return cursor.fetchall()
-        finally:
-            cursor.close()
 
     def yank(self, statement):
-        locals = inspect.currentframe(1).f_locals
-        cursor = Cursor(self.connection)
-
-        try:
-            cursor.execute(statement.format(**locals), locals)
+        with self.connection.execute(statement, 1) as cursor:
             rows = cursor.fetchall()
-            return rows[0] if len(rows) != 0 else None
-        finally:
-            cursor.close()
+            return rows[0] if rows else None  # first row, or None when the result set is empty
 
     def push(self, statement):
-        locals = inspect.currentframe(1).f_locals
-        cursor = Cursor(self.connection)
-
-        try:
-            cursor.execute(statement.format(**locals), locals)
-        finally:
-            cursor.close()
+        with self.connection.execute(statement, 1):  # executed for its side effects only; nothing is fetched
+            pass
 
 @contextmanager
-def ConnectSQL(dsn):
+def connect(dsn):
     attempt = 0
     while True:
         try:
-            sql = psycopg2.connect(**dsn)
+            driver = psycopg2.connect(**dsn)
             break
         except psycopg2.OperationalError, e:
             if attempt == 2:
@@ -63,21 +74,12 @@ def ConnectSQL(dsn):
             attempt = attempt + 1
 
     try:
-        sql.set_client_encoding('UNICODE')
-
-        @contextmanager
-        def transact():
-            try:
-                yield Transaction(sql)
-                sql.commit()
-            except:
-                sql.rollback()
-                raise
-
-        yield transact
+        driver.set_client_encoding('UNICODE')
+        yield connection(driver)
     finally:
-        sql.close()
+        driver.close()
 
+"""
 def slap_(sql, table, keys, values, path):
     csr = sql.cursor()
     try:
@@ -111,3 +113,4 @@ def slap_(sql, table, keys, values, path):
         return path_(csr, path)
     finally:
         csr.close()
+"""