Mercurial > p > mysql-python > mysqldb-2
changeset 74:80164eb2f090 MySQLdb
This passes all test, yet is still broken and ugly in many ways.
However, a lot of ugliness has been removed.
author | adustman |
---|---|
date | Sat, 20 Feb 2010 04:27:21 +0000 |
parents | 24fa6a40c706 |
children | 3b03cb566032 |
files | MySQLdb/connections.py MySQLdb/converters.py MySQLdb/cursors.py tests/test_MySQLdb_capabilities.py |
diffstat | 4 files changed, 112 insertions(+), 77 deletions(-) [+] |
line wrap: on
line diff
--- a/MySQLdb/connections.py Fri Feb 19 02:21:11 2010 +0000 +++ b/MySQLdb/connections.py Sat Feb 20 04:27:21 2010 +0000 @@ -138,6 +138,7 @@ if 'decoder_stack' not in kwargs2: kwargs2['decoder_stack'] = default_decoders; self.encoders = kwargs2.pop('encoders', default_encoders) + self.decoders = kwargs2.pop('decoders', default_decoders) client_flag = kwargs.get('client_flag', 0) client_version = tuple( @@ -167,6 +168,7 @@ # PEP-249 requires autocommit to be initially off self.autocommit(False) self.messages = [] + self._active_cursor = None def autocommit(self, do_autocommit): self._autocommit = do_autocommit @@ -192,16 +194,23 @@ def string_literal(self, s): return self._db.string_literal(s) - def cursor(self, encoders=None): + def cursor(self, encoders=None, decoders=None): """ Create a cursor on which queries may be performed. The optional cursorclass parameter is used to create the Cursor. By default, self.cursorclass=cursors.Cursor is used. """ + if self._active_cursor: + self._active_cursor._flush() + if not encoders: encoders = self.encoders[:] - return self.cursorclass(self, encoders) + if not decoders: + decoders = self.decoders[:] + + self._active_cursor = self.cursorclass(self, encoders, decoders) + return self._active_cursor def __enter__(self): return self.cursor()
--- a/MySQLdb/converters.py Fri Feb 19 02:21:11 2010 +0000 +++ b/MySQLdb/converters.py Sat Feb 20 04:27:21 2010 +0000 @@ -15,6 +15,7 @@ import array import datetime from decimal import Decimal +from itertools import izip __revision__ = "$Revision$"[11:-2] __author__ = "$Author$"[9:-2] @@ -143,6 +144,8 @@ charset = field.result.connection.character_set_name() def char_to_unicode(s): + if s is None: + return s return s.decode(charset) return char_to_unicode @@ -158,8 +161,25 @@ default_encoder, ] +def get_codec(field, codecs): + for c in codecs: + func = c(field) + if func: + return func + # the default codec is guaranteed to work + +def iter_row_decoder(decoders, row): + if row is None: + return None + return ( d(col) for d, col in izip(decoders, row) ) + +def tuple_row_decoder(decoders, row): + if row is None: + return None + return tuple(iter_row_decoder(decoders, row)) +
--- a/MySQLdb/cursors.py Fri Feb 19 02:21:11 2010 +0000 +++ b/MySQLdb/cursors.py Sat Feb 20 04:27:21 2010 +0000 @@ -13,6 +13,7 @@ import re import sys import weakref +from MySQLdb.converters import get_codec, tuple_row_decoder INSERT_VALUES = re.compile(r"(?P<start>.+values\s*)" r"(?P<values>\(((?<!\\)'[^\)]*?\)[^\)]*(?<!\\)?'|[^\(\)]|(?:\([^\)]*\)))+\))" @@ -39,8 +40,7 @@ _defer_warnings = False _fetch_type = None - def __init__(self, connection, encoders): - from MySQLdb.converters import default_decoders + def __init__(self, connection, encoders, decoders): self.connection = weakref.proxy(connection) self.description = None self.description_flags = None @@ -54,17 +54,41 @@ self._warnings = 0 self._info = None self.rownumber = None - self._encoders = encoders + self.maxrows = 0 + self.encoders = encoders + self.decoders = decoders + self._row_decoders = () + self.row_decoder = tuple_row_decoder + def _flush(self): + """_flush() reads to the end of the current result set, buffering what + it can, and then releases the result set.""" + if self._result: + for row in self._result: + pass + self._result = None + def __del__(self): self.close() self.errorhandler = None self._result = None + def _reset(self): + while True: + if self._result: + for row in self._result: + pass + self._result = None + if not self.nextset(): + break + del self.messages[:] + def close(self): """Close the cursor. No further queries will be possible.""" if not self.connection: return + + self._flush() try: while self.nextset(): pass @@ -106,22 +130,21 @@ num_rows = connection.next_result() if num_rows == -1: return None - self._do_get_result() - self._post_get_result() - self._warning_check() - return True - - def _do_get_result(self): - """Get the result from the last query.""" - connection = self._get_db() - self._result = self._get_result() - self.rowcount = connection.affected_rows() + result = connection.use_result() + self._result = result + if result: + self.field_flags = result.field_flags() + self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ] + self.description = result.describe() + else: + self._row_decoders = self.field_flags = () + self.description = None + self.rowcount = -1 #connection.affected_rows() self.rownumber = 0 - self.description = self._result and self._result.describe() or None - self.description_flags = self._result and self._result.field_flags() or None self.lastrowid = connection.insert_id() self._warnings = connection.warning_count() - self._info = connection.info() + self._info = connection.info() + return True def setinputsizes(self, *args): """Does nothing, required by DB API.""" @@ -150,15 +173,15 @@ Returns long integer rows affected, if any """ - del self.messages[:] db = self._get_db() + self._reset() charset = db.character_set_name() if isinstance(query, unicode): query = query.encode(charset) try: if args is not None: query = query % tuple(map(self.connection.literal, args)) - result = self._query(query) + self._query(query) except TypeError, msg: if msg.args[0] in ("not enough arguments for format string", "not all arguments converted"): @@ -173,10 +196,9 @@ self.messages.append((exc, value)) self.errorhandler(self, exc, value) - self._executed = query if not self._defer_warnings: self._warning_check() - return result + return None def executemany(self, query, args): """Execute a multi-row query. @@ -197,8 +219,8 @@ execute(). """ - del self.messages[:] db = self._get_db() + self._reset() if not args: return charset = self.connection.character_set_name() @@ -216,8 +238,7 @@ try: sql_params = ( values % tuple(map(self.connection.literal, row)) for row in args ) multirow_query = '\n'.join([start, ',\n'.join(sql_params), end]) - self._executed = multirow_query - self.rowcount = int(self._query(multirow_query)) + self._query(multirow_query) except TypeError, msg: if msg.args[0] in ("not enough arguments for format string", @@ -234,7 +255,7 @@ if not self._defer_warnings: self._warning_check() - return self.rowcount + return None def callproc(self, procname, args=()): """Execute stored procedure procname with args @@ -283,71 +304,62 @@ if isinstance(query, unicode): query = query.encode(charset) self._query(query) - self._executed = query if not self._defer_warnings: self._warning_check() return args - - def _do_query(self, query): - """Low-levey query wrapper. Overridden by MixIns.""" - connection = self._get_db() - self._executed = query - connection.query(query) - self._do_get_result() - return self.rowcount - - def _fetch_row(self, size=1): - """Low-level fetch_row wrapper.""" - if not self._result: - return () - return self._result.fetch_row(size, self._fetch_type) def __iter__(self): return iter(self.fetchone, None) - def _get_result(self): - """Low-level; uses mysql_store_result()""" - return self._get_db().store_result() - def _query(self, query): - """Low-level; executes query, gets result, and returns rowcount.""" - rowcount = self._do_query(query) - self._post_get_result() - return rowcount - - def _post_get_result(self): - """Low-level""" - self._rows = self._fetch_row(0) - self._result = None - + """Low-level; executes query, gets result, sets up decoders.""" + connection = self._get_db() + self._flush() + self._executed = query + connection.query(query) + result = connection.use_result() + self._result = result + if result: + self.field_flags = result.field_flags() + self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ] + self.description = result.describe() + else: + self._row_decoders = self.field_flags = () + self.description = None + self.rowcount = -1 #connection.affected_rows() + self.rownumber = 0 + self.lastrowid = connection.insert_id() + self._warnings = connection.warning_count() + self._info = connection.info() + def fetchone(self): """Fetches a single row from the cursor. None indicates that no more rows are available.""" self._check_executed() - if self.rownumber >= len(self._rows): - return None - result = self._rows[self.rownumber] - self.rownumber += 1 - return result + row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row()) + return row def fetchmany(self, size=None): """Fetch up to size rows from the cursor. Result set may be smaller than size. If size is not defined, cursor.arraysize is used.""" self._check_executed() - end = self.rownumber + (size or self.arraysize) - result = self._rows[self.rownumber:end] - self.rownumber = min(end, len(self._rows)) - return result + if size is None: + size = self.arraysize + rows = [] + for i in range(size): + row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row()) + if row is None: break + rows.append(row) + return rows def fetchall(self): - """Fetchs all available rows from the cursor.""" + """Fetches all available rows from the cursor.""" self._check_executed() - if self.rownumber: - result = self._rows[self.rownumber:] + if self._result: + rows = [ self.row_decoder(self._row_decoders, row) for row in self._result ] else: - result = self._rows - self.rownumber = len(self._rows) - return result + rows = [] + return rows def scroll(self, value, mode='relative'): """Scroll the cursor in the result set to a new position according @@ -368,9 +380,3 @@ self.errorhandler(self, IndexError, "out of range") self.rownumber = row - def __iter__(self): - self._check_executed() - result = self.rownumber and self._rows[self.rownumber:] or self._rows - return iter(result) - - _fetch_type = 0
--- a/tests/test_MySQLdb_capabilities.py Fri Feb 19 02:21:11 2010 +0000 +++ b/tests/test_MySQLdb_capabilities.py Sat Feb 20 04:27:21 2010 +0000 @@ -13,7 +13,7 @@ connect_kwargs = dict(db='test', read_default_file='~/.my.cnf', charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL") create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8" - leak_test = True + leak_test = False def quote_identifier(self, ident): return "`%s`" % ident