git-svn-id: http://ncclient.googlecode.com/svn/trunk@109 6dbcf712-26ac-11de-a2f3...
[ncclient] / ncclient / content.py
index e6b6823..a895af1 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"TODO: docstring"
-
 from xml.etree import cElementTree as ET
 
-iselement = ET.iselement
-element2string = ET.tostring
+from ncclient import NCClientError
+
+class ContentError(NCClientError):
+    pass
 
-### Namespace-related ###
+### Namespace-related
 
 BASE_NS = 'urn:ietf:params:xml:ns:netconf:base:1.0'
 # and this is BASE_NS according to cisco devices...
@@ -36,61 +36,31 @@ except AttributeError:
 # we'd like BASE_NS to be prefixed as "netconf"
 register_namespace('netconf', BASE_NS)
 
-qualify = lambda tag, ns=BASE_NS: '{%s}%s' % (ns, tag)
+qualify = lambda tag, ns=BASE_NS: tag if ns is None else '{%s}%s' % (ns, tag)
 
-# i would have written a def if lambdas weren't so much fun
-multiqualify = lambda tag, nslist=(BASE_NS, CISCO_BS): [qualify(tag, ns)
-                                                        for ns in nslist]
+multiqualify = lambda tag, nslist=(BASE_NS, CISCO_BS): [qualify(tag, ns) for ns in nslist]
 
 unqualify = lambda tag: tag[tag.rfind('}')+1:]
 
-def namespaced_find(ele, tag, workaround=True):
-    """`workaround` is for Cisco implementations (at least the one tested), 
-    which uses an incorrect namespace.
-    """
-    found = None
-    if not workaround:
-        found = ele.find(tag)
-    else:
-        for qname in multiqualify(tag):
-            found = ele.find(qname)
-            if found is not None:
-                break
-    return found
-    
+### XML with Python data structures
 
-### Build XML using Python data structures ###
+dtree2ele = DictTree.Element
+dtree2xml = DictTree.XML
+ele2dtree = Element.DictTree
+ele2xml = Element.XML
+xml2dtree = XML.DictTree
+xml2ele = XML.Element
+
+class DictTree:
 
-class XMLConverter:
-    """Build an ElementTree.Element instance from an XML tree specification
-    based on nested dictionaries. TODO: describe spec
-    """
-    
-    def __init__(self, spec):
-        "TODO: docstring"
-        self._root = XMLConverter.build(spec)
-    
-    def to_string(self, encoding='utf-8'):
-        "TODO: docstring"
-        xml = ET.tostring(self._root, encoding)
-        # some etree versions don't include xml decl with utf-8
-        # this is a problem with some devices
-        return (xml if xml.startswith('<?xml')
-                else '<?xml version="1.0" encoding="%s"?>%s' % (encoding, xml))
-    
-    @property
-    def tree(self):
-        "TODO: docstring"
-        return self._root
-    
     @staticmethod
-    def build(spec):
-        "TODO: docstring"
+    def Element(spec):
         if iselement(spec):
             return spec
         elif isinstance(spec, basestring):
-            return ET.XML(spec)
-        ## assume isinstance(spec, dict)
+            return XML.Element(spec)
+        if not isinstance(spec, dict):
+            raise ContentError("Invalid tree spec")
         if 'tag' in spec:
             ele = ET.Element(spec.get('tag'), spec.get('attributes', {}))
             ele.text = spec.get('text', '')
@@ -100,7 +70,7 @@ class XMLConverter:
             if isinstance(subtree, dict):
                 subtree = [subtree]
             for subele in subtree:
-                ele.append(XMLConverter.build(subele))
+                ele.append(DictTree.Element(subele))
             return ele
         elif 'comment' in spec:
             return ET.Comment(spec.get('comment'))
@@ -108,15 +78,72 @@ class XMLConverter:
             raise ContentError('Invalid tree spec')
     
     @staticmethod
-    def from_string(xml):
-        return XMLConverter.parse(ET.fromstring(xml))
+    def XML(spec):
+        Element.XML(DictTree.Element(spec))
+
+class Element:
     
     @staticmethod
-    def parse(root):
+    def DictTree(ele):
         return {
-            'tag': root.tag,
-            'attributes': root.attrib,
-            'text': root.text,
-            'tail': root.tail,
-            'subtree': [ XMLConverter.parse(child) for child in root.getchildren() ]
+            'tag': ele.tag,
+            'attributes': ele.attrib,
+            'text': ele.text,
+            'tail': ele.tail,
+            'subtree': [ Element.DictTree(child) for child in root.getchildren() ]
         }
+    
+    @staticmethod
+    def XML(ele, encoding='utf-8'):
+        xml = ET.tostring(ele, encoding)
+        return xml if xml.startswith('<?xml') else '<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, xml)
+
+class XML:
+    
+    @staticmethod
+    def DictTree(ele):
+        return Element.DictTree(Element.XML(ele))
+    
+    @staticmethod
+    def Element(xml):
+        return ET.fromstring(xml)
+
+### Other utility functions
+
+iselement = ET.iselement
+
+def find(ele, tag, strict=False):
+    """In strict mode, doesn't workaround Cisco implementations sending incorrectly
+    namespaced XML. Supply qualified tag name if using strict mode.
+    """
+    if strict:
+        return ele.find(tag)
+    else:
+        for qname in multiqualify(tag):
+            found = ele.find(qname)
+            if found is not None:
+                return found
+
+def parse_root(raw):
+    '''Parse the top-level element from XML string.
+    
+    Returns a `(tag, attributes)` tuple, where `tag` is a string representing
+    the qualified name of the root element and `attributes` is an
+    `{attribute: value}` dictionary.
+    '''
+    fp = StringIO(raw[:1024]) # this is a guess but start element beyond 1024 bytes would be a bit absurd
+    for event, element in ET.iterparse(fp, events=('start',)):
+        return (element.tag, element.attrib)
+
+def validated_element(rep, tag, attrs=None):
+    ele = dtree2ele(rep)
+    if ele.tag not in (tag, qualify(tag)):
+        raise ContentError("Required root element [%s] not found" % tag)
+    if attrs is not None:
+        for req in attrs:
+            for attr in ele.attrib:
+                if unqualify(attr) == req:
+                    break
+            else:
+                raise ContentError("Required attribute [%s] not found in element [%s]" % (req, req_tag))
+    return ele