git-svn-id: http://ncclient.googlecode.com/svn/trunk@86 6dbcf712-26ac-11de-a2f3-13738...
author Shikhar Bhushan <shikhar@schmizz.net>
Wed, 29 Apr 2009 20:49:18 +0000 (20:49 +0000)
committer Shikhar Bhushan <shikhar@schmizz.net>
Wed, 29 Apr 2009 20:49:18 +0000 (20:49 +0000)
ncclient/glue.py
ncclient/operations/rpc.py
ncclient/operations/session.py
ncclient/transport/hello.py
ncclient/transport/session.py
ncclient/transport/ssh.py
ncclient/transport/util.py

index 4fe6367..7500fc9 100644 (file)
 "TODO: docstring"
 
 from cStringIO import StringIO
 "TODO: docstring"
 
 from cStringIO import StringIO
+from threading import Thread
 from Queue import Queue
 from threading import Lock
 from xml.etree import cElementTree as ET
 
 from Queue import Queue
 from threading import Lock
 from xml.etree import cElementTree as ET
 
+import logging
+logger = logging.getLogger('ncclient.glue')
 
 def parse_root(raw):
     '''Parse the top-level element from a string representing an XML document.
 
 def parse_root(raw):
     '''Parse the top-level element from a string representing an XML document.
@@ -32,14 +35,15 @@ def parse_root(raw):
         return (element.tag, element.attrib)
 
 
         return (element.tag, element.attrib)
 
 
-class Subject(object):
+class Subject(Thread):
     
     'Meant for subclassing by transport.Session'
 
     def __init__(self):
         "TODO: docstring"
     
     'Meant for subclassing by transport.Session'
 
     def __init__(self):
         "TODO: docstring"
+        Thread.__init__(self)
         self._q = Queue()
         self._q = Queue()
-        self._listeners = set()
+        self._listeners = set() # TODO(?) weakref
         self._lock = Lock()
     
     def _dispatch_received(self, raw):
         self._lock = Lock()
     
     def _dispatch_received(self, raw):
@@ -48,6 +52,7 @@ class Subject(object):
         with self._lock:
             listeners = list(self._listeners)
         for l in listeners:
         with self._lock:
             listeners = list(self._listeners)
         for l in listeners:
+            logger.debug('[dispatching] message to %s' % l)
             l.callback(root, raw)
     
     def _dispatch_error(self, err):
             l.callback(root, raw)
     
     def _dispatch_error(self, err):
@@ -55,15 +60,18 @@ class Subject(object):
         with self._lock:
             listeners = list(self._listeners)
         for l in listeners:
         with self._lock:
             listeners = list(self._listeners)
         for l in listeners:
+            logger.debug('[dispatching] error to %s' % l)
             l.errback(err)
     
     def add_listener(self, listener):
         "TODO: docstring"
             l.errback(err)
     
     def add_listener(self, listener):
         "TODO: docstring"
+        logger.debug('[installing listener] %r' % listener)
         with self._lock:
             self._listeners.add(listener)
     
     def remove_listener(self, listener):
         "TODO: docstring"
         with self._lock:
             self._listeners.add(listener)
     
     def remove_listener(self, listener):
         "TODO: docstring"
+        logger.debug('[discarding listener] %r' % listener)
         with self._lock:
             self._listeners.discard(listener)
     
         with self._lock:
             self._listeners.discard(listener)
     
@@ -78,7 +86,7 @@ class Subject(object):
     
     def send(self, message):
         "TODO: docstring"
     
     def send(self, message):
         "TODO: docstring"
-        logger.debug('queueing:%s' % message)
+        logger.debug('[queueing] %s' % message)
         self._q.put(message)
 
 
         self._q.put(message)
 
 
index 8cf978e..1687198 100644 (file)
 
 from threading import Event, Lock
 from uuid import uuid1
 
 from threading import Event, Lock
 from uuid import uuid1
+from weakref import WeakValueDictionary
 
 
-from ncclient.content import TreeBuilder, BASE_NS
+from ncclient.content import TreeBuilder
+from ncclient.content import qualify as _
+from ncclient.content import unqualify as __
 from ncclient.glue import Listener
 
 from . import logger
 from ncclient.glue import Listener
 
 from . import logger
