aboutsummaryrefslogtreecommitdiff
blob: 757299b018428b441030ba68cea3f82c927bf490 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
--- Lib/xml/__init__.py
+++ Lib/xml/__init__.py
@@ -28,20 +28,23 @@
 _MINIMUM_XMLPLUS_VERSION = (0, 8, 4)
 
 
-try:
+def use_pyxml():
     import _xmlplus
-except ImportError:
-    pass
-else:
-    try:
-        v = _xmlplus.version_info
-    except AttributeError:
-        # _xmlplus is too old; ignore it
-        pass
+    v = _xmlplus.version_info
+    if v >= _MINIMUM_XMLPLUS_VERSION:
+        import sys
+        _xmlplus.__path__.extend(__path__)
+        sys.modules[__name__] = _xmlplus
+        cleared_modules = []
+        redefined_modules = []
+        for module in sys.modules:
+            if module.startswith("xml.") and not module.startswith(("xml.marshal", "xml.schema", "xml.utils", "xml.xpath", "xml.xslt")):
+                cleared_modules.append(module)
+            if module.startswith(("xml.__init__", "xml.dom", "xml.parsers", "xml.sax")) and sys.modules[module] is not None:
+                redefined_modules.append(module)
+        for module in cleared_modules:
+            del sys.modules[module]
+        for module in sorted(redefined_modules):
+            __import__(module)
     else:
-        if v >= _MINIMUM_XMLPLUS_VERSION:
-            import sys
-            _xmlplus.__path__.extend(__path__)
-            sys.modules[__name__] = _xmlplus
-        else:
-            del v
+        raise ImportError("PyXML too old: %s" % ".".join(str(x) for x in v))