git-svn-id: http://ncclient.googlecode.com/svn/trunk@109 6dbcf712-26ac-11de-a2f3...
[ncclient] / ncclient / content.py
index 462ada4..a895af1 100644 (file)
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"TODO: docstring"
-
 from xml.etree import cElementTree as ET
 
 from ncclient import NCClientError
@@ -44,43 +42,6 @@ multiqualify = lambda tag, nslist=(BASE_NS, CISCO_BS): [qualify(tag, ns) for ns
 
 unqualify = lambda tag: tag[tag.rfind('}')+1:]
 
-### Other utility functions
-
-iselement = ET.iselement
-
-def namespaced_find(ele, tag, strict=False):
-    """In strict mode, doesn't work around Cisco implementations sending incorrectly
-    namespaced XML. Supply qualified name if using strict mode.
-    """
-    found = None
-    if strict:
-        found = ele.find(tag)
-    else:
-        for qname in multiqualify(tag):
-            found = ele.find(qname)
-            if found is not None:
-                break
-    return found
-
-def parse_root(raw):
-    '''Parse the top-level element from XML string.
-    
-    Returns a `(tag, attributes)` tuple, where `tag` is a string representing
-    the qualified name of the root element and `attributes` is an
-    `{attribute: value}` dictionary.
-    '''
-    fp = StringIO(raw[:1024]) # this is a guess but start element beyond 1024 bytes would be a bit absurd
-    for event, element in ET.iterparse(fp, events=('start',)):
-        return (element.tag, element.attrib)
-
-def root_ensured(rep, req_tag, req_attrs=None):
-    rep = to_element(rep)
-    if rep.tag not in (req_tag, qualify(req_tag)):
-        raise ContentError("Required root element [%s] not found" % req_tag)
-    if req_attrs is not None:
-        pass # TODO
-    return rep
-
 ### XML with Python data structures
 
 dtree2ele = DictTree.Element
@@ -146,3 +107,43 @@ class XML:
     @staticmethod
     def Element(xml):
         return ET.fromstring(xml)
+
+### Other utility functions
+
+iselement = ET.iselement
+
+def find(ele, tag, strict=False):
+    """In strict mode, doesn't workaround Cisco implementations sending incorrectly
+    namespaced XML. Supply qualified tag name if using strict mode.
+    """
+    if strict:
+        return ele.find(tag)
+    else:
+        for qname in multiqualify(tag):
+            found = ele.find(qname)
+            if found is not None:
+                return found
+
+def parse_root(raw):
+    '''Parse the top-level element from XML string.
+    
+    Returns a `(tag, attributes)` tuple, where `tag` is a string representing
+    the qualified name of the root element and `attributes` is an
+    `{attribute: value}` dictionary.
+    '''
+    fp = StringIO(raw[:1024]) # this is a guess but start element beyond 1024 bytes would be a bit absurd
+    for event, element in ET.iterparse(fp, events=('start',)):
+        return (element.tag, element.attrib)
+
+def validated_element(rep, tag, attrs=None):
+    ele = dtree2ele(rep)
+    if ele.tag not in (tag, qualify(tag)):
+        raise ContentError("Required root element [%s] not found" % tag)
+    if attrs is not None:
+        for req in attrs:
+            for attr in ele.attrib:
+                if unqualify(attr) == req:
+                    break
+            else:
+                raise ContentError("Required attribute [%s] not found in element [%s]" % (req, req_tag))
+    return ele