[Lxml-checkins] r40929 - lxml/branch/extension_refactoring/src/lxml

scoder at codespeak.net scoder at codespeak.net
Wed Mar 21 14:48:10 CET 2007


Author: scoder
Date: Wed Mar 21 14:48:09 2007
New Revision: 40929

Modified:
   lxml/branch/extension_refactoring/src/lxml/extensions.pxi
   lxml/branch/extension_refactoring/src/lxml/xpath.pxi
   lxml/branch/extension_refactoring/src/lxml/xslt.pxi
Log:
Another major rewrite of extension function registration, plus general cleanup.

Modified: lxml/branch/extension_refactoring/src/lxml/extensions.pxi
==============================================================================
--- lxml/branch/extension_refactoring/src/lxml/extensions.pxi	(original)
+++ lxml/branch/extension_refactoring/src/lxml/extensions.pxi	Wed Mar 21 14:48:09 2007
@@ -1,4 +1,4 @@
-# supports for extension functions in XPath and XSLT
+# support for extension functions in XPath and XSLT
 
 class XPathError(LxmlError):
     pass
@@ -9,17 +9,20 @@
 class XPathResultError(XPathError):
     pass
 
-################################################################################
-# Base class for XSLT and XPath evaluation contexts: functions, namespaces, ...
+# forward declarations
 
 ctypedef int _register_function(void* ctxt, name_utf, ns_uri_utf)
+cdef class _ExsltRegExp
+
+################################################################################
+# Base class for XSLT and XPath evaluation contexts: functions, namespaces, ...
 
 cdef class _BaseContext:
     cdef xpath.xmlXPathContext* _xpathCtxt
     cdef _Document _doc
     cdef object _extensions
     cdef object _namespaces
-    cdef object _registered_namespaces
+    cdef object _global_namespaces
     cdef object _utf_refs
     cdef object _function_cache
     cdef object _function_cache_ns
@@ -28,10 +31,10 @@
     cdef _TempStore _temp_refs
     cdef _ExceptionContext _exc
 
-    def __init__(self, namespaces, extensions):
-        self._xpathCtxt = NULL
+    def __init__(self, namespaces, extensions, enable_regexp):
+        cdef _ExsltRegExp _regexp 
         self._utf_refs = {}
-        self._registered_namespaces = []
+        self._global_namespaces = []
         self._function_cache = {}
         self._function_cache_ns = {}
 
@@ -39,7 +42,7 @@
             # convert extensions to UTF-8
             if python.PyDict_Check(extensions):
                 extensions = (extensions,)
-            # format: [ {(ns,name):function} ] -> {(ns_utf,name_utf):function}
+            # format: [ {(ns, name):function} ] -> {(ns_utf, name_utf):function}
             new_extensions = {}
             for extension in extensions:
                 for (ns_uri, name), function in extension.items():
@@ -52,17 +55,38 @@
                         new_extensions, (ns_utf, name_utf), function)
             extensions = new_extensions or None
 
+        if namespaces is not None:
+            if python.PyDict_Check(namespaces):
+                namespaces = namespaces.items()
+            if namespaces:
+                ns = []
+                for prefix, ns_uri in namespaces:
+                    if prefix is None:
+                        raise TypeError, \
+                              "empty namespace prefix is not supported in XPath"
+                    if ns_uri is None:
+                        raise TypeError, \
+                              "setting default namespace is not supported in XPath"
+                    prefix_utf = self._to_utf(prefix)
+                    ns_uri_utf = self._to_utf(ns_uri)
+                    python.PyList_Append(ns, (prefix_utf, ns_uri_utf))
+                namespaces = ns
+
         self._doc        = None
         self._exc        = _ExceptionContext()
         self._extensions = extensions
         self._namespaces = namespaces
         self._temp_refs = _TempStore()
 
+        if enable_regexp:
+            _regexp = _ExsltRegExp()
+            _regexp._register_in_context(self)
+
     cdef _copy(self):
         cdef _BaseContext context
         if self._namespaces is not None:
-            namespaces = python.PyDict_Copy(self._namespaces)
-        context = self.__class__(namespaces, None)
+            namespaces = self._namespaces[:]
+        context = self.__class__(namespaces, None, False)
         if self._extensions is not None:
             context._extensions = python.PyDict_Copy(self._extensions)
         return context
@@ -86,57 +110,72 @@
     cdef _register_context(self, _Document doc):
         self._doc = doc
         self._exc.clear()
-        python.PyDict_Clear(self._function_cache)
-        python.PyDict_Clear(self._function_cache_ns)
-        namespaces = self._namespaces
-        if namespaces is not None:
-            self.registerNamespaces(namespaces)
 
-    cdef _unregister_context(self):
-        self._unregisterNamespaces()
-#        xpath.xmlXPathRegisteredNsCleanup(self._xpathCtxt)
-        self._free_context()
-
-    cdef _free_context(self):
+    cdef _cleanup_context(self):
+        #xpath.xmlXPathRegisteredNsCleanup(self._xpathCtxt)
+        #self.unregisterGlobalNamespaces()
         python.PyDict_Clear(self._utf_refs)
         self._doc = None
+
+    cdef _release_context(self):
         if self._xpathCtxt is not NULL:
             self._xpathCtxt.userData = NULL
             self._xpathCtxt = NULL
 
     # namespaces (internal UTF-8 methods with leading '_')
 
-    cdef addNamespace(self, prefix, uri):
+    cdef addNamespace(self, prefix, ns_uri):
+        if prefix is None:
+            raise TypeError, "empty prefix is not supported in XPath"
+        prefix_utf = self._to_utf(prefix)
+        ns_uri_utf = self._to_utf(ns_uri)
+        new_item = (prefix_utf, ns_uri_utf)
         if self._namespaces is None:
-            self._namespaces = {}
-        python.PyDict_SetItem(self._namespaces, prefix, uri)
+            self._namespaces = [new_item]
+        else:
+            namespaces = []
+            for item in self._namespaces:
+                if item[0] == prefix_utf:
+                    item = new_item
+                    new_item = None
+                python.PyList_Append(namespaces, item)
+            if new_item is not None:
+                python.PyList_Append(namespaces, new_item)
+            self._namespaces = namespaces
+        if self._xpathCtxt is not NULL:
+            xpath.xmlXPathRegisterNs(
+                self._xpathCtxt, _cstr(prefix_utf), _cstr(ns_uri_utf))
 
-    cdef registerNamespaces(self, namespaces):
-        for prefix, uri in namespaces.items():
-            self.registerNamespace(prefix, uri)
-    
     cdef registerNamespace(self, prefix, ns_uri):
         if prefix is None:
             raise TypeError, "empty prefix is not supported in XPath"
         prefix_utf = self._to_utf(prefix)
         ns_uri_utf = self._to_utf(ns_uri)
-        python.PyList_Append(self._registered_namespaces, prefix_utf)
+        python.PyList_Append(self._global_namespaces, prefix_utf)
         xpath.xmlXPathRegisterNs(self._xpathCtxt,
                                  _cstr(prefix_utf), _cstr(ns_uri_utf))
 
-    cdef _registerNamespace(self, prefix_utf, ns_uri_utf):
-        python.PyList_Append(self._registered_namespaces, prefix_utf)
-        xpath.xmlXPathRegisterNs(self._xpathCtxt,
-                                 _cstr(prefix_utf), _cstr(ns_uri_utf))
-    
-    cdef void _unregisterNamespaces(self):
-        if python.PyList_GET_SIZE(self._registered_namespaces) > 0:
-            for prefix_utf in self._registered_namespaces:
-                sys.stderr.write(prefix_utf)
-                sys.stderr.flush()
+    cdef registerLocalNamespaces(self):
+        if self._namespaces is None:
+            return
+        for prefix_utf, ns_uri_utf in self._namespaces:
+            xpath.xmlXPathRegisterNs(
+                self._xpathCtxt, _cstr(prefix_utf), _cstr(ns_uri_utf))
+
+    cdef registerGlobalNamespaces(self):
+        ns_prefixes = _find_all_extension_prefixes()
+        if python.PyList_GET_SIZE(ns_prefixes) > 0:
+            for prefix_utf, ns_uri_utf in ns_prefixes:
+                python.PyList_Append(self._global_namespaces, prefix_utf)
+                xpath.xmlXPathRegisterNs(
+                    self._xpathCtxt, _cstr(prefix_utf), _cstr(ns_uri_utf))
+
+    cdef unregisterGlobalNamespaces(self):
+        if python.PyList_GET_SIZE(self._global_namespaces) > 0:
+            for prefix_utf in self._global_namespaces:
                 xpath.xmlXPathRegisterNs(self._xpathCtxt,
                                          _cstr(prefix_utf), NULL)
-            self._registered_namespaces = []
+            del self._global_namespaces[:]
     
     cdef void _unregisterNamespace(self, prefix_utf):
         xpath.xmlXPathRegisterNs(self._xpathCtxt,
@@ -149,7 +188,7 @@
             self._extensions = {}
         python.PyDict_SetItem(self._extensions, (ns_utf, name_utf), function)
 
-    cdef void _registerAllFunctions(self, void* ctxt,
+    cdef void registerGlobalFunctions(self, void* ctxt,
                                     _register_function reg_func):
         cdef python.PyObject* dict_result
         for ns_utf, ns_functions in _iter_ns_extension_functions():
@@ -167,6 +206,10 @@
             for name_utf, function in ns_functions.iteritems():
                 python.PyDict_SetItem(d, name_utf, function)
                 reg_func(ctxt, name_utf, ns_utf)
+
+    cdef void registerLocalFunctions(self, void* ctxt,
+                                      _register_function reg_func):
+        cdef python.PyObject* dict_result
         if self._extensions is None:
             return # done
         last_ns = None
@@ -188,7 +231,7 @@
             python.PyDict_SetItem(d, name_utf, function)
             reg_func(ctxt, name_utf, ns_utf)
 
-    cdef void _unregisterAllFunctions(self, void* ctxt,
+    cdef unregisterAllFunctions(self, void* ctxt,
                                       _register_function unreg_func):
         for name_utf in self._function_cache:
             unreg_func(ctxt, name_utf, None)
@@ -196,6 +239,18 @@
             for name_utf in functions:
                 unreg_func(ctxt, name_utf, ns_utf)
 
+    cdef unregisterGlobalFunctions(self, void* ctxt,
+                                         _register_function unreg_func):
+        for name_utf in self._function_cache:
+            if self._extensions is None or \
+                   (None, name_utf) not in self._extensions:
+                unreg_func(ctxt, name_utf, None)
+        for ns_utf, functions in self._function_cache_ns.iteritems():
+            for name_utf in functions:
+                if self._extensions is None or \
+                       (ns_utf, name_utf) not in self._extensions:
+                    unreg_func(ctxt, name_utf, ns_utf)
+
     cdef _find_cached_function(self, char* c_ns_uri, char* c_name):
         """Lookup an extension function in the cache and return it.
 
@@ -215,7 +270,7 @@
                 return <object>dict_result
         return None
 
-    cdef int _prepare_function_call(self, char* c_ns_uri, char* c_name):
+    cdef int __prepare_function_call(self, char* c_ns_uri, char* c_name):
         """Find an extension function and store it in 'self._called_function'.
 
         This is absolutely performance-critical for XPath/XSLT!
@@ -393,27 +448,6 @@
 ################################################################################
 # helper functions
 
-cdef xpath.xmlXPathFunction _function_check(void* ctxt,
-                                            char* c_name, char* c_ns_uri):
-    cdef python.PyGILState_STATE gil_state
-    cdef xpath.xmlXPathFunction c_func
-    gil_state = python.PyGILState_Ensure()
-    c_func = _python_function_check(ctxt, c_name, c_ns_uri)
-    python.PyGILState_Release(gil_state)
-    return c_func
-
-cdef xpath.xmlXPathFunction _python_function_check(void* ctxt,
-                                                   char* c_name, char* c_ns_uri):
-    "Module level lookup function for XPath/XSLT functions"
-    cdef xpath.xmlXPathFunction c_func
-    cdef _BaseContext context
-    context = <_BaseContext>ctxt
-    if context._prepare_function_call(c_ns_uri, c_name):
-        c_func = _call_prepared_function
-    else:
-        c_func = NULL
-    return c_func
-
 cdef xpath.xmlXPathObject* _wrapXPathObject(object obj) except NULL:
     cdef xpath.xmlNodeSet* resultSet
     cdef _Element node
@@ -575,19 +609,3 @@
         xpath.xmlXPathErr(ctxt, xpath.XPATH_UNKNOWN_FUNC_ERROR)
         exception = XPathFunctionError("XPath function '%s' not found" % fref)
         context._exc._store_exception(exception)
-
-# call the function that was stored in 'context._called_function'
-
-cdef void _call_prepared_function(xpath.xmlXPathParserContext* ctxt, int nargs):
-    cdef python.PyGILState_STATE gil_state
-    gil_state = python.PyGILState_Ensure()
-    _call_prepared_python_function(ctxt, nargs)
-    python.PyGILState_Release(gil_state)
-
-cdef void _call_prepared_python_function(xpath.xmlXPathParserContext* ctxt,
-                                         int nargs):
-    cdef xpath.xmlXPathContext* rctxt
-    cdef _BaseContext context
-    rctxt = ctxt.context
-    context = <_BaseContext>(rctxt.userData)
-    _extension_function_call(context, context._called_function, ctxt, nargs)

Modified: lxml/branch/extension_refactoring/src/lxml/xpath.pxi
==============================================================================
--- lxml/branch/extension_refactoring/src/lxml/xpath.pxi	(original)
+++ lxml/branch/extension_refactoring/src/lxml/xpath.pxi	Wed Mar 21 14:48:09 2007
@@ -30,29 +30,29 @@
 
 cdef class _XPathContext(_BaseContext):
     cdef object _variables
-    def __init__(self, namespaces, extensions, variables):
+    def __init__(self, namespaces, extensions, enable_regexp, variables):
         self._variables = variables
-        _BaseContext.__init__(self, namespaces, extensions)
-        
-    cdef register_context(self, xpath.xmlXPathContext* xpathCtxt, _Document doc):
+        _BaseContext.__init__(self, namespaces, extensions, enable_regexp)
+
+    cdef set_context(self, xpath.xmlXPathContext* xpathCtxt):
         self._set_xpath_context(xpathCtxt)
-        ns_prefixes = _find_all_extension_prefixes()
-        if python.PyList_GET_SIZE(ns_prefixes) > 0:
-            for (prefix, ns_uri) in ns_prefixes:
-                self._registerNamespace(prefix, ns_uri)
+        self._setupDict(xpathCtxt)
+        self.registerLocalNamespaces()
+        self.registerLocalFunctions(xpathCtxt, _register_xpath_function)
+
+    cdef register_context(self, _Document doc):
         self._register_context(doc)
+        self.registerGlobalNamespaces()
+        self.registerGlobalFunctions(self._xpathCtxt, _register_xpath_function)
         if self._variables is not None:
             self.registerVariables(self._variables)
-        self._registerAllFunctions(xpathCtxt, _register_xpath_function)
 
     cdef unregister_context(self):
-        cdef xpath.xmlXPathContext* xpathCtxt
-        xpathCtxt = self._xpathCtxt
-        if xpathCtxt is NULL:
-            return
-        xpath.xmlXPathRegisteredVariablesCleanup(xpathCtxt)
-        self._unregisterAllFunctions(xpathCtxt, _unregister_xpath_function)
-        self._unregister_context()
+        self.unregisterGlobalFunctions(
+            self._xpathCtxt, _unregister_xpath_function)
+        self.unregisterGlobalNamespaces()
+        xpath.xmlXPathRegisteredVariablesCleanup(self._xpathCtxt)
+        self._cleanup_context()
 
     cdef registerVariables(self, variable_dict):
         for name, value in variable_dict.items():
@@ -69,25 +69,26 @@
         xpath.xmlXPathRegisterVariable(
             self._xpathCtxt, _cstr(name_utf), _wrapXPathObject(value))
 
-cdef void _setupDict(xpath.xmlXPathContext* xpathCtxt):
-    __GLOBAL_PARSER_CONTEXT.initXPathParserDict(xpathCtxt)
+    cdef void _setupDict(self, xpath.xmlXPathContext* xpathCtxt):
+        __GLOBAL_PARSER_CONTEXT.initXPathParserDict(xpathCtxt)
 
 cdef class _XPathEvaluatorBase:
     cdef xpath.xmlXPathContext* _xpathCtxt
     cdef _XPathContext _context
     cdef python.PyThread_type_lock _eval_lock
 
-    def __init__(self, namespaces, extensions, regexp):
-        cdef _ExsltRegExp _regexp 
-        self._context = _XPathContext(namespaces, extensions, None)
-        if regexp:
-            _regexp = _ExsltRegExp()
-            _regexp._register_in_context(self._context)
+    def __init__(self, namespaces, extensions, enable_regexp):
+        self._context = _XPathContext(namespaces, extensions,
+                                      enable_regexp, None)
 
     def __dealloc__(self):
         if self._xpathCtxt is not NULL:
             xpath.xmlXPathFreeContext(self._xpathCtxt)
 
+    cdef set_context(self, xpath.xmlXPathContext* xpathCtxt):
+        self._xpathCtxt = xpathCtxt
+        self._context.set_context(xpathCtxt)
+
     def evaluate(self, _eval_arg, **_variables):
         """Evaluate an XPath expression.
 
@@ -170,14 +171,13 @@
         cdef xpath.xmlXPathContext* xpathCtxt
         cdef int ns_register_status
         cdef _Document doc
+        self._element = element
         doc = element._doc
+        _XPathEvaluatorBase.__init__(self, namespaces, extensions, regexp)
         xpathCtxt = xpath.xmlXPathNewContext(doc._c_doc)
-        self._xpathCtxt = xpathCtxt
         if xpathCtxt is NULL:
             raise XPathContextError, "Unable to create new XPath context"
-        _setupDict(xpathCtxt)
-        self._element = element
-        _XPathEvaluatorBase.__init__(self, namespaces, extensions, regexp)
+        self.set_context(xpathCtxt)
 
     def registerNamespace(self, prefix, uri):
         """Register a namespace with the XPath context.
@@ -200,27 +200,27 @@
         against the ElementTree as returned by getroottree().
         """
         cdef python.PyThreadState* state
-        cdef xpath.xmlXPathContext* xpathCtxt
         cdef xpath.xmlXPathObject*  xpathObj
         cdef _Document doc
         cdef char* c_path
         path = _utf8(_path)
-        xpathCtxt = self._xpathCtxt
-        xpathCtxt.node = self._element._c_node
         doc = self._element._doc
 
         self._lock()
-        self._context.register_context(xpathCtxt, doc)
+        self._xpathCtxt.node = self._element._c_node
         try:
+            self._context.register_context(doc)
             self._context.registerVariables(_variables)
             state = python.PyEval_SaveThread()
-            xpathObj = xpath.xmlXPathEvalExpression(_cstr(path), xpathCtxt)
-        finally:
+            xpathObj = xpath.xmlXPathEvalExpression(
+                _cstr(path), self._xpathCtxt)
             python.PyEval_RestoreThread(state)
+            result = self._handle_result(xpathObj, doc)
+        finally:
             self._context.unregister_context()
             self._unlock()
 
-        return self._handle_result(xpathObj, doc)
+        return result
 
 
 cdef class XPathDocumentEvaluator(XPathElementEvaluator):
@@ -242,30 +242,32 @@
         are currently not supported for variables.
         """
         cdef python.PyThreadState* state
-        cdef xpath.xmlXPathContext* xpathCtxt
         cdef xpath.xmlXPathObject*  xpathObj
         cdef xmlDoc* c_doc
         cdef _Document doc
         path = _utf8(_path)
-        xpathCtxt = self._xpathCtxt
         doc = self._element._doc
 
         self._lock()
-        self._context.register_context(xpathCtxt, doc)
-        c_doc = _fakeRootDoc(doc._c_doc, self._element._c_node)
         try:
-            self._context.registerVariables(_variables)
-            state = python.PyEval_SaveThread()
-            xpathCtxt.doc  = c_doc
-            xpathCtxt.node = tree.xmlDocGetRootElement(c_doc)
-            xpathObj = xpath.xmlXPathEvalExpression(_cstr(path), xpathCtxt)
+            self._context.register_context(doc)
+            c_doc = _fakeRootDoc(doc._c_doc, self._element._c_node)
+            try:
+                self._context.registerVariables(_variables)
+                state = python.PyEval_SaveThread()
+                self._xpathCtxt.doc  = c_doc
+                self._xpathCtxt.node = tree.xmlDocGetRootElement(c_doc)
+                xpathObj = xpath.xmlXPathEvalExpression(
+                    _cstr(path), self._xpathCtxt)
+                python.PyEval_RestoreThread(state)
+                result = self._handle_result(xpathObj, doc)
+            finally:
+                _destroyFakeDoc(doc._c_doc, c_doc)
+                self._context.unregister_context()
         finally:
-            python.PyEval_RestoreThread(state)
-            _destroyFakeDoc(doc._c_doc, c_doc)
-            self._context.unregister_context()
             self._unlock()
 
-        return self._handle_result(xpathObj, doc)
+        return result
 
 
 def XPathEvaluator(etree_or_element, namespaces=None, extensions=None,
@@ -300,19 +302,20 @@
     cdef readonly object path
 
     def __init__(self, path, namespaces=None, extensions=None, regexp=True):
+        cdef xpath.xmlXPathContext* xpathCtxt
         _XPathEvaluatorBase.__init__(self, namespaces, extensions, regexp)
-        self._xpath = NULL
         self.path = path
         path = _utf8(path)
-        self._xpathCtxt = xpath.xmlXPathNewContext(NULL)
-        _setupDict(self._xpathCtxt)
-        self._xpath = xpath.xmlXPathCtxtCompile(self._xpathCtxt, _cstr(path))
+        xpathCtxt = xpath.xmlXPathNewContext(NULL)
+        if xpathCtxt is NULL:
+            raise XPathContextError, "Unable to create new XPath context"
+        self.set_context(xpathCtxt)
+        self._xpath = xpath.xmlXPathCtxtCompile(xpathCtxt, _cstr(path))
         if self._xpath is NULL:
             self._raise_parse_error()
 
     def __call__(self, _etree_or_element, **_variables):
         cdef python.PyThreadState* state
-        cdef xpath.xmlXPathContext* xpathCtxt
         cdef xpath.xmlXPathObject*  xpathObj
         cdef _Document document
         cdef _Element element
@@ -325,18 +328,18 @@
         self._xpathCtxt.doc  = document._c_doc
         self._xpathCtxt.node = element._c_node
 
-        context = self._context
-        context.register_context(self._xpathCtxt, document)
-        context.registerVariables(_variables)
         try:
+            self._context.register_context(document)
+            self._context.registerVariables(_variables)
             state = python.PyEval_SaveThread()
             xpathObj = xpath.xmlXPathCompiledEval(
                 self._xpath, self._xpathCtxt)
-        finally:
             python.PyEval_RestoreThread(state)
-            context.unregister_context()
+            result = self._handle_result(xpathObj, document)
+        finally:
+            self._context.unregister_context()
             self._unlock()
-        return self._handle_result(xpathObj, document)
+        return result
 
     def __dealloc__(self):
         if self._xpath is not NULL:

Modified: lxml/branch/extension_refactoring/src/lxml/xslt.pxi
==============================================================================
--- lxml/branch/extension_refactoring/src/lxml/xslt.pxi	(original)
+++ lxml/branch/extension_refactoring/src/lxml/xslt.pxi	Wed Mar 21 14:48:09 2007
@@ -210,28 +210,29 @@
 
 cdef class _XSLTContext(_BaseContext):
     cdef xslt.xsltTransformContext* _xsltCtxt
-    def __init__(self, namespaces, extensions):
+    def __init__(self, namespaces, extensions, enable_regexp):
         self._xsltCtxt = NULL
-        if extensions and None in extensions:
-            raise XSLTExtensionError, "extensions must not have empty namespaces"
-        _BaseContext.__init__(self, namespaces, extensions)
+        if extensions is not None:
+            for ns, prefix in extensions:
+                if ns is None:
+                    raise XSLTExtensionError, \
+                          "extensions must not have empty namespaces"
+        _BaseContext.__init__(self, namespaces, extensions, enable_regexp)
 
     cdef register_context(self, xslt.xsltTransformContext* xsltCtxt,
                                _Document doc):
         self._xsltCtxt = xsltCtxt
         self._set_xpath_context(xsltCtxt.xpathCtxt)
         self._register_context(doc)
-        xsltCtxt.xpathCtxt.userData = <void*>self
-        self._registerAllFunctions(xsltCtxt, _register_xslt_function)
+        self.registerLocalFunctions(xsltCtxt, _register_xslt_function)
+        self.registerGlobalFunctions(xsltCtxt, _register_xslt_function)
 
     cdef free_context(self):
-        cdef xslt.xsltTransformContext* xsltCtxt
-        xsltCtxt = self._xsltCtxt
-        if xsltCtxt is NULL:
-            return
-        self._free_context()
-        self._xsltCtxt = NULL
-        xslt.xsltFreeTransformContext(xsltCtxt)
+        self._cleanup_context()
+        self._release_context()
+        if self._xsltCtxt is not NULL:
+            xslt.xsltFreeTransformContext(self._xsltCtxt)
+            self._xsltCtxt = NULL
         self._release_temp_refs()
 
 
@@ -253,7 +254,8 @@
     cdef XSLTAccessControl _access_control
     cdef _ErrorLog _error_log
 
-    def __init__(self, xslt_input, extensions=None, regexp=True, access_control=None):
+    def __init__(self, xslt_input, extensions=None, regexp=True,
+                 access_control=None):
         cdef python.PyThreadState* state
         cdef xslt.xsltStylesheet* c_style
         cdef xmlDoc* c_doc
@@ -299,10 +301,7 @@
         c_doc._private = NULL # no longer used!
         self._c_style = c_style
 
-        self._context = _XSLTContext(None, extensions)
-        if regexp:
-            _regexp = _ExsltRegExp()
-            _regexp._register_in_context(self._context)
+        self._context = _XSLTContext(None, extensions, regexp)
 
     def __dealloc__(self):
         if self._xslt_resolver_context is not None and \
@@ -315,20 +314,24 @@
         def __get__(self):
             return self._error_log.copy()
 
+    def apply(self, _input, profile_run=False, **_kw):
+        return self(_input, profile_run, **_kw)
+
+    def tostring(self, _ElementTree result_tree):
+        """Save result doc to string based on stylesheet output method.
+        """
+        return str(result_tree)
+
     def __call__(self, _input, profile_run=False, **_kw):
-        cdef python.PyThreadState* state
         cdef _XSLTContext context
         cdef _Document input_doc
         cdef _Element root_node
         cdef _Document result_doc
         cdef _Document profile_doc
         cdef xmlDoc* c_profile_doc
-        cdef _XSLTResolverContext resolver_context
         cdef xslt.xsltTransformContext* transform_ctxt
         cdef xmlDoc* c_result
         cdef xmlDoc* c_doc
-        cdef char** params
-        cdef Py_ssize_t i, kw_count
 
         if not _checkThreadDict(self._c_style.doc.dict):
             raise RuntimeError, "stylesheet is not usable in this thread"
@@ -336,9 +339,6 @@
         input_doc = _documentOrRaise(_input)
         root_node = _rootNodeOrRaise(_input)
 
-        resolver_context = _XSLTResolverContext(input_doc._parser)
-        resolver_context._c_style_doc = self._xslt_resolver_context._c_style_doc
-
         c_doc = _fakeRootDoc(input_doc._c_doc, root_node._c_node)
 
         transform_ctxt = xslt.xsltNewTransformContext(self._c_style, c_doc)
@@ -348,28 +348,82 @@
 
         initTransformDict(transform_ctxt)
 
-        self._error_log.connect()
+        if profile_run:
+            transform_ctxt.profile = 1
+
+        try:
+            self._error_log.connect()
+            context = self._context._copy()
+            context.register_context(transform_ctxt, input_doc)
+
+            c_result = self._run_transform(
+                input_doc, c_doc, _kw, context, transform_ctxt)
+
+            if transform_ctxt.profile:
+                c_profile_doc = xslt.xsltGetProfileInformation(transform_ctxt)
+                if c_profile_doc is not NULL:
+                    profile_doc = _documentFactory(
+                        c_profile_doc, input_doc._parser)
+        finally:
+            if context is not None:
+                context.free_context()
+            _destroyFakeDoc(input_doc._c_doc, c_doc)
+            self._error_log.disconnect()
+
+        try:
+            if self._xslt_resolver_context._has_raised():
+                if c_result is not NULL:
+                    tree.xmlFreeDoc(c_result)
+                self._xslt_resolver_context._raise_if_stored()
+
+            if c_result is NULL:
+                error = self._error_log.last_error
+                if error is not None and error.message:
+                    if error.line >= 0:
+                        message = "%s, line %d" % (error.message, error.line)
+                    else:
+                        message = error.message
+                elif error.line >= 0:
+                    message = "Error applying stylesheet, line %d" % error.line
+                else:
+                    message = "Error applying stylesheet"
+                raise XSLTApplyError, message
+        finally:
+            self._xslt_resolver_context.clear()
+
+        result_doc = _documentFactory(c_result, input_doc._parser)
+        return _xsltResultTreeFactory(result_doc, self, profile_doc)
+
+    cdef xmlDoc* _run_transform(self, _Document input_doc, xmlDoc* c_input_doc,
+                               parameters, _XSLTContext context,
+                               xslt.xsltTransformContext* transform_ctxt):
+        cdef python.PyThreadState* state
+        cdef _XSLTResolverContext resolver_context
+        cdef xmlDoc* c_result
+        cdef char** params
+        cdef Py_ssize_t i, parameter_count
+
+        resolver_context = _XSLTResolverContext(input_doc._parser)
+        resolver_context._c_style_doc = self._xslt_resolver_context._c_style_doc
+
         xslt.xsltSetTransformErrorFunc(transform_ctxt, <void*>self._error_log,
                                        _receiveXSLTError)
 
         if self._access_control is not None:
             self._access_control._register_in_context(transform_ctxt)
 
-        if profile_run:
-            transform_ctxt.profile = 1
-
         transform_ctxt._private = <python.PyObject*>self._xslt_resolver_context
 
-        kw_count = python.PyDict_Size(_kw)
-        if kw_count > 0:
+        parameter_count = python.PyDict_Size(parameters)
+        if parameter_count > 0:
             # allocate space for parameters
             # * 2 as we want an entry for both key and value,
             # and + 1 as array is NULL terminated
             params = <char**>python.PyMem_Malloc(
-                sizeof(char*) * (kw_count * 2 + 1))
+                sizeof(char*) * (parameter_count * 2 + 1))
             i = 0
             keep_ref = []
-            for key, value in _kw.iteritems():
+            for key, value in parameters.iteritems():
                 k = _utf8(key)
                 python.PyList_Append(keep_ref, k)
                 v = _utf8(value)
@@ -382,59 +436,16 @@
         else:
             params = NULL
 
-        context = self._context._copy()
-        context.register_context(transform_ctxt, input_doc)
-
         state = python.PyEval_SaveThread()
-        c_result = xslt.xsltApplyStylesheetUser(self._c_style, c_doc, params,
-                                                NULL, NULL, transform_ctxt)
+        c_result = xslt.xsltApplyStylesheetUser(
+            self._c_style, c_input_doc, params, NULL, NULL, transform_ctxt)
         python.PyEval_RestoreThread(state)
 
         if params is not NULL:
             # deallocate space for parameters
             python.PyMem_Free(params)
-            keep_ref = None
-
-        if transform_ctxt.profile:
-            c_profile_doc = xslt.xsltGetProfileInformation(transform_ctxt)
-            if c_profile_doc is not NULL:
-                profile_doc = _documentFactory(c_profile_doc, input_doc._parser)
 
-        context.free_context()
-        _destroyFakeDoc(input_doc._c_doc, c_doc)
-
-        self._error_log.disconnect()
-        try:
-            if self._xslt_resolver_context._has_raised():
-                if c_result is not NULL:
-                    tree.xmlFreeDoc(c_result)
-                self._xslt_resolver_context._raise_if_stored()
-
-            if c_result is NULL:
-                error = self._error_log.last_error
-                if error is not None and error.message:
-                    if error.line >= 0:
-                        message = "%s, line %d" % (error.message, error.line)
-                    else:
-                        message = error.message
-                elif error.line >= 0:
-                    message = "Error applying stylesheet, line %d" % error.line
-                else:
-                    message = "Error applying stylesheet"
-                raise XSLTApplyError, message
-        finally:
-            self._xslt_resolver_context.clear()
-
-        result_doc = _documentFactory(c_result, input_doc._parser)
-        return _xsltResultTreeFactory(result_doc, self, profile_doc)
-
-    def apply(self, _input, profile_run=False, **_kw):
-        return self(_input, profile_run, **_kw)
-
-    def tostring(self, _ElementTree result_tree):
-        """Save result doc to string based on stylesheet output method.
-        """
-        return str(result_tree)
+        return c_result
 
 cdef class _XSLTResultTree(_ElementTree):
     cdef XSLT _xslt
@@ -511,17 +522,6 @@
 # enable EXSLT support for XSLT
 xslt.exsltRegisterAll()
 
-# extension function lookup for XSLT
-cdef xpath.xmlXPathFunction _xslt_function_check(void* ctxt,
-                                                 char* c_name, char* c_ns_uri):
-    "Find XSLT extension function from set of XPath and XSLT functions"
-    cdef xpath.xmlXPathFunction result
-    result = _function_check(ctxt, c_name, c_ns_uri)
-    if result is NULL:
-        return xslt.xsltExtModuleFunctionLookup(c_name, c_ns_uri)
-    else:
-        return result
-
 cdef void initTransformDict(xslt.xsltTransformContext* transform_ctxt):
     __GLOBAL_PARSER_CONTEXT.initThreadDictRef(&transform_ctxt.dict)
 


More information about the lxml-checkins mailing list