From 4f650d549a5eb870704729de4614279d5a21d797 Mon Sep 17 00:00:00 2001 From: Shikhar Bhushan Date: Thu, 14 May 2009 13:49:26 +0000 Subject: [PATCH] docstrings and fixes git-svn-id: http://ncclient.googlecode.com/svn/trunk@119 6dbcf712-26ac-11de-a2f3-1373824ab735 --- ncclient/content.py | 34 ++++++----- ncclient/manager.py | 55 +++++++++-------- ncclient/operations/retrieve.py | 2 +- ncclient/operations/rpc.py | 114 +++++++++++++++++------------------ ncclient/operations/subscribe.py | 6 +- ncclient/operations/util.py | 2 +- ncclient/transport/errors.py | 14 ++--- ncclient/transport/session.py | 115 +++++++++++++++++++---------------- ncclient/transport/ssh.py | 122 +++++++++++++++++++++++++------------- ncclient/util.py | 8 +-- 10 files changed, 267 insertions(+), 205 deletions(-) diff --git a/ncclient/content.py b/ncclient/content.py index 72a73ce..434231d 100644 --- a/ncclient/content.py +++ b/ncclient/content.py @@ -53,8 +53,9 @@ class DictTree: @staticmethod def Element(spec): """DictTree -> Element - + :type spec: :obj:`dict` or :obj:`string` or :class:`~xml.etree.ElementTree.Element` + :rtype: :class:`~xml.etree.ElementTree.Element` """ if iselement(spec): @@ -64,7 +65,7 @@ class DictTree: if not isinstance(spec, dict): raise ContentError("Invalid tree spec") if 'tag' in spec: - ele = ET.Element(spec.get('tag'), spec.get('attributes', {})) + ele = ET.Element(spec.get('tag'), spec.get('attrib', {})) ele.text = spec.get('text', '') ele.tail = spec.get('tail', '') subtree = spec.get('subtree', []) @@ -78,23 +79,25 @@ class DictTree: return ET.Comment(spec.get('comment')) else: raise ContentError('Invalid tree spec') - + @staticmethod def XML(spec, encoding='UTF-8'): """DictTree -> XML - + :type spec: :obj:`dict` or :obj:`string` or :class:`~xml.etree.ElementTree.Element` + :arg encoding: chraracter encoding + :rtype: string """ return Element.XML(DictTree.Element(spec), encoding) class Element: - + @staticmethod def DictTree(ele): """DictTree -> Element - + :type spec: :class:`~xml.etree.ElementTree.Element` :rtype: :obj:`dict` """ @@ -105,11 +108,11 @@ class Element: 'tail': ele.tail, 'subtree': [ Element.DictTree(child) for child in root.getchildren() ] } - + @staticmethod def XML(ele, encoding='UTF-8'): """Element -> XML - + :type spec: :class:`~xml.etree.ElementTree.Element` :arg encoding: character encoding :rtype: :obj:`string` @@ -121,20 +124,20 @@ class Element: return '%s' % (encoding, xml) class XML: - + @staticmethod def DictTree(xml): """XML -> DictTree - + :type spec: :obj:`string` :rtype: :obj:`dict` """ return Element.DictTree(XML.Element(xml)) - + @staticmethod def Element(xml): """XML -> Element - + :type xml: :obj:`string` :rtype: :class:`~xml.etree.ElementTree.Element` """ @@ -153,7 +156,7 @@ iselement = ET.iselement def find(ele, tag, nslist=[]): """If `nslist` is empty, same as :meth:`xml.etree.ElementTree.Element.find`. If it is not, `tag` is interpreted as an unqualified name and qualified using each item in `nslist`. The first match is returned. - + :arg nslist: optional list of namespaces """ if nslist: @@ -166,7 +169,7 @@ def find(ele, tag, nslist=[]): def parse_root(raw): """Efficiently parses the root element of an XML document. - + :type raw: string :returns: a tuple of `(tag, attributes)`, where `tag` is the (qualified) name of the element and `attributes` is a dictionary of its attributes. """ @@ -176,12 +179,11 @@ def parse_root(raw): def validated_element(rep, tag=None, attrs=None, text=None): """Checks if the root element meets the supplied criteria. Returns a :class:`~xml.etree.ElementTree.Element` instance if so, otherwise raises :exc:`ContentError`. - + :arg tag: tag name or a list of allowable tag names :arg attrs: list of required attribute names, each item may be a list of allowable alternatives :arg text: textual content to match :type rep: :obj:`dict` or :obj:`string` or :class:`~xml.etree.ElementTree.Element` - :see: :ref:`dtree` """ ele = dtree2ele(rep) err = False diff --git a/ncclient/manager.py b/ncclient/manager.py index 1795fa6..14d6443 100644 --- a/ncclient/manager.py +++ b/ncclient/manager.py @@ -28,12 +28,15 @@ connect = ssh_connect # default session type RAISE_ALL, RAISE_ERROR, RAISE_NONE = range(3) class Manager: - + "Thin layer of abstraction for the ncclient API." - - def __init__(self, session, rpc_error=RAISE_ALL): + + def __init__(self, session): self._session = session - self._raise = rpc_error + self._rpc_error_handling = RAISE_ALL + + def set_rpc_error_option(self, option): + self._rpc_error_handling = option def do(self, op, *args, **kwds): op = operations.OPERATIONS[op](self._session) @@ -46,44 +49,44 @@ class Manager: if error.severity == 'error': raise error return reply - + def __enter__(self): pass - + def __exit__(self, *args): self.close() return False - + def locked(self, target): - """Returns a context manager for use withthe 'with' statement. + """Returns a context manager for use with the 'with' statement. `target` is the datastore to lock, e.g. 'candidate """ return operations.LockContext(self._session, target) - + get = lambda self, *args, **kwds: self.do('get', *args, **kwds).data - + get_config = lambda self, *args, **kwds: self.do('get-config', *args, **kwds).data - + edit_config = lambda self, *args, **kwds: self.do('edit-config', *args, **kwds) - + copy_config = lambda self, *args, **kwds: self.do('copy-config', *args, **kwds) - + validate = lambda self, *args, **kwds: self.do('validate', *args, **kwds) - + commit = lambda self, *args, **kwds: self.do('commit', *args, **kwds) - + discard_changes = lambda self, *args, **kwds: self.do('discard-changes', *args, **kwds) - + delete_config = lambda self, *args, **kwds: self.do('delete-config', *args, **kwds) - + lock = lambda self, *args, **kwds: self.do('lock', *args, **kwds) - + unlock = lambda self, *args, **kwds: self.do('unlock', *args, **kwds) - + close_session = lambda self, *args, **kwds: self.do('close-session', *args, **kwds) - + kill_session = lambda self, *args, **kwds: self.do('kill-session', *args, **kwds) - + def close(self): try: # try doing it clean self.close_session() @@ -91,17 +94,21 @@ class Manager: pass if self._session.connected: # if that didn't work... self._session.close() - + @property def session(self, session): return self._session - + def get_capabilities(self, whose): if whose in ('manager', 'client'): return self._session._client_capabilities elif whose in ('agent', 'server'): return self._session._server_capabilities - + @property def capabilities(self): return self._session._client_capabilities + + @property + def server_capabilities(self): + return self._session._server_capabilities diff --git a/ncclient/operations/retrieve.py b/ncclient/operations/retrieve.py index d70d0d1..9ea4c95 100644 --- a/ncclient/operations/retrieve.py +++ b/ncclient/operations/retrieve.py @@ -28,7 +28,7 @@ class GetReply(RPCReply): def _parsing_hook(self, root): self._data = None if not self._errors: - self._data = content.find(root, 'data', strict=False) + self._data = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS]) @property def data(self): diff --git a/ncclient/operations/rpc.py b/ncclient/operations/rpc.py index 6e37a3d..c82d872 100644 --- a/ncclient/operations/rpc.py +++ b/ncclient/operations/rpc.py @@ -22,35 +22,32 @@ from ncclient.transport import SessionListener from errors import OperationError import logging -logger = logging.getLogger('ncclient.rpc') +logger = logging.getLogger('ncclient.operations.rpc') class RPCReply: - - 'NOTES: memory considerations?? storing both raw xml + ET.Element' - + def __init__(self, raw): self._raw = raw self._parsed = False self._root = None self._errors = [] - + def __repr__(self): return self._raw - - def _parsing_hook(self, root): - pass - + + def _parsing_hook(self, root): pass + def parse(self): if self._parsed: return root = self._root = content.xml2ele(self._raw) # element # per rfc 4741 an tag is sent when there are no errors or warnings - ok = content.find(root, 'data', strict=False) + ok = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS]) if ok is not None: logger.debug('parsed [%s]' % ok.tag) else: # create RPCError objects from elements - error = content.find(root, 'data', strict=False) + error = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS]) if error is not None: logger.debug('parsed [%s]' % error.tag) for err in root.getiterator(error.tag): @@ -65,18 +62,18 @@ class RPCReply: self._errors.append(RPCError(d)) self._parsing_hook(root) self._parsed = True - + @property def xml(self): ' as returned' return self._raw - + @property def ok(self): if not self._parsed: self.parse() return not self._errors # empty list => false - + @property def error(self): if not self._parsed: @@ -85,7 +82,7 @@ class RPCReply: return self._errors[0] else: return None - + @property def errors(self): 'List of RPCError objects. Will be empty if no elements in reply.' @@ -95,65 +92,65 @@ class RPCReply: class RPCError(OperationError): # raise it if you like - + def __init__(self, err_dict): self._dict = err_dict if self.message is not None: OperationError.__init__(self, self.message) else: OperationError.__init__(self) - + @property def type(self): return self.get('error-type', None) - + @property def severity(self): return self.get('error-severity', None) - + @property def tag(self): return self.get('error-tag', None) - + @property def path(self): return self.get('error-path', None) - + @property def message(self): return self.get('error-message', None) - + @property def info(self): return self.get('error-info', None) ## dictionary interface - + __getitem__ = lambda self, key: self._dict.__getitem__(key) - + __iter__ = lambda self: self._dict.__iter__() - + __contains__ = lambda self, key: self._dict.__contains__(key) - + keys = lambda self: self._dict.keys() - + get = lambda self, key, default: self._dict.get(key, default) - + iteritems = lambda self: self._dict.iteritems() - + iterkeys = lambda self: self._dict.iterkeys() - + itervalues = lambda self: self._dict.itervalues() - + values = lambda self: self._dict.values() - + items = lambda self: self._dict.items() - + __repr__ = lambda self: repr(self._dict) class RPCReplyListener(SessionListener): - + # one instance per session def __new__(cls, session): instance = session.get_listener_instance(cls) @@ -164,11 +161,11 @@ class RPCReplyListener(SessionListener): instance._pipelined = session.can_pipeline session.add_listener(instance) return instance - + def register(self, id, rpc): with self._lock: self._id2rpc[id] = rpc - + def callback(self, root, raw): tag, attrs = root if content.unqualify(tag) != 'rpc-reply': @@ -195,17 +192,17 @@ class RPCReplyListener(SessionListener): logger.warning(' without message-id received: %s' % raw) logger.debug('delivering to %r' % rpc) rpc.deliver(raw) - + def errback(self, err): for rpc in self._id2rpc.values(): rpc.error(err) class RPC(object): - + DEPENDS = [] REPLY_CLS = RPCReply - + def __init__(self, session, async=False, timeout=None): if not session.can_pipeline: raise UserWarning('Asynchronous mode not supported for this device/session') @@ -214,7 +211,7 @@ class RPC(object): for cap in self.DEPENDS: self._assert(cap) except AttributeError: - pass + pass self._async = async self._timeout = timeout # keeps things simple instead of having a class attr that has to be locked @@ -223,74 +220,77 @@ class RPC(object): self._listener = RPCReplyListener(session) self._listener.register(self._id, self) self._reply = None + self._error = None self._reply_event = Event() - + def _build(self, opspec): "TODO: docstring" spec = { 'tag': content.qualify('rpc'), - 'attributes': {'message-id': self._id}, + 'attrib': {'message-id': self._id}, 'subtree': opspec } return content.dtree2xml(spec) - + def _request(self, op): req = self._build(op) self._session.send(req) if self._async: - return (self._reply_event, self._error_event) + return self._reply_event else: self._reply_event.wait(self._timeout) - if self._reply_event.is_set(): + if self._reply_event.isSet(): if self._error: raise self._error self._reply.parse() return self._reply else: raise ReplyTimeoutError - + def request(self): return self._request(self.SPEC) - + def _delivery_hook(self): 'For subclasses' pass - + def _assert(self, capability): if capability not in self._session.server_capabilities: raise MissingCapabilityError('Server does not support [%s]' % cap) - + def deliver(self, raw): self._reply = self.REPLY_CLS(raw) self._delivery_hook() self._reply_event.set() - + def error(self, err): self._error = err self._reply_event.set() - + @property def has_reply(self): return self._reply_event.is_set() - + @property def reply(self): + if self.error: + raise self._error return self._reply - + @property def id(self): return self._id - + @property def session(self): return self._session - + @property def reply_event(self): return self._reply_event - + def set_async(self, bool): self._async = bool async = property(fget=lambda self: self._async, fset=set_async) - + def set_timeout(self, timeout): self._timeout = timeout timeout = property(fget=lambda self: self._timeout, fset=set_timeout) diff --git a/ncclient/operations/subscribe.py b/ncclient/operations/subscribe.py index c4607d3..42f8035 100644 --- a/ncclient/operations/subscribe.py +++ b/ncclient/operations/subscribe.py @@ -14,8 +14,10 @@ from rpc import RPC -from ncclient.glue import Listener from ncclient.content import qualify as _ +from ncclient.transport import SessionListener + +NOTIFICATION_NS = 'urn:ietf:params:xml:ns:netconf:notification:1.0' # TODO when can actually test it... @@ -28,4 +30,4 @@ class CreateSubscription(RPC): class Notification: pass -class NotificationListener(Listener): pass +class NotificationListener(SessionListener): pass diff --git a/ncclient/operations/util.py b/ncclient/operations/util.py index 2ca923e..20ba089 100644 --- a/ncclient/operations/util.py +++ b/ncclient/operations/util.py @@ -46,7 +46,7 @@ def build_filter(spec, capcheck=None): type, criteria = tuple rep = { 'tag': 'filter', - 'attributes': {'type': type}, + 'attrib': {'type': type}, 'subtree': criteria } else: diff --git a/ncclient/transport/errors.py b/ncclient/transport/errors.py index 683129d..532e452 100644 --- a/ncclient/transport/errors.py +++ b/ncclient/transport/errors.py @@ -21,7 +21,7 @@ class AuthenticationError(TransportError): pass class SessionCloseError(TransportError): - + def __init__(self, in_buf, out_buf=None): msg = 'Unexpected session close.' if in_buf: @@ -34,9 +34,9 @@ class SSHError(TransportError): pass class SSHUnknownHostError(SSHError): - - def __init__(self, hostname, key): - from binascii import hexlify - SSHError(self, 'Unknown host key [%s] for [%s]' - % (hexlify(key.get_fingerprint()), hostname)) - self.hostname = hostname + + def __init__(self, host, fingerprint): + SSHError.__init__(self, 'Unknown host key [%s] for [%s]' + % (fingerprint, host)) + self.host = host + self.fingerprint = fingerprint diff --git a/ncclient/transport/session.py b/ncclient/transport/session.py index 0a28782..8cdbc33 100644 --- a/ncclient/transport/session.py +++ b/ncclient/transport/session.py @@ -22,14 +22,14 @@ import logging logger = logging.getLogger('ncclient.transport.session') class Session(Thread): - "This is a base class for use by protocol implementations" - + "Base class for use by transport protocol implementations." + def __init__(self, capabilities): Thread.__init__(self) - self.set_daemon(True) - self._listeners = set() # 3.0's weakset ideal + self.setDaemon(True) + self._listeners = set() # 3.0's weakset would be ideal self._lock = Lock() - self.set_name('session') + self.setName('session') self._q = Queue() self._client_capabilities = capabilities self._server_capabilities = None # yet @@ -37,7 +37,7 @@ class Session(Thread): self._connected = False # to be set/cleared by subclass implementation logger.debug('%r created: client_capabilities=%r' % (self, self._client_capabilities)) - + def _dispatch_message(self, raw): try: root = content.parse_root(raw) @@ -52,7 +52,7 @@ class Session(Thread): l.callback(root, raw) except Exception as e: logger.warning('[error] %r' % e) - + def _dispatch_error(self, err): with self._lock: listeners = list(self._listeners) @@ -62,7 +62,7 @@ class Session(Thread): l.errback(err) except Exception as e: logger.warning('error %r' % e) - + def _post_connect(self): "Greeting stuff" init_event = Event() @@ -86,109 +86,120 @@ class Session(Thread): self.remove_listener(listener) if error[0]: raise error[0] - logger.info('initialized: session-id=%s | server_capabilities=%s' % (self._id, self._server_capabilities)) - + logger.info('initialized: session-id=%s | server_capabilities=%s' % + (self._id, self._server_capabilities)) + def add_listener(self, listener): - """Register a listener that will be notified of incoming messages and errors. - - :type listener: :class:`SessionListener` + """Register a listener that will be notified of incoming messages and + errors. + + :arg listener: :class:`SessionListener` """ logger.debug('installing listener %r' % listener) if not isinstance(listener, SessionListener): raise SessionError("Listener must be a SessionListener type") with self._lock: self._listeners.add(listener) - + def remove_listener(self, listener): - "Unregister some listener; ignoring if the listener was never registered." + """Unregister some listener; ignore if the listener was never + registered.""" logger.debug('discarding listener %r' % listener) with self._lock: self._listeners.discard(listener) - + def get_listener_instance(self, cls): - """If a listener of the specified type is registered, returns it. This is useful when it is desirable to have only one instance of a particular type per session, i.e. a multiton. - - :type cls: :class:`type` - :rtype: :class:`SessionListener` or :const:`None` + """If a listener of the sspecified type is registered, returns the + instance. This is useful when it is desirable to have only one instance + of a particular type per session, i.e. a multiton. + + :arg cls: class of the listener """ with self._lock: for listener in self._listeners: if isinstance(listener, cls): return listener - + def connect(self, *args, **kwds): # subclass implements raise NotImplementedError def run(self): # subclass implements raise NotImplementedError - + def send(self, message): - """ - :param message: XML document - :type message: string + """Send the supplied *message* to NETCONF server. + + :arg message: an XML document + + :type message: :obj:`string` """ logger.debug('queueing %s' % message) self._q.put(message) - + ### Properties @property def connected(self): - ":rtype: bool" + "Connection status of the session." return self._connected @property def client_capabilities(self): - ":rtype: :class:`Capabilities`" + "Client's :class:`Capabilities`" return self._client_capabilities - + @property def server_capabilities(self): - ":rtype: :class:`Capabilities` or :const:`None`" + "Server's :class:`Capabilities`" return self._server_capabilities - + @property def id(self): - ":rtype: :obj:`string` or :const:`None`" + """A :obj:`string` representing the `session-id`. If the session has not + been initialized it will be :const:`None`""" return self._id - + @property def can_pipeline(self): - ":rtype: :obj:`bool`" + "Whether this session supports pipelining" return True class SessionListener(object): - - """'Listen' to incoming messages on a NETCONF :class:`Session` - + + """Base class for :class:`Session` listeners, which are notified when a new + NETCONF message is received or an error occurs. + .. note:: - Avoid computationally intensive tasks in the callbacks. + Avoid time-intensive tasks in a callback's context. """ - + def callback(self, root, raw): - """Called when a new XML document is received. The `root` argument allows the callback to determine whether it wants to further process the document. - - :param root: tuple of (tag, attrs) where tag is the qualified name of the root element and attrs is a dictionary of its attributes (also qualified names) - :param raw: XML document - :type raw: string + """Called when a new XML document is received. The `root` argument + allows the callback to determine whether it wants to further process the + document. + + :arg root: is a tuple of `(tag, attributes)` where `tag` is the qualified name of the root element and `attributes` is a dictionary of its attributes (also qualified names) + + :arg raw: XML document + :type raw: :obj:`string` """ raise NotImplementedError - + def errback(self, ex): """Called when an error occurs. - - :type ex: :class:`Exception` + + :type ex: :exc:`Exception` """ raise NotImplementedError class HelloHandler(SessionListener): - + def __init__(self, init_cb, error_cb): self._init_cb = init_cb self._error_cb = error_cb - + def callback(self, root, raw): if content.unqualify(root[0]) == 'hello': try: @@ -197,10 +208,10 @@ class HelloHandler(SessionListener): self._error_cb(e) else: self._init_cb(id, capabilities) - + def errback(self, err): self._error_cb(err) - + @staticmethod def build(capabilities): "Given a list of capability URI's returns message XML string" @@ -213,7 +224,7 @@ class HelloHandler(SessionListener): }] } return content.dtree2xml(spec) - + @staticmethod def parse(raw): "Returns tuple of (session-id (str), capabilities (Capabilities)" diff --git a/ncclient/transport/ssh.py b/ncclient/transport/ssh.py index 0f41fb9..21a56ab 100644 --- a/ncclient/transport/ssh.py +++ b/ncclient/transport/ssh.py @@ -30,10 +30,23 @@ BUF_SIZE = 4096 MSG_DELIM = ']]>]]>' TICK = 0.1 +def default_unknown_host_cb(host, key): + """An `unknown host callback` returns :const:`True` if it finds the key + acceptable, and :const:`False` if not. + + :arg host: the hostname/address which needs to be verified + + :arg key: a hex string representing the host key fingerprint + + :returns: this default callback always returns :const:`False` + """ + return False + + class SSHSession(Session): - - "A NETCONF SSH session, per :rfc:`4742`" - + + "Implements a :rfc:`4742` NETCONF session over SSH." + def __init__(self, capabilities): Session.__init__(self, capabilities) self._host_keys = paramiko.HostKeys() @@ -44,15 +57,14 @@ class SSHSession(Session): self._expecting_close = False self._buffer = StringIO() # for incoming data # parsing-related, see _parse() - self._parsing_state = 0 + self._parsing_state = 0 self._parsing_pos = 0 - + def _parse(self): '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a maximum of BUF_SIZE bytes everytime this method is called. Retains state - across method calls and if a byte has been read it will not be considered - again. - ''' + across method calls and if a byte has been read it will not be + considered again. ''' delim = MSG_DELIM n = len(delim) - 1 expect = self._parsing_state @@ -90,7 +102,7 @@ class SSHSession(Session): self._buffer = buf self._parsing_state = expect self._parsing_pos = self._buffer.tell() - + def load_system_host_keys(self, filename=None): if filename is None: filename = os.path.expanduser('~/.ssh/known_hosts') @@ -105,31 +117,57 @@ class SSHSession(Session): pass return self._system_host_keys.load(filename) - + def load_host_keys(self, filename): self._host_keys.load(filename) def add_host_key(self, key): self._host_keys.add(key) - + def save_host_keys(self, filename): f = open(filename, 'w') for host, keys in self._host_keys.iteritems(): for keytype, key in keys.iteritems(): f.write('%s %s %s\n' % (host, keytype, key.get_base64())) - f.close() - + f.close() + def close(self): self._expecting_close = True if self._transport.is_active(): self._transport.close() self._connected = False - + def connect(self, host, port=830, timeout=None, - unknown_host_cb=None, username=None, password=None, + unknown_host_cb=default_unknown_host_cb, + username=None, password=None, key_filename=None, allow_agent=True, look_for_keys=True): + """Connect via SSH and initialize the NETCONF session. First attempts + the publickey authentication method and then password authentication. + + To disable publickey authentication, call with *allow_agent* and + *look_for_keys* as :const:`False` + + :arg host: the hostname or IP address to connect to + + :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified + + :arg timeout: an optional timeout for the TCP handshake + + :arg unknown_host_cb: called when a host key is not known. See :func:`unknown_host_cb` for details on signature + + :arg username: the username to use for SSH authentication + + :arg password: the password used if using password authentication, or the passphrase to use in order to unlock keys that require it + + :arg key_filename: a filename where a the private key to be used can be found + + :arg allow_agent: enables querying SSH agent (if found) for keys + + :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`) + """ + assert(username is not None) - + for (family, socktype, proto, canonname, sockaddr) in \ socket.getaddrinfo(host, port): if socktype == socket.SOCK_STREAM: @@ -143,44 +181,43 @@ class SSHSession(Session): sock.connect(addr) t = self._transport = paramiko.Transport(sock) t.set_log_channel(logger.name) - + try: t.start_client() except paramiko.SSHException: raise SSHError('Negotiation failed') - + # host key verification server_key = t.get_remote_server_key() known_host = self._host_keys.check(host, server_key) or \ self._system_host_keys.check(host, server_key) - - if unknown_host_cb is None: - unknown_host_cb = lambda *args: False - if not known_host and not unknown_host_cb(host, server_key): - raise SSHUnknownHostError(host, server_key) - + + fp = hexlify(server_key.get_fingerprint()) + if not known_host and not unknown_host_cb(host, fp): + raise SSHUnknownHostError(host, fp) + if key_filename is None: key_filenames = [] elif isinstance(key_filename, basestring): key_filenames = [ key_filename ] else: key_filenames = key_filename - + self._auth(username, password, key_filenames, allow_agent, look_for_keys) - + self._connected = True # there was no error authenticating - + c = self._channel = self._transport.open_session() c.set_name('netconf') c.invoke_subsystem('netconf') - + self._post_connect() - + # on the lines of paramiko.SSHClient._auth() def _auth(self, username, password, key_filenames, allow_agent, look_for_keys): saved_exception = None - + for key_filename in key_filenames: for cls in (paramiko.RSAKey, paramiko.DSSKey): try: @@ -192,7 +229,7 @@ class SSHSession(Session): except Exception as e: saved_exception = e logger.debug(e) - + if allow_agent: for key in paramiko.Agent().get_keys(): try: @@ -203,7 +240,7 @@ class SSHSession(Session): except Exception as e: saved_exception = e logger.debug(e) - + keyfiles = [] if look_for_keys: rsa_key = os.path.expanduser('~/.ssh/id_rsa') @@ -219,7 +256,7 @@ class SSHSession(Session): keyfiles.append((paramiko.RSAKey, rsa_key)) if os.path.isfile(dsa_key): keyfiles.append((paramiko.DSSKey, dsa_key)) - + for cls, filename in keyfiles: try: key = cls.from_private_key_file(filename, password) @@ -230,7 +267,7 @@ class SSHSession(Session): except Exception as e: saved_exception = e logger.debug(e) - + if password is not None: try: self._transport.auth_password(username, password) @@ -238,13 +275,13 @@ class SSHSession(Session): except Exception as e: saved_exception = e logger.debug(e) - + if saved_exception is not None: # need pep-3134 to do this right raise SSHAuthenticationError(repr(saved_exception)) - + raise SSHAuthenticationError('No authentication methods available') - + def run(self): chan = self._channel chan.setblocking(0) @@ -252,7 +289,7 @@ class SSHSession(Session): try: while True: # select on a paramiko ssh channel object does not ever return - # it in the writable list, so it channel's don't exactly emulate + # it in the writable list, so it channel's don't exactly emulate # the socket api r, w, e = select([chan], [], [], TICK) # will wakeup evey TICK seconds to check if something @@ -278,12 +315,15 @@ class SSHSession(Session): self.close() if not (isinstance(e, SessionCloseError) and self._expecting_close): self._dispatch_error(e) - + @property def transport(self): - "gug" + """The underlying `paramiko.Transport + `_ + object. This makes it possible to call methods like set_keepalive on it. + """ return self._transport - + @property def can_pipeline(self): if 'Cisco' in self._transport.remote_version: diff --git a/ncclient/util.py b/ncclient/util.py index 548447e..65429cf 100644 --- a/ncclient/util.py +++ b/ncclient/util.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ncclient.glue import Listener +from ncclient.transport import SessionListener + +class PrintListener(SessionListener): -class PrintListener(Listener): - def callback(self, root, raw): print('\n# RECEIVED MESSAGE with root=[tag=%r, attrs=%r] #\n%r\n' % (root[0], root[1], raw)) - + def errback(self, err): print('\n# RECEIVED ERROR #\n%r\n' % err) -- 1.7.10.4