docstrings and fixes
authorShikhar Bhushan <shikhar@schmizz.net>
Thu, 14 May 2009 13:49:26 +0000 (13:49 +0000)
committerShikhar Bhushan <shikhar@schmizz.net>
Thu, 14 May 2009 13:49:26 +0000 (13:49 +0000)
git-svn-id: http://ncclient.googlecode.com/svn/trunk@119 6dbcf712-26ac-11de-a2f3-1373824ab735

ncclient/content.py
ncclient/manager.py
ncclient/operations/retrieve.py
ncclient/operations/rpc.py
ncclient/operations/subscribe.py
ncclient/operations/util.py
ncclient/transport/errors.py
ncclient/transport/session.py
ncclient/transport/ssh.py
ncclient/util.py

index 72a73ce..434231d 100644 (file)
@@ -53,8 +53,9 @@ class DictTree:
     @staticmethod
     def Element(spec):
         """DictTree -> Element
-        
+
         :type spec: :obj:`dict` or :obj:`string` or :class:`~xml.etree.ElementTree.Element`
+
         :rtype: :class:`~xml.etree.ElementTree.Element`
         """
         if iselement(spec):
@@ -64,7 +65,7 @@ class DictTree:
         if not isinstance(spec, dict):
             raise ContentError("Invalid tree spec")
         if 'tag' in spec:
-            ele = ET.Element(spec.get('tag'), spec.get('attributes', {}))
+            ele = ET.Element(spec.get('tag'), spec.get('attrib', {}))
             ele.text = spec.get('text', '')
             ele.tail = spec.get('tail', '')
             subtree = spec.get('subtree', [])
@@ -78,23 +79,25 @@ class DictTree:
             return ET.Comment(spec.get('comment'))
         else:
             raise ContentError('Invalid tree spec')
-    
+
     @staticmethod
     def XML(spec, encoding='UTF-8'):
         """DictTree -> XML
-        
+
         :type spec: :obj:`dict` or :obj:`string` or :class:`~xml.etree.ElementTree.Element`
+
         :arg encoding: chraracter encoding
+
         :rtype: string
         """
         return Element.XML(DictTree.Element(spec), encoding)
 
 class Element:
-    
+
     @staticmethod
     def DictTree(ele):
         """DictTree -> Element
-        
+
         :type spec: :class:`~xml.etree.ElementTree.Element`
         :rtype: :obj:`dict`
         """
@@ -105,11 +108,11 @@ class Element:
             'tail': ele.tail,
             'subtree': [ Element.DictTree(child) for child in root.getchildren() ]
         }
-    
+
     @staticmethod
     def XML(ele, encoding='UTF-8'):
         """Element -> XML
-        
+
         :type spec: :class:`~xml.etree.ElementTree.Element`
         :arg encoding: character encoding
         :rtype: :obj:`string`
@@ -121,20 +124,20 @@ class Element:
             return '<?xml version="1.0" encoding="%s"?>%s' % (encoding, xml)
 
 class XML:
-    
+
     @staticmethod
     def DictTree(xml):
         """XML -> DictTree
-        
+
         :type spec: :obj:`string`
         :rtype: :obj:`dict`
         """
         return Element.DictTree(XML.Element(xml))
-    
+
     @staticmethod
     def Element(xml):
         """XML -> Element
-        
+
         :type xml: :obj:`string`
         :rtype: :class:`~xml.etree.ElementTree.Element`
         """
@@ -153,7 +156,7 @@ iselement = ET.iselement
 
 def find(ele, tag, nslist=[]):
     """If `nslist` is empty, same as :meth:`xml.etree.ElementTree.Element.find`. If it is not, `tag` is interpreted as an unqualified name and qualified using each item in `nslist`. The first match is returned.
-    
+
     :arg nslist: optional list of namespaces
     """
     if nslist:
@@ -166,7 +169,7 @@ def find(ele, tag, nslist=[]):
 
 def parse_root(raw):
     """Efficiently parses the root element of an XML document.
