changeset 74:80164eb2f090 MySQLdb

This passes all test, yet is still broken and ugly in many ways. However, a lot of ugliness has been removed.
author adustman
date Sat, 20 Feb 2010 04:27:21 +0000
parents 24fa6a40c706
children 3b03cb566032
files MySQLdb/connections.py MySQLdb/converters.py MySQLdb/cursors.py tests/test_MySQLdb_capabilities.py
diffstat 4 files changed, 112 insertions(+), 77 deletions(-) [+]
line wrap: on
line diff
--- a/MySQLdb/connections.py	Fri Feb 19 02:21:11 2010 +0000
+++ b/MySQLdb/connections.py	Sat Feb 20 04:27:21 2010 +0000
@@ -138,6 +138,7 @@
         if 'decoder_stack' not in kwargs2:
             kwargs2['decoder_stack'] = default_decoders;
         self.encoders = kwargs2.pop('encoders', default_encoders)
+        self.decoders = kwargs2.pop('decoders', default_decoders)
         
         client_flag = kwargs.get('client_flag', 0)
         client_version = tuple(
@@ -167,6 +168,7 @@
             # PEP-249 requires autocommit to be initially off
             self.autocommit(False)
         self.messages = []
+        self._active_cursor = None
     
     def autocommit(self, do_autocommit):
         self._autocommit = do_autocommit
@@ -192,16 +194,23 @@
     def string_literal(self, s):
         return self._db.string_literal(s)
     
-    def cursor(self, encoders=None):
+    def cursor(self, encoders=None, decoders=None):
         """
         Create a cursor on which queries may be performed. The optional
         cursorclass parameter is used to create the Cursor. By default,
         self.cursorclass=cursors.Cursor is used.
         """
+        if self._active_cursor:
+            self._active_cursor._flush()
+            
         if not encoders:
             encoders = self.encoders[:]
         
-        return self.cursorclass(self, encoders)
+        if not decoders:
+            decoders = self.decoders[:]
+            
+        self._active_cursor = self.cursorclass(self, encoders, decoders)
+        return self._active_cursor
 
     def __enter__(self):
         return self.cursor()
--- a/MySQLdb/converters.py	Fri Feb 19 02:21:11 2010 +0000
+++ b/MySQLdb/converters.py	Sat Feb 20 04:27:21 2010 +0000
@@ -15,6 +15,7 @@
 import array
 import datetime
 from decimal import Decimal
+from itertools import izip
 
 __revision__ = "$Revision$"[11:-2]
 __author__ = "$Author$"[9:-2]
@@ -143,6 +144,8 @@
     
     charset = field.result.connection.character_set_name()
     def char_to_unicode(s):
+        if s is None:
+            return s
         return s.decode(charset)
     
     return char_to_unicode
@@ -158,8 +161,25 @@
     default_encoder,
     ]
 
+def get_codec(field, codecs):
+    for c in codecs:
+        func = c(field)
+        if func:
+            return func
+    # the default codec is guaranteed to work
+
+def iter_row_decoder(decoders, row):
+    if row is None:
+        return None
+    return ( d(col) for d, col in izip(decoders, row) )
+
+def tuple_row_decoder(decoders, row):
+    if row is None:
+        return None
+    return tuple(iter_row_decoder(decoders, row))
 
 
 
 
 
+
--- a/MySQLdb/cursors.py	Fri Feb 19 02:21:11 2010 +0000
+++ b/MySQLdb/cursors.py	Sat Feb 20 04:27:21 2010 +0000
@@ -13,6 +13,7 @@
 import re
 import sys
 import weakref
