use StringIO for buffers; efficient implementation for resynchronizing around message...
author Shikhar Bhushan <shikhar@schmizz.net>
Wed, 22 Apr 2009 01:32:39 +0000 (01:32 +0000)
committer Shikhar Bhushan <shikhar@schmizz.net>
Wed, 22 Apr 2009 01:32:39 +0000 (01:32 +0000)
git-svn-id: http://ncclient.googlecode.com/svn/trunk@35 6dbcf712-26ac-11de-a2f3-1373824ab735

ncclient/ssh.py

index 77d9527..a1db987 100644 (file)
 import logging
 import paramiko
 
+from os import SEEK_CUR
+from cStringIO import StringIO
+
 from session import Session, SessionError
 
 logger = logging.getLogger('ncclient.ssh')
 
+
 class SessionCloseError(SessionError):
     
     def __str__(self):
@@ -28,27 +32,34 @@ class SessionCloseError(SessionError):
         SessionError.__init__(self)
         self._in_buf, self._out_buf = in_buf, out_buf
 
+
 class SSHSession(Session):
 
     BUF_SIZE = 4096
-    MSG_DELIM = ']]>]]>'
+    MSG_DELIM = ']]>]]>'
+    
     
     def __init__(self, load_known_hosts=True,
                  missing_host_key_policy=paramiko.RejectPolicy):
         Session.__init__(self)
-        self._in_buf = ''
-        self._out_buf = ''
         self._client = paramiko.SSHClient()
         if load_known_hosts:
             self._client.load_system_host_keys()
         self._client.set_missing_host_key_policy(missing_host_key_policy)
+        self._in_buf = StringIO()
+        self._out_buf = StringIO()
+        self._parsing_state = -1
+        self._parsing_pos = 0
+    
     
     def load_host_keys(self, filename):
         self._client.load_host_keys(filename)
     
+    
     def set_missing_host_key_policy(self, policy):
         self._client.set_missing_host_key_policy(policy)
     
+    
     # paramiko exceptions ok?
     # user might be looking for ClientError
     def connect(self, hostname, port=830, username=None, password=None,
@@ -64,6 +75,7 @@ class SSHSession(Session):
         self._channel.set_name('netconf')
         self._connect()
     
+    
     def run(self):
         
         chan = self._channel
@@ -71,41 +83,77 @@ class SSHSession(Session):
         q = self._q
         
         while True:
-            
             if chan.closed:
                 break
-            
             if chan.recv_ready():
                 data = chan.recv(SSHSession.BUF_SIZE)
                 if data:
-                    self._in_buf += data
-                    while True:
-                        before, delim, after = self._in_buf.partition(SSHSession.MSG_DELIM)
-                        if delim:
-                            self.dispatch('reply', before)
-                            self._in_buf = after
-                        else:
-                            break
+                    self._in_buf.write(data)
+                    self._parse()
                 else:
                     break
-            
             if chan.send_ready():
                 if not q.empty():
-                    msg = q.get()
-                    self._out_buf += ( msg + SSHSession.MSG_DELIM )
-                    while self._out_buf:
-                        n = chan.send(self._out_buf)
-                        if n <= 0:
-                            break
-                        self._out_buf = self._out_buf[n:]
+                    self._out_buf.write(q.get() + SSHSession.MSG_DELIM)
+                    self._dump()
         
         logger.debug('** broke out of main loop **')
         self.dispatch('close', SessionCloseError(self._in_buf, self._out_buf))
     
+    
     def _close(self):
         self._channel.close()
         Session._close(self)
-
+    
+    def _dump(self):
+        """Send the pending contents of the output buffer over the channel.
+
+        Data is pushed with repeated ``send`` calls. If ``send`` reports that
+        nothing was written (a result of 0 or less) the loop stops, and the
+        unsent remainder is kept in ``self._out_buf`` for the next call.
+        """
+        # run() holds the channel in a local variable only; here it has to be
+        # taken from the instance.
+        chan = self._channel
+        data = self._out_buf.getvalue()
+        while data:
+            n = chan.send(data)
+            if n <= 0:
+                break
+            # slice the string being sent -- a StringIO cannot be sliced
+            data = data[n:]
+        # start over with just the part that could not be sent
+        self._out_buf = StringIO()
+        self._out_buf.write(data)
+    
+    def _parse(self):
+        """Dispatch every complete message held in the input buffer.
+
+        A message is the text preceding an occurrence of MSG_DELIM. Each one
+        is handed to ``self.dispatch('reply', msg)``; whatever follows the
+        last delimiter stays in ``self._in_buf`` for the next call.
+        """
+        delim = SSHSession.MSG_DELIM
+        data = self._in_buf.getvalue()
+        # Text before _parsing_pos was searched on an earlier call, but a
+        # delimiter may straddle two reads, so back up by len(delim) - 1.
+        search_from = max(self._parsing_pos - (len(delim) - 1), 0)
+        start = 0
+        while True:
+            end = data.find(delim, search_from)
+            if end < 0:
+                break
+            self.dispatch('reply', data[start:end])
+            start = search_from = end + len(delim)
+        if start:
+            # drop what was dispatched, keep only the unparsed tail
+            rest = StringIO()
+            rest.write(data[start:])
+            self._in_buf = rest
+        # the substring search needs no per-character matching state
+        self._parsing_state = 0
+        self._parsing_pos = len(data) - start
 
 class MissingHostKeyPolicy(paramiko.MissingHostKeyPolicy):