-    
+
     :type raw: string
     :returns: a tuple of `(tag, attributes)`, where `tag` is the (qualified) name of the element and `attributes` is a dictionary of its attributes.
     """
@@ -176,12 +179,11 @@ def parse_root(raw):
 
 def validated_element(rep, tag=None, attrs=None, text=None):
     """Checks if the root element meets the supplied criteria. Returns a :class:`~xml.etree.ElementTree.Element` instance if so, otherwise raises :exc:`ContentError`.
-    
+
     :arg tag: tag name or a list of allowable tag names
     :arg attrs: list of required attribute names, each item may be a list of allowable alternatives
     :arg text: textual content to match
     :type rep: :obj:`dict` or :obj:`string` or :class:`~xml.etree.ElementTree.Element`
-    :see: :ref:`dtree`
     """
     ele = dtree2ele(rep)
     err = False
index 1795fa6..14d6443 100644 (file)
@@ -28,12 +28,15 @@ connect = ssh_connect # default session type
 RAISE_ALL, RAISE_ERROR, RAISE_NONE = range(3)
 
 class Manager:
-    
+
     "Thin layer of abstraction for the ncclient API."
-    
-    def __init__(self, session, rpc_error=RAISE_ALL):
+
+    def __init__(self, session):
         self._session = session
-        self._raise = rpc_error
+        self._rpc_error_handling = RAISE_ALL
+
+    def set_rpc_error_option(self, option):
+        self._rpc_error_handling = option
 
     def do(self, op, *args, **kwds):
         op = operations.OPERATIONS[op](self._session)
@@ -46,44 +49,44 @@ class Manager:
                     if error.severity == 'error':
                         raise error
         return reply
-    
+
     def __enter__(self):
         pass
-    
+
     def __exit__(self, *args):
         self.close()
         return False
-    
+
     def locked(self, target):
-        """Returns a context manager for use withthe 'with' statement.
+        """Returns a context manager for use with the 'with' statement.
         `target` is the datastore to lock, e.g. 'candidate
         """
         return operations.LockContext(self._session, target)
-     
+
     get = lambda self, *args, **kwds: self.do('get', *args, **kwds).data
-    
+
     get_config = lambda self, *args, **kwds: self.do('get-config', *args, **kwds).data
-    
+
     edit_config = lambda self, *args, **kwds: self.do('edit-config', *args, **kwds)
-    
+
     copy_config = lambda self, *args, **kwds: self.do('copy-config', *args, **kwds)
-    
+
     validate = lambda self, *args, **kwds: self.do('validate', *args, **kwds)
-    
+
     commit = lambda self, *args, **kwds: self.do('commit', *args, **kwds)
-    
+
     discard_changes = lambda self, *args, **kwds: self.do('discard-changes', *args, **kwds)
-    
+
     delete_config = lambda self, *args, **kwds: self.do('delete-config', *args, **kwds)
-    
+
     lock = lambda self, *args, **kwds: self.do('lock', *args, **kwds)
-    
+
     unlock = lambda self, *args, **kwds: self.do('unlock', *args, **kwds)
-    
+
     close_session = lambda self, *args, **kwds: self.do('close-session', *args, **kwds)
-    
+
     kill_session = lambda self, *args, **kwds: self.do('kill-session', *args, **kwds)
-    
+
     def close(self):
         try: # try doing it clean
             self.close_session()
@@ -91,17 +94,21 @@ class Manager:
             pass
         if self._session.connected: # if that didn't work...
             self._session.close()
-    
+
     @property
     def session(self, session):
         return self._session
-    
+
     def get_capabilities(self, whose):
         if whose in ('manager', 'client'):
             return self._session._client_capabilities
         elif whose in ('agent', 'server'):
             return self._session._server_capabilities
-    
+
     @property
     def capabilities(self):
         return self._session._client_capabilities
+
+    @property
+    def server_capabilities(self):
+        return self._session._server_capabilities
index d70d0d1..9ea4c95 100644 (file)
@@ -28,7 +28,7 @@ class GetReply(RPCReply):
     def _parsing_hook(self, root):
         self._data = None
         if not self._errors:
-            self._data = content.find(root, 'data', strict=False)
+            self._data = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS])
     
     @property
     def data(self):
index 6e37a3d..c82d872 100644 (file)
@@ -22,35 +22,32 @@ from ncclient.transport import SessionListener
 from errors import OperationError
 
 import logging
-logger = logging.getLogger('ncclient.rpc')
+logger = logging.getLogger('ncclient.operations.rpc')
 
 
 class RPCReply:
-    
-    'NOTES: memory considerations?? storing both raw xml + ET.Element'
-    
+
     def __init__(self, raw):
         self._raw = raw
         self._parsed = False
         self._root = None
         self._errors = []
-    
+
     def __repr__(self):
         return self._raw
-    
-    def _parsing_hook(self, root):
-        pass
-    
+
+    def _parsing_hook(self, root): pass
+
     def parse(self):
         if self._parsed:
             return
         root = self._root = content.xml2ele(self._raw) # <rpc-reply> element
         # per rfc 4741 an <ok/> tag is sent when there are no errors or warnings
-        ok = content.find(root, 'data', strict=False)
+        ok = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS])
         if ok is not None:
             logger.debug('parsed [%s]' % ok.tag)
         else: # create RPCError objects from <rpc-error> elements
-            error = content.find(root, 'data', strict=False)
+            error = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS])
             if error is not None:
                 logger.debug('parsed [%s]' % error.tag)
                 for err in root.getiterator(error.tag):
@@ -65,18 +62,18 @@ class RPCReply:
                     self._errors.append(RPCError(d))
         self._parsing_hook(root)
         self._parsed = True
-    
+
     @property
     def xml(self):
         '<rpc-reply> as returned'
         return self._raw
-    
+
     @property
     def ok(self):
         if not self._parsed:
             self.parse()
         return not self._errors # empty list => false
-    
+
     @property
     def error(self):
         if not self._parsed:
@@ -85,7 +82,7 @@ class RPCReply:
             return self._errors[0]
         else:
             return None
-    
+
     @property
     def errors(self):
         'List of RPCError objects. Will be empty if no <rpc-error> elements in reply.'
@@ -95,65 +92,65 @@ class RPCReply:
 
 
 class RPCError(OperationError): # raise it if you like
-    
+
     def __init__(self, err_dict):
         self._dict = err_dict
         if self.message is not None:
             OperationError.__init__(self, self.message)
         else:
             OperationError.__init__(self)
-    
+
     @property
     def type(self):
         return self.get('error-type', None)
-    
+
     @property
     def severity(self):
         return self.get('error-severity', None)
-    
+
     @property
     def tag(self):
         return self.get('error-tag', None)
-    
+
     @property
     def path(self):
         return self.get('error-path', None)
-    
+
     @property
     def message(self):
         return self.get('error-message', None)
-    
+
     @property
     def info(self):
         return self.get('error-info', None)
 
     ## dictionary interface
-    
+
     __getitem__ = lambda self, key: self._dict.__getitem__(key)
-    
+
     __iter__ = lambda self: self._dict.__iter__()
-    
+
     __contains__ = lambda self, key: self._dict.__contains__(key)
-    
+
     keys = lambda self: self._dict.keys()
-    
+
     get = lambda self, key, default: self._dict.get(key, default)
-        
+
     iteritems = lambda self: self._dict.iteritems()
-    
+
     iterkeys = lambda self: self._dict.iterkeys()
-    
+
     itervalues = lambda self: self._dict.itervalues()
-    
+
     values = lambda self: self._dict.values()
-    
+
     items = lambda self: self._dict.items()
-    
+
     __repr__ = lambda self: repr(self._dict)
 
 
 class RPCReplyListener(SessionListener):
-    
+
     # one instance per session
     def __new__(cls, session):
         instance = session.get_listener_instance(cls)
@@ -164,11 +161,11 @@ class RPCReplyListener(SessionListener):
             instance._pipelined = session.can_pipeline
             session.add_listener(instance)
         return instance
-    
+
     def register(self, id, rpc):
         with self._lock:
             self._id2rpc[id] = rpc
