c82d8720c3d017cfa162829818c8919bc16989d1
[ncclient] / ncclient / operations / rpc.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from threading import Event, Lock
16 from uuid import uuid1
17 from weakref import WeakValueDictionary
18
19 from ncclient import content
20 from ncclient.transport import SessionListener
21
22 from errors import OperationError
23
24 import logging
25 logger = logging.getLogger('ncclient.operations.rpc')
26
27
28 class RPCReply:
29
30     def __init__(self, raw):
31         self._raw = raw
32         self._parsed = False
33         self._root = None
34         self._errors = []
35
36     def __repr__(self):
37         return self._raw
38
39     def _parsing_hook(self, root): pass
40
41     def parse(self):
42         if self._parsed:
43             return
44         root = self._root = content.xml2ele(self._raw) # <rpc-reply> element
45         # per rfc 4741 an <ok/> tag is sent when there are no errors or warnings
46         ok = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS])
47         if ok is not None:
48             logger.debug('parsed [%s]' % ok.tag)
49         else: # create RPCError objects from <rpc-error> elements
50             error = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS])
51             if error is not None:
52                 logger.debug('parsed [%s]' % error.tag)
53                 for err in root.getiterator(error.tag):
54                     # process a particular <rpc-error>
55                     d = {}
56                     for err_detail in err.getchildren(): # <error-type> etc..
57                         tag = content.unqualify(err_detail.tag)
58                         if tag != 'error-info':
59                             d[tag] = err_detail.text.strip()
60                         else:
61                             d[tag] = content.ele2xml(err_detail)
62                     self._errors.append(RPCError(d))
63         self._parsing_hook(root)
64         self._parsed = True
65
66     @property
67     def xml(self):
68         '<rpc-reply> as returned'
69         return self._raw
70
71     @property
72     def ok(self):
73         if not self._parsed:
74             self.parse()
75         return not self._errors # empty list => false
76
77     @property
78     def error(self):
79         if not self._parsed:
80             self.parse()
81         if self._errors:
82             return self._errors[0]
83         else:
84             return None
85
86     @property
87     def errors(self):
88         'List of RPCError objects. Will be empty if no <rpc-error> elements in reply.'
89         if not self._parsed:
90             self.parse()
91         return self._errors
92
93
94 class RPCError(OperationError): # raise it if you like
95
96     def __init__(self, err_dict):
97         self._dict = err_dict
98         if self.message is not None:
99             OperationError.__init__(self, self.message)
100         else:
101             OperationError.__init__(self)
102
103     @property
104     def type(self):
105         return self.get('error-type', None)
106
107     @property
108     def severity(self):
109         return self.get('error-severity', None)
110
111     @property
112     def tag(self):
113         return self.get('error-tag', None)
114
115     @property
116     def path(self):
117         return self.get('error-path', None)
118
119     @property
120     def message(self):
121         return self.get('error-message', None)
122
123     @property
124     def info(self):
125         return self.get('error-info', None)
126
127     ## dictionary interface
128
129     __getitem__ = lambda self, key: self._dict.__getitem__(key)
130
131     __iter__ = lambda self: self._dict.__iter__()
132
133     __contains__ = lambda self, key: self._dict.__contains__(key)
134
135     keys = lambda self: self._dict.keys()
136
137     get = lambda self, key, default: self._dict.get(key, default)
138
139     iteritems = lambda self: self._dict.iteritems()
140
141     iterkeys = lambda self: self._dict.iterkeys()
142
143     itervalues = lambda self: self._dict.itervalues()
144
145     values = lambda self: self._dict.values()
146
147     items = lambda self: self._dict.items()
148
149     __repr__ = lambda self: repr(self._dict)
150
151
152 class RPCReplyListener(SessionListener):
153
154     # one instance per session
155     def __new__(cls, session):
156         instance = session.get_listener_instance(cls)
157         if instance is None:
158             instance = object.__new__(cls)
159             instance._lock = Lock()
160             instance._id2rpc = WeakValueDictionary()
161             instance._pipelined = session.can_pipeline
162             session.add_listener(instance)
163         return instance
164
165     def register(self, id, rpc):
166         with self._lock:
167             self._id2rpc[id] = rpc
168
169     def callback(self, root, raw):
170         tag, attrs = root
171         if content.unqualify(tag) != 'rpc-reply':
172             return
173         rpc = None
174         for key in attrs:
175             if content.unqualify(key) == 'message-id':
176                 id = attrs[key]
177                 try:
178                     with self._lock:
179                         rpc = self._id2rpc.pop(id)
180                 except KeyError:
181                     logger.warning('no object registered for message-id: [%s]' % id)
182                 except Exception as e:
183                     logger.debug('error - %r' % e)
184                 break
185         else:
186             if not self._pipelined:
187                 with self._lock:
188                     assert(len(self._id2rpc) == 1)
189                     rpc = self._id2rpc.values()[0]
190                     self._id2rpc.clear()
191             else:
192                 logger.warning('<rpc-reply> without message-id received: %s' % raw)
193         logger.debug('delivering to %r' % rpc)
194         rpc.deliver(raw)
195
196     def errback(self, err):
197         for rpc in self._id2rpc.values():
198             rpc.error(err)
199
200
201 class RPC(object):
202
203     DEPENDS = []
204     REPLY_CLS = RPCReply
205
206     def __init__(self, session, async=False, timeout=None):
207         if not session.can_pipeline:
208             raise UserWarning('Asynchronous mode not supported for this device/session')
209         self._session = session
210         try:
211             for cap in self.DEPENDS:
212                 self._assert(cap)
213         except AttributeError:
214             pass
215         self._async = async
216         self._timeout = timeout
217         # keeps things simple instead of having a class attr that has to be locked
218         self._id = uuid1().urn
219         # RPCReplyListener itself makes sure there isn't more than one instance -- i.e. multiton
220         self._listener = RPCReplyListener(session)
221         self._listener.register(self._id, self)
222         self._reply = None
223         self._error = None
224         self._reply_event = Event()
225
226     def _build(self, opspec):
227         "TODO: docstring"
228         spec = {
229             'tag': content.qualify('rpc'),
230             'attrib': {'message-id': self._id},
231             'subtree': opspec
232             }
233         return content.dtree2xml(spec)
234
235     def _request(self, op):
236         req = self._build(op)
237         self._session.send(req)
238         if self._async:
239             return self._reply_event
240         else:
241             self._reply_event.wait(self._timeout)
242             if self._reply_event.isSet():
243                 if self._error:
244                     raise self._error
245                 self._reply.parse()
246                 return self._reply
247             else:
248                 raise ReplyTimeoutError
249
250     def request(self):
251         return self._request(self.SPEC)
252
253     def _delivery_hook(self):
254         'For subclasses'
255         pass
256
257     def _assert(self, capability):
258         if capability not in self._session.server_capabilities:
259             raise MissingCapabilityError('Server does not support [%s]' % cap)
260
261     def deliver(self, raw):
262         self._reply = self.REPLY_CLS(raw)
263         self._delivery_hook()
264         self._reply_event.set()
265
266     def error(self, err):
267         self._error = err
268         self._reply_event.set()
269
270     @property
271     def has_reply(self):
272         return self._reply_event.is_set()
273
274     @property
275     def reply(self):
276         if self.error:
277             raise self._error
278         return self._reply
279
280     @property
281     def id(self):
282         return self._id
283
284     @property
285     def session(self):
286         return self._session
287
288     @property
289     def reply_event(self):
290         return self._reply_event
291
292     def set_async(self, bool): self._async = bool
293     async = property(fget=lambda self: self._async, fset=set_async)
294
295     def set_timeout(self, timeout): self._timeout = timeout
296     timeout = property(fget=lambda self: self._timeout, fset=set_timeout)