refactoring work
author Shikhar Bhushan <shikhar@schmizz.net>
Sat, 25 Apr 2009 15:46:10 +0000 (15:46 +0000)
committer Shikhar Bhushan <shikhar@schmizz.net>
Sat, 25 Apr 2009 15:46:10 +0000 (15:46 +0000)
git-svn-id: http://ncclient.googlecode.com/svn/trunk@56 6dbcf712-26ac-11de-a2f3-1373824ab735

ncclient/session/error.py
ncclient/session/ssh.py

index db0b823..45dc422 100644 (file)
@@ -17,23 +17,6 @@ from ncclient import ClientError
 class SessionError(ClientError):
     pass
 
-class RemoteClosedError(SessionError):
-    
-    def __init__(self, in_buf, out_buf=None):
-        SessionError.__init__(self)
-        self._in_buf, self._out_buf = in_buf, out_buf
-        
-    def __str__(self):
-        msg = 'Session closed by remote endpoint.'
-        if self._in_buf:
-            msg += '\nIN_BUFFER: %s' % self._in_buf
-        if self._out_buf:
-            msg += '\nOUT_BUFFER: %s' % self._out_buf
-        return msg
-
-class AuthenticationError(SessionError):
-    pass
-
 class SSHError(SessionError):
     pass
 
@@ -48,11 +31,19 @@ class SSHUnknownHostError(SSHError):
         return ('Unknown host key [%s] for [%s]' %
                 (hexlify(self.key.get_fingerprint()), self.hostname))
 
-class SSHAuthenticationError(AuthenticationError, SSHError):
-    'wraps a paramiko exception that occured during auth'
-    
-    def __init__(self, ex):
-        self.ex = ex
+class SSHAuthenticationError(SSHError):
+    pass
+
+class SSHSessionClosedError(SSHError):
     
-    def __repr__(self):
-        return repr(ex)
+    def __init__(self, in_buf, out_buf=None):
+        SessionError.__init__(self, "Unexpected session close.")
+        self._in_buf, self._out_buf = in_buf, out_buf
+        
+    def __str__(self):
+        msg = SessionError(self).__str__()
+        if self._in_buf:
+            msg += '\nIN_BUFFER: %s' % self._in_buf
+        if self._out_buf:
+            msg += '\nOUT_BUFFER: %s' % self._out_buf
+        return msg
\ No newline at end of file
index e029961..ff0db2b 100644 (file)
@@ -22,7 +22,7 @@ import paramiko
 
 import session
 from . import logger
-from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, RemoteClosedError
+from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, SSHSessionClosedError
 from session import Session
 
 BUF_SIZE = 4096
@@ -40,42 +40,43 @@ class SSHSession(Session):
         self._connected = False
         self._channel = None
         self._buffer = StringIO() # for incoming data
-        # parsing-related, see _fresh_data()
+        # parsing-related, see _parse()
         self._parsing_state = 0 
         self._parsing_pos = 0
     
-    def _fresh_data(self):
-        '''The buffer could have grown by a maximum of BUF_SIZE bytes everytime 
-        this method is called. Retains state across method calls and if a byte
-        has been read it will not be parsed again.
+    def _parse(self):
+        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
+        maximum of BUF_SIZE bytes everytime this method is called. Retains state
+        across method calls and if a byte has been read it will not be considered
+        again.
         '''
         delim = MSG_DELIM
         n = len(delim) - 1
-        state = self._parsing_state
+        expect = self._parsing_state
         buf = self._buffer
         buf.seek(self._parsing_pos)
         while True:
             x = buf.read(1)
             if not x: # done reading
                 break
-            elif x == delim[state]:
-                state += 1
+            elif x == delim[expect]: # what we expected
+                expect += 1 # expect the next delim char
             else:
                 continue
             # loop till last delim char expected, break if other char encountered
-            for i in range(state, n):
+            for i in range(expect, n):
                 x = buf.read(1)
                 if not x: # done reading
                     break
-                if x==delim[state]: # what we expected
-                    state += 1 # expect the next delim char
+                if x == delim[expect]: # what we expected
+                    expect += 1 # expect the next delim char
                 else:
-                    state = 0 # reset
+                    expect = 0 # reset
                     break
-            else: # if we didn't break out of above loop, full delim parsed
-                till = buf.tell() - n
+            else: # if we didn't break out of the loop, full delim was parsed
+                msg_till = buf.tell() - n
                 buf.seek(0)
-                msg = buf.read(till)
+                msg = buf.read(msg_till)
                 self.dispatch('received', msg)
                 buf.seek(n+1, os.SEEK_CUR)
                 rest = buf.read()
