git-svn-id: http://ncclient.googlecode.com/svn/trunk@117 6dbcf712-26ac-11de-a2f3...
[ncclient] / ncclient / content.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from cStringIO import StringIO
16 from xml.etree import cElementTree as ET
17
18 from ncclient import NCClientError
19
20 class ContentError(NCClientError):
21     pass
22
23 ### Namespace-related
24
25 BASE_NS = 'urn:ietf:params:xml:ns:netconf:base:1.0'
26 NOTIFICATION_NS = 'urn:ietf:params:xml:ns:netconf:notification:1.0'
27 # and this is BASE_NS according to cisco devices...
28 CISCO_BS = 'urn:ietf:params:netconf:base:1.0'
29
30 try:
31     register_namespace = ET.register_namespace
32 except AttributeError:
33     def register_namespace(prefix, uri):
34         from xml.etree import ElementTree
35         # cElementTree uses ElementTree's _namespace_map, so that's ok
36         ElementTree._namespace_map[uri] = prefix
37
38 # we'd like BASE_NS to be prefixed as "netconf"
39 register_namespace('netconf', BASE_NS)
40
41 qualify = lambda tag, ns=BASE_NS: tag if ns is None else '{%s}%s' % (ns, tag)
42
43 # deprecated
44 multiqualify = lambda tag, nslist=(BASE_NS, CISCO_BS): [qualify(tag, ns) for ns in nslist]
45
46 unqualify = lambda tag: tag[tag.rfind('}')+1:]
47
48 ### XML with Python data structures
49
50 class DictTree:
51
52     @staticmethod
53     def Element(spec):
54         if iselement(spec):
55             return spec
56         elif isinstance(spec, basestring):
57             return XML.Element(spec)
58         if not isinstance(spec, dict):
59             raise ContentError("Invalid tree spec")
60         if 'tag' in spec:
61             ele = ET.Element(spec.get('tag'), spec.get('attributes', {}))
62             ele.text = spec.get('text', '')
63             ele.tail = spec.get('tail', '')
64             subtree = spec.get('subtree', [])
65             # might not be properly specified as list but may be dict
66             if isinstance(subtree, dict):
67                 subtree = [subtree]
68             for subele in subtree:
69                 ele.append(DictTree.Element(subele))
70             return ele
71         elif 'comment' in spec:
72             return ET.Comment(spec.get('comment'))
73         else:
74             raise ContentError('Invalid tree spec')
75     
76     @staticmethod
77     def XML(spec, encoding='UTF-8'):
78         return Element.XML(DictTree.Element(spec), encoding)
79
80 class Element:
81     
82     @staticmethod
83     def DictTree(ele):
84         return {
85             'tag': ele.tag,
86             'attributes': ele.attrib,
87             'text': ele.text,
88             'tail': ele.tail,
89             'subtree': [ Element.DictTree(child) for child in root.getchildren() ]
90         }
91     
92     @staticmethod
93     def XML(ele, encoding='UTF-8'):
94         xml = ET.tostring(ele, encoding)
95         if xml.startswith('<?xml'):
96             return xml
97         else:
98             return '<?xml version="1.0" encoding="%s"?>%s' % (encoding, xml)
99
100 class XML:
101     
102     @staticmethod
103     def DictTree(xml):
104         return Element.DictTree(XML.Element(xml))
105     
106     @staticmethod
107     def Element(xml):
108         return ET.fromstring(xml)
109
110 dtree2ele = DictTree.Element
111 dtree2xml = DictTree.XML
112 ele2dtree = Element.DictTree
113 ele2xml = Element.XML
114 xml2dtree = XML.DictTree
115 xml2ele = XML.Element
116
117 ### Other utility functions
118
119 iselement = ET.iselement
120
121 def find(ele, tag, strict=True, nslist=[BASE_NS, CISCO_BS]):
122     """In strict mode, doesn't work around Cisco implementations sending incorrectly namespaced XML. Supply qualified tag name if using strict mode.
123     """
124     if strict:
125         return ele.find(tag)
126     else:
127         for qname in multiqualify(tag):
128             found = ele.find(qname)
129             if found is not None:
130                 return found        
131
132 def parse_root(raw):
133     """
134     """
135     fp = StringIO(raw[:1024]) # this is a guess but start element beyond 1024 bytes would be a bit absurd
136     for event, element in ET.iterparse(fp, events=('start',)):
137         return (element.tag, element.attrib)
138
139 def validated_element(rep, tag, attrs=None):
140     """
141     """
142     ele = dtree2ele(rep)
143     if ele.tag not in (tag, qualify(tag)):
144         raise ContentError("Required root element [%s] not found" % tag)
145     if attrs is not None:
146         for req in attrs:
147             for attr in ele.attrib:
148                 if unqualify(attr) == req:
149                     break
150             else:
151                 raise ContentError("Required attribute [%s] not found in element [%s]" % (req, req_tag))
152     return ele