@@ -26,6 +29,7 @@ class RPC(object):
     
     def __init__(self, session, async=False):
         self._session = session
     
     def __init__(self, session, async=False):
         self._session = session
+        self._async = async
         self._id = uuid1().urn
         self._listener = RPCReplyListener(session)
         self._listener.register(self._id, self)
         self._id = uuid1().urn
         self._listener = RPCReplyListener(session)
         self._listener.register(self._id, self)
@@ -41,7 +45,7 @@ class RPC(object):
     def _request(self, op):
         req = self._build(op)
         self._session.send(req)
     def _request(self, op):
         req = self._build(op)
         self._session.send(req)
-        if async:
+        if self._async:
             self._reply_event.wait()
             self._reply.parse()
             return self._reply
             self._reply_event.wait()
             self._reply.parse()
             return self._reply
@@ -74,7 +78,7 @@ class RPC(object):
     def build_from_spec(msgid, opspec, encoding='utf-8'):
         "TODO: docstring"
         spec = {
     def build_from_spec(msgid, opspec, encoding='utf-8'):
         "TODO: docstring"
         spec = {
-            'tag': _('rpc', BASE_NS),
+            'tag': _('rpc'),
             'attributes': {'message-id': msgid},
             'children': opspec
             }
             'attributes': {'message-id': msgid},
             'children': opspec
             }
@@ -132,4 +136,4 @@ class RPCReplyListener(Listener):
     def errback(self, err):
         logger.error('RPCReplyListener.errback: %r' % err)
         if self._errback is not None:
     def errback(self, err):
         logger.error('RPCReplyListener.errback: %r' % err)
         if self._errback is not None:
-            self._errback(err)
\ No newline at end of file
+            self._errback(err)
index 49c85bb..b92533c 100644 (file)
@@ -30,8 +30,8 @@ class CloseSession(RPC):
             self._session.expect_close()
         self._session.close()
     
             self._session.expect_close()
         self._session.close()
     
-    def request(self, reply_event=None):
-        self._request(self.spec, reply_event)
+    def request(self):
+        self._request(self.spec)
 
 
 class KillSession(RPC):
 
 
 class KillSession(RPC):
index b17f51c..ee8693f 100644 (file)
@@ -32,7 +32,7 @@ class HelloHandler(Listener):
     def callback(self, root, raw):
         if __(root[0]) == 'hello':
             try:
     def callback(self, root, raw):
         if __(root[0]) == 'hello':
             try:
-                id, capabilities = parse(raw)
+                id, capabilities = HelloHandler.parse(raw)
             except Exception as e:
                 self._error_cb(e)
             else:
             except Exception as e:
                 self._error_cb(e)
             else:
index 3b2527e..ca488cd 100644 (file)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from threading import Thread, Event
+from threading import Event
 
 from ncclient.capabilities import Capabilities, CAPABILITIES
 from ncclient.glue import Subject
 
 from ncclient.capabilities import Capabilities, CAPABILITIES
 from ncclient.glue import Subject
@@ -20,44 +20,47 @@ from ncclient.glue import Subject
 from . import logger
 from hello import HelloHandler
 
 from . import logger
 from hello import HelloHandler
 
-class Session(Thread, Subject):
+class Session(Subject):
     
     "TODO: docstring"
     
     def __init__(self):
         "TODO: docstring"
         Subject.__init__(self)
     
     "TODO: docstring"
     
     def __init__(self):
         "TODO: docstring"
         Subject.__init__(self)
-        Thread.__init__(self, name='session')
-        self.setDaemon(True)
+        self.setName('session')
+        self.setDaemon(True) #hmm
         self._client_capabilities = CAPABILITIES
         self._server_capabilities = None # yet
         self._id = None # session-id
         self._connected = False # to be set/cleared by subclass implementation
         self._client_capabilities = CAPABILITIES
         self._server_capabilities = None # yet
         self._id = None # session-id
         self._connected = False # to be set/cleared by subclass implementation
+        logger.debug('[session object created] client_capabilities=%r' %
+                     self._client_capabilities)
     
     def _post_connect(self):
         "TODO: docstring"
     
     def _post_connect(self):
         "TODO: docstring"
-        self.send(HelloHandler.build(self._client_capabilities))
-        error = None
         init_event = Event()
         init_event = Event()
+        error = [None] # so that err_cb can bind error[0]. just how it is.
         # callbacks
         def ok_cb(id, capabilities):
         # callbacks
         def ok_cb(id, capabilities):
-            self._id, self._server_capabilities = id, Capabilities(capabilities)
+            self._id = id
+            self._server_capabilities = Capabilities(capabilities)
             init_event.set()
         def err_cb(err):
             init_event.set()
         def err_cb(err):
-            error = err
+            error[0] = err
             init_event.set()
         listener = HelloHandler(ok_cb, err_cb)
         self.add_listener(listener)
             init_event.set()
         listener = HelloHandler(ok_cb, err_cb)
         self.add_listener(listener)
-        # start the subclass' main loop
+        self.send(HelloHandler.build(self._client_capabilities))
+        logger.debug('[starting main loop]')
         self.start()
         # we expect server's hello message
         init_event.wait()
         # received hello message or an error happened
         self.remove_listener(listener)
         self.start()
         # we expect server's hello message
         init_event.wait()
         # received hello message or an error happened
         self.remove_listener(listener)
-        if error:
-            raise error
+        if error[0]:
+            raise error[0]
         logger.info('initialized: session-id=%s | server_capabilities=%s' %
         logger.info('initialized: session-id=%s | server_capabilities=%s' %
-                     (self.id, self.server_capabilities))
+                     (self._id, self._server_capabilities))
     
     def connect(self, *args, **kwds):
         "TODO: docstring"
     
     def connect(self, *args, **kwds):
         "TODO: docstring"
index ea5f4b9..dcc0be0 100644 (file)
@@ -37,10 +37,12 @@ class SSHSession(Session):
         self._transport = None
         self._connected = False
         self._channel = None
         self._transport = None
         self._connected = False
         self._channel = None
+        self._expecting_close = False
         self._buffer = StringIO() # for incoming data
         # parsing-related, see _parse()
         self._parsing_state = 0 
         self._parsing_pos = 0
         self._buffer = StringIO() # for incoming data
         # parsing-related, see _parse()
         self._parsing_state = 0 
         self._parsing_pos = 0
+        logger.debug('[SSHSession object created]')
     
     def _parse(self):
         '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
     
     def _parse(self):
         '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
@@ -85,6 +87,9 @@ class SSHSession(Session):
         self._parsing_state = expect
         self._parsing_pos = self._buffer.tell()
     
         self._parsing_state = expect
         self._parsing_pos = self._buffer.tell()
     
+    def expect_close(self):
+        self._expecting_close = True
+    
     def load_system_host_keys(self, filename=None):
         if filename is None:
             filename = os.path.expanduser('~/.ssh/known_hosts')
     def load_system_host_keys(self, filename=None):
         if filename is None:
             filename = os.path.expanduser('~/.ssh/known_hosts')
@@ -266,9 +271,10 @@ class SSHSession(Session):
                             raise SessionCloseError(self._buffer.getvalue(), data)
                         data = data[n:]
         except Exception as e:
                             raise SessionCloseError(self._buffer.getvalue(), data)
                         data = data[n:]
         except Exception as e:
-            self.close()
             logger.debug('*** broke out of main loop ***')
             logger.debug('*** broke out of main loop ***')
-            self._dispatch_error(e)
+            self.close()
+            if not (isinstance(e, SessionCloseError) and self._expecting_close):
+                self._dispatch_error(e)
     
     @property
     def transport(self):
     
     @property
     def transport(self):
index b38a5b3..e47c5c6 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import logger
+from ncclient.glue import Listener
 
 
-class DebugListener:
+import logging
+logger = logging.getLogger('DebugListener')
+
+class DebugListener(Listener):
     
     def __str__(self):
         return 'DebugListener'
     
     def received(self, raw):
     
     def __str__(self):
         return 'DebugListener'
     
     def received(self, raw):
-        logger.debug('DebugListener:[received]:||%s||' % raw)
+        logger.debug('[received]:||%s||' % raw)
     
     
-    def error(self, err):
-        logger.debug('DebugListener:[error]:%r' % err)
+    def errback(self, err):
+        logger.debug('[error]:%r' % err)