* cisco compatibility in the face of non-compliance * other fixes from testing *
[ncclient] / ncclient / operations / rpc.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from threading import Event, Lock
16 from uuid import uuid1
17 from weakref import WeakValueDictionary
18
19 from ncclient.content import TreeBuilder
20 from ncclient.content import qualify as _
21 from ncclient.content import unqualify as __
22 from ncclient.glue import Listener
23
24 from . import logger
25 from reply import RPCReply
26
27 # Cisco does not include message-id attribute in <rpc-reply> in case of an error.
28 # This is messed up however we have to deal with it.
29 # So essentially, there can be only one operation at a time if we are talking to
30 # a Cisco device.
31
32 class RPC(object):
33     
34     def __init__(self, session, async=False):
35         if session.is_remote_cisco and async:
36             raise UserWarning('Asynchronous mode not supported for Cisco devices')
37         self._session = session
38         self._async = async
39         self._id = uuid1().urn
40         self._listener = RPCReplyListener(session)
41         self._listener.register(self._id, self)
42         self._reply = None
43         self._reply_event = Event()
44     
45     def _build(self, op, encoding='utf-8'):
46         if isinstance(op, dict):
47             return self.build_from_spec(self._id, op, encoding)
48         else:
49             return self.build_from_string(self._id, op, encoding)
50     
51     def _request(self, op):
52         req = self._build(op)
53         self._session.send(req)
54         if self._async:
55             return self._reply_event
56         else:
57             self._reply_event.wait()
58             self._reply.parse()
59             return self._reply
60     
61     def _set_reply(self, raw):
62         self._reply = RPCReply(raw)
63     
64     def _set_reply_event(self):
65         self._reply_event.set()
66     
67     def _delivery_hook(self):
68         'For subclasses'
69         pass
70     
71     def deliver(self, raw):
72         self._set_reply(raw)
73         self._delivery_hook()
74         self._set_reply_event()
75     
76     @property
77     def has_reply(self):
78         return self._reply_event.isSet()
79     
80     @property
81     def reply(self):
82         return self._reply
83     
84     @property
85     def is_async(self):
86         return self._async
87     
88     @property
89     def id(self):
90         return self._id
91     
92     @property
93     def session(self):
94         return self._session
95     
96     @property
97     def reply_event(self):
98         return self._reply_event
99     
100     @staticmethod
101     def build_from_spec(msgid, opspec, encoding='utf-8'):
102         "TODO: docstring"
103         spec = {
104             'tag': _('rpc'),
105             'attributes': {'message-id': msgid},
106             'children': opspec
107             }
108         return TreeBuilder(spec).to_string(encoding)
109     
110     @staticmethod
111     def build_from_string(msgid, opstr, encoding='utf-8'):
112         "TODO: docstring"
113         decl = '<?xml version="1.0" encoding="%s"?>' % encoding
114         doc = (u'<rpc message-id="%s" xmlns="%s">%s</rpc>' %
115                (msgid, BASE_NS, opstr)).encode(encoding)
116         return '%s%s' % (decl, doc)
117
118
119 class RPCReplyListener(Listener):
120     
121     # TODO - determine if need locking
122     
123     # one instance per session
124     def __new__(cls, session):
125         instance = session.get_listener_instance(cls)
126         if instance is None:
127             instance = object.__new__(cls)
128             instance._id2rpc = WeakValueDictionary()
129             instance._cisco = session.is_remote_cisco
130             instance._errback = None
131             session.add_listener(instance)
132         return instance
133     
134     def __str__(self):
135         return 'RPCReplyListener'
136     
137     def set_errback(self, errback):
138         self._errback = errback
139
140     def register(self, id, rpc):
141         self._id2rpc[id] = rpc
142     
143     def callback(self, root, raw):
144         tag, attrs = root
145         if __(tag) != 'rpc-reply':
146             return
147         rpc = None
148         for key in attrs:
149             if __(key) == 'message-id':
150                 id = attrs[key]
151                 try:
152                     rpc = self._id2rpc.pop(id)
153                 except KeyError:
154                     logger.warning('[RPCReplyListener.callback] no object '
155                                    + 'registered for message-id: [%s]' % id)
156                 except Exception as e:
157                     logger.debug('[RPCReplyListener.callback] error - %r' % e)
158                 break
159         else:
160             if self._cisco:
161                 assert(len(self._id2rpc) == 1)
162                 rpc = self._id2rpc.values()[0]
163                 self._id2rpc.clear()
164             else:
165                 logger.warning('<rpc-reply> without message-id received: %s' % raw)
166         logger.debug('[RPCReplyListener.callback] delivering to %r' % rpc)
167         rpc.deliver(raw)
168     
169     def errback(self, err):
170         if self._errback is not None:
171             self._errback(err)