Mercurial > p > mysql-python > mysqldb-2
changeset 75:3b03cb566032 MySQLdb
More serious restructuring and cleaning, especially in the handling
of result sets. All tests pass.
author | adustman |
---|---|
date | Mon, 22 Feb 2010 03:56:44 +0000 |
parents | 80164eb2f090 |
children | 17062a65fde9 |
files | MySQLdb/connections.py MySQLdb/converters.py MySQLdb/cursors.py src/connections.c src/mysqlmod.c src/mysqlmod.h src/results.c |
diffstat | 7 files changed, 315 insertions(+), 583 deletions(-) [+] |
line wrap: on
line diff
--- a/MySQLdb/connections.py Sat Feb 20 04:27:21 2010 +0000 +++ b/MySQLdb/connections.py Mon Feb 22 03:56:44 2010 +0000 @@ -126,8 +126,7 @@ """ from MySQLdb.constants import CLIENT, FIELD_TYPE - from MySQLdb.converters import default_decoders, default_encoders - from MySQLdb.converters import simple_type_encoders as conversions + from MySQLdb.converters import default_decoders, default_encoders, default_row_formatter from MySQLdb.cursors import Cursor import _mysql @@ -135,10 +134,10 @@ self.cursorclass = Cursor charset = kwargs2.pop('charset', '') - if 'decoder_stack' not in kwargs2: - kwargs2['decoder_stack'] = default_decoders; + self.encoders = kwargs2.pop('encoders', default_encoders) self.decoders = kwargs2.pop('decoders', default_decoders) + self.row_formatter = kwargs2.pop('row_formatter', default_row_formatter) client_flag = kwargs.get('client_flag', 0) client_version = tuple( @@ -187,14 +186,14 @@ def close(self): return self._db.close() - + def escape_string(self, s): return self._db.escape_string(s) def string_literal(self, s): - return self._db.string_literal(s) - - def cursor(self, encoders=None, decoders=None): + return self._db.string_literal(s) + + def cursor(self, encoders=None, decoders=None, row_formatter=None): """ Create a cursor on which queries may be performed. The optional cursorclass parameter is used to create the Cursor. By default, @@ -208,8 +207,11 @@ if not decoders: decoders = self.decoders[:] + + if not row_formatter: + row_formatter = self.row_formatter - self._active_cursor = self.cursorclass(self, encoders, decoders) + self._active_cursor = self.cursorclass(self, encoders, decoders, row_formatter) return self._active_cursor def __enter__(self): @@ -220,7 +222,7 @@ self.rollback() else: self.commit() - + def literal(self, obj): """ Given an object obj, returns an SQL literal as a string. @@ -234,17 +236,6 @@ raise self.NotSupportedError("could not encode as SQL", obj) - def _warning_count(self): - """Return the number of warnings generated from the last query.""" - if hasattr(self._db, "warning_count"): - return self._db.warning_count() - else: - info = self._db.info() - if info: - return int(info.split()[-1]) - else: - return 0 - def character_set_name(self): return self._db.character_set_name() @@ -263,9 +254,7 @@ if self._server_version < (4, 1): raise self.NotSupportedError("server is too old to set charset") self._db.query('SET NAMES %s' % charset) - self._db.store_result() - self.string_decoder.charset = charset - self.unicode_literal.charset = charset + self._db.get_result() def set_sql_mode(self, sql_mode): """Set the connection sql_mode. See MySQL documentation for legal @@ -276,8 +265,19 @@ if self._server_version < (4, 1): raise self.NotSupportedError("server is too old to set sql_mode") self._db.query("SET SESSION sql_mode='%s'" % sql_mode) - self._db.store_result() + self._db.get_result() + def _warning_count(self): + """Return the number of warnings generated from the last query.""" + if hasattr(self._db, "warning_count"): + return self._db.warning_count() + else: + info = self._db.info() + if info: + return int(info.split()[-1]) + else: + return 0 + def _show_warnings(self): """Return detailed information about warnings as a sequence of tuples of (Level, Code, Message). This is only supported in MySQL-4.1 and up. @@ -287,7 +287,6 @@ so you should not usually call it yourself.""" if self._server_version < (4, 1): return () self._db.query("SHOW WARNINGS") - result = self._db.store_result() - warnings = result.fetch_row(0) - return warnings + return tuple(self._db.get_result()) +
--- a/MySQLdb/converters.py Sat Feb 20 04:27:21 2010 +0000 +++ b/MySQLdb/converters.py Mon Feb 22 03:56:44 2010 +0000 @@ -110,7 +110,7 @@ } # Decoder protocol -# Each decoder is passed a cursor object and a field object. +# Each decoder is passed a field object. # The decoder returns a single value: # * A callable that given an SQL value, returns a Python object. # This can be as simple as int or str, etc. If the decoder @@ -178,8 +178,5 @@ return None return tuple(iter_row_decoder(decoders, row)) - - +default_row_formatter = tuple_row_decoder - -
--- a/MySQLdb/cursors.py Sat Feb 20 04:27:21 2010 +0000 +++ b/MySQLdb/cursors.py Mon Feb 22 03:56:44 2010 +0000 @@ -13,7 +13,8 @@ import re import sys import weakref -from MySQLdb.converters import get_codec, tuple_row_decoder +from MySQLdb.converters import get_codec +from warnings import warn INSERT_VALUES = re.compile(r"(?P<start>.+values\s*)" r"(?P<values>\(((?<!\\)'[^\)]*?\)[^\)]*(?<!\\)?'|[^\(\)]|(?:\([^\)]*\)))+\))" @@ -40,9 +41,8 @@ _defer_warnings = False _fetch_type = None - def __init__(self, connection, encoders, decoders): + def __init__(self, connection, encoders, decoders, row_formatter): self.connection = weakref.proxy(connection) - self.description = None self.description_flags = None self.rowcount = -1 self.arraysize = 1 @@ -51,6 +51,7 @@ self.messages = [] self.errorhandler = connection.errorhandler self._result = None + self._pending_results = [] self._warnings = 0 self._info = None self.rownumber = None @@ -58,29 +59,45 @@ self.encoders = encoders self.decoders = decoders self._row_decoders = () - self.row_decoder = tuple_row_decoder + self.row_formatter = row_formatter + self.use_result = False + @property + def description(self): + if self._result: + return self._result.description + return None + def _flush(self): """_flush() reads to the end of the current result set, buffering what it can, and then releases the result set.""" if self._result: - for row in self._result: - pass + self._result.flush() self._result = None + db = self._get_db() + while db.next_result(): + result = Result(self) + result.flush() + self._pending_results.append(result) def __del__(self): self.close() self.errorhandler = None self._result = None + del self._pending_results[:] - def _reset(self): - while True: - if self._result: - for row in self._result: - pass - self._result = None - if not self.nextset(): - break + def _clear(self): + if self._result: + self._result.clear() + self._result = None + for result in self._pending_results: + result.clear() + del self._pending_results[:] + db = self._get_db() + while db.next_result(): + result = db.get_result(True) + if result: + result.clear() del self.messages[:] def close(self): @@ -120,31 +137,19 @@ def nextset(self): """Advance to the next result set. - Returns None if there are no more result sets. + Returns False if there are no more result sets. """ - if self._executed: - self.fetchall() - del self.messages[:] - - connection = self._get_db() - num_rows = connection.next_result() - if num_rows == -1: - return None - result = connection.use_result() - self._result = result - if result: - self.field_flags = result.field_flags() - self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ] - self.description = result.describe() - else: - self._row_decoders = self.field_flags = () - self.description = None - self.rowcount = -1 #connection.affected_rows() - self.rownumber = 0 - self.lastrowid = connection.insert_id() - self._warnings = connection.warning_count() - self._info = connection.info() - return True + db = self._get_db() + self._result.clear() + self._result = None + if self._pending_results: + self._result = self._pending_results[0] + del self._pending_results[0] + return True + if db.next_result(): + self._result = Result(self) + return True + return False def setinputsizes(self, *args): """Does nothing, required by DB API.""" @@ -174,13 +179,13 @@ """ db = self._get_db() - self._reset() + self._clear() charset = db.character_set_name() if isinstance(query, unicode): query = query.encode(charset) try: if args is not None: - query = query % tuple(map(self.connection.literal, args)) + query = query % tuple(( get_codec(a, self.encoders)(db, a) for a in args )) self._query(query) except TypeError, msg: if msg.args[0] in ("not enough arguments for format string", @@ -220,7 +225,7 @@ """ db = self._get_db() - self._reset() + self._clear() if not args: return charset = self.connection.character_set_name() @@ -228,15 +233,19 @@ query = query.encode(charset) matched = INSERT_VALUES.match(query) if not matched: - self.rowcount = sum(( self.execute(query, arg) for arg in args )) - return self.rowcount + rowcount = 0 + for row in args: + self.execute(query, row) + rowcount += self.rowcount + self.rowcount = rowcount + return start = matched.group('start') values = matched.group('values') end = matched.group('end') try: - sql_params = ( values % tuple(map(self.connection.literal, row)) for row in args ) + sql_params = ( values % tuple(( get_codec(a, self.encoders)(db, a) for a in row )) for row in args ) multirow_query = '\n'.join([start, ',\n'.join(sql_params), end]) self._query(multirow_query) @@ -317,49 +326,32 @@ self._flush() self._executed = query connection.query(query) - result = connection.use_result() - self._result = result - if result: - self.field_flags = result.field_flags() - self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ] - self.description = result.describe() - else: - self._row_decoders = self.field_flags = () - self.description = None - self.rowcount = -1 #connection.affected_rows() - self.rownumber = 0 - self.lastrowid = connection.insert_id() - self._warnings = connection.warning_count() - self._info = connection.info() + self._result = Result(self) def fetchone(self): """Fetches a single row from the cursor. None indicates that no more rows are available.""" self._check_executed() - row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row()) - return row + if not self._result: + return None + return self._result.fetchone() def fetchmany(self, size=None): """Fetch up to size rows from the cursor. Result set may be smaller than size. If size is not defined, cursor.arraysize is used.""" self._check_executed() + if not self._result: + return [] if size is None: size = self.arraysize - rows = [] - for i in range(size): - row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row()) - if row is None: break - rows.append(row) - return rows + return self._result.fetchmany(size) def fetchall(self): """Fetches all available rows from the cursor.""" self._check_executed() - if self._result: - rows = [ self.row_decoder(self._row_decoders, row) for row in self._result ] - else: - rows = [] - return rows + if not self._result: + return [] + return self._result.fetchall() def scroll(self, value, mode='relative'): """Scroll the cursor in the result set to a new position according @@ -380,3 +372,108 @@ self.errorhandler(self, IndexError, "out of range") self.rownumber = row + +class Result(object): + + def __init__(self, cursor): + self.cursor = cursor + db = cursor._get_db() + result = db.get_result(cursor.use_result) + self.result = result + decoders = cursor.decoders + self.row_formatter = cursor.row_formatter + self.max_buffer = 1000 + self.rows = [] + self.row_start = 0 + self.rows_read = 0 + self.row_index = 0 + self.lastrowid = db.insert_id() + self.warning_count = db.warning_count() + self.info = db.info() + self.rowcount = -1 + self.description = None + self.field_flags = () + self.row_decoders = () + + if result: + self.description = result.describe() + self.field_flags = result.field_flags() + self.row_decoders = tuple(( get_codec(field, decoders) for field in result.fields )) + if not cursor.use_result: + self.rowcount = db.affected_rows() + self.flush() + + def flush(self): + if self.result: + self.rows.extend([ self.row_formatter(self.row_decoders, row) for row in self.result ]) + self.result.clear() + self.result = None + + def clear(self): + if self.result: + self.result.clear() + self.result = None + + def fetchone(self): + if self.result: + while self.row_index >= len(self.rows): + row = self.result.fetch_row() + if row is None: + return row + self.rows.append(self.row_formatter(self.row_decoders, row)) + if self.row_index >= len(self.rows): + return None + row = self.rows[self.row_index] + self.row_index += 1 + return row + + def __iter__(self): return self + + def next(self): + row = self.fetchone() + if row is None: + raise StopIteration + return row + + def fetchmany(self, size): + """Fetch up to size rows from the cursor. Result set may be smaller + than size. If size is not defined, cursor.arraysize is used.""" + row_end = self.row_index + size + if self.result: + while self.row_index >= len(self.rows): + row = self.result.fetch_row() + if row is None: + break + self.rows.append(self.row_formatter(self.row_decoders, row)) + if self.row_index >= len(self.rows): + return [] + if row_end >= len(self.rows): + row_end = len(self.rows) + rows = self.rows[self.row_index:row_end] + self.row_index = row_end + return rows + + def fetchall(self): + if self.result: + self.flush() + rows = self.rows[self.row_index:] + self.row_index = len(self.rows) + return rows + + def warning_check(self): + """Check for warnings, and report via the warnings module.""" + if self.warning_count: + cursor = self.cursor + warnings = cursor._get_db()._show_warnings() + if warnings: + # This is done in two loops in case + # Warnings are set to raise exceptions. + for warning in warnings: + cursor.warnings.append((self.Warning, warning)) + for warning in warnings: + warn(warning[-1], self.Warning, 3) + elif self._info: + cursor.messages.append((self.Warning, self._info)) + warn(self._info, self.Warning, 3) + +
--- a/src/connections.c Sat Feb 20 04:27:21 2010 +0000 +++ b/src/connections.c Mon Feb 22 03:56:44 2010 +0000 @@ -9,7 +9,6 @@ PyObject *kwargs) { MYSQL *conn = NULL; - PyObject *decoder_stack = NULL; PyObject *ssl = NULL; #if HAVE_OPENSSL char *key = NULL, *cert = NULL, *ca = NULL, @@ -20,7 +19,7 @@ unsigned int port = 0; unsigned int client_flag = 0; static char *kwlist[] = { "host", "user", "passwd", "db", "port", - "unix_socket", "decoder_stack", + "unix_socket", "connect_timeout", "compress", "named_pipe", "init_command", "read_default_file", "read_default_group", @@ -33,13 +32,12 @@ *read_default_file=NULL, *read_default_group=NULL; - self->decoder_stack = NULL; self->open = 0; check_server_init(-1); - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssisOiiisssiOi:connect", + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssisiiisssiOi:connect", kwlist, &host, &user, &passwd, &db, - &port, &unix_socket, &decoder_stack, + &port, &unix_socket, &connect_timeout, &compress, &named_pipe, &init_command, &read_default_file, @@ -107,12 +105,6 @@ return -1; } - if (!decoder_stack) - decoder_stack = PyList_New(0); - else - Py_INCREF(decoder_stack); - self->decoder_stack = decoder_stack; - /* PyType_GenericAlloc() automatically sets up GC allocation and tracking for GC objects, at least in 2.2.1, so it does not need to @@ -195,16 +187,12 @@ visitproc visit, void *arg) { - if (self->decoder_stack) - return visit(self->decoder_stack, arg); return 0; } static int _mysql_ConnectionObject_clear( _mysql_ConnectionObject *self) { - Py_XDECREF(self->decoder_stack); - self->decoder_stack = NULL; return 0; } @@ -399,7 +387,7 @@ now call store_result(), warning_count(), affected_rows()\n\ , and so forth. \n\ \n\ -Returns 0 if there are more results; -1 if there are no more results\n\ +Returns True if there are more results.\n\ \n\ Non-standard.\n\ "; @@ -418,7 +406,7 @@ #endif Py_END_ALLOW_THREADS if (err > 0) return _mysql_Exception(self); - return PyInt_FromLong(err); + return PyInt_FromLong(err == 0); } #if MYSQL_VERSION_ID >= 40100 @@ -917,22 +905,25 @@ return PyString_FromString(s); } -static char _mysql_ConnectionObject_store_result__doc__[] = -"Returns a result object acquired by mysql_store_result\n\ -(results stored in the client). If no results are available,\n\ -None is returned. Non-standard.\n\ +static char _mysql_ConnectionObject_get_result__doc__[] = +"Returns a result object. If use is True, mysql_use_result()\n\ +is used; otherwise mysql_store_result() is used (the default).\n\ "; static PyObject * -_mysql_ConnectionObject_store_result( +_mysql_ConnectionObject_get_result( _mysql_ConnectionObject *self, - PyObject *unused) + PyObject *args, + PyObject *kwargs) { PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; + static char *kwlist[] = {"use", NULL}; _mysql_ResultObject *r=NULL; - + int use = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i:get_result", kwlist, &use)) return NULL; check_connection(self); - arglist = Py_BuildValue("(OiO)", self, 0, self->decoder_stack); + arglist = Py_BuildValue("(Oi)", self, use); if (!arglist) goto error; kwarglist = PyDict_New(); if (!kwarglist) goto error; @@ -977,41 +968,6 @@ return PyInt_FromLong((long)pid); } -static char _mysql_ConnectionObject_use_result__doc__[] = -"Returns a result object acquired by mysql_use_result\n\ -(results stored in the server). If no results are available,\n\ -None is returned. Non-standard.\n\ -"; - -static PyObject * -_mysql_ConnectionObject_use_result( - _mysql_ConnectionObject *self, - PyObject *unused) -{ - PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; - _mysql_ResultObject *r=NULL; - - check_connection(self); - arglist = Py_BuildValue("(OiO)", self, 1, self->decoder_stack); - if (!arglist) return NULL; - kwarglist = PyDict_New(); - if (!kwarglist) goto error; - r = MyAlloc(_mysql_ResultObject, _mysql_ResultObject_Type); - if (!r) goto error; - result = (PyObject *) r; - if (_mysql_ResultObject_Initialize(r, arglist, kwarglist)) - goto error; - if (!(r->result)) { - Py_DECREF(result); - Py_INCREF(Py_None); - result = Py_None; - } - error: - Py_DECREF(arglist); - Py_XDECREF(kwarglist); - return result; -} - static void _mysql_ConnectionObject_dealloc( _mysql_ConnectionObject *self) @@ -1171,6 +1127,12 @@ _mysql_ConnectionObject_get_proto_info__doc__ }, { + "get_result", + (PyCFunction)_mysql_ConnectionObject_get_result, + METH_VARARGS | METH_KEYWORDS, + _mysql_ConnectionObject_get_result__doc__ + }, + { "get_server_info", (PyCFunction)_mysql_ConnectionObject_get_server_info, METH_NOARGS, @@ -1225,12 +1187,6 @@ _mysql_ConnectionObject_stat__doc__ }, { - "store_result", - (PyCFunction)_mysql_ConnectionObject_store_result, - METH_NOARGS, - _mysql_ConnectionObject_store_result__doc__ - }, - { "string_literal", (PyCFunction)_mysql_string_literal, METH_VARARGS, @@ -1241,12 +1197,6 @@ METH_NOARGS, _mysql_ConnectionObject_thread_id__doc__ }, - { - "use_result", - (PyCFunction)_mysql_ConnectionObject_use_result, - METH_NOARGS, - _mysql_ConnectionObject_use_result__doc__ - }, {NULL, NULL} /* sentinel */ }; @@ -1259,13 +1209,6 @@ "True if connection is open" }, { - "decoder_stack", - T_OBJECT, - offsetof(_mysql_ConnectionObject, decoder_stack), - 0, - "Type decoder stack" - }, - { "server_capabilities", T_UINT, offsetof(_mysql_ConnectionObject, connection.server_capabilities),
--- a/src/mysqlmod.c Sat Feb 20 04:27:21 2010 +0000 +++ b/src/mysqlmod.c Mon Feb 22 03:56:44 2010 +0000 @@ -255,70 +255,6 @@ PyObject *self, PyObject *args); -static char _mysql_escape_sequence__doc__[] = -"escape_sequence(seq, dict) -- escape any special characters in sequence\n\ -seq using mapping dict to provide quoting functions for each type.\n\ -Returns a tuple of escaped items."; -static PyObject * -_mysql_escape_sequence( - PyObject *self, - PyObject *args) -{ - PyObject *o=NULL, *d=NULL, *r=NULL, *item, *quoted; - int i, n; - if (!PyArg_ParseTuple(args, "OO:escape_sequence", &o, &d)) - goto error; - if (!PyMapping_Check(d)) { - PyErr_SetString(PyExc_TypeError, - "argument 2 must be a mapping"); - return NULL; - } - if ((n = PyObject_Length(o)) == -1) goto error; - if (!(r = PyTuple_New(n))) goto error; - for (i=0; i<n; i++) { - item = PySequence_GetItem(o, i); - if (!item) goto error; - quoted = _escape_item(item, d); - Py_DECREF(item); - if (!quoted) goto error; - PyTuple_SET_ITEM(r, i, quoted); - } - return r; - error: - Py_XDECREF(r); - return NULL; -} - -static char _mysql_escape_dict__doc__[] = -"escape_sequence(d, dict) -- escape any special characters in\n\ -dictionary d using mapping dict to provide quoting functions for each type.\n\ -Returns a dictionary of escaped items."; -static PyObject * -_mysql_escape_dict( - PyObject *self, - PyObject *args) -{ - PyObject *o=NULL, *d=NULL, *r=NULL, *item, *quoted, *pkey; - Py_ssize_t ppos = 0; - if (!PyArg_ParseTuple(args, "O!O:escape_dict", &PyDict_Type, &o, &d)) - goto error; - if (!PyMapping_Check(d)) { - PyErr_SetString(PyExc_TypeError, - "argument 2 must be a mapping"); - return NULL; - } - if (!(r = PyDict_New())) goto error; - while (PyDict_Next(o, &ppos, &pkey, &item)) { - quoted = _escape_item(item, d); - if (!quoted) goto error; - if (PyDict_SetItem(r, pkey, quoted)==-1) goto error; - Py_DECREF(quoted); - } - return r; - error: - Py_XDECREF(r); - return NULL; -} static char _mysql_get_client_info__doc__[] = "get_client_info() -- Returns a string that represents\n\
--- a/src/mysqlmod.h Sat Feb 20 04:27:21 2010 +0000 +++ b/src/mysqlmod.h Mon Feb 22 03:56:44 2010 +0000 @@ -26,7 +26,6 @@ PyObject_HEAD MYSQL connection; int open; - PyObject *decoder_stack; } _mysql_ConnectionObject; #define check_connection(c) if (!(c->open)) return _mysql_Exception(c) @@ -41,7 +40,6 @@ MYSQL_RES *result; int nfields; int use; - PyObject *decoders; PyObject *fields; } _mysql_ResultObject;
--- a/src/results.c Sat Feb 20 04:27:21 2010 +0000 +++ b/src/results.c Mon Feb 22 03:56:44 2010 +0000 @@ -37,7 +37,7 @@ } static char _mysql_ResultObject__doc__[] = -"result(connection, use=0, decoder_stack=[]) -- Result set from a query.\n\ +"result(connection, use=0) -- Result set from a query.\n\ \n\ Creating instances of this class directly is an excellent way to\n\ shoot yourself in the foot. If using _mysql.connection directly,\n\ @@ -51,18 +51,16 @@ PyObject *args, PyObject *kwargs) { - static char *kwlist[] = {"connection", "use", "decoder_stack", NULL}; + static char *kwlist[] = {"connection", "use", NULL}; MYSQL_RES *result; _mysql_ConnectionObject *conn = NULL; int use = 0; - PyObject *decoder_stack = NULL; - int n, ns, i, j; + int n; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|iO", kwlist, - &conn, &use, &decoder_stack)) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|i", kwlist, + &conn, &use)) return -1; - if (!decoder_stack) decoder_stack = PyList_New(0); - if (!decoder_stack) return -1; + self->conn = (PyObject *) conn; Py_INCREF(conn); self->use = use; @@ -74,32 +72,13 @@ self->result = result; Py_END_ALLOW_THREADS ; if (!result) { - self->decoders = PyTuple_New(0); return 0; } n = mysql_num_fields(result); - ns = PySequence_Length(decoder_stack); self->nfields = n; - if (!(self->decoders = PyTuple_New(n))) return -1; self->fields = _mysql_ResultObject_get_fields(self, NULL); - for (i=0; i<n; i++) { - PyObject *field = PyTuple_GET_ITEM(self->fields, i); - for (j=0; j<ns; j++) { - PyObject *df = PySequence_GetItem(decoder_stack, j); - if (!df) goto error; - PyObject *f = PyObject_CallFunctionObjArgs(df, field, NULL); - Py_DECREF(df); - if (!f) goto error; - if (f != Py_None) { - PyTuple_SET_ITEM(self->decoders, i, f); - break; - } - Py_DECREF(f); - } - } + return 0; - error: - return -1; } static int @@ -109,25 +88,14 @@ void *arg) { int r; - if (self->decoders) { - if (!(r = visit(self->decoders, arg))) return r; + if (self->fields) { + if (!(r = visit(self->fields, arg))) return r; } if (self->conn) return visit(self->conn, arg); return 0; } -static int -_mysql_ResultObject_clear( - _mysql_ResultObject *self) -{ - Py_XDECREF(self->decoders); - self->decoders = NULL; - Py_XDECREF(self->conn); - self->conn = NULL; - return 0; -} - static char _mysql_ResultObject_describe__doc__[] = "Returns the sequence of 7-tuples required by the DB-API for\n\ the Cursor.description attribute.\n\ @@ -165,6 +133,61 @@ return NULL; } +static char _mysql_ResultObject_fetch_row__doc__[] = +"fetchrow()\n\ + Fetches one row as a tuple of strings.\n\ + NULL is returned as None.\n\ + A single None indicates the end of the result set.\n\ +"; + +static PyObject * +_mysql_ResultObject_fetch_row( + _mysql_ResultObject *self, + PyObject *unused) + { + unsigned int n, i; + unsigned long *length; + PyObject *r=NULL; + MYSQL_ROW row; + + check_result_connection(self); + + if (!self->use) + row = mysql_fetch_row(self->result); + else { + Py_BEGIN_ALLOW_THREADS; + row = mysql_fetch_row(self->result); + Py_END_ALLOW_THREADS; + } + if (!row && mysql_errno(&(((_mysql_ConnectionObject *)(self->conn))->connection))) { + _mysql_Exception((_mysql_ConnectionObject *)self->conn); + goto error; + } + if (!row) { + Py_INCREF(Py_None); + return Py_None; + } + + n = mysql_num_fields(self->result); + if (!(r = PyTuple_New(n))) return NULL; + length = mysql_fetch_lengths(self->result); + for (i=0; i<n; i++) { + PyObject *v; + if (row[i]) { + v = PyString_FromStringAndSize(row[i], length[i]); + if (!v) goto error; + } else /* NULL */ { + v = Py_None; + Py_INCREF(v); + } + PyTuple_SET_ITEM(r, i, v); + } + return r; + error: + Py_XDECREF(r); + return NULL; +} + static char _mysql_ResultObject_field_flags__doc__[] = "Returns a tuple of field flags, one for each column in the result.\n\ " ; @@ -193,294 +216,40 @@ return NULL; } -static PyObject * -_mysql_field_to_python( - PyObject *decoder, - char *rowitem, - unsigned long length) -{ - PyObject *v; - if (rowitem) { - if (decoder != Py_None) - v = PyObject_CallFunction(decoder, - "s#", - rowitem, - (int)length); - else - v = PyString_FromStringAndSize(rowitem, - (int)length); - if (!v) - return NULL; - } else { - Py_INCREF(Py_None); - v = Py_None; - } - return v; -} - -static PyObject * -_mysql_row_to_tuple( - _mysql_ResultObject *self, - MYSQL_ROW row) -{ - unsigned int n, i; - unsigned long *length; - PyObject *r, *c; - - n = mysql_num_fields(self->result); - if (!(r = PyTuple_New(n))) return NULL; - length = mysql_fetch_lengths(self->result); - for (i=0; i<n; i++) { - PyObject *v; - c = PyTuple_GET_ITEM(self->decoders, i); - v = _mysql_field_to_python(c, row[i], length[i]); - if (!v) goto error; - PyTuple_SET_ITEM(r, i, v); - } - return r; - error: - Py_XDECREF(r); - return NULL; -} - -static PyObject * -_mysql_row_to_dict( - _mysql_ResultObject *self, - MYSQL_ROW row) -{ - unsigned int n, i; - unsigned long *length; - PyObject *r, *c; - MYSQL_FIELD *fields; - - n = mysql_num_fields(self->result); - if (!(r = PyDict_New())) return NULL; - length = mysql_fetch_lengths(self->result); - fields = mysql_fetch_fields(self->result); - for (i=0; i<n; i++) { - PyObject *v; - c = PyTuple_GET_ITEM(self->decoders, i); - v = _mysql_field_to_python(c, row[i], length[i]); - if (!v) goto error; - if (!PyMapping_HasKeyString(r, fields[i].name)) { - PyMapping_SetItemString(r, fields[i].name, v); - } else { - int len; - char buf[256]; - strncpy(buf, fields[i].table, 256); - len = strlen(buf); - strncat(buf, ".", 256-len); - len = strlen(buf); - strncat(buf, fields[i].name, 256-len); - PyMapping_SetItemString(r, buf, v); - } - Py_DECREF(v); - } - return r; - error: - Py_XDECREF(r); - return NULL; -} - -static PyObject * -_mysql_row_to_dict_old( - _mysql_ResultObject *self, - MYSQL_ROW row) -{ - unsigned int n, i; - unsigned long *length; - PyObject *r, *c; - MYSQL_FIELD *fields; - - n = mysql_num_fields(self->result); - if (!(r = PyDict_New())) return NULL; - length = mysql_fetch_lengths(self->result); - fields = mysql_fetch_fields(self->result); - for (i=0; i<n; i++) { - PyObject *v; - c = PyTuple_GET_ITEM(self->decoders, i); - v = _mysql_field_to_python(c, row[i], length[i]); - if (!v) goto error; - { - int len=0; - char buf[256]=""; - if (strlen(fields[i].table)) { - strncpy(buf, fields[i].table, 256); - len = strlen(buf); - strncat(buf, ".", 256-len); - len = strlen(buf); - } - strncat(buf, fields[i].name, 256-len); - PyMapping_SetItemString(r, buf, v); - } - Py_DECREF(v); - } - return r; - error: - Py_XDECREF(r); - return NULL; -} - typedef PyObject *_PYFUNC(_mysql_ResultObject *, MYSQL_ROW); -int -_mysql__fetch_row( - _mysql_ResultObject *self, - PyObject **r, - int skiprows, - int maxrows, - _PYFUNC *convert_row) -{ - unsigned int i; - MYSQL_ROW row; - - for (i = skiprows; i<(skiprows+maxrows); i++) { - PyObject *v; - if (!self->use) - row = mysql_fetch_row(self->result); - else { - Py_BEGIN_ALLOW_THREADS; - row = mysql_fetch_row(self->result); - Py_END_ALLOW_THREADS; - } - if (!row && mysql_errno(&(((_mysql_ConnectionObject *)(self->conn))->connection))) { - _mysql_Exception((_mysql_ConnectionObject *)self->conn); - goto error; - } - if (!row) { - if (_PyTuple_Resize(r, i) == -1) goto error; - break; - } - v = convert_row(self, row); - if (!v) goto error; - PyTuple_SET_ITEM(*r, i, v); - } - return i-skiprows; - error: - return -1; -} - -static char _mysql_ResultObject_fetch_row__doc__[] = -"fetch_row([maxrows, how]) -- Fetches up to maxrows as a tuple.\n\ -The rows are formatted according to how:\n\ -\n\ - 0 -- tuples (default)\n\ - 1 -- dictionaries, key=column or table.column if duplicated\n\ - 2 -- dictionaries, key=table.column\n\ +static char _mysql_ResultObject_clear__doc__[] = +"clear()\n\ + Reads to the end of the result set, discarding all the rows.\n\ "; static PyObject * -_mysql_ResultObject_fetch_row( - _mysql_ResultObject *self, - PyObject *args, - PyObject *kwargs) -{ - typedef PyObject *_PYFUNC(_mysql_ResultObject *, MYSQL_ROW); - static char *kwlist[] = { "maxrows", "how", NULL }; - static _PYFUNC *row_converters[] = - { - _mysql_row_to_tuple, - _mysql_row_to_dict, - _mysql_row_to_dict_old - }; - _PYFUNC *convert_row; - unsigned int maxrows=1, how=0, skiprows=0, rowsadded; - PyObject *r=NULL; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ii:fetch_row", kwlist, - &maxrows, &how)) - return NULL; - check_result_connection(self); - if (how < 0 || how >= sizeof(row_converters)) { - PyErr_SetString(PyExc_ValueError, "how out of range"); - return NULL; - } - convert_row = row_converters[how]; - if (maxrows) { - if (!(r = PyTuple_New(maxrows))) goto error; - rowsadded = _mysql__fetch_row(self, &r, skiprows, maxrows, - convert_row); - if (rowsadded == -1) goto error; - } else { - if (self->use) { - maxrows = 1000; - if (!(r = PyTuple_New(maxrows))) goto error; - while (1) { - rowsadded = _mysql__fetch_row(self, &r, skiprows, - maxrows, convert_row); - if (rowsadded == -1) goto error; - skiprows += rowsadded; - if (rowsadded < maxrows) break; - if (_PyTuple_Resize(&r, skiprows + maxrows) == -1) - goto error; - } - } else { - /* XXX if overflow, maxrows<0? */ - maxrows = (int) mysql_num_rows(self->result); - if (!(r = PyTuple_New(maxrows))) goto error; - rowsadded = _mysql__fetch_row(self, &r, 0, - maxrows, convert_row); - if (rowsadded == -1) goto error; - } - } - return r; - error: - Py_XDECREF(r); - return NULL; -} - -static char _mysql_ResultObject_simple_fetch_row__doc__[] = -"simple_fetchrow()\n\ - Fetches one row as a tuple of strings.\n\ - NULL is returned as None.\n\ -"; - -static PyObject * -_mysql_ResultObject_simple_fetch_row( +_mysql_ResultObject_clear( _mysql_ResultObject *self, PyObject *unused) { - unsigned int n, i; - unsigned long *length; - PyObject *r=NULL; - MYSQL_ROW row; - - check_result_connection(self); - - if (!self->use) - row = mysql_fetch_row(self->result); - else { - Py_BEGIN_ALLOW_THREADS; - row = mysql_fetch_row(self->result); - Py_END_ALLOW_THREADS; - } - if (!row && mysql_errno(&(((_mysql_ConnectionObject *)(self->conn))->connection))) { - _mysql_Exception((_mysql_ConnectionObject *)self->conn); - goto error; + if (self->result) { + if (self->use) { + Py_BEGIN_ALLOW_THREADS; + while (mysql_fetch_row(self->result)); + Py_END_ALLOW_THREADS; + + if (mysql_errno(&(((_mysql_ConnectionObject *)(self->conn))->connection))) { + _mysql_Exception((_mysql_ConnectionObject *)self->conn); + return NULL; + } + } } - if (!row) { - Py_INCREF(Py_None); - return Py_None; + Py_XDECREF(self->fields); + self->fields = NULL; + Py_XDECREF(self->conn); + self->conn = NULL; + if (self->result) { + mysql_free_result(self->result); + self->result = NULL; } - - n = mysql_num_fields(self->result); - if (!(r = PyTuple_New(n))) return NULL; - length = mysql_fetch_lengths(self->result); - for (i=0; i<n; i++) { - PyObject *v; - if (row[i]) { - v = PyString_FromStringAndSize(row[i], length[i]); - if (!v) goto error; - } else /* NULL */ { - v = Py_None; - Py_INCREF(v); - } - PyTuple_SET_ITEM(r, i, v); - } - return r; - error: - Py_XDECREF(r); - return NULL; + Py_INCREF(Py_None); + return Py_None; } static PyObject * @@ -500,7 +269,7 @@ { PyObject *row; check_result_connection(self); - row = _mysql_ResultObject_simple_fetch_row(self, NULL); + row = _mysql_ResultObject_fetch_row(self, NULL); if (row == Py_None) { Py_DECREF(row); PyErr_SetString(PyExc_StopIteration, ""); @@ -598,8 +367,8 @@ _mysql_ResultObject *self) { PyObject_GC_UnTrack((PyObject *)self); + _mysql_ResultObject_clear(self, NULL); mysql_free_result(self->result); - _mysql_ResultObject_clear(self); MyFree(self); } @@ -633,6 +402,12 @@ _mysql_ResultObject_row_tell__doc__ }, { + "clear", + (PyCFunction)_mysql_ResultObject_clear, + METH_NOARGS, + _mysql_ResultObject_clear__doc__ + }, + { "describe", (PyCFunction)_mysql_ResultObject_describe, METH_NOARGS, @@ -644,12 +419,6 @@ METH_VARARGS | METH_KEYWORDS, _mysql_ResultObject_fetch_row__doc__ }, - { - "simple_fetch_row", - (PyCFunction)_mysql_ResultObject_simple_fetch_row, - METH_VARARGS | METH_KEYWORDS, - _mysql_ResultObject_simple_fetch_row__doc__ - }, { "field_flags", @@ -681,13 +450,6 @@ "Connection associated with result" }, { - "decoders", - T_OBJECT, - offsetof(_mysql_ResultObject, decoders), - RO, - "Field decoders for result set" - }, - { "fields", T_OBJECT, offsetof(_mysql_ResultObject, fields),