git-svn-id: http://ncclient.googlecode.com/svn/trunk@109 6dbcf712-26ac-11de-a2f3...
[ncclient] / ncclient / content.py
index 8ce888b..a895af1 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"TODO: docstring"
-
 from xml.etree import cElementTree as ET
 
-iselement = ET.iselement
+from ncclient import NCClientError
 
-### Namespace-related ###
+class ContentError(NCClientError):
+    pass
+
+### Namespace-related
 
 BASE_NS = 'urn:ietf:params:xml:ns:netconf:base:1.0'
 # and this is BASE_NS according to cisco devices...
@@ -35,56 +36,114 @@ except AttributeError:
 # we'd like BASE_NS to be prefixed as "netconf"
 register_namespace('netconf', BASE_NS)
 
-qualify = lambda tag, ns=BASE_NS: '{%s}%s' % (ns, tag)
+qualify = lambda tag, ns=BASE_NS: tag if ns is None else '{%s}%s' % (ns, tag)
 
-# i would have written a def if lambdas weren't so much fun
-multiqualify = lambda tag, nslist=(BASE_NS, CISCO_BS): [qualify(tag, ns)
-                                                        for ns in nslist]
+multiqualify = lambda tag, nslist=(BASE_NS, CISCO_BS): [qualify(tag, ns) for ns in nslist]
 
 unqualify = lambda tag: tag[tag.rfind('}')+1:]
 
-### Build XML using Python data structures ###
+### XML with Python data structures
+
+dtree2ele = DictTree.Element
+dtree2xml = DictTree.XML
+ele2dtree = Element.DictTree
+ele2xml = Element.XML
+xml2dtree = XML.DictTree
+xml2ele = XML.Element
+
+class DictTree:
 
-class XMLConverter:
-    """Build an ElementTree.Element instance from an XML tree specification
-    based on nested dictionaries. TODO: describe spec
-    """
-    
-    def __init__(self, spec):
-        "TODO: docstring"
-        self._root = XMLConverter.build(spec)
-    
-    def to_string(self, encoding='utf-8'):
-        "TODO: docstring"
-        xml = ET.tostring(self._root, encoding)
-        # some etree versions don't include xml decl with utf-8
-        # this is a problem with some devices
-        return (xml if xml.startswith('<?xml')
-                else '<?xml version="1.0" encoding="%s"?>%s' % (encoding, xml))
-    
-    @property
-    def tree(self):
-        "TODO: docstring"
-        return self._root
-    
     @staticmethod
-    def build(spec):
-        "TODO: docstring"
-        if ET.iselement(spec):
+    def Element(spec):
+        if iselement(spec):
             return spec
         elif isinstance(spec, basestring):
-            return ET.XML(spec)
-        ## assume isinstance(spec, dict)
+            return XML.Element(spec)
+        if not isinstance(spec, dict):
+            raise ContentError("Invalid tree spec")
         if 'tag' in spec:
             ele = ET.Element(spec.get('tag'), spec.get('attributes', {}))
             ele.text = spec.get('text', '')
-            children = spec.get('children', [])
-            if isinstance(children, dict):
-                children = [children]
-            for child in children:
-                ele.append(XMLConverter.build(child))
+            ele.tail = spec.get('tail', '')
+            subtree = spec.get('subtree', [])
+            # might not be properly specified as list but may be dict
+            if isinstance(subtree, dict):
+                subtree = [subtree]
+            for subele in subtree:
+                ele.append(DictTree.Element(subele))
             return ele
         elif 'comment' in spec:
             return ET.Comment(spec.get('comment'))
         else:
-            raise ValueError('Invalid tree spec')
+            raise ContentError('Invalid tree spec')
+    
+    @staticmethod
+    def XML(spec):
+        Element.XML(DictTree.Element(spec))
+
+class Element:
+    
+    @staticmethod
+    def DictTree(ele):
+        return {
+            'tag': ele.tag,
+            'attributes': ele.attrib,
+            'text': ele.text,
+            'tail': ele.tail,
+            'subtree': [ Element.DictTree(child) for child in root.getchildren() ]
+        }
+    
+    @staticmethod
+    def XML(ele, encoding='utf-8'):
+        xml = ET.tostring(ele, encoding)
+        return xml if xml.startswith('<?xml') else '<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, xml)
+
+class XML:
+    
+    @staticmethod
+    def DictTree(ele):
+        return Element.DictTree(Element.XML(ele))
+    
+    @staticmethod
+    def Element(xml):
+        return ET.fromstring(xml)
+
+### Other utility functions
+
+iselement = ET.iselement
+
+def find(ele, tag, strict=False):
+    """In strict mode, doesn't workaround Cisco implementations sending incorrectly
+    namespaced XML. Supply qualified tag name if using strict mode.
+    """
+    if strict:
+        return ele.find(tag)
+    else:
+        for qname in multiqualify(tag):
+            found = ele.find(qname)
+            if found is not None:
+                return found
+
+def parse_root(raw):
+    '''Parse the top-level element from XML string.
+    
+    Returns a `(tag, attributes)` tuple, where `tag` is a string representing
+    the qualified name of the root element and `attributes` is an
+    `{attribute: value}` dictionary.
+    '''
+    fp = StringIO(raw[:1024]) # this is a guess but start element beyond 1024 bytes would be a bit absurd
+    for event, element in ET.iterparse(fp, events=('start',)):
+        return (element.tag, element.attrib)
+
+def validated_element(rep, tag, attrs=None):
+    ele = dtree2ele(rep)
+    if ele.tag not in (tag, qualify(tag)):
+        raise ContentError("Required root element [%s] not found" % tag)
+    if attrs is not None:
+        for req in attrs:
+            for attr in ele.attrib:
+                if unqualify(attr) == req:
+                    break
+            else:
+                raise ContentError("Required attribute [%s] not found in element [%s]" % (req, req_tag))
+    return ele