more doc updates
[ncclient] / ncclient / transport / ssh.py
index f84bb41..3e2be10 100644 (file)
@@ -14,6 +14,7 @@
 
 import os
 import socket
+import getpass
 from binascii import hexlify
 from cStringIO import StringIO
 from select import select
@@ -24,24 +25,30 @@ from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownH
 from session import Session
 
 import logging
-logger = logging.getLogger('ncclient.transport.ssh')
+logger = logging.getLogger("ncclient.transport.ssh")
 
 BUF_SIZE = 4096
-MSG_DELIM = ']]>]]>'
+MSG_DELIM = "]]>]]>"
 TICK = 0.1
 
-def default_unknown_host_cb(host, key):
-    """An `unknown host callback` returns :const:`True` if it finds the key
-    acceptable, and :const:`False` if not.
+def default_unknown_host_cb(host, fingerprint):
+    """An unknown host callback returns `True` if it finds the key acceptable, and `False` if not.
 
-    :arg host: the hostname/address which needs to be verified
+    This default callback always returns `False`, which would lead to :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
+    
+    Supply another valid callback if you need to verify the host key programatically.
 
-    :arg key: a hex string representing the host key fingerprint
+    *host* is the hostname that needs to be verified
 
-    :returns: this default callback always returns :const:`False`
+    *fingerprint* is a hex string representing the host key fingerprint, colon-delimited e.g. `"4b:69:6c:72:6f:79:20:77:61:73:20:68:65:72:65:21"`
     """
     return False
 
+def _colonify(fp):
+    finga = fp[:2]
+    for idx  in range(2, len(fp), 2):
+        finga += ":" + fp[idx:idx+2]
+    return finga
 
 class SSHSession(Session):
 
@@ -50,21 +57,16 @@ class SSHSession(Session):
     def __init__(self, capabilities):
         Session.__init__(self, capabilities)
         self._host_keys = paramiko.HostKeys()
-        self._system_host_keys = paramiko.HostKeys()
         self._transport = None
         self._connected = False
         self._channel = None
-        self._expecting_close = False
         self._buffer = StringIO() # for incoming data
         # parsing-related, see _parse()
         self._parsing_state = 0
         self._parsing_pos = 0
-
+    
     def _parse(self):
-        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
-        maximum of BUF_SIZE bytes everytime this method is called. Retains state
-        across method calls and if a byte has been read it will not be
-        considered again. '''
+        "Messages ae delimited by MSG_DELIM. The buffer could have grown by a maximum of BUF_SIZE bytes everytime this method is called. Retains state across method calls and if a byte has been read it will not be considered again."
         delim = MSG_DELIM
         n = len(delim) - 1
         expect = self._parsing_state
@@ -77,6 +79,7 @@ class SSHSession(Session):
             elif x == delim[expect]: # what we expected
                 expect += 1 # expect the next delim char
             else:
+                expect = 0
                 continue
             # loop till last delim char expected, break if other char encountered
             for i in range(expect, n):
@@ -103,82 +106,75 @@ class SSHSession(Session):
         self._parsing_state = expect
         self._parsing_pos = self._buffer.tell()
 
-    def load_system_host_keys(self, filename=None):
+    def load_known_hosts(self, filename=None):
+        """Load host keys from an openssh :file:`known_hosts`-style file. Can be called multiple times.
+
+        If *filename* is not specified, looks in the default locations i.e. :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
+        """
         if filename is None:
             filename = os.path.expanduser('~/.ssh/known_hosts')
             try:
-                self._system_host_keys.load(filename)
+                self._host_keys.load(filename)
             except IOError:
                 # for windows
                 filename = os.path.expanduser('~/ssh/known_hosts')
                 try:
-                    self._system_host_keys.load(filename)
+                    self._host_keys.load(filename)
                 except IOError:
                     pass
-            return
-        self._system_host_keys.load(filename)
-
-    def load_host_keys(self, filename):
-        self._host_keys.load(filename)
-
-    def add_host_key(self, key):
-        self._host_keys.add(key)
-
-    def save_host_keys(self, filename):
-        f = open(filename, 'w')
-        for host, keys in self._host_keys.iteritems():
-            for keytype, key in keys.iteritems():
-                f.write('%s %s %s\n' % (host, keytype, key.get_base64()))
-        f.close()
+        else:
+            self._host_keys.load(filename)
 
     def close(self):
-        self._expecting_close = True
         if self._transport.is_active():
             self._transport.close()
         self._connected = False
 
-    def connect(self, host, port=830, timeout=None,
-                unknown_host_cb=default_unknown_host_cb,
-                username=None, password=None,
-                key_filename=None, allow_agent=True, look_for_keys=True):
-        """Connect via SSH and initialize the NETCONF session. First attempts
-        the publickey authentication method and then password authentication.
+    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
+    def connect(self, host, port=830, timeout=None, unknown_host_cb=default_unknown_host_cb,
+                username=None, password=None, key_filename=None, allow_agent=True, look_for_keys=True):
+        """Connect via SSH and initialize the NETCONF session. First attempts the publickey authentication method and then password authentication.
 
-        To disable publickey authentication, call with *allow_agent* and
-        *look_for_keys* as :const:`False`
+        To disable attempting publickey authentication altogether, call with *allow_agent* and *look_for_keys* as `False`.
 
-        :arg host: the hostname or IP address to connect to
+        *host* is the hostname or IP address to connect to
 
-        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
+        *port* is by default 830, but some devices use the default SSH port of 22 so this may need to be specified
 
-        :arg timeout: an optional timeout for the TCP handshake
+        *timeout* is an optional timeout for socket connect
 
-        :arg unknown_host_cb: called when a host key is not known. See :func:`unknown_host_cb` for details on signature
+        *unknown_host_cb* is called when the server host key is not recognized. It takes two arguments, the hostname and the fingerprint (see the signature of :func:`default_unknown_host_cb`)
 
-        :arg username: the username to use for SSH authentication
+        *username* is the username to use for SSH authentication
 
-        :arg password: the password used if using password authentication, or the passphrase to use in order to unlock keys that require it
+        *password* is the password used if using password authentication, or the passphrase to use for unlocking keys that require it
 
-        :arg key_filename: a filename where a the private key to be used can be found
+        *key_filename* is a filename where a the private key to be used can be found
 
-        :arg allow_agent: enables querying SSH agent (if found) for keys
+        *allow_agent* enables querying SSH agent (if found) for keys
 
-        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
+        *look_for_keys* enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
         """
-
-        assert(username is not None)
-
-        for (family, socktype, proto, canonname, sockaddr) in \
-        socket.getaddrinfo(host, port):
-            if socktype == socket.SOCK_STREAM:
-                af = family
-                addr = sockaddr
-                break
+        if username is None:
+            username = getpass.getuser()
+        
+        sock = None
+        for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+            af, socktype, proto, canonname, sa = res
+            try:
+                sock = socket.socket(af, socktype, proto)
+                sock.settimeout(timeout)
+            except socket.error:
+                continue
+            try:
+                sock.connect(sa)
+            except socket.error:
+                sock.close()
+                continue
+            break
         else:
-            raise SSHError('No suitable address family for %s' % host)
-        sock = socket.socket(af, socket.SOCK_STREAM)
-        sock.settimeout(timeout)
-        sock.connect(addr)
+            raise SSHError("Could not open socket to %s:%s" % (host, port))
+
         t = self._transport = paramiko.Transport(sock)
         t.set_log_channel(logger.name)
 
@@ -189,12 +185,12 @@ class SSHSession(Session):
 
         # host key verification
         server_key = t.get_remote_server_key()
-        known_host = self._host_keys.check(host, server_key) or \
-                        self._system_host_keys.check(host, server_key)
+        known_host = self._host_keys.check(host, server_key)
 
-        fp = hexlify(server_key.get_fingerprint())
-        if not known_host and not unknown_host_cb(host, fp):
-            raise SSHUnknownHostError(host, fp)
+        fingerprint = _colonify(hexlify(server_key.get_fingerprint()))
+
+        if not known_host and not unknown_host_cb(host, fingerprint):
+            raise SSHUnknownHostError(host, fingerprint)
 
         if key_filename is None:
             key_filenames = []
@@ -208,11 +204,11 @@ class SSHSession(Session):
         self._connected = True # there was no error authenticating
 
         c = self._channel = self._transport.open_session()
-        c.set_name('netconf')
-        c.invoke_subsystem('netconf')
+        c.set_name("netconf")
+        c.invoke_subsystem("netconf")
 
         self._post_connect()
-
+    
     # on the lines of paramiko.SSHClient._auth()
     def _auth(self, username, password, key_filenames, allow_agent,
               look_for_keys):
@@ -222,7 +218,7 @@ class SSHSession(Session):
             for cls in (paramiko.RSAKey, paramiko.DSSKey):
                 try:
                     key = cls.from_private_key_file(key_filename, password)
-                    logger.debug('Trying key %s from %s' %
+                    logger.debug("Trying key %s from %s" %
                               (hexlify(key.get_fingerprint()), key_filename))
                     self._transport.auth_publickey(username, key)
                     return
@@ -233,7 +229,7 @@ class SSHSession(Session):
         if allow_agent:
             for key in paramiko.Agent().get_keys():
                 try:
-                    logger.debug('Trying SSH agent key %s' %
+                    logger.debug("Trying SSH agent key %s" %
                                  hexlify(key.get_fingerprint()))
                     self._transport.auth_publickey(username, key)
                     return
@@ -243,15 +239,15 @@ class SSHSession(Session):
 
         keyfiles = []
         if look_for_keys:
-            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
-            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
+            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
+            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
             if os.path.isfile(rsa_key):
                 keyfiles.append((paramiko.RSAKey, rsa_key))
             if os.path.isfile(dsa_key):
                 keyfiles.append((paramiko.DSSKey, dsa_key))
             # look in ~/ssh/ for windows users:
-            rsa_key = os.path.expanduser('~/ssh/id_rsa')
-            dsa_key = os.path.expanduser('~/ssh/id_dsa')
+            rsa_key = os.path.expanduser("~/ssh/id_rsa")
+            dsa_key = os.path.expanduser("~/ssh/id_dsa")
             if os.path.isfile(rsa_key):
                 keyfiles.append((paramiko.RSAKey, rsa_key))
             if os.path.isfile(dsa_key):
@@ -260,7 +256,7 @@ class SSHSession(Session):
         for cls, filename in keyfiles:
             try:
                 key = cls.from_private_key_file(filename, password)
-                logger.debug('Trying discovered key %s in %s' %
+                logger.debug("Trying discovered key %s in %s" %
                           (hexlify(key.get_fingerprint()), filename))
                 self._transport.auth_publickey(username, key)
                 return
@@ -280,7 +276,7 @@ class SSHSession(Session):
             # need pep-3134 to do this right
             raise AuthenticationError(repr(saved_exception))
 
-        raise AuthenticationError('No authentication methods available')
+        raise AuthenticationError("No authentication methods available")
 
     def run(self):
         chan = self._channel
@@ -288,13 +284,9 @@ class SSHSession(Session):
         q = self._q
         try:
             while True:
-                # select on a paramiko ssh channel object does not ever return
-                # it in the writable list, so it channel's don't exactly emulate
-                # the socket api
+                # select on a paramiko ssh channel object does not ever return it in the writable list, so channels don't exactly emulate the socket api
                 r, w, e = select([chan], [], [], TICK)
-                # will wakeup evey TICK seconds to check if something
-                # to send, more if something to read (due to select returning
-                # chan in readable list)
+                # will wakeup evey TICK seconds to check if something to send, more if something to read (due to select returning chan in readable list)
                 if r:
                     data = chan.recv(BUF_SIZE)
                     if data:
@@ -303,7 +295,7 @@ class SSHSession(Session):
                     else:
                         raise SessionCloseError(self._buffer.getvalue())
                 if not q.empty() and chan.send_ready():
-                    logger.debug('sending message')
+                    logger.debug("Sending message")
                     data = q.get() + MSG_DELIM
                     while data:
                         n = chan.send(data)
@@ -311,22 +303,11 @@ class SSHSession(Session):
                             raise SessionCloseError(self._buffer.getvalue(), data)
                         data = data[n:]
         except Exception as e:
-            logger.debug('broke out of main loop')
+            logger.debug("Broke out of main loop, error=%r", e)
             self.close()
-            if not (isinstance(e, SessionCloseError) and self._expecting_close):
-                self._dispatch_error(e)
+            self._dispatch_error(e)
 
     @property
     def transport(self):
-        """The underlying `paramiko.Transport
-        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
-        object. This makes it possible to call methods like set_keepalive on it.
-        """
+        "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
         return self._transport
-
-    @property
-    def can_pipeline(self):
-        if 'Cisco' in self._transport.remote_version:
-            return False
-        # elif ..
-        return True