]> git.saurik.com Git - redis.git/commitdiff
sync python client to the new protocol
authorLudovico Magnocavallo <ludo@ankh.qix.it>
Tue, 24 Mar 2009 13:30:04 +0000 (14:30 +0100)
committerLudovico Magnocavallo <ludo@ankh.qix.it>
Tue, 24 Mar 2009 13:30:04 +0000 (14:30 +0100)
client-libraries/python/redis.py

index e844f812b05e348aa470fbdcad2219446e93308a..6187cf9d42de4ba0651ffcdb9157ff0e11fbeb5b 100644 (file)
@@ -83,7 +83,7 @@ class Redis(object):
         """
         self.connect()
         self._write('PING\r\n')
-        return self._get_simple_response()
+        return self.get_response()
     
     def set(self, name, value, preserve=False):
         """
@@ -114,7 +114,7 @@ class Redis(object):
             ))
         except UnicodeEncodeError, e:
             raise InvalidData("Error encoding unicode value for key '%s': %s." % (name, e))
-        return self._get_numeric_response() if preserve else self._get_simple_response()
+        return self.get_response()
     
     def get(self, name):
         """
@@ -138,7 +138,20 @@ class Redis(object):
         """
         self.connect()
         self._write('GET %s\r\n' % name)
-        return self._get_value()
+        return self.get_response()
+    
+    def mget(self, *args):
+        """
+        >>> r = Redis()
+        >>> r.set('a', 'pippo'), r.set('b', 15), r.set('c', '\\r\\naaa\\nbbb\\r\\ncccc\\nddd\\r\\n'), r.set('d', '\\r\\n')
+        ('OK', 'OK', 'OK', 'OK')
+        >>> r.mget('a', 'b', 'c', 'd')
+        ['pippo', '15', '\\r\\naaa\\nbbb\\r\\ncccc\\nddd\\r\\n', '\\r\\n']
+        >>> 
+        """
+        self.connect()
+        self._write('MGET %s\r\n' % ' '.join(args))
+        return self.get_response()
     
     def incr(self, name, amount=1):
         """
@@ -158,7 +171,7 @@ class Redis(object):
             self._write('INCR %s\r\n' % name)
         else:
             self._write('INCRBY %s %s\r\n' % (name, amount))
-        return self._get_numeric_response()
+        return self.get_response()
 
     def decr(self, name, amount=1):
         """
@@ -181,7 +194,7 @@ class Redis(object):
             self._write('DECR %s\r\n' % name)
         else:
             self._write('DECRBY %s %s\r\n' % (name, amount))
-        return self._get_numeric_response()
+        return self.get_response()
     
     def exists(self, name):
         """
@@ -196,7 +209,7 @@ class Redis(object):
         """
         self.connect()
         self._write('EXISTS %s\r\n' % name)
-        return self._get_numeric_response()
+        return self.get_response()
 
     def delete(self, name):
         """
@@ -215,7 +228,7 @@ class Redis(object):
         """
         self.connect()
         self._write('DEL %s\r\n' % name)
-        return self._get_numeric_response()
+        return self.get_response()
 
     def key_type(self, name):
         """
@@ -223,7 +236,7 @@ class Redis(object):
         """
         self.connect()
         self._write('TYPE %s\r\n' % name)
-        return self._get_simple_response()
+        return self.get_response()
     
     def keys(self, pattern):
         """
@@ -246,7 +259,7 @@ class Redis(object):
         """
         self.connect()
         self._write('KEYS %s\r\n' % pattern)
-        return self._get_value().split()
+        return self.get_response().split()
     
     def randomkey(self):
         """
@@ -260,9 +273,7 @@ class Redis(object):
         #raise NotImplementedError("Implemented but buggy, do not use.")
         self.connect()
         self._write('RANDOMKEY\r\n')
-        data = self._read().strip()
-        self._check_for_error(data)
-        return data
+        return self.get_response()
     
     def rename(self, src, dst, preserve=False):
         """
@@ -271,7 +282,7 @@ class Redis(object):
         ...     r.rename('a', 'a')
         ... except ResponseError, e:
         ...     print e
-        src and dest key are the same
+        source and destination objects are the same
         >>> r.rename('a', 'b')
         'OK'
         >>> try:
@@ -288,10 +299,10 @@ class Redis(object):
         self.connect()
         if preserve:
             self._write('RENAMENX %s %s\r\n' % (src, dst))
-            return self._get_numeric_response()
+            return self.get_response()
         else:
             self._write('RENAME %s %s\r\n' % (src, dst))