-    
+
     def callback(self, root, raw):
         tag, attrs = root
         if content.unqualify(tag) != 'rpc-reply':
@@ -195,17 +192,17 @@ class RPCReplyListener(SessionListener):
                 logger.warning('<rpc-reply> without message-id received: %s' % raw)
         logger.debug('delivering to %r' % rpc)
         rpc.deliver(raw)
-    
+
     def errback(self, err):
         for rpc in self._id2rpc.values():
             rpc.error(err)
 
 
 class RPC(object):
-    
+
     DEPENDS = []
     REPLY_CLS = RPCReply
-    
+
     def __init__(self, session, async=False, timeout=None):
         if not session.can_pipeline:
             raise UserWarning('Asynchronous mode not supported for this device/session')
@@ -214,7 +211,7 @@ class RPC(object):
             for cap in self.DEPENDS:
                 self._assert(cap)
         except AttributeError:
-            pass        
+            pass
         self._async = async
         self._timeout = timeout
         # keeps things simple instead of having a class attr that has to be locked
@@ -223,74 +220,77 @@ class RPC(object):
         self._listener = RPCReplyListener(session)
         self._listener.register(self._id, self)
         self._reply = None
+        self._error = None
         self._reply_event = Event()
-    
+
     def _build(self, opspec):
         "TODO: docstring"
         spec = {
             'tag': content.qualify('rpc'),
-            'attributes': {'message-id': self._id},
+            'attrib': {'message-id': self._id},
             'subtree': opspec
             }
         return content.dtree2xml(spec)
-    
+
     def _request(self, op):
         req = self._build(op)
         self._session.send(req)
         if self._async:
-            return (self._reply_event, self._error_event)
+            return self._reply_event
         else:
             self._reply_event.wait(self._timeout)
-            if self._reply_event.is_set():
+            if self._reply_event.isSet():
                 if self._error:
                     raise self._error
                 self._reply.parse()
                 return self._reply
             else:
                 raise ReplyTimeoutError
-    
+
     def request(self):
         return self._request(self.SPEC)
-    
+
     def _delivery_hook(self):
         'For subclasses'
         pass
-    
+
     def _assert(self, capability):
         if capability not in self._session.server_capabilities:
             raise MissingCapabilityError('Server does not support [%s]' % cap)
-    
+
     def deliver(self, raw):
         self._reply = self.REPLY_CLS(raw)
         self._delivery_hook()
         self._reply_event.set()
-    
+
     def error(self, err):
         self._error = err
         self._reply_event.set()
-    
+
     @property
     def has_reply(self):
         return self._reply_event.is_set()
-    
+
     @property
     def reply(self):
+        if self.error:
+            raise self._error
         return self._reply
-    
+
     @property
     def id(self):
         return self._id
-    
+
     @property
     def session(self):
         return self._session
-    
+
     @property
     def reply_event(self):
         return self._reply_event
-    
+
     def set_async(self, bool): self._async = bool
     async = property(fget=lambda self: self._async, fset=set_async)
-    
+
     def set_timeout(self, timeout): self._timeout = timeout
     timeout = property(fget=lambda self: self._timeout, fset=set_timeout)
index c4607d3..42f8035 100644 (file)
 
 from rpc import RPC
 
-from ncclient.glue import Listener
 from ncclient.content import qualify as _
+from ncclient.transport import SessionListener
+
+NOTIFICATION_NS = 'urn:ietf:params:xml:ns:netconf:notification:1.0'
 
 # TODO when can actually test it...
 
@@ -28,4 +30,4 @@ class CreateSubscription(RPC):
 
 class Notification: pass
 
-class NotificationListener(Listener): pass
+class NotificationListener(SessionListener): pass
index 2ca923e..20ba089 100644 (file)
@@ -46,7 +46,7 @@ def build_filter(spec, capcheck=None):
         type, criteria = tuple
         rep = {
             'tag': 'filter',
-            'attributes': {'type': type},
+            'attrib': {'type': type},
             'subtree': criteria
         }
     else:
