same goes for <hello>, wasn't causing problems but still..
[ncclient] / ncclient / transport / session.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from Queue import Queue
16 from threading import Thread, Lock, Event
17
18 from ncclient import content
19 from ncclient.capabilities import Capabilities
20
21 from errors import TransportError
22
23 import logging
24 logger = logging.getLogger('ncclient.transport.session')
25
26 class Session(Thread):
27
28     "Base class for use by transport protocol implementations."
29
30     def __init__(self, capabilities):
31         Thread.__init__(self)
32         self.setDaemon(True)
33         self._listeners = set() # 3.0's weakset would be ideal
34         self._lock = Lock()
35         self.setName('session')
36         self._q = Queue()
37         self._client_capabilities = capabilities
38         self._server_capabilities = None # yet
39         self._id = None # session-id
40         self._connected = False # to be set/cleared by subclass implementation
41         logger.debug('%r created: client_capabilities=%r' %
42                      (self, self._client_capabilities))
43
44     def _dispatch_message(self, raw):
45         try:
46             root = content.parse_root(raw)
47         except Exception as e:
48             logger.error('error parsing dispatch message: %s' % e)
49             return
50         with self._lock:
51             listeners = list(self._listeners)
52         for l in listeners:
53             logger.debug('dispatching message to %r' % l)
54             try:
55                 l.callback(root, raw)
56             except Exception as e:
57                 logger.warning('[error] %r' % e)
58
59     def _dispatch_error(self, err):
60         with self._lock:
61             listeners = list(self._listeners)
62         for l in listeners:
63             logger.debug('dispatching error to %r' % l)
64             try:
65                 l.errback(err)
66             except Exception as e:
67                 logger.warning('error dispatching to %r: %r' % (l, e))
68
69     def _post_connect(self):
70         "Greeting stuff"
71         init_event = Event()
72         error = [None] # so that err_cb can bind error[0]. just how it is.
73         # callbacks
74         def ok_cb(id, capabilities):
75             self._id = id
76             self._server_capabilities = capabilities
77             init_event.set()
78         def err_cb(err):
79             error[0] = err
80             init_event.set()
81         listener = HelloHandler(ok_cb, err_cb)
82         self.add_listener(listener)
83         self.send(HelloHandler.build(self._client_capabilities))
84         logger.debug('starting main loop')
85         self.start()
86         # we expect server's hello message
87         init_event.wait()
88         # received hello message or an error happened
89         self.remove_listener(listener)
90         if error[0]:
91             raise error[0]
92         #if ':base:1.0' not in self.server_capabilities:
93         #    raise MissingCapabilityError(':base:1.0')
94         logger.info('initialized: session-id=%s | server_capabilities=%s' %
95                     (self._id, self._server_capabilities))
96
97     def add_listener(self, listener):
98         """Register a listener that will be notified of incoming messages and
99         errors.
100
101         :type listener: :class:`SessionListener`
102         """
103         logger.debug('installing listener %r' % listener)
104         if not isinstance(listener, SessionListener):
105             raise SessionError("Listener must be a SessionListener type")
106         with self._lock:
107             self._listeners.add(listener)
108
109     def remove_listener(self, listener):
110         """Unregister some listener; ignore if the listener was never
111         registered.
112
113         :type listener: :class:`SessionListener`
114         """
115         logger.debug('discarding listener %r' % listener)
116         with self._lock:
117             self._listeners.discard(listener)
118
119     def get_listener_instance(self, cls):
120         """If a listener of the specified type is registered, returns the
121         instance.
122
123         :type cls: :class:`SessionListener`
124         """
125         with self._lock:
126             for listener in self._listeners:
127                 if isinstance(listener, cls):
128                     return listener
129
130     def connect(self, *args, **kwds): # subclass implements
131         raise NotImplementedError
132
133     def run(self): # subclass implements
134         raise NotImplementedError
135
136     def send(self, message):
137         """Send the supplied *message* to NETCONF server.
138
139         :arg message: an XML document
140
141         :type message: `string`
142         """
143         if not self.connected:
144             raise TransportError('Not connected to NETCONF server')
145         logger.debug('queueing %s' % message)
146         self._q.put(message)
147
148     ### Properties
149
150     @property
151     def connected(self):
152         "Connection status of the session."
153         return self._connected
154
155     @property
156     def client_capabilities(self):
157         "Client's :class:`Capabilities`"
158         return self._client_capabilities
159
160     @property
161     def server_capabilities(self):
162         "Server's :class:`Capabilities`"
163         return self._server_capabilities
164
165     @property
166     def id(self):
167         """A `string` representing the `session-id`. If the session has not
168         been initialized it will be :const:`None`"""
169         return self._id
170
171     @property
172     def can_pipeline(self):
173         "Whether this session supports pipelining"
174         return True
175
176
177 class SessionListener(object):
178
179     """Base class for :class:`Session` listeners, which are notified when a new
180     NETCONF message is received or an error occurs.
181
182     .. note::
183         Avoid time-intensive tasks in a callback's context.
184     """
185
186     def callback(self, root, raw):
187         """Called when a new XML document is received. The `root` argument
188         allows the callback to determine whether it wants to further process the
189         document.
190
191         :arg root: tuple of `(tag, attributes)` where `tag` is the qualified name of the root element and `attributes` is a dictionary of its attributes (also qualified names)
192         :type root: `tuple`
193
194         :arg raw: XML document
195         :type raw: `string`
196         """
197         raise NotImplementedError
198
199     def errback(self, ex):
200         """Called when an error occurs.
201
202         :type ex: :exc:`Exception`
203         """
204         raise NotImplementedError
205
206
207 class HelloHandler(SessionListener):
208
209     def __init__(self, init_cb, error_cb):
210         self._init_cb = init_cb
211         self._error_cb = error_cb
212
213     def callback(self, root, raw):
214         if content.unqualify(root[0]) == 'hello':
215             try:
216                 id, capabilities = HelloHandler.parse(raw)
217             except Exception as e:
218                 self._error_cb(e)
219             else:
220                 self._init_cb(id, capabilities)
221
222     def errback(self, err):
223         self._error_cb(err)
224
225     @staticmethod
226     def build(capabilities):
227         "Given a list of capability URI's returns <hello> message XML string"
228         spec = {
229             'tag': 'hello',
230             'attrib': {'xmlns': content.BASE_NS},
231             'subtree': [{
232                 'tag': 'capabilities',
233                 'subtree': # this is fun :-)
234                     [{'tag': 'capability', 'text': uri} for uri in capabilities]
235                 }]
236             }
237         return content.dtree2xml(spec)
238
239     @staticmethod
240     def parse(raw):
241         "Returns tuple of (session-id (str), capabilities (Capabilities)"
242         sid, capabilities = 0, []
243         root = content.xml2ele(raw)
244         for child in root.getchildren():
245             tag = content.unqualify(child.tag)
246             if tag == 'session-id':
247                 sid = child.text
248             elif tag == 'capabilities':
249                 for cap in child.getchildren():
250                     if content.unqualify(cap.tag) == 'capability':
251                         capabilities.append(cap.text)
252         return sid, Capabilities(capabilities)