changeset 32:4a5668deee4a MySQLdb

Merge back r554 for deprecated sets module
author kylev
date Wed, 11 Feb 2009 23:45:57 +0000
parents 10038670b963
children 7c7b7114864b
files MySQLdb/__init__.py
diffstat 1 files changed, 20 insertions(+), 31 deletions(-) [+]
line wrap: on
line diff
--- a/MySQLdb/__init__.py	Wed Feb 11 22:15:41 2009 +0000
+++ b/MySQLdb/__init__.py	Wed Feb 11 23:45:57 2009 +0000
@@ -30,36 +30,19 @@
 from MySQLdb.times import Date, Time, Timestamp, \
     DateFromTicks, TimeFromTicks, TimestampFromTicks
 
-from sets import ImmutableSet
+try:
+    frozenset
+except NameError:
+    from sets import ImmutableSet as frozenset
 
-class DBAPISet(ImmutableSet):
-
+class DBAPISet(frozenset):
     """A special type of set for which A == x is True if A is a
     DBAPISet and x is a member of that set.
-    
-      >>> from MySQLdb.constants import FIELD_TYPE
-      >>> FIELD_TYPE.VAR_STRING == STRING
-      True
-      >>> FIELD_TYPE.DATE == NUMBER
-      False
-      >>> FIELD_TYPE.DATE != DATE
-      False
-      
     """
-
-    def __ne__(self, other):
-        from sets import BaseSet
-        if isinstance(other, BaseSet):
-            return super(DBAPISet, self).__ne__(self, other)
-        else:
-            return other not in self
-
     def __eq__(self, other):
-        from sets import BaseSet
-        if isinstance(other, BaseSet):
-            return super(DBAPISet, self).__eq__(self, other)
-        else:
-            return other in self
+        if isinstance(other, DBAPISet):
+            return not self.difference(other)
+        return other in self
 
 
 STRING    = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING,
@@ -75,6 +58,18 @@
 DATETIME  = TIMESTAMP
 ROWID     = DBAPISet()
 
+def test_DBAPISet_set_equality():
+    assert STRING == STRING
+
+def test_DBAPISet_set_inequality():
+    assert STRING != NUMBER
+
+def test_DBAPISet_set_equality_membership():
+    assert FIELD_TYPE.VAR_STRING == STRING
+
+def test_DBAPISet_set_inequality_membership():
+    assert FIELD_TYPE.DATE != STRING
+
 def Binary(x):
     """Return x as a binary type."""
     return str(x)
@@ -97,9 +92,3 @@
     'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
     'paramstyle', 'string_literal', 'threadsafety', 'version_info',
     ]
-
-
-if __name__ == "__main__":
-    import doctest
-    doctest.testmod()
-