# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from xml.etree import cElementTree as ElementTree
-NAMESPACE = 'urn:ietf:params:xml:ns:netconf:base:1.0'
+logger = logging.getLogger('ncclient.content')
-def qualify(tag, ns=NAMESPACE):
+BASE_NS = 'urn:ietf:params:xml:ns:netconf:base:1.0'
+NOTIFICATION_NS = 'urn:ietf:params:xml:ns:netconf:notification:1.0'
+
+def qualify(tag, ns=BASE_NS):
+    """Return *tag* in '{namespace}tag' form; *ns* defaults to BASE_NS."""
return '{%s}%s' % (ns, tag)
_ = qualify
def make_hello(capabilities):
-    return '<hello xmlns="%s">%s</hello>' % (NAMESPACE, capabilities)
+    """Return a <hello> element in BASE_NS whose body is str(capabilities)."""
+    return '<hello xmlns="%s">%s</hello>' % (BASE_NS, capabilities)
def make_rpc(id, op):
-    return '<rpc message-id="%s" xmlns="%s">%s</rpc>' % (id, NAMESPACE, op)
+    """Return an <rpc> element in BASE_NS with message-id *id* wrapping *op*.
+
+    NOTE(review): *id* and *op* are interpolated without escaping; callers
+    must pass XML-safe values.
+    """
+    return '<rpc message-id="%s" xmlns="%s">%s</rpc>' % (id, BASE_NS, op)
def parse_hello(raw):
- from capability import Capabilities
+ from capabilities import Capabilities
id, capabilities = 0, Capabilities()
root = ElementTree.fromstring(raw)
if root.tag == _('hello'):
capabilities.add(cap.text)
return id, capabilities
-def parse_message_type(raw):
-
-    target = RootElementParser()
-    parser = ElementTree.XMLTreeBuilder(target=target)
-    parser.feed(raw)
-    return target.id
+def parse_message_root(raw):
+    """Classify a raw NETCONF message by its root element.
+
+    Returns:
+      - the message-id of an <rpc-reply> root (None if the attribute is
+        missing),
+      - the string 'notification' for a <notification> root,
+      - None for any other root element.
+    Only the root element is inspected.
+    """
+    from cStringIO import StringIO
+    fp = StringIO(raw)
+    for event, element in ElementTree.iterparse(fp, events=('start',)):
+        # The first 'start' event is the root element; decide on it alone.
+        # Servers answer an <rpc> with an <rpc-reply> carrying the same
+        # message-id, so that is the tag replies must be matched on.
+        if element.tag == _('rpc-reply'):
+            return element.attrib.get('message-id')
+        elif element.tag == _('notification', NOTIFICATION_NS):
+            return 'notification'
+        return None
+    return None
\ No newline at end of file
- # Copyright 2009 Shikhar Bhushan
+# Copyright 2009 Shikhar Bhushan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# limitations under the License.
import logging
-import weakref
+from weakref import WeakValueDictionary
+
+import content
logger = logging.getLogger('ncclient.listeners')
-import content
+session_listeners = {}
+
+def session_listener_factory(session):
+    """Return the SessionListener for *session*, creating it on first use."""
+    # NOTE(review): entries are never removed in this module, so every
+    # session passed in stays referenced by this dict -- confirm intended.
+    listener = session_listeners.get(session)
+    if listener is None:
+        listener = SessionListener()
+        session_listeners[session] = listener
+    return listener
class SessionListener(object):
-    'A multiton - one listener per session'
-
-    instances = weakref.WeakValueDictionary()
-
-    def __new__(cls, sid):
-        if sid in instances:# not been gc'd
-            return cls.instances[sid]
-        else:
-            inst = object.__new__(cls)
-            cls.instances[sid] = inst
-            return inst
+    """Routes incoming session messages to the objects registered by id."""
+
+    def __init__(self):
+        # message-id -> registered op; weak values, so entries disappear
+        # once nothing else references the op
+        self._id2rpc = WeakValueDictionary()
+        self._expecting_close = False
+        # message-id whose op receives notifications (see reply())
+        self._subscription = None
def __str__(self):
return 'SessionListener'
-    def set_subscription(self, id):
+    def set_subscription(self, id):
+        """Route notifications to the op registered under *id*."""
self._subscription = id
+    def expect_close(self):
+        """Mark a session close as expected, so error() does not raise."""
+        self._expecting_close = True
+
def register(self, id, op):
+        """Route messages carrying message-id *id* to *op*."""
self._id2rpc[id] = op
-    def unregister(self, id):
-        del self._id2prc[id]
-
### Events
def reply(self, raw):
-        id = content.parse_message(raw)
-        if id:
-            self._id2rpc[id]._deliver(raw)
-        else:
-            self._id2rpc[self._sub_id]._notify(raw)
+        """Deliver a raw incoming message to the op it belongs to.
+
+        Replies go to the op registered under their message-id via
+        _response_cb(raw); notifications go to the op registered under
+        the subscription id via _notify(raw). Failures (unparseable
+        message, unknown id, no subscription) are logged, not raised.
+        """
+        try:
+            msg_id = content.parse_message_root(raw)
+            if msg_id is None:
+                pass  # not a message this listener routes
+            elif msg_id == 'notification':
+                # the subscribed op must provide _notify(raw)
+                self._id2rpc[self._subscription]._notify(raw)
+            else:
+                self._id2rpc[msg_id]._response_cb(raw)
+        except Exception as e:
+            # log instead of propagating into the dispatcher
+            logger.warning(e)
-    def close(self, buf):
-        pass # TODO
+    def error(self, err):
+        """Handle an error reported by the session.
+
+        A SessionCloseError is logged at debug level. Any error is
+        re-raised unless expect_close() has been called.
+        """
+        from ssh import SessionCloseError
+        # err is an exception instance, so test its type, not its identity
+        if isinstance(err, SessionCloseError):
+            logger.debug('received session close, expecting_close=%s' %
+                         self._expecting_close)
+        if not self._expecting_close:
+            raise err
class DebugListener:
return 'DebugListener'
def reply(self, raw):
- logger.debug('reply:\n%s' % raw)
+ logger.debug('DebugListener:reply:\n%s' % raw)
def error(self, err):
- logger.debug(err)
+ logger.debug('DebugListener:error:\n%s' % err)
# limitations under the License.
from threading import Event, Lock
-
-from listener import SessionListener
-
from uuid import uuid1
+import content
+from listeners import session_listener_factory
+
class RPC:
-    def __init__(self, session, async=False):
+    """An RPC bound to one session.
+
+    request() must be provided by subclasses. _do_request(op) sends *op*
+    and, unless the RPC is async, blocks until the reply is delivered.
+    """
+
+    def __init__(self, session, async=False, parse=True):
+        # NOTE: *parse* is accepted but not used yet
self._session = session
self._async = async
self._reply = None
self._reply_event = Event()
+        # Reply state is set up above, before this RPC is registered, so a
+        # delivered reply never finds the object half-initialised.
+        self._id = uuid1().urn
+        self._listener = session_listener_factory(self._session)
+        self._listener.register(self._id, self)
+        session.add_listener(self._listener)
-        self._id = uuid1().urn
-    def _listener(self):
-        if not RPC.listeners.has_key(self._session.id):
-            RPC.listeners[self._session.id] = SessionListener()
-        return RPC.listeners[self._session.id]
-
-    def request(self, async=False):
-        self._async = async
-        listener = SessionListener(self._session.id)
-        session.add_listener(listener)
-        listener.register(self._id, self)
-        self._session.send(self.to_xml())
-
-    def response_cb(self, reply):
-        self._reply = reply # does callback parse??
+    def _response_cb(self, reply):
+        """Store *reply* and wake any thread waiting for it."""
+        self._reply = reply
-        self._event.set()
+        self._reply_event.set()
+    def _do_request(self, op):
+        """Send *op* inside an <rpc> carrying this RPC's id.
+
+        Unless the RPC is async, block until the reply arrives. Returns
+        self._reply, which is still None for an async RPC whose reply has
+        not arrived yet.
+        """
+        self._session.send(content.make_rpc(self._id, op))
+        if not self._async:
+            self._reply_event.wait()
+        return self._reply
+
+    def request(self):
+        """Issue this RPC; subclasses must implement it."""
+        raise NotImplementedError
+
+    def wait_for_reply(self, timeout=None):
+        """Block until the reply arrives or *timeout* seconds pass."""
+        self._reply_event.wait(timeout)
+
@property
def has_reply(self):
+        """True once a reply has been delivered."""
return self._reply_event.isSet()
-    def wait_on_reply(self, timeout=None):
-        self._reply_event.wait(timeout)
-
@property
def is_async(self):
return self._async
@property
+    def reply(self):
+        """The delivered reply, or None if none has arrived yet."""
+        return self._reply
+
+    @property
def id(self):
-        return self._id
\ No newline at end of file
+        return self._id
+
+    @property
+    def session(self):
+        """The session this RPC was created for."""
+        return self._session
+
+
+class RPCReply:
+    """Placeholder; no behavior implemented yet."""
+
+class RPCError:
+    """Placeholder; no behavior implemented yet."""
\ No newline at end of file
# limitations under the License.
import logging
-
-import content
-
from threading import Thread, Event
from Queue import Queue
-from capability import CAPABILITIES
+import content
+from capabilities import CAPABILITIES
from error import ClientError
from subject import Subject
logger = logging.getLogger('ncclient.session')
-class SessionError(ClientError):
-
-    pass
+class SessionError(ClientError):
+    """Session-level error; a ClientError subclass."""
class Session(Thread, Subject):
self._client_capabilities = CAPABILITIES
self._server_capabilities = None # yet
self._id = None # session-id
- self._connected = False # subclasses should set this
self._error = None
self._init_event = Event()
self._q = Queue()
+ self._connected = False # to be set/cleared by subclass
+
+    def _post_connect(self):
+        """Complete session start-up once the transport is connected.
+
+        Starts this thread, queues the client <hello>, then blocks until
+        _init_event is set. If self._error was recorded in the meantime,
+        the session is closed and that error is raised.
+        """
+        # start the subclass' main loop
+        self.start()
+        # queue client's hello message for sending
+        self.send(content.make_hello(self._client_capabilities))
+        # we expect server's hello message, wait for _init_event to be set by HelloListener
+        # NOTE(review): no timeout -- blocks indefinitely if nothing sets _init_event
+        self._init_event.wait()
+        # there may have been an error
+        if self._error:
+            self._close()
+            raise self._error
def send(self, message):
-        message = (u'<?xml version="1.0" encoding="UTF-8"?>%s' % message).encode('utf-8')
+        """Prepend an XML declaration to *message*, UTF-8 encode it and queue it."""
+        message = (u'<?xml version="1.0" encoding="UTF-8"?>%s' %
+                   message).encode('utf-8')
logger.debug('queueing message: \n%s' % message)
self._q.put(message)
return self._client_capabilities
@property
-    def serve_capabilities(self):
+    def server_capabilities(self):
+        """Capabilities advertised by the server; None until they are set."""
return self._server_capabilities
@property
def id(self):
+        """This session's session-id; None until it is set."""
return self._id
- def _post_connect(self):
- # start the subclass' main loop
- self.start()
- # queue client's hello message for sending
- self.send(content.make_hello(self._client_capabilities))
- # we expect server's hello message, wait for _init_event to be set by HelloListener
- self._init_event.wait()
- # there may have been an error
- if self._error:
- self._close()
- raise self._error
-
class HelloListener:
def __str__(self):
# limitations under the License.
import logging
-import paramiko
-
-from os import SEEK_CUR
from cStringIO import StringIO
+from os import SEEK_CUR
+
+import paramiko
from session import Session, SessionError
missing_host_key_policy=paramiko.RejectPolicy):
Session.__init__(self)
self._client = paramiko.SSHClient()
+ self._channel = None
if load_known_hosts:
self._client.load_system_host_keys()
self._client.set_missing_host_key_policy(missing_host_key_policy)
self._parsing_state = 0
self._parsing_pos = 0
+    def _close(self):
+        """Close the SSH channel, if one was opened, and mark disconnected."""
+        # _channel is None until a channel is opened; closing a session that
+        # never opened one must not raise AttributeError.
+        if self._channel is not None:
+            self._channel.close()
+        self._connected = False
+
+    def _fresh_data(self):
+        """Dispatch every complete message held in self._in_buf.
+
+        Messages are terminated by SSHSession.MSG_DELIM. Each complete
+        message, without its delimiter, is dispatched as a 'reply' event.
+        Any trailing partial message stays buffered, and
+        self._parsing_pos records how far it has been scanned so the next
+        call does not rescan it.
+        """
+        delim = SSHSession.MSG_DELIM
+        data = self._in_buf.getvalue()
+        # Back up len(delim) - 1 chars so a delimiter split across two
+        # reads is still found. This replaces a char-by-char state machine
+        # that kept a stale partial match after a mismatch (giving false
+        # delimiters) and never checked the delimiter's final character.
+        start = max(self._parsing_pos - (len(delim) - 1), 0)
+        consumed = False
+        idx = data.find(delim, start)
+        while idx != -1:
+            self.dispatch('reply', data[:idx])
+            data = data[idx + len(delim):]
+            consumed = True
+            idx = data.find(delim)
+        if consumed:
+            # fresh buffer holding only the unterminated remainder
+            buf = StringIO()
+            buf.write(data)
+            self._in_buf = buf
+        else:
+            # leave the position at the end of the buffered data
+            self._in_buf.seek(len(data))
+        self._parsing_state = 0  # no longer needed by this method
+        self._parsing_pos = len(data)
+
def load_host_keys(self, filename):
self._client.load_host_keys(filename)
except Exception as e:
logger.debug('*** broke out of main loop ***')
self.dispatch('error', e)
-
- def _close(self):
- self._channel.close()
- self._connected = False
-
- def _fresh_data(self):
- delim = SSHSession.MSG_DELIM
- n = len(delim) - 1
- state = self._parsing_state
- buf = self._in_buf
- buf.seek(self._parsing_pos)
- while True:
- x = buf.read(1)
- if not x: # done reading
- break
- elif x == delim[state]:
- state += 1
- else:
- continue
- # loop till last delim char expected, break if other char encountered
- for i in range(state, n):
- x = buf.read(1)
- if not x: # done reading
- break
- if x==delim[i]: # what we expected
- state += 1 # expect the next delim char
- else:
- state = 0 # reset
- break
- else: # if we didn't break out of above loop, full delim parsed
- till = buf.tell() - n
- buf.seek(0)
- msg = buf.read(till)
- self.dispatch('reply', msg)
- buf.seek(n+1, SEEK_CUR)
- rest = buf.read()
- buf = StringIO()
- buf.write(rest)
- buf.seek(0)
- state = 0
- self._in_buf = buf
- self._parsing_state = state
- self._parsing_pos = self._in_buf.tell()
class MissingHostKeyPolicy(paramiko.MissingHostKeyPolicy):