Revision d771dffc ncclient/content.py

b/ncclient/content.py
12 12
# See the License for the specific language governing permissions and
13 13
# limitations under the License.
14 14

  
15
"TODO: docstring"
16

  
17 15
from xml.etree import cElementTree as ET
18 16

  
19 17
from ncclient import NCClientError
......
44 42

  
45 43
unqualify = lambda tag: tag[tag.rfind('}')+1:]
46 44

  
47
### Other utility functions
48

  
49
iselement = ET.iselement
50

  
51
def namespaced_find(ele, tag, strict=False):
52
    """In strict mode, doesn't work around Cisco implementations sending incorrectly
53
    namespaced XML. Supply qualified name if using strict mode.
54
    """
55
    found = None
56
    if strict:
57
        found = ele.find(tag)
58
    else:
59
        for qname in multiqualify(tag):
60
            found = ele.find(qname)
61
            if found is not None:
62
                break
63
    return found
64

  
65
def parse_root(raw):
66
    '''Parse the top-level element from XML string.
67
    
68
    Returns a `(tag, attributes)` tuple, where `tag` is a string representing
69
    the qualified name of the root element and `attributes` is an
70
    `{attribute: value}` dictionary.
71
    '''
72
    fp = StringIO(raw[:1024]) # this is a guess but start element beyond 1024 bytes would be a bit absurd
73
    for event, element in ET.iterparse(fp, events=('start',)):
74
        return (element.tag, element.attrib)
75

  
76
def root_ensured(rep, req_tag, req_attrs=None):
77
    rep = to_element(rep)
78
    if rep.tag not in (req_tag, qualify(req_tag)):
79
        raise ContentError("Required root element [%s] not found" % req_tag)
80
    if req_attrs is not None:
81
        pass # TODO
82
    return rep
83

  
84 45
### XML with Python data structures
85 46

  
86 47
dtree2ele = DictTree.Element
......
146 107
    @staticmethod
147 108
    def Element(xml):
148 109
        return ET.fromstring(xml)
110

  
111
### Other utility functions
112

  
113
iselement = ET.iselement
114

  
115
def find(ele, tag, strict=False):
116
    """In strict mode, doesn't workaround Cisco implementations sending incorrectly
117
    namespaced XML. Supply qualified tag name if using strict mode.
118
    """
119
    if strict:
120
        return ele.find(tag)
121
    else:
122
        for qname in multiqualify(tag):
123
            found = ele.find(qname)
124
            if found is not None:
125
                return found
126

  
127
def parse_root(raw):
128
    '''Parse the top-level element from XML string.
129
    
130
    Returns a `(tag, attributes)` tuple, where `tag` is a string representing
131
    the qualified name of the root element and `attributes` is an
132
    `{attribute: value}` dictionary.
133
    '''
134
    fp = StringIO(raw[:1024]) # this is a guess but start element beyond 1024 bytes would be a bit absurd
135
    for event, element in ET.iterparse(fp, events=('start',)):
136
        return (element.tag, element.attrib)
137

  
138
def validated_element(rep, tag, attrs=None):
139
    ele = dtree2ele(rep)
140
    if ele.tag not in (tag, qualify(tag)):
141
        raise ContentError("Required root element [%s] not found" % tag)
142
    if attrs is not None:
143
        for req in attrs:
144
            for attr in ele.attrib:
145
                if unqualify(attr) == req:
146
                    break
147
            else:
148
                raise ContentError("Required attribute [%s] not found in element [%s]" % (req, req_tag))
149
    return ele

Also available in: Unified diff