# See the License for the specific language governing permissions and
# limitations under the License.
-"TODO: docstring"
-
from xml.etree import cElementTree as ET
from ncclient import NCClientError
unqualify = lambda tag: tag[tag.rfind('}')+1:]
-### Other utility functions
-
-iselement = ET.iselement
-
-def namespaced_find(ele, tag, strict=False):
- """In strict mode, doesn't work around Cisco implementations sending incorrectly
- namespaced XML. Supply qualified name if using strict mode.
- """
- found = None
- if strict:
- found = ele.find(tag)
- else:
- for qname in multiqualify(tag):
- found = ele.find(qname)
- if found is not None:
- break
- return found
-
-def parse_root(raw):
- '''Parse the top-level element from XML string.
-
- Returns a `(tag, attributes)` tuple, where `tag` is a string representing
- the qualified name of the root element and `attributes` is an
- `{attribute: value}` dictionary.
- '''
- fp = StringIO(raw[:1024]) # this is a guess but start element beyond 1024 bytes would be a bit absurd
- for event, element in ET.iterparse(fp, events=('start',)):
- return (element.tag, element.attrib)
-
-def root_ensured(rep, req_tag, req_attrs=None):
- rep = to_element(rep)
- if rep.tag not in (req_tag, qualify(req_tag)):
- raise ContentError("Required root element [%s] not found" % req_tag)
- if req_attrs is not None:
- pass # TODO
- return rep
-
### XML with Python data structures
dtree2ele = DictTree.Element
@staticmethod
def Element(xml):
return ET.fromstring(xml)
+
+### Other utility functions
+
+iselement = ET.iselement
+
+def find(ele, tag, strict=False):
+    """Return the first subelement of `ele` matching `tag`, or `None`.
+
+    In strict mode `tag` is looked up exactly as given, so supply the
+    qualified tag name. Otherwise every name produced by `multiqualify(tag)`
+    is tried in order, which works around Cisco implementations sending
+    incorrectly namespaced XML.
+    """
+    candidates = (tag,) if strict else multiqualify(tag)
+    for name in candidates:
+        match = ele.find(name)
+        if match is not None:
+            return match
+    return None
+
+def parse_root(raw):
+    """Parse only the top-level element of the XML string `raw`.
+
+    Returns a `(tag, attributes)` tuple: `tag` is the qualified name of the
+    root element and `attributes` is its `{attribute: value}` dictionary.
+    """
+    # Only the first 1024 characters are parsed: this is a guess, but a root
+    # start element extending beyond that would be a bit absurd.
+    head = StringIO(raw[:1024])
+    start_events = ET.iterparse(head, events=('start',))
+    # The first 'start' event belongs to the root element; stop right there.
+    for _event, root in start_events:
+        return (root.tag, root.attrib)
+
+def validated_element(rep, tag, attrs=None):
+    """Convert `rep` to an element and check its root tag and attributes.
+
+    `rep` is passed through `dtree2ele`; the resulting element's tag must
+    equal `tag` or `qualify(tag)`. If `attrs` is given, each name in it must
+    match the unqualified name of some attribute on the element.
+
+    Returns the element. Raises `ContentError` if the root tag does not
+    match or a required attribute is missing.
+    """
+    ele = dtree2ele(rep)
+    if ele.tag not in (tag, qualify(tag)):
+        raise ContentError("Required root element [%s] not found" % tag)
+    if attrs is not None:
+        for req in attrs:
+            # Attribute names may be namespace-qualified, so compare on the
+            # unqualified name.
+            if not any(unqualify(attr) == req for attr in ele.attrib):
+                # Report `tag`: the name previously used here was undefined
+                # and turned this error path into a NameError.
+                raise ContentError("Required attribute [%s] not found in element [%s]" % (req, tag))
+    return ele