@@ -84,7 +85,7 @@ class SSHSession(Session):
                 buf.seek(0)
                 state = 0
         self._buffer = buf
-        self._parsing_state = state
+        self._parsing_state = expect
         self._parsing_pos = self._buffer.tell()
     
     def load_system_host_keys(self, filename=None):
@@ -167,18 +168,14 @@ class SSHSession(Session):
         c.invoke_subsystem('netconf')
         c.set_name('netconf')
         
-        Session._post_connect(self)
+        self._post_connect()
     
     # on the lines of paramiko.SSHClient._auth()
     def _auth(self, username, password, key_filenames, allow_agent,
               look_for_keys):
         saved_exception = None
         
-        allowed = ['publickey', 'keyboard-interactive', 'password']
-        
         for key_filename in key_filenames:
-            if 'publickey' not in allowed:
-                    break
             for cls in (paramiko.RSAKey, paramiko.DSSKey):
                 try:
                     key = cls.from_private_key_file(key_filename, password)
@@ -186,31 +183,23 @@ class SSHSession(Session):
                               (hexlify(key.get_fingerprint()), key_filename))
                     self._transport.auth_publickey(username, key)
                     return
-                except paramiko.BadAuthenticationType as e:
-                    allowed = e.allowed_types
-                    logger.debug(e)
                 except Exception as e:
                     saved_exception = e
                     logger.debug(e)
         
         if allow_agent:
             for key in paramiko.Agent().get_keys():
-                if 'publickey' not in allowed:
-                    break
                 try:
                     logger.debug('Trying SSH agent key %s' %
                                  hexlify(key.get_fingerprint()))
-                    logger.error( self._transport.auth_publickey(username, key) )
+                    self._transport.auth_publickey(username, key)
                     return
-                except paramiko.BadAuthenticationType as e:
-                    allowed = e.allowed_types
-                    logger.debug(e)
                 except Exception as e:
                     saved_exception = e
                     logger.debug(e)
         
         keyfiles = []
-        if look_for_keys and 'publickey' in allowed:
+        if look_for_keys:
             rsa_key = os.path.expanduser('~/.ssh/id_rsa')
             dsa_key = os.path.expanduser('~/.ssh/id_dsa')
             if os.path.isfile(rsa_key):
@@ -230,7 +219,7 @@ class SSHSession(Session):
                 key = cls.from_private_key_file(filename, password)
                 logger.debug('Trying discovered key %s in %s' %
                           (hexlify(key.get_fingerprint()), filename))
-                allowed = self._transport.auth_publickey(username, key)
+                self._transport.auth_publickey(username, key)
                 return
             except Exception as e:
                 saved_exception = e
@@ -245,9 +234,9 @@ class SSHSession(Session):
                 logger.debug(e)
         
         if saved_exception is not None:
-            raise SSHAuthenticationError(saved_exception)
+            raise AuthenticationError(repr(saved_exception))
         
-        raise SSHAuthenticationError('No authentication methods available')
+        raise AuthenticationError('No authentication methods available')
     
     def run(self):
         chan = self._channel
@@ -255,31 +244,35 @@ class SSHSession(Session):
         q = self._q
         try:
             while True:
-                # select on a paramiko ssh channel object does not ever
-                # return it in the writable list, so it does not exactly
-                # emulate the socket api
+                # select on a paramiko ssh channel object does not ever return
+                # it in the writable list, so it channel's don't exactly emulate 
+                # the socket api
                 r, w, e = select([chan], [], [], TICK)
                 # will wakeup evey TICK seconds to check if something
-                # to send, more if something to read (due to select returning chan
-                # in readable list)
+                # to send, more if something to read (due to select returning
+                # chan in readable list)
                 if r:
                     data = chan.recv(BUF_SIZE)
                     if data:
                         self._buffer.write(data)
-                        self._fresh_data()
+                        self._parse()
                     else:
-                        raise RemoteClosedError(self._buffer.getvalue())
+                        raise SSHSessionClosedError(self._buffer.getvalue())
                 if not q.empty() and chan.send_ready():
                     data = q.get() + MSG_DELIM
                     while data:
                         n = chan.send(data)
                         if n <= 0:
-                            raise RemoteClosedError(self._buffer.getvalue(), data)
+                            raise SSHSessionClosedError(self._buffer.getvalue(), data)
                         data = data[n:]
         except Exception as e:
             self.close()
             logger.debug('*** broke out of main loop ***')
             self.dispatch('error', e)
     
-    def set_keepalive(self, interval=0):
-        self._transport.set_keepalive(interval)
\ No newline at end of file
+    @property
+    def transport(self):
+        '''Get underlying paramiko.transport object; this is provided so methods
+        like transport.set_keepalive can be called.
+        '''
+        return self._transport