index 683129d..532e452 100644 (file)
@@ -21,7 +21,7 @@ class AuthenticationError(TransportError):
     pass
 
 class SessionCloseError(TransportError):
-    
+
     def __init__(self, in_buf, out_buf=None):
         msg = 'Unexpected session close.'
         if in_buf:
@@ -34,9 +34,9 @@ class SSHError(TransportError):
     pass
 
 class SSHUnknownHostError(SSHError):
-    
-    def __init__(self, hostname, key):
-        from binascii import hexlify
-        SSHError(self, 'Unknown host key [%s] for [%s]'
-                 % (hexlify(key.get_fingerprint()), hostname))
-        self.hostname = hostname
+
+    def __init__(self, host, fingerprint):
+        SSHError.__init__(self, 'Unknown host key [%s] for [%s]'
+                          % (fingerprint, host))
+        self.host = host
+        self.fingerprint = fingerprint
index 0a28782..8cdbc33 100644 (file)
@@ -22,14 +22,14 @@ import logging
 logger = logging.getLogger('ncclient.transport.session')
 
 class Session(Thread):
-    "This is a base class for use by protocol implementations"
-    
+    "Base class for use by transport protocol implementations."
+
     def __init__(self, capabilities):
         Thread.__init__(self)
-        self.set_daemon(True)
-        self._listeners = set() # 3.0's weakset ideal
+        self.setDaemon(True)
+        self._listeners = set() # 3.0's weakset would be ideal
         self._lock = Lock()
-        self.set_name('session')
+        self.setName('session')
         self._q = Queue()
         self._client_capabilities = capabilities
         self._server_capabilities = None # yet
@@ -37,7 +37,7 @@ class Session(Thread):
         self._connected = False # to be set/cleared by subclass implementation
         logger.debug('%r created: client_capabilities=%r' %
                      (self, self._client_capabilities))
-    
+
     def _dispatch_message(self, raw):
         try:
             root = content.parse_root(raw)
@@ -52,7 +52,7 @@ class Session(Thread):
                 l.callback(root, raw)
             except Exception as e:
                 logger.warning('[error] %r' % e)
-    
+
     def _dispatch_error(self, err):
         with self._lock:
             listeners = list(self._listeners)
@@ -62,7 +62,7 @@ class Session(Thread):
                 l.errback(err)
             except Exception as e:
                 logger.warning('error %r' % e)
-    
+
     def _post_connect(self):
         "Greeting stuff"
         init_event = Event()
@@ -86,109 +86,120 @@ class Session(Thread):
         self.remove_listener(listener)
         if error[0]:
             raise error[0]
-        logger.info('initialized: session-id=%s | server_capabilities=%s' % (self._id, self._server_capabilities))
-    
+        logger.info('initialized: session-id=%s | server_capabilities=%s' %
+                    (self._id, self._server_capabilities))
+
     def add_listener(self, listener):
-        """Register a listener that will be notified of incoming messages and errors.
-        
-        :type listener: :class:`SessionListener`
+        """Register a listener that will be notified of incoming messages and
+        errors.
+
+        :arg listener: :class:`SessionListener`
         """
         logger.debug('installing listener %r' % listener)
         if not isinstance(listener, SessionListener):
             raise SessionError("Listener must be a SessionListener type")
         with self._lock:
             self._listeners.add(listener)
-    
+
     def remove_listener(self, listener):
-        "Unregister some listener; ignoring if the listener was never registered."
+        """Unregister some listener; ignore if the listener was never
+        registered."""
         logger.debug('discarding listener %r' % listener)
         with self._lock:
             self._listeners.discard(listener)
-    
+
     def get_listener_instance(self, cls):
-        """If a listener of the specified type is registered, returns it. This is useful when it is desirable to have only one instance of a particular type per session, i.e. a multiton.
-        
-        :type cls: :class:`type`
-        :rtype: :class:`SessionListener` or :const:`None`
+        """If a listener of the sspecified type is registered, returns the
+        instance. This is useful when it is desirable to have only one instance
+        of a particular type per session, i.e. a multiton.
+
+        :arg cls: class of the listener
         """
         with self._lock:
             for listener in self._listeners:
                 if isinstance(listener, cls):
                     return listener
-    
+
     def connect(self, *args, **kwds): # subclass implements
         raise NotImplementedError
 
     def run(self): # subclass implements
         raise NotImplementedError
-    
+
     def send(self, message):
-        """
-        :param message: XML document
-        :type message: string
+        """Send the supplied *message* to NETCONF server.
+
+        :arg message: an XML document
+
+        :type message: :obj:`string`
         """
         logger.debug('queueing %s' % message)
         self._q.put(message)
-    
+
     ### Properties
 
     @property
     def connected(self):
-        ":rtype: bool"
+        "Connection status of the session."
         return self._connected
 
     @property
     def client_capabilities(self):
-        ":rtype: :class:`Capabilities`"
+        "Client's :class:`Capabilities`"
         return self._client_capabilities
-    
+
     @property
     def server_capabilities(self):
-        ":rtype: :class:`Capabilities` or :const:`None`"
+        "Server's :class:`Capabilities`"
         return self._server_capabilities
-    
+
     @property
     def id(self):
-        ":rtype: :obj:`string` or :const:`None`"
+        """A :obj:`string` representing the `session-id`. If the session has not
+        been initialized it will be :const:`None`"""
         return self._id
-    
+
     @property
     def can_pipeline(self):
-        ":rtype: :obj:`bool`"
+        "Whether this session supports pipelining"
         return True
 
 
 class SessionListener(object):
-    
-    """'Listen' to incoming messages on a NETCONF :class:`Session`
-    
+
+    """Base class for :class:`Session` listeners, which are notified when a new
+    NETCONF message is received or an error occurs.
+
     .. note::
-        Avoid computationally intensive tasks in the callbacks.
+        Avoid time-intensive tasks in a callback's context.
     """
-    
+
     def callback(self, root, raw):
-        """Called when a new XML document is received. The `root` argument allows the callback to determine whether it wants to further process the document.
-        
-        :param root: tuple of (tag, attrs) where tag is the qualified name of the root element and attrs is a dictionary of its attributes (also qualified names)
-        :param raw: XML document
-        :type raw: string
+        """Called when a new XML document is received. The `root` argument
+        allows the callback to determine whether it wants to further process the
+        document.
+
+        :arg root: is a tuple of `(tag, attributes)` where `tag` is the qualified name of the root element and `attributes` is a dictionary of its attributes (also qualified names)
+
+        :arg raw: XML document
+        :type raw: :obj:`string`
         """
         raise NotImplementedError
-    
+
     def errback(self, ex):
         """Called when an error occurs.
-        
-        :type ex: :class:`Exception`
+
+        :type ex: :exc:`Exception`
         """
         raise NotImplementedError
 
 
 class HelloHandler(SessionListener):
-    
+
     def __init__(self, init_cb, error_cb):
         self._init_cb = init_cb
         self._error_cb = error_cb
-    
+
     def callback(self, root, raw):
         if content.unqualify(root[0]) == 'hello':
             try:
@@ -197,10 +208,10 @@ class HelloHandler(SessionListener):
                 self._error_cb(e)
             else:
                 self._init_cb(id, capabilities)
-    
+
     def errback(self, err):
         self._error_cb(err)
-    
+
     @staticmethod
     def build(capabilities):
         "Given a list of capability URI's returns <hello> message XML string"
@@ -213,7 +224,7 @@ class HelloHandler(SessionListener):
                 }]
             }
         return content.dtree2xml(spec)
-    
+
     @staticmethod
     def parse(raw):
         "Returns tuple of (session-id (str), capabilities (Capabilities)"
index 0f41fb9..21a56ab 100644 (file)
@@ -30,10 +30,23 @@ BUF_SIZE = 4096
 MSG_DELIM = ']]>]]>'
 TICK = 0.1
 
+def default_unknown_host_cb(host, key):
+    """An `unknown host callback` returns :const:`True` if it finds the key
+    acceptable, and :const:`False` if not.
+
+    :arg host: the hostname/address which needs to be verified
+
+    :arg key: a hex string representing the host key fingerprint
+
+    :returns: this default callback always returns :const:`False`
+    """
+    return False
+
+
 class SSHSession(Session):
-    
-    "A NETCONF SSH session, per :rfc:`4742`"
-    
+
+    "Implements a :rfc:`4742` NETCONF session over SSH."
+
     def __init__(self, capabilities):
         Session.__init__(self, capabilities)
         self._host_keys = paramiko.HostKeys()
@@ -44,15 +57,14 @@ class SSHSession(Session):
         self._expecting_close = False
         self._buffer = StringIO() # for incoming data
         # parsing-related, see _parse()
-        self._parsing_state = 0 
+        self._parsing_state = 0
         self._parsing_pos = 0
-    
+
     def _parse(self):
         '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
         maximum of BUF_SIZE bytes everytime this method is called. Retains state
-        across method calls and if a byte has been read it will not be considered
-        again.
-        '''
+        across method calls and if a byte has been read it will not be
+        considered again. '''
         delim = MSG_DELIM
         n = len(delim) - 1
         expect = self._parsing_state
@@ -90,7 +102,7 @@ class SSHSession(Session):
         self._buffer = buf
         self._parsing_state = expect
         self._parsing_pos = self._buffer.tell()
-    
+
     def load_system_host_keys(self, filename=None):
         if filename is None:
             filename = os.path.expanduser('~/.ssh/known_hosts')
@@ -105,31 +117,57 @@ class SSHSession(Session):
                     pass
             return
         self._system_host_keys.load(filename)
-    
+
     def load_host_keys(self, filename):
         self._host_keys.load(filename)
 
     def add_host_key(self, key):
         self._host_keys.add(key)
-    
+
     def save_host_keys(self, filename):
         f = open(filename, 'w')
         for host, keys in self._host_keys.iteritems():
             for keytype, key in keys.iteritems():
                 f.write('%s %s %s\n' % (host, keytype, key.get_base64()))
-        f.close()    
-    
+        f.close()
+
     def close(self):
         self._expecting_close = True
         if self._transport.is_active():
             self._transport.close()
         self._connected = False
-    
+
     def connect(self, host, port=830, timeout=None,
-                unknown_host_cb=None, username=None, password=None,
+                unknown_host_cb=default_unknown_host_cb,
+                username=None, password=None,
                 key_filename=None, allow_agent=True, look_for_keys=True):
+        """Connect via SSH and initialize the NETCONF session. First attempts
+        the publickey authentication method and then password authentication.
+
+        To disable publickey authentication, call with *allow_agent* and
+        *look_for_keys* as :const:`False`
+
+        :arg host: the hostname or IP address to connect to
+
+        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
+
+        :arg timeout: an optional timeout for the TCP handshake
+
+        :arg unknown_host_cb: called when a host key is not known. See :func:`unknown_host_cb` for details on signature
+
+        :arg username: the username to use for SSH authentication
+
+        :arg password: the password used if using password authentication, or the passphrase to use in order to unlock keys that require it
+
+        :arg key_filename: a filename where a the private key to be used can be found
+
+        :arg allow_agent: enables querying SSH agent (if found) for keys
+
+        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
+        """
+
         assert(username is not None)
-        
+
         for (family, socktype, proto, canonname, sockaddr) in \
         socket.getaddrinfo(host, port):
             if socktype == socket.SOCK_STREAM:
@@ -143,44 +181,43 @@ class SSHSession(Session):
         sock.connect(addr)
         t = self._transport = paramiko.Transport(sock)
         t.set_log_channel(logger.name)
-        
+
         try:
             t.start_client()
         except paramiko.SSHException:
             raise SSHError('Negotiation failed')
-        
+
         # host key verification
         server_key = t.get_remote_server_key()
         known_host = self._host_keys.check(host, server_key) or \
                         self._system_host_keys.check(host, server_key)
-        
-        if unknown_host_cb is None:
-            unknown_host_cb = lambda *args: False
-        if not known_host and not unknown_host_cb(host, server_key):
-                raise SSHUnknownHostError(host, server_key)
-        
+
+        fp = hexlify(server_key.get_fingerprint())
+        if not known_host and not unknown_host_cb(host, fp):
+            raise SSHUnknownHostError(host, fp)
+
         if key_filename is None:
             key_filenames = []
         elif isinstance(key_filename, basestring):
             key_filenames = [ key_filename ]
         else:
             key_filenames = key_filename
-        
+
         self._auth(username, password, key_filenames, allow_agent, look_for_keys)
-        
+
         self._connected = True # there was no error authenticating
-        
+
         c = self._channel = self._transport.open_session()
         c.set_name('netconf')
         c.invoke_subsystem('netconf')
-        
+
         self._post_connect()
-    
+
     # on the lines of paramiko.SSHClient._auth()
     def _auth(self, username, password, key_filenames, allow_agent,
               look_for_keys):
         saved_exception = None
-        
+
         for key_filename in key_filenames:
             for cls in (paramiko.RSAKey, paramiko.DSSKey):
                 try:
@@ -192,7 +229,7 @@ class SSHSession(Session):
                 except Exception as e:
                     saved_exception = e
                     logger.debug(e)
-        
+
         if allow_agent:
             for key in paramiko.Agent().get_keys():
                 try:
@@ -203,7 +240,7 @@ class SSHSession(Session):
                 except Exception as e:
                     saved_exception = e
                     logger.debug(e)
-        
+
         keyfiles = []
         if look_for_keys:
             rsa_key = os.path.expanduser('~/.ssh/id_rsa')
@@ -219,7 +256,7 @@ class SSHSession(Session):
                 keyfiles.append((paramiko.RSAKey, rsa_key))
             if os.path.isfile(dsa_key):
                 keyfiles.append((paramiko.DSSKey, dsa_key))
-        
+
         for cls, filename in keyfiles:
             try:
                 key = cls.from_private_key_file(filename, password)
@@ -230,7 +267,7 @@ class SSHSession(Session):
             except Exception as e:
                 saved_exception = e
                 logger.debug(e)
-        
+
         if password is not None:
             try:
                 self._transport.auth_password(username, password)
@@ -238,13 +275,13 @@ class SSHSession(Session):
             except Exception as e:
                 saved_exception = e
                 logger.debug(e)
-        
+
         if saved_exception is not None:
             # need pep-3134 to do this right
             raise SSHAuthenticationError(repr(saved_exception))
-        
+
         raise SSHAuthenticationError('No authentication methods available')
-    
+
     def run(self):
         chan = self._channel
         chan.setblocking(0)
@@ -252,7 +289,7 @@ class SSHSession(Session):
         try:
             while True:
                 # select on a paramiko ssh channel object does not ever return
-                # it in the writable list, so it channel's don't exactly emulate 
+                # it in the writable list, so it channel's don't exactly emulate
                 # the socket api
                 r, w, e = select([chan], [], [], TICK)
                 # will wakeup evey TICK seconds to check if something
@@ -278,12 +315,15 @@ class SSHSession(Session):
             self.close()
             if not (isinstance(e, SessionCloseError) and self._expecting_close):
                 self._dispatch_error(e)
-    
+
     @property
     def transport(self):
-        "gug"
+        """The underlying `paramiko.Transport
+        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
+        object. This makes it possible to call methods like set_keepalive on it.
+        """
         return self._transport
-    
+
     @property
     def can_pipeline(self):
         if 'Cisco' in self._transport.remote_version:
index 548447e..65429cf 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ncclient.glue import Listener
+from ncclient.transport import SessionListener
+
+class PrintListener(SessionListener):
 
-class PrintListener(Listener):
-    
     def callback(self, root, raw):
         print('\n# RECEIVED MESSAGE with root=[tag=%r, attrs=%r] #\n%r\n' %
               (root[0], root[1], raw))
-    
+
     def errback(self, err):
         print('\n# RECEIVED ERROR #\n%r\n' % err)