-            return self._get_simple_response().strip()
+            return self.get_response().strip()
     
     def push(self, name, value, tail=False):
         """
@@ -318,7 +329,7 @@ class Redis(object):
             ))
         except UnicodeEncodeError, e:
             raise InvalidData("Error encoding unicode value for element in list '%s': %s." % (name, e))
-        return self._get_simple_response()
+        return self.get_response()
     
     def llen(self, name):
         """
@@ -337,7 +348,7 @@ class Redis(object):
         """
         self.connect()
         self._write('LLEN %s\r\n' % name)
-        return self._get_numeric_response()
+        return self.get_response()
 
     def lrange(self, name, start, end):
         """
@@ -364,7 +375,7 @@ class Redis(object):
         """
         self.connect()
         self._write('LRANGE %s %s %s\r\n' % (name, start, end))
-        return self._get_multi_response()
+        return self.get_response()
         
     def ltrim(self, name, start, end):
         """
@@ -394,7 +405,7 @@ class Redis(object):
         """
         self.connect()
         self._write('LTRIM %s %s %s\r\n' % (name, start, end))
-        return self._get_simple_response()
+        return self.get_response()
     
     def lindex(self, name, index):
         """
@@ -416,7 +427,7 @@ class Redis(object):
         """
         self.connect()
         self._write('LINDEX %s %s\r\n' % (name, index))
-        return self._get_value()
+        return self.get_response()
         
     def pop(self, name, tail=False):
         """
@@ -446,7 +457,7 @@ class Redis(object):
         """
         self.connect()
         self._write('%s %s\r\n' % ('RPOP' if tail else 'LPOP', name))
-        return self._get_value()
+        return self.get_response()
     
     def lset(self, name, index, value):
         """
@@ -479,7 +490,7 @@ class Redis(object):
             ))
         except UnicodeEncodeError, e:
             raise InvalidData("Error encoding unicode value for element %s in list '%s': %s." % (index, name, e))
-        return self._get_simple_response()
+        return self.get_response()
     
     def lrem(self, name, value, num=0):
         """
@@ -516,7 +527,7 @@ class Redis(object):
             ))
         except UnicodeEncodeError, e:
             raise InvalidData("Error encoding unicode value for element %s in list '%s': %s." % (index, name, e))
-        return self._get_numeric_response()
+        return self.get_response()
     
     def sort(self, name, by=None, get=None, start=None, num=None, desc=False, alpha=False):
         """
@@ -577,7 +588,7 @@ class Redis(object):
             stmt.append("ALPHA")
         self.connect()
         self._write(' '.join(stmt + ["\r\n"]))
-        return self._get_multi_response()
+        return self.get_response()
     
     def sadd(self, name, value):
         """
@@ -598,7 +609,7 @@ class Redis(object):
             ))
         except UnicodeEncodeError, e:
             raise InvalidData("Error encoding unicode value for element in set '%s': %s." % (name, e))
-        return self._get_numeric_response()
+        return self.get_response()
         
     def srem(self, name, value):
         """
@@ -624,7 +635,7 @@ class Redis(object):
             ))
         except UnicodeEncodeError, e:
             raise InvalidData("Error encoding unicode value for element in set '%s': %s." % (name, e))
-        return self._get_numeric_response()
+        return self.get_response()
     
     def sismember(self, name, value):
         """
@@ -650,7 +661,7 @@ class Redis(object):
             ))
         except UnicodeEncodeError, e:
             raise InvalidData("Error encoding unicode value for element in set '%s': %s." % (name, e))
-        return self._get_numeric_response()
+        return self.get_response()
     
     def sinter(self, *args):
         """
@@ -682,7 +693,7 @@ class Redis(object):
         """
         self.connect()
         self._write('SINTER %s\r\n' % ' '.join(args))
-        return set(self._get_multi_response())
+        return set(self.get_response())
     
     def sinterstore(self, dest, *args):
         """
@@ -706,7 +717,7 @@ class Redis(object):
         """
         self.connect()
         self._write('SINTERSTORE %s %s\r\n' % (dest, ' '.join(args)))
-        return self._get_simple_response()
+        return self.get_response()
 
     def smembers(self, name):
         """
@@ -728,7 +739,7 @@ class Redis(object):
         """
         self.connect()
         self._write('SMEMBERS %s\r\n' % name)
-        return set(self._get_multi_response())
+        return set(self.get_response())
 
     def select(self, db):
         """
@@ -746,7 +757,7 @@ class Redis(object):
         """
         self.connect()
         self._write('SELECT %s\r\n' % db)
-        return self._get_simple_response()
+        return self.get_response()
     
     def move(self, name, db):
         """
@@ -777,7 +788,7 @@ class Redis(object):
         """
         self.connect()
         self._write('MOVE %s %s\r\n' % (name, db))
-        return self._get_numeric_response()
+        return self.get_response()
     
     def save(self, background=False):
         """
@@ -797,7 +808,7 @@ class Redis(object):
             self._write('BGSAVE\r\n')
         else:
             self._write('SAVE\r\n')
-        return self._get_simple_response()
+        return self.get_response()
         
     def lastsave(self):
         """
@@ -812,7 +823,7 @@ class Redis(object):
         """
         self.connect()
         self._write('LASTSAVE\r\n')
-        return self._get_numeric_response()
+        return self.get_response()
     
     def flush(self, all_dbs=False):
         """
@@ -825,76 +836,71 @@ class Redis(object):
         """
         self.connect()
         self._write('%s\r\n' % ('FLUSHALL' if all_dbs else 'FLUSHDB'))
-        return self._get_simple_response()
+        return self.get_response()
+    
+    def info(self):
+        """
+        >>> r = Redis()
+        >>> info = r.info()
+        >>> info and isinstance(info, dict)
+        True
+        >>> isinstance(info.get('connected_clients'), int)
+        True
+        >>> 
+        """
+        self.connect()
+        self._write('INFO\r\n')
+        info = dict()
+        for l in self.get_response().split('\r\n'):
+            if not l:
+                continue
+            k, v = l.split(':', 1)
+            info[k] = int(v) if v.isdigit() else v
+        return info
     
-    def _get_value(self, negative_as_nil=False):
+    def get_response(self):
         data = self._read().strip()
-        if data == 'nil' or (negative_as_nil and data == '-1'):
-            return
-        elif data[0] == '-':
-            self._check_for_error(data)
+        c = data[0]
+        if c == '-':
+            raise ResponseError(data[5:] if data[:5] == '-ERR ' else data[1:])
+        if c == '+':
+            return data[1:]
+        if c == '*':
+            try:
+                num = int(data[1:])
+            except (TypeError, ValueError):
+                raise InvalidResponse("Cannot convert multi-response header '%s' to integer" % data)
+            result = list()
+            for i in range(num):
+                result.append(self._get_value())
+            return result
+        return self._get_value(data)
+    
+    def _get_value(self, data=None):
+        data = data or self._read().strip()
+        if data == '$-1':
+            return None
         try:
-            l = int(data)
-        except (TypeError, ValueError):
-            raise ResponseError("Cannot parse response '%s' as data length." % data)
+            c, i = data[0], (int(data[1:]) if data.find('.') == -1 else float(data[1:]))
+        except ValueError:
+            raise InvalidResponse("Cannot convert data '%s' to integer" % data)
+        if c == ':':
+            return i
+        if c != '$':
+            raise InvalidResponse("Unkown response prefix for '%s'" % data)
         buf = []
-        while l > 0:
+        while i > 0:
             data = self._read()
-            l -= len(data)
-            if len(data) > l:
+            i -= len(data)
+            if len(data) > i:
                 # we got the ending crlf
                 data = data.rstrip()
             buf.append(data)
-        if l == 0:
+        if i == 0:
             # the data has a trailing crlf embedded, let's restore it
             buf.append(self._read())
         return ''.join(buf)
     
-    def _get_simple_response(self):
-        data = self._read().strip()
-        if data[0] == '+':
-            return data[1:]
-        self._check_for_error(data)
-        raise InvalidResponse("Cannot parse first line '%s' for a simple response." % data, data)
-
-    def _get_numeric_response(self, allow_negative=True):
-        data = self._read().strip()
-        try:
-            value = int(data)
-        except (TypeError, ValueError), e:
-            pass
-        else:
-            if not allow_negative and value < 0:
-                self._check_for_error(data)
-            return value
-        self._check_for_error(data)
-        raise InvalidResponse("Cannot parse first line '%s' for a numeric response: %s." % (data, e), data)
-        
-    def _get_multi_response(self):
-        results = list()
-        try:
-            num = self._get_numeric_response(allow_negative=False)
-        except InvalidResponse, e:
-            if e.args[1] == 'nil':
-                return results
-            raise
-        while num:
-            results.append(self._get_value(negative_as_nil=True))
-            num -= 1
-        return results
-        
-    def _check_for_error(self, data):
-        if not data or data[0] != '-':
-            return
-        if data.startswith('-ERR'):
-            raise ResponseError(data[4:].strip())
-        try:
-            error_len = int(data[1:])
-        except (TypeError, ValueError):
-            raise ResponseError("Unknown error format '%s'." % data)
-        error_message = self._read().strip()[5:]
-        raise ResponseError(error_message)
-        
     def disconnect(self):
         if isinstance(self._sock, socket.socket):
             try: