rename content to xml_
[ncclient] / ncclient / transport / ssh.py
index 41122da..b5597ce 100644 (file)
@@ -30,29 +30,46 @@ BUF_SIZE = 4096
 MSG_DELIM = ']]>]]>'
 TICK = 0.1
 
+def default_unknown_host_cb(host, key):
+    """An `unknown host callback` returns :const:`True` if it finds the key
+    acceptable, and :const:`False` if not.
+
+    This default callback always returns :const:`False`, which would lead to
+    :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
+
+    Supply another valid callback if you need to verify the host key
+    programatically.
+
+    :arg host: the host for whom key needs to be verified
+    :type host: string
+
+    :arg key: a hex string representing the host key fingerprint
+    :type key: string
+    """
+    return False
+
+
 class SSHSession(Session):
-    
-    "A NETCONF SSH session, per :rfc: 4742"
-    
+
+    "Implements a :rfc:`4742` NETCONF session over SSH."
+
     def __init__(self, capabilities):
         Session.__init__(self, capabilities)
         self._host_keys = paramiko.HostKeys()
-        self._system_host_keys = paramiko.HostKeys()
         self._transport = None
         self._connected = False
         self._channel = None
         self._expecting_close = False
         self._buffer = StringIO() # for incoming data
         # parsing-related, see _parse()
-        self._parsing_state = 0 
+        self._parsing_state = 0
         self._parsing_pos = 0
-    
+
     def _parse(self):
         '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
         maximum of BUF_SIZE bytes everytime this method is called. Retains state
-        across method calls and if a byte has been read it will not be considered
-        again.
-        '''
+        across method calls and if a byte has been read it will not be
+        considered again. '''
         delim = MSG_DELIM
         n = len(delim) - 1
         expect = self._parsing_state
@@ -90,97 +107,133 @@ class SSHSession(Session):
         self._buffer = buf
         self._parsing_state = expect
         self._parsing_pos = self._buffer.tell()
-    
-    def load_system_host_keys(self, filename=None):
+
+    def load_known_hosts(self, filename=None):
+        """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
+        times.
+
+        If *filename* is not specified, looks in the default locations i.e.
+        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
+        """
         if filename is None:
             filename = os.path.expanduser('~/.ssh/known_hosts')
             try:
-                self._system_host_keys.load(filename)
+                self._host_keys.load(filename)
             except IOError:
                 # for windows
                 filename = os.path.expanduser('~/ssh/known_hosts')
                 try:
-                    self._system_host_keys.load(filename)
+                    self._host_keys.load(filename)
                 except IOError:
                     pass
-            return
-        self._system_host_keys.load(filename)
-    
-    def load_host_keys(self, filename):
-        self._host_keys.load(filename)
-
-    def add_host_key(self, key):
-        self._host_keys.add(key)
-    
-    def save_host_keys(self, filename):
-        f = open(filename, 'w')
-        for host, keys in self._host_keys.iteritems():
-            for keytype, key in keys.iteritems():
-                f.write('%s %s %s\n' % (host, keytype, key.get_base64()))
-        f.close()    
-    
+        else:
+            self._host_keys.load(filename)
+
     def close(self):
         self._expecting_close = True
         if self._transport.is_active():
             self._transport.close()
         self._connected = False
-    
+
     def connect(self, host, port=830, timeout=None,
-                unknown_host_cb=None, username=None, password=None,
+                unknown_host_cb=default_unknown_host_cb,
+                username=None, password=None,
                 key_filename=None, allow_agent=True, look_for_keys=True):
-        assert(username is not None)
-        
-        for (family, socktype, proto, canonname, sockaddr) in \
-        socket.getaddrinfo(host, port):
-            if socktype == socket.SOCK_STREAM:
-                af = family
-                addr = sockaddr
-                break
+        """Connect via SSH and initialize the NETCONF session. First attempts
+        the publickey authentication method and then password authentication.
+
+        To disable attemting publickey authentication altogether, call with
+        *allow_agent* and *look_for_keys* as :const:`False`. This may be needed
+        for Cisco devices which immediately disconnect on an incorrect
+        authentication attempt.
+
+        :arg host: the hostname or IP address to connect to
+        :type host: `string`
+
+        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
+        :type port: `int`
+
+        :arg timeout: an optional timeout for the TCP handshake
+        :type timeout: `int`
+
+        :arg unknown_host_cb: called when a host key is not recognized
+        :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
+
+        :arg username: the username to use for SSH authentication
+        :type username: `string`
+
+        :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
+        :type password: `string`
+
+        :arg key_filename: a filename where a the private key to be used can be found
+        :type key_filename: `string`
+
+        :arg allow_agent: enables querying SSH agent (if found) for keys
+        :type allow_agent: `bool`
+
+        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
+        :type look_for_keys: `bool`
+        """
+
+        if username is None:
+            raise SSHError("No username specified")
+
+        sock = None
+        for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+            af, socktype, proto, canonname, sa = res
+            try:
+                sock = socket.socket(af, socktype, proto)
+                sock.settimeout(timeout)
+            except socket.error:
+                continue
+            try:
+                sock.connect(sa)
+            except socket.error:
+                sock.close()
+                continue
+            break
         else:
-            raise SSHError('No suitable address family for %s' % host)
-        sock = socket.socket(af, socket.SOCK_STREAM)
-        sock.settimeout(timeout)
-        sock.connect(addr)
+            raise SSHError("Could not open socket")
+
         t = self._transport = paramiko.Transport(sock)
         t.set_log_channel(logger.name)
-        
+
         try:
             t.start_client()
         except paramiko.SSHException:
             raise SSHError('Negotiation failed')
-        
+
         # host key verification
         server_key = t.get_remote_server_key()
-        known_host = self._host_keys.check(host, server_key) or \
-                        self._system_host_keys.check(host, server_key)
-        
-        if unknown_host_cb is None:
-            unknown_host_cb = lambda *args: False
-        if not known_host and not unknown_host_cb(host, server_key):
-                raise SSHUnknownHostError(host, server_key)
-        
+        known_host = self._host_keys.check(host, server_key)
+
+        fingerprint = hexlify(server_key.get_fingerprint())
+
+        if not known_host and not unknown_host_cb(host, fingerprint):
+            raise SSHUnknownHostError(host, fingerprint)
+
         if key_filename is None:
             key_filenames = []
         elif isinstance(key_filename, basestring):
             key_filenames = [ key_filename ]
         else:
             key_filenames = key_filename
-        
+
         self._auth(username, password, key_filenames, allow_agent, look_for_keys)
-        
+
         self._connected = True # there was no error authenticating
-        
+
         c = self._channel = self._transport.open_session()
         c.set_name('netconf')
         c.invoke_subsystem('netconf')
-        
+
         self._post_connect()
-    
+
     # on the lines of paramiko.SSHClient._auth()
     def _auth(self, username, password, key_filenames, allow_agent,
               look_for_keys):
         saved_exception = None
-        
+
         for key_filename in key_filenames:
             for cls in (paramiko.RSAKey, paramiko.DSSKey):
                 try:
@@ -192,7 +245,7 @@ class SSHSession(Session):
                 except Exception as e:
                     saved_exception = e
                     logger.debug(e)
-        
+
         if allow_agent:
             for key in paramiko.Agent().get_keys():
                 try:
@@ -203,7 +256,7 @@ class SSHSession(Session):
                 except Exception as e:
                     saved_exception = e
                     logger.debug(e)
-        
+
         keyfiles = []
         if look_for_keys:
             rsa_key = os.path.expanduser('~/.ssh/id_rsa')
@@ -219,7 +272,7 @@ class SSHSession(Session):
                 keyfiles.append((paramiko.RSAKey, rsa_key))
             if os.path.isfile(dsa_key):
                 keyfiles.append((paramiko.DSSKey, dsa_key))
-        
+
         for cls, filename in keyfiles:
             try:
                 key = cls.from_private_key_file(filename, password)
@@ -230,7 +283,7 @@ class SSHSession(Session):
             except Exception as e:
                 saved_exception = e
                 logger.debug(e)
-        
+
         if password is not None:
             try:
                 self._transport.auth_password(username, password)
@@ -238,13 +291,13 @@ class SSHSession(Session):
             except Exception as e:
                 saved_exception = e
                 logger.debug(e)
-        
+
         if saved_exception is not None:
             # need pep-3134 to do this right
-            raise SSHAuthenticationError(repr(saved_exception))
-        
-        raise SSHAuthenticationError('No authentication methods available')
-    
+            raise AuthenticationError(repr(saved_exception))
+
+        raise AuthenticationError('No authentication methods available')
+
     def run(self):
         chan = self._channel
         chan.setblocking(0)
@@ -252,7 +305,7 @@ class SSHSession(Session):
         try:
             while True:
                 # select on a paramiko ssh channel object does not ever return
-                # it in the writable list, so it channel's don't exactly emulate 
+                # it in the writable list, so it channel's don't exactly emulate
                 # the socket api
                 r, w, e = select([chan], [], [], TICK)
                 # will wakeup evey TICK seconds to check if something
@@ -275,14 +328,21 @@ class SSHSession(Session):
                         data = data[n:]
         except Exception as e:
             logger.debug('broke out of main loop')
+            expecting = self._expecting_close
             self.close()
-            if not (isinstance(e, SessionCloseError) and self._expecting_close):
+            logger.debug('error=%r' % e)
+            logger.debug('expecting_close=%r' % expecting)
+            if not (isinstance(e, SessionCloseError) and expecting):
                 self._dispatch_error(e)
-    
+
     @property
     def transport(self):
+        """Underlying `paramiko.Transport
+        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
+        object. This makes it possible to call methods like set_keepalive on it.
+        """
         return self._transport
-    
+
     @property
     def can_pipeline(self):
         if 'Cisco' in self._transport.remote_version: