Mercurial > p > mysql-python > mysqldb-2
changeset 67:98d968f5af11 MySQLdb
Reimplement MySQL->Python type conversion in C; much simpler and easier to deal with now. Hey, all my tests pass, so I guess that means I need to write some more tests.
author | adustman |
---|---|
date | Mon, 30 Mar 2009 20:21:24 +0000 |
parents | 5a7c30cd9de2 |
children | 1e1e24fddc74 |
files | MySQLdb/connections.py MySQLdb/converters.py MySQLdb/cursors.py MySQLdb/times.py src/connections.c src/mysqlmod.c src/mysqlmod.h src/results.c tests/test_MySQLdb_capabilities.py |
diffstat | 9 files changed, 187 insertions(+), 337 deletions(-) [+] |
line wrap: on
line diff
--- a/MySQLdb/connections.py Sun Mar 29 16:26:30 2009 +0000 +++ b/MySQLdb/connections.py Mon Mar 30 20:21:24 2009 +0000 @@ -135,10 +135,9 @@ self.cursorclass = Cursor charset = kwargs2.pop('charset', '') - if 'decoders' not in kwargs2: - kwargs2['decoders'] = default_decoders; - self.decoders = kwargs2.pop('decoders', default_decoders) # XXX kwargs2['decoders'] - self.encoders = conversions # XXX kwargs2.pop('encoders', default_encoders) + if 'decoder_stack' not in kwargs2: + kwargs2['decoder_stack'] = default_decoders; + self.encoders = kwargs2.pop('encoders', default_encoders) client_flag = kwargs.get('client_flag', 0) client_version = tuple( @@ -184,18 +183,22 @@ def close(self): return self._db.close() - def cursor(self, decoders=None, encoders=None): + def escape_string(self, s): + return self._db.escape_string(s) + + def string_literal(self, s): + return self._db.string_literal(s) + + def cursor(self, encoders=None): """ Create a cursor on which queries may be performed. The optional cursorclass parameter is used to create the Cursor. By default, self.cursorclass=cursors.Cursor is used. """ - if not decoders: - decoders = self.decoders[:] if not encoders: - encoders = self.encoders.copy() #[:] + encoders = self.encoders[:] - return self.cursorclass(self, decoders, encoders) + return self.cursorclass(self, encoders) def __enter__(self): return self.cursor() @@ -208,13 +211,16 @@ def literal(self, obj): """ - If obj is a single object, returns an SQL literal as a string. If - obj is a non-string sequence, the items of the sequence are converted - and returned as a sequence. + Given an object obj, returns an SQL literal as a string. - Non-standard. For internal use; do not use this in your applications. + Non-standard. """ - return self._db.escape(obj, self.encoders) + for encoder in self.encoders: + f = encoder(obj) + if f: + return f(self, obj) + + raise self.NotSupportedError("could not encode as SQL", obj) def _warning_count(self): """Return the number of warnings generated from the last query."""
--- a/MySQLdb/converters.py Sun Mar 29 16:26:30 2009 +0000 +++ b/MySQLdb/converters.py Mon Mar 30 20:21:24 2009 +0000 @@ -34,7 +34,7 @@ """ -from _mysql import string_literal, escape_sequence, escape_dict, NULL +from _mysql import NULL from MySQLdb.constants import FIELD_TYPE, FLAG from MySQLdb.times import datetime_to_sql, timedelta_to_sql, \ timedelta_or_None, datetime_or_None, date_or_None, \ @@ -47,7 +47,7 @@ __revision__ = "$Revision$"[11:-2] __author__ = "$Author$"[9:-2] -def bool_to_sql(boolean, conv): +def bool_to_sql(connection, boolean): """Convert a Python bool to an SQL literal.""" return str(int(boolean)) @@ -55,37 +55,33 @@ """Convert MySQL SET column to Python set.""" return set([ i for i in value.split(',') if i ]) -def Set_to_sql(value, conv): +def Set_to_sql(connection, value): """Convert a Python set to an SQL literal.""" - return string_literal(','.join(value), conv) - -def object_to_sql(obj, conv): - """Convert something into a string via str().""" - return str(obj) + return connection.string_literal(','.join(value)) -def unicode_to_sql(value, conv): - """Convert a unicode object to a string using the default encoding. - This is only used as a placeholder for the real function, which - is connection-dependent.""" - assert isinstance(value, unicode) - return value.encode() +def object_to_sql(connection, obj): + """Convert something into a string via str(). + The result will not be quoted.""" + return connection.escape_string(str(obj)) -def float_to_sql(value, conv): +def unicode_to_sql(connection, value): + """Convert a unicode object to a string using the connection encoding.""" + return connection.string_literal(value.encode(connection.character_set_name())) + +def float_to_sql(connection, value): return '%.15g' % value -def None_to_sql(value, conv): +def None_to_sql(connection, value): """Convert None to NULL.""" return NULL # duh -def object_to_quoted_sql(obj, conv): - """Convert something into a SQL string literal. If using - MySQL-3.23 or newer, string_literal() is a method of the - _mysql.MYSQL object, and this function will be overridden with - that method when the connection is created.""" +def object_to_quoted_sql(connection, obj): + """Convert something into a SQL string literal.""" + if hasattr(obj, "__unicode__"): + return unicode_to_sql(connection, obj) + return connection.string_literal(str(obj)) - return string_literal(obj, conv) - -def instance_to_sql(obj, conv): +def instance_to_sql(connection, obj): """Convert an Instance to a string representation. If the __str__() method produces acceptable output, then you don't need to add the class to conversions; it will be handled by the default @@ -101,22 +97,14 @@ conv[obj.__class__] = conv[classes[0]] return conv[classes[0]](obj, conv) -def char_array(obj): - return array.array('c', obj) - -def array_to_sql(obj, conv): - return object_to_quoted_sql(obj.tostring(), conv) +def array_to_sql(connection, obj): + return connection.string_literal(obj.tostring()) simple_type_encoders = { int: object_to_sql, long: object_to_sql, float: float_to_sql, type(None): None_to_sql, - tuple: escape_sequence, - list: escape_sequence, - dict: escape_dict, - InstanceType: instance_to_sql, - array.array: array_to_sql, unicode: unicode_to_sql, object: instance_to_sql, bool: bool_to_sql, @@ -156,19 +144,18 @@ # returns None, this decoder will be ignored and the next decoder # on the stack will be checked. -def filter_NULL(f): - def _filter_NULL(o): - if o is None: return o - return f(o) - _filter_NULL.__name__ = f.__name__ - return _filter_NULL - def default_decoder(field): return str +def default_encoder(value): + return object_to_quoted_sql + def simple_decoder(field): return simple_field_decoders.get(field.type, None) +def simple_encoder(value): + return simple_type_encoders.get(type(value), None) + character_types = [ FIELD_TYPE.BLOB, FIELD_TYPE.STRING, @@ -195,6 +182,8 @@ ] default_encoders = [ + simple_encoder, + default_encoder, ]
--- a/MySQLdb/cursors.py Sun Mar 29 16:26:30 2009 +0000 +++ b/MySQLdb/cursors.py Mon Mar 30 20:21:24 2009 +0000 @@ -45,7 +45,7 @@ _defer_warnings = False _fetch_type = None - def __init__(self, connection, decoders, encoders): + def __init__(self, connection, encoders): from MySQLdb.converters import default_decoders self.connection = weakref.proxy(connection) self.description = None @@ -60,7 +60,7 @@ self._warnings = 0 self._info = None self.rownumber = None - self._decoders = decoders + self._encoders = encoders def __del__(self): self.close() @@ -117,25 +117,10 @@ self._warning_check() return True - def _lookup_decoder(self, field): - from MySQLdb.converters import filter_NULL - for plugin in self._decoders: - f = plugin(field) - if f: - return filter_NULL(f) - return None # this should never happen - def _do_get_result(self): """Get the result from the last query.""" connection = self._get_db() self._result = self._get_result() - if self._result: - self.sql_to_python = [ - self._lookup_decoder(f) - for f in self._result.fields() - ] - else: - self.sql_to_python = [] self.rowcount = connection.affected_rows() self.rownumber = 0 self.description = self._result and self._result.describe() or None @@ -176,9 +161,9 @@ charset = db.character_set_name() if isinstance(query, unicode): query = query.encode(charset) - if args is not None: - query = query % self.connection.literal(args) try: + if args is not None: + query = query % tuple(map(self.connection.literal, args)) result = self._query(query) except TypeError, msg: if msg.args[0] in ("not enough arguments for format string", @@ -235,7 +220,11 @@ values = matched.group('values') try: - sql_params = [ values % self.connection.literal(arg) for arg in args ] + sql_params = ( values % tuple(map(self.connection.literal, row)) for row in args ) + multirow_query = '\n'.join([start, ',\n'.join(sql_params), end]) + self._executed = multirow_query + self.rowcount = int(self._query(multirow_query)) + except TypeError, msg: if msg.args[0] in ("not enough arguments for format string", "not all arguments converted"): @@ -248,9 +237,7 @@ exc, value, traceback = sys.exc_info() del traceback self.errorhandler(self, exc, value) - self.rowcount = int(self._query( - '\n'.join([start, ',\n'.join(sql_params), end, - ]))) + if not self._defer_warnings: self._warning_check() return self.rowcount @@ -319,14 +306,7 @@ """Low-level fetch_row wrapper.""" if not self._result: return () - # unfortunately it is necessary to wrap these generators up as tuples - # as the rows are expected to be subscriptable. - return tuple( - ( - tuple( ( f(x) for f, x in zip(self.sql_to_python, row) ) ) - for row in self._result.fetch_row(size, self._fetch_type) - ) - ) + return self._result.fetch_row(size, self._fetch_type) def __iter__(self): return iter(self.fetchone, None)
--- a/MySQLdb/times.py Sun Mar 29 16:26:30 2009 +0000 +++ b/MySQLdb/times.py Mon Mar 30 20:21:24 2009 +0000 @@ -12,7 +12,6 @@ from time import localtime from datetime import date, datetime, time, timedelta -from _mysql import string_literal # These are required for DB-API (PEP-249) Date = date @@ -192,17 +191,17 @@ """ try: - return date(*[ int(x) for x in obj.split('-', 2) ]) + return date(*map(int, obj.split('-', 2))) except ValueError: return None -def datetime_to_sql(obj, conv): +def datetime_to_sql(connection, obj): """Format a DateTime object as an ISO timestamp.""" - return string_literal(datetime_to_str(obj), conv) + return connection.string_literal(datetime_to_str(obj)) -def timedelta_to_sql(obj, conv): +def timedelta_to_sql(connection, obj): """Format a timedelta as an SQL literal.""" - return string_literal(timedelta_to_str(obj), conv) + return connection.string_literal(timedelta_to_str(obj)) def mysql_timestamp_converter(timestamp): """Convert a MySQL TIMESTAMP to a Timestamp object.
--- a/src/connections.c Sun Mar 29 16:26:30 2009 +0000 +++ b/src/connections.c Mon Mar 30 20:21:24 2009 +0000 @@ -9,7 +9,7 @@ PyObject *kwargs) { MYSQL *conn = NULL; - PyObject *conv = NULL; + PyObject *decoder_stack = NULL; PyObject *ssl = NULL; #if HAVE_OPENSSL char *key = NULL, *cert = NULL, *ca = NULL, @@ -20,7 +20,7 @@ unsigned int port = 0; unsigned int client_flag = 0; static char *kwlist[] = { "host", "user", "passwd", "db", "port", - "unix_socket", "conv", + "unix_socket", "decoder_stack", "connect_timeout", "compress", "named_pipe", "init_command", "read_default_file", "read_default_group", @@ -33,13 +33,13 @@ *read_default_file=NULL, *read_default_group=NULL; - self->converter = NULL; + self->decoder_stack = NULL; self->open = 0; check_server_init(-1); if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssisOiiisssiOi:connect", kwlist, &host, &user, &passwd, &db, - &port, &unix_socket, &conv, + &port, &unix_socket, &decoder_stack, &connect_timeout, &compress, &named_pipe, &init_command, &read_default_file, @@ -49,15 +49,12 @@ )) return -1; - /* Keep the converter mapping or a blank mapping dict */ - if (!conv) - conv = PyDict_New(); + if (!decoder_stack) + decoder_stack = PyList_New(0); else - Py_INCREF(conv); - if (!conv) - return -1; - self->converter = conv; - + Py_INCREF(decoder_stack); + self->decoder_stack = decoder_stack; + #define _stringsuck(d,t,s) {t=PyMapping_GetItemString(s,#d);\ if(t){d=PyString_AsString(t);Py_DECREF(t);}\ PyErr_Clear();} @@ -148,10 +145,6 @@ unix_socket\n\ string, location of unix_socket (UNIX-ish only)\n\ \n\ -conv\n\ - mapping, maps MySQL FIELD_TYPE.* to Python functions which\n\ - convert a string to the appropriate Python type\n\ -\n\ connect_timeout\n\ number of seconds to wait before the connection\n\ attempt fails.\n\ @@ -201,16 +194,16 @@ visitproc visit, void *arg) { - if (self->converter) - return visit(self->converter, arg); + if (self->decoder_stack) + return visit(self->decoder_stack, arg); return 0; } static int _mysql_ConnectionObject_clear( _mysql_ConnectionObject *self) { - Py_XDECREF(self->converter); - self->converter = NULL; + Py_XDECREF(self->decoder_stack); + self->decoder_stack = NULL; return 0; } @@ -218,44 +211,11 @@ _escape_item( PyObject *item, PyObject *d); - -char _mysql_escape__doc__[] = -"escape(obj, dict) -- escape any special characters in object obj\n\ -using mapping dict to provide quoting functions for each type.\n\ -Returns a SQL literal string."; -PyObject * -_mysql_escape( - PyObject *self, - PyObject *args) -{ - PyObject *o=NULL, *d=NULL; - if (!PyArg_ParseTuple(args, "O|O:escape", &o, &d)) - return NULL; - if (d) { - if (!PyMapping_Check(d)) { - PyErr_SetString(PyExc_TypeError, - "argument 2 must be a mapping"); - return NULL; - } - return _escape_item(o, d); - } else { - if (!self) { - PyErr_SetString(PyExc_TypeError, - "argument 2 must be a mapping"); - return NULL; - } - return _escape_item(o, - ((_mysql_ConnectionObject *) self)->converter); - } -} char _mysql_escape_string__doc__[] = "escape_string(s) -- quote any SQL-interpreted characters in string s.\n\ -\n\ -Use connection.escape_string(s), if you use it at all.\n\ -_mysql.escape_string(s) cannot handle character sets. You are\n\ -probably better off using connection.escape(o) instead, since\n\ -it will escape entire sequences as well as strings."; +If you want quotes around your value, use string_literal(s) instead.\n\ +"; PyObject * _mysql_escape_string( @@ -269,57 +229,34 @@ str = PyString_FromStringAndSize((char *) NULL, size*2+1); if (!str) return PyErr_NoMemory(); out = PyString_AS_STRING(str); -#if MYSQL_VERSION_ID < 32321 - len = mysql_escape_string(out, in, size); -#else - check_server_init(NULL); - if (self && self->open) - len = mysql_real_escape_string(&(self->connection), out, in, size); - else - len = mysql_escape_string(out, in, size); -#endif + len = mysql_real_escape_string(&(self->connection), out, in, size); if (_PyString_Resize(&str, len) < 0) return NULL; return (str); } char _mysql_string_literal__doc__[] = -"string_literal(obj) -- converts object obj into a SQL string literal.\n\ +"string_literal(s) -- converts string s into a SQL string literal.\n\ This means, any special SQL characters are escaped, and it is enclosed\n\ within single quotes. In other words, it performs:\n\ \n\ -\"'%s'\" % escape_string(str(obj))\n\ -\n\ -Use connection.string_literal(obj), if you use it at all.\n\ -_mysql.string_literal(obj) cannot handle character sets."; +\"'%s'\" % escape_string(s)\n\ +"; PyObject * _mysql_string_literal( _mysql_ConnectionObject *self, PyObject *args) { - PyObject *str, *s, *o, *d; + PyObject *str; char *in, *out; int len, size; - if (!PyArg_ParseTuple(args, "O|O:string_literal", &o, &d)) return NULL; - s = PyObject_Str(o); - if (!s) return NULL; - in = PyString_AsString(s); - size = PyString_GET_SIZE(s); + if (!PyArg_ParseTuple(args, "s#:string_literal", &in, &size)) return NULL; str = PyString_FromStringAndSize((char *) NULL, size*2+3); if (!str) return PyErr_NoMemory(); out = PyString_AS_STRING(str); -#if MYSQL_VERSION_ID < 32321 - len = mysql_escape_string(out+1, in, size); -#else - check_server_init(NULL); - if (self && self->open) - len = mysql_real_escape_string(&(self->connection), out+1, in, size); - else - len = mysql_escape_string(out+1, in, size); -#endif + len = mysql_real_escape_string(&(self->connection), out+1, in, size); *out = *(out+len+1) = '\''; if (_PyString_Resize(&str, len+2) < 0) return NULL; - Py_DECREF(s); return (str); } @@ -994,7 +931,7 @@ _mysql_ResultObject *r=NULL; check_connection(self); - arglist = Py_BuildValue("(OiO)", self, 0, self->converter); + arglist = Py_BuildValue("(OiO)", self, 0, self->decoder_stack); if (!arglist) goto error; kwarglist = PyDict_New(); if (!kwarglist) goto error; @@ -1054,7 +991,7 @@ _mysql_ResultObject *r=NULL; check_connection(self); - arglist = Py_BuildValue("(OiO)", self, 1, self->converter); + arglist = Py_BuildValue("(OiO)", self, 1, self->decoder_stack); if (!arglist) return NULL; kwarglist = PyDict_New(); if (!kwarglist) goto error; @@ -1197,12 +1134,6 @@ _mysql_ConnectionObject_dump_debug_info__doc__ }, { - "escape", - (PyCFunction)_mysql_escape, - METH_VARARGS, - _mysql_escape__doc__ - }, - { "escape_string", (PyCFunction)_mysql_escape_string, METH_VARARGS, @@ -1327,11 +1258,11 @@ "True if connection is open" }, { - "converter", + "decoder_stack", T_OBJECT, - offsetof(_mysql_ConnectionObject, converter), + offsetof(_mysql_ConnectionObject, decoder_stack), 0, - "Type conversion mapping" + "Type decoder stack" }, { "server_capabilities",
--- a/src/mysqlmod.c Sun Mar 29 16:26:30 2009 +0000 +++ b/src/mysqlmod.c Mon Mar 30 20:21:24 2009 +0000 @@ -350,36 +350,6 @@ _mysql_debug__doc__ }, { - "escape", - (PyCFunction)_mysql_escape, - METH_VARARGS, - _mysql_escape__doc__ - }, - { - "escape_sequence", - (PyCFunction)_mysql_escape_sequence, - METH_VARARGS, - _mysql_escape_sequence__doc__ - }, - { - "escape_dict", - (PyCFunction)_mysql_escape_dict, - METH_VARARGS, - _mysql_escape_dict__doc__ - }, - { - "escape_string", - (PyCFunction)_mysql_escape_string, - METH_VARARGS, - _mysql_escape_string__doc__ - }, - { - "string_literal", - (PyCFunction)_mysql_string_literal, - METH_VARARGS, - _mysql_string_literal__doc__ - }, - { "get_client_info", (PyCFunction)_mysql_get_client_info, METH_NOARGS,
--- a/src/mysqlmod.h Sun Mar 29 16:26:30 2009 +0000 +++ b/src/mysqlmod.h Mon Mar 30 20:21:24 2009 +0000 @@ -26,7 +26,7 @@ PyObject_HEAD MYSQL connection; int open; - PyObject *converter; + PyObject *decoder_stack; } _mysql_ConnectionObject; #define check_connection(c) if (!(c->open)) return _mysql_Exception(c) @@ -41,7 +41,8 @@ MYSQL_RES *result; int nfields; int use; - PyObject *converter; + PyObject *decoders; + PyObject *fields; } _mysql_ResultObject; extern PyTypeObject _mysql_ResultObject_Type;
--- a/src/results.c Sun Mar 29 16:26:30 2009 +0000 +++ b/src/results.c Mon Mar 30 20:21:24 2009 +0000 @@ -2,8 +2,42 @@ #include "mysqlmod.h" +static PyObject * +_mysql_ResultObject_get_fields( + _mysql_ResultObject *self, + PyObject *unused) +{ + PyObject *arglist=NULL, *kwarglist=NULL; + PyObject *fields=NULL; + _mysql_FieldObject *field=NULL; + unsigned int i, n; + + check_result_connection(self); + kwarglist = PyDict_New(); + if (!kwarglist) goto error; + n = mysql_num_fields(self->result); + if (!(fields = PyTuple_New(n))) return NULL; + for (i=0; i<n; i++) { + arglist = Py_BuildValue("(Oi)", self, i); + if (!arglist) goto error; + field = MyAlloc(_mysql_FieldObject, _mysql_FieldObject_Type); + if (!field) goto error; + if (_mysql_FieldObject_Initialize(field, arglist, kwarglist)) + goto error; + Py_DECREF(arglist); + PyTuple_SET_ITEM(fields, i, (PyObject *) field); + } + Py_DECREF(kwarglist); + return fields; + error: + Py_XDECREF(arglist); + Py_XDECREF(kwarglist); + Py_XDECREF(fields); + return NULL; +} + static char _mysql_ResultObject__doc__[] = -"result(connection, use=0, converter={}) -- Result set from a query.\n\ +"result(connection, use=0, decoder_stack=[]) -- Result set from a query.\n\ \n\ Creating instances of this class directly is an excellent way to\n\ shoot yourself in the foot. If using _mysql.connection directly,\n\ @@ -17,19 +51,18 @@ PyObject *args, PyObject *kwargs) { - static char *kwlist[] = {"connection", "use", "converter", NULL}; + static char *kwlist[] = {"connection", "use", "decoder_stack", NULL}; MYSQL_RES *result; - _mysql_ConnectionObject *conn=NULL; - int use=0; - PyObject *conv=NULL; - int n, i; - MYSQL_FIELD *fields; + _mysql_ConnectionObject *conn = NULL; + int use = 0; + PyObject *decoder_stack = NULL; + int n, ns, i, j; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|iO", kwlist, - &conn, &use, &conv)) + &conn, &use, &decoder_stack)) return -1; - if (!conv) conv = PyDict_New(); - if (!conv) return -1; + if (!decoder_stack) decoder_stack = PyList_New(0); + if (!decoder_stack) return -1; self->conn = (PyObject *) conn; Py_INCREF(conn); self->use = use; @@ -41,61 +74,32 @@ self->result = result; Py_END_ALLOW_THREADS ; if (!result) { - self->converter = PyTuple_New(0); + self->decoders = PyTuple_New(0); return 0; } n = mysql_num_fields(result); + ns = PySequence_Length(decoder_stack); self->nfields = n; - if (!(self->converter = PyTuple_New(n))) return -1; - fields = mysql_fetch_fields(result); + if (!(self->decoders = PyTuple_New(n))) return -1; + self->fields = _mysql_ResultObject_get_fields(self, NULL); for (i=0; i<n; i++) { - PyObject *tmp, *fun; - tmp = PyInt_FromLong((long) fields[i].type); - if (!tmp) return -1; - fun = PyObject_GetItem(conv, tmp); - Py_DECREF(tmp); - if (!fun) { - PyErr_Clear(); - fun = Py_None; - Py_INCREF(Py_None); + PyObject *field = PyTuple_GET_ITEM(self->fields, i); + for (j=0; j<ns; j++) { + PyObject *df = PySequence_GetItem(decoder_stack, j); + if (!df) goto error; + PyObject *f = PyObject_CallFunctionObjArgs(df, field, NULL); + Py_DECREF(df); + if (!f) goto error; + if (f != Py_None) { + PyTuple_SET_ITEM(self->decoders, i, f); + break; + } + Py_DECREF(f); } - if (PySequence_Check(fun)) { - int j, n2=PySequence_Size(fun); - PyObject *fun2=NULL; - for (j=0; j<n2; j++) { - PyObject *t = PySequence_GetItem(fun, j); - if (!t) continue; - if (!PyTuple_Check(t)) goto cleanup; - if (PyTuple_GET_SIZE(t) == 2) { - long mask; - PyObject *pmask=NULL; - pmask = PyTuple_GET_ITEM(t, 0); - fun2 = PyTuple_GET_ITEM(t, 1); - if (PyInt_Check(pmask)) { - mask = PyInt_AS_LONG(pmask); - if (mask & fields[i].flags) { - Py_DECREF(t); - break; - } - else { - goto cleanup; - } - } else { - Py_DECREF(t); - break; - } - } - cleanup: - Py_DECREF(t); - } - if (!fun2) fun2 = Py_None; - Py_INCREF(fun2); - Py_DECREF(fun); - fun = fun2; - } - PyTuple_SET_ITEM(self->converter, i, fun); } return 0; + error: + return -1; } static int @@ -105,8 +109,8 @@ void *arg) { int r; - if (self->converter) { - if (!(r = visit(self->converter, arg))) return r; + if (self->decoders) { + if (!(r = visit(self->decoders, arg))) return r; } if (self->conn) return visit(self->conn, arg); @@ -117,8 +121,8 @@ _mysql_ResultObject_clear( _mysql_ResultObject *self) { - Py_XDECREF(self->converter); - self->converter = NULL; + Py_XDECREF(self->decoders); + self->decoders = NULL; Py_XDECREF(self->conn); self->conn = NULL; return 0; @@ -161,45 +165,6 @@ return NULL; } -static char _mysql_ResultObject_fields__doc__[] = -"Returns the sequence of 7-tuples required by the DB-API for\n\ -the Cursor.description attribute.\n\ -"; - -static PyObject * -_mysql_ResultObject_fields( - _mysql_ResultObject *self, - PyObject *unused) -{ - PyObject *arglist=NULL, *kwarglist=NULL; - PyObject *fields=NULL; - _mysql_FieldObject *field=NULL; - unsigned int i, n; - - check_result_connection(self); - kwarglist = PyDict_New(); - if (!kwarglist) goto error; - n = mysql_num_fields(self->result); - if (!(fields = PyTuple_New(n))) return NULL; - for (i=0; i<n; i++) { - arglist = Py_BuildValue("(Oi)", self, i); - if (!arglist) goto error; - field = MyAlloc(_mysql_FieldObject, _mysql_FieldObject_Type); - if (!field) goto error; - if (_mysql_FieldObject_Initialize(field, arglist, kwarglist)) - goto error; - Py_DECREF(arglist); - PyTuple_SET_ITEM(fields, i, (PyObject *) field); - } - Py_DECREF(kwarglist); - return fields; - error: - Py_XDECREF(arglist); - Py_XDECREF(kwarglist); - Py_XDECREF(fields); - return NULL; -} - static char _mysql_ResultObject_field_flags__doc__[] = "Returns a tuple of field flags, one for each column in the result.\n\ " ; @@ -230,14 +195,14 @@ static PyObject * _mysql_field_to_python( - PyObject *converter, + PyObject *decoder, char *rowitem, unsigned long length) { PyObject *v; if (rowitem) { - if (converter != Py_None) - v = PyObject_CallFunction(converter, + if (decoder != Py_None) + v = PyObject_CallFunction(decoder, "s#", rowitem, (int)length); @@ -267,7 +232,7 @@ length = mysql_fetch_lengths(self->result); for (i=0; i<n; i++) { PyObject *v; - c = PyTuple_GET_ITEM(self->converter, i); + c = PyTuple_GET_ITEM(self->decoders, i); v = _mysql_field_to_python(c, row[i], length[i]); if (!v) goto error; PyTuple_SET_ITEM(r, i, v); @@ -294,7 +259,7 @@ fields = mysql_fetch_fields(self->result); for (i=0; i<n; i++) { PyObject *v; - c = PyTuple_GET_ITEM(self->converter, i); + c = PyTuple_GET_ITEM(self->decoders, i); v = _mysql_field_to_python(c, row[i], length[i]); if (!v) goto error; if (!PyMapping_HasKeyString(r, fields[i].name)) { @@ -333,7 +298,7 @@ fields = mysql_fetch_fields(self->result); for (i=0; i<n; i++) { PyObject *v; - c = PyTuple_GET_ITEM(self->converter, i); + c = PyTuple_GET_ITEM(self->decoders, i); v = _mysql_field_to_python(c, row[i], length[i]); if (!v) goto error; { @@ -595,12 +560,6 @@ _mysql_ResultObject_describe__doc__ }, { - "fields", - (PyCFunction)_mysql_ResultObject_fields, - METH_NOARGS, - _mysql_ResultObject_fields__doc__ - }, - { "fetch_row", (PyCFunction)_mysql_ResultObject_fetch_row, METH_VARARGS | METH_KEYWORDS, @@ -636,13 +595,19 @@ "Connection associated with result" }, { - "converter", + "decoders", T_OBJECT, - offsetof(_mysql_ResultObject, converter), + offsetof(_mysql_ResultObject, decoders), RO, - "Type conversion mapping" + "Field decoders for result set" }, - {NULL} /* Sentinel */ + { + "fields", + T_OBJECT, + offsetof(_mysql_ResultObject, fields), + RO, + "Field metadata for result set" + }, {NULL} /* Sentinel */ }; static PyObject *
--- a/tests/test_MySQLdb_capabilities.py Sun Mar 29 16:26:30 2009 +0000 +++ b/tests/test_MySQLdb_capabilities.py Mon Mar 30 20:21:24 2009 +0000 @@ -13,7 +13,7 @@ connect_kwargs = dict(db='test', read_default_file='~/.my.cnf', charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL") create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8" - leak_test = True + leak_test = False def quote_identifier(self, ident): return "`%s`" % ident @@ -96,6 +96,15 @@ def test_ping(self): self.connection.ping() + def test_literal_int(self): + self.failUnless("2" == self.connection.literal(2)) + + def test_literal_float(self): + self.failUnless("3.1415" == self.connection.literal(3.1415)) + + def test_literal_string(self): + self.failUnless("'foo'" == self.connection.literal("foo")) + if __name__ == '__main__': if test_MySQLdb.leak_test: