]> git.saurik.com Git - cyql.git/blobdiff - __init__.py
Commit a path to access result rows via generator.
[cyql.git] / __init__.py
index 73aa9ad650bbee577bdb280e5992721fa4175900..d40d9b1eb57869e0bf9efb42cc241fae9a777f73 100644 (file)
@@ -17,23 +17,60 @@ import psycopg2.pool
 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
 
+class ConnectionError():
+    pass
+
 class connect(object):
     def __init__(self, dsn):
+        options = dsn.copy()
+        if 'cache' in options:
+            del options['cache']
+
+        if 'cache' in dsn:
+            cached = True
+            cache = dsn['cache']
+        else:
+            cached = False
+            cache = {
+                'hstore': None,
+            }
+
         attempt = 0
         while True:
             try:
-                self.driver = psycopg2.connect(**dsn)
+                self.driver = psycopg2.connect(**options)
                 break
             except psycopg2.OperationalError, e:
+                if e.message.startswith('could not connect to server: '):
+                    raise ConnectionError()
                 if attempt == 2:
-                    raise e
+                    raise
                 attempt = attempt + 1
 
-        try:
-            self.driver.set_client_encoding('UNICODE')
-            self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
-        except:
-            self.driver.close()
+        self.driver.autocommit = True
+
+        # XXX: all of my databases default to this...
+        #try:
+        #    self.driver.set_client_encoding('UNICODE')
+        #except:
+        #    self.driver.close()
+        #    raise
+
+        hstore = cache['hstore']
+        if hstore == None:
+            hstore = psycopg2.extras.HstoreAdapter.get_oids(self.driver)
+            if hstore != None:
+                hstore = hstore[0]
+                cache['hstore'] = hstore
+
+        if hstore != None:
+            try:
+                psycopg2.extras.register_hstore(self.driver, globally=False, unicode=True, oid=hstore)
+            except psycopg2.ProgrammingError, e:
+                pass
+
+        if not cached:
+            dsn['cache'] = cache
 
     def close(self):
         self.driver.close()
@@ -44,6 +81,15 @@ class connect(object):
     def __exit__(self, type, value, traceback):
         self.close()
 
+    def begin(self):
+        self.driver.autocommit = False  # manual transactions from here on
+
+    def commit(self):
+        self.driver.commit()  # NOTE(review): autocommit stays off, unlike transact()
+
+    def rollback(self):
+        self.driver.rollback()  # NOTE(review): autocommit stays off here too
+
     @contextmanager
     def cursor(self):
         cursor = self.driver.cursor(cursor_factory=psycopg2.extras.DictCursor)
@@ -103,7 +149,7 @@ class connect(object):
 
     @contextmanager
     def transact(self, synchronous_commit=True):
-        self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
+        self.driver.autocommit = False
         try:
             with self.cursor() as cursor:
                 if not synchronous_commit:
@@ -115,7 +161,7 @@ class connect(object):
             self.driver.rollback()
             raise
         finally:
-            self.driver.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
+            self.driver.autocommit = True
 
     def one_(self, statement, context=None):
         with self.execute(statement, 2, context) as cursor:
@@ -134,13 +180,21 @@ class connect(object):
         with self.execute(statement, 1, context) as cursor:
             return cursor.rowcount
 
+    def gen(self, statement, context=None):
+        """Yield the result rows of statement lazily, one fetchone() at a
+        time; the execute() context stays open until the generator is
+        exhausted or closed. context is passed through as in all()."""
+        with self.execute(statement, 1, context) as cursor:
+            for row in iter(cursor.fetchone, None):
+                yield row
+
     @contextmanager
     def set(self, statement):
-        with self.execute(statement, 1) as cursor:
+        with self.execute(statement, 2) as cursor:
             yield cursor
 
-    def all(self, statement):
-        with self.execute(statement, 1) as cursor:
+    def all(self, statement, context=None):
+        with self.execute(statement, 1, context) as cursor:
             return cursor.fetchall()
 
     def one(self, statement, context=None):