+from MySQLdb.converters import get_codec, tuple_row_decoder
 
 INSERT_VALUES = re.compile(r"(?P<start>.+values\s*)"
                            r"(?P<values>\(((?<!\\)'[^\)]*?\)[^\)]*(?<!\\)?'|[^\(\)]|(?:\([^\)]*\)))+\))"
@@ -39,8 +40,7 @@
     _defer_warnings = False
     _fetch_type = None
 
-    def __init__(self, connection, encoders):
-        from MySQLdb.converters import default_decoders
+    def __init__(self, connection, encoders, decoders):
         self.connection = weakref.proxy(connection)
         self.description = None
         self.description_flags = None
@@ -54,17 +54,41 @@
         self._warnings = 0
         self._info = None
         self.rownumber = None
-        self._encoders = encoders
+        self.maxrows = 0
+        self.encoders = encoders
+        self.decoders = decoders
+        self._row_decoders = ()
+        self.row_decoder = tuple_row_decoder
 
+    def _flush(self):
+        """_flush() reads to the end of the current result set, buffering what
+        it can, and then releases the result set."""
+        if self._result:
+            for row in self._result:
+                pass
+            self._result = None
+    
     def __del__(self):
         self.close()
         self.errorhandler = None
         self._result = None
 
+    def _reset(self):
+        while True:
+            if self._result:
+                for row in self._result:
+                    pass
+                self._result = None
+            if not self.nextset():
+                break
+        del self.messages[:]
+            
     def close(self):
         """Close the cursor. No further queries will be possible."""
         if not self.connection:
             return
+        
+        self._flush()
         try:
             while self.nextset():
                 pass
@@ -106,22 +130,21 @@
         num_rows = connection.next_result()
         if num_rows == -1:
             return None
-        self._do_get_result()
-        self._post_get_result()
-        self._warning_check()
-        return True
-
-    def _do_get_result(self):
-        """Get the result from the last query."""
-        connection = self._get_db()
-        self._result = self._get_result()
-        self.rowcount = connection.affected_rows()
+        result = connection.use_result()
+        self._result = result
+        if result:
+            self.field_flags = result.field_flags()
+            self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ]
+            self.description = result.describe()
+        else:
+            self._row_decoders = self.field_flags = ()
+            self.description = None
+        self.rowcount = -1 #connection.affected_rows()
         self.rownumber = 0
-        self.description = self._result and self._result.describe() or None
-        self.description_flags = self._result and self._result.field_flags() or None
         self.lastrowid = connection.insert_id()
         self._warnings = connection.warning_count()
-        self._info = connection.info()
+        self._info = connection.info()        
+        return True
     
     def setinputsizes(self, *args):
         """Does nothing, required by DB API."""
@@ -150,15 +173,15 @@
         Returns long integer rows affected, if any
 
         """
-        del self.messages[:]
         db = self._get_db()
+        self._reset()
         charset = db.character_set_name()
         if isinstance(query, unicode):
             query = query.encode(charset)
         try:
             if args is not None:
                 query = query % tuple(map(self.connection.literal, args))
-            result = self._query(query)
+            self._query(query)
         except TypeError, msg:
             if msg.args[0] in ("not enough arguments for format string",
                                "not all arguments converted"):
@@ -173,10 +196,9 @@
             self.messages.append((exc, value))
             self.errorhandler(self, exc, value)
             
-        self._executed = query
         if not self._defer_warnings:
             self._warning_check()
-        return result
+        return None
 
     def executemany(self, query, args):
         """Execute a multi-row query.
@@ -197,8 +219,8 @@
         execute().
 
         """
-        del self.messages[:]
         db = self._get_db()
+        self._reset()
         if not args:
             return
         charset = self.connection.character_set_name()
@@ -216,8 +238,7 @@
         try:
             sql_params = ( values % tuple(map(self.connection.literal, row)) for row in args )
             multirow_query = '\n'.join([start, ',\n'.join(sql_params), end])
-            self._executed = multirow_query
-            self.rowcount = int(self._query(multirow_query))
+            self._query(multirow_query)
 
         except TypeError, msg:
             if msg.args[0] in ("not enough arguments for format string",
@@ -234,7 +255,7 @@
         
         if not self._defer_warnings:
             self._warning_check()
-        return self.rowcount
+        return None
     
     def callproc(self, procname, args=()):
         """Execute stored procedure procname with args
@@ -283,71 +304,62 @@
         if isinstance(query, unicode):
             query = query.encode(charset)
         self._query(query)
-        self._executed = query
         if not self._defer_warnings:
             self._warning_check()
         return args
-    
-    def _do_query(self, query):
-        """Low-levey query wrapper. Overridden by MixIns."""
-        connection = self._get_db()
-        self._executed = query
-        connection.query(query)
-        self._do_get_result()
-        return self.rowcount
-
-    def _fetch_row(self, size=1):
-        """Low-level fetch_row wrapper."""
-        if not self._result:
-            return ()
-        return self._result.fetch_row(size, self._fetch_type)
 
     def __iter__(self):
         return iter(self.fetchone, None)
 
-    def _get_result(self):
-        """Low-level; uses mysql_store_result()"""
-        return self._get_db().store_result()
-
     def _query(self, query):
-        """Low-level; executes query, gets result, and returns rowcount."""
-        rowcount = self._do_query(query)
-        self._post_get_result()
-        return rowcount
-
-    def _post_get_result(self):
-        """Low-level"""
-        self._rows = self._fetch_row(0)
-        self._result = None
-
+        """Low-level; executes query, gets result, sets up decoders."""
+        connection = self._get_db()
+        self._flush()
+        self._executed = query
+        connection.query(query)
+        result = connection.use_result()
+        self._result = result
+        if result:
+            self.field_flags = result.field_flags()
+            self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ]
+            self.description = result.describe()
+        else:
+            self._row_decoders = self.field_flags = ()
+            self.description = None
+        self.rowcount = -1 #connection.affected_rows()
+        self.rownumber = 0
+        self.lastrowid = connection.insert_id()
+        self._warnings = connection.warning_count()
+        self._info = connection.info()
+    
     def fetchone(self):
         """Fetches a single row from the cursor. None indicates that
         no more rows are available."""
         self._check_executed()
-        if self.rownumber >= len(self._rows):
-            return None
-        result = self._rows[self.rownumber]
-        self.rownumber += 1
-        return result
+        row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row())
+        return row
 
     def fetchmany(self, size=None):
         """Fetch up to size rows from the cursor. Result set may be smaller
         than size. If size is not defined, cursor.arraysize is used."""
         self._check_executed()
-        end = self.rownumber + (size or self.arraysize)
-        result = self._rows[self.rownumber:end]
-        self.rownumber = min(end, len(self._rows))
-        return result
+        if size is None:
+            size = self.arraysize
+        rows = []
+        for i in range(size):
+            row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row())
+            if row is None: break
+            rows.append(row)
+        return rows
 
     def fetchall(self):
-        """Fetchs all available rows from the cursor."""
+        """Fetches all available rows from the cursor."""
         self._check_executed()
-        if self.rownumber:
-            result = self._rows[self.rownumber:]
+        if self._result:
+            rows = [ self.row_decoder(self._row_decoders, row) for row in self._result ]
         else:
-            result = self._rows
-        self.rownumber = len(self._rows)
-        return result
+            rows = []
+        return rows
     
     def scroll(self, value, mode='relative'):
         """Scroll the cursor in the result set to a new position according
@@ -368,9 +380,3 @@
             self.errorhandler(self, IndexError, "out of range")
         self.rownumber = row
 
-    def __iter__(self):
-        self._check_executed()
-        result = self.rownumber and self._rows[self.rownumber:] or self._rows
-        return iter(result)
-    
-    _fetch_type = 0
--- a/tests/test_MySQLdb_capabilities.py	Fri Feb 19 02:21:11 2010 +0000
+++ b/tests/test_MySQLdb_capabilities.py	Sat Feb 20 04:27:21 2010 +0000
@@ -13,7 +13,7 @@
     connect_kwargs = dict(db='test', read_default_file='~/.my.cnf',
                           charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")
     create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
-    leak_test = True
+    leak_test = False
     
     def quote_identifier(self, ident):
         return "`%s`" % ident