* Cisco compatibility despite non-compliant <rpc-reply> messages (missing message-id) * other fixes from testing
author Shikhar Bhushan <shikhar@schmizz.net>
Thu, 30 Apr 2009 03:35:36 +0000 (03:35 +0000)
committer Shikhar Bhushan <shikhar@schmizz.net>
Thu, 30 Apr 2009 03:35:36 +0000 (03:35 +0000)
git-svn-id: http://ncclient.googlecode.com/svn/trunk@88 6dbcf712-26ac-11de-a2f3-1373824ab735

ncclient/operations/reply.py
ncclient/operations/rpc.py
ncclient/operations/session.py
ncclient/transport/ssh.py
ncclient/util.py

index ceaedf5..b66812d 100644 (file)
@@ -31,13 +31,14 @@ class RPCReply:
         return self._raw
     
     def parse(self):
+        if self._parsed: return
         root = ET.fromstring(self._raw) # <rpc-reply> element
         
         # per rfc 4741 an <ok/> tag is sent when there are no errors or warnings
         oktags = _('ok')
         for oktag in oktags:
             if root.find(oktag) is not None:
-                logger.debug('found %s' % oktag)
+                logger.debug('parsed [%s]' % oktag)
                 self._parsed = True
                 return
         
@@ -45,9 +46,10 @@ class RPCReply:
         errtags = _('rpc-error')
         for errtag in errtags:
             for err in root.getiterator(errtag): # a particular <rpc-error>
+                logger.debug('parsed [%s]' % errtag)
                 d = {}
                 for err_detail in err.getchildren(): # <error-type> etc..
-                    d[__(err_detail)] = err_detail.text
+                    d[__(err_detail.tag)] = err_detail.text
                 self._errors.append(RPCError(d))
             if self._errors:
                 break
@@ -128,4 +130,4 @@ class RPCError(Exception): # raise it if you like
     
     items = lambda self: self._dict.items()
     
-    __repr__ = lambda self: repr(self._dict)
\ No newline at end of file
+    __repr__ = lambda self: repr(self._dict)
index fe69530..ced32fc 100644 (file)
@@ -24,10 +24,16 @@ from ncclient.glue import Listener
 from . import logger
 from reply import RPCReply
 
+# Cisco devices omit the message-id attribute from <rpc-reply> when replying with an error.
+# This is non-compliant, but we have to work around it: such a reply is matched to the
+# single outstanding request, so only one operation may be in flight at a time when
+# talking to a Cisco device (asynchronous mode is refused for them).
 
 class RPC(object):
     
     def __init__(self, session, async=False):
+        if session.is_remote_cisco and async:
+            raise UserWarning('Asynchronous mode not supported for Cisco devices')
         self._session = session
         self._async = async
         self._id = uuid1().urn
@@ -114,14 +120,15 @@ class RPCReplyListener(Listener):
     
     # TODO - determine if need locking
     
-    # one instance per subject    
-    def __new__(cls, subject):
-        instance = subject.get_listener_instance(cls)
+    # one instance per session
+    def __new__(cls, session):
+        instance = session.get_listener_instance(cls)
         if instance is None:
             instance = object.__new__(cls)
             instance._id2rpc = WeakValueDictionary()
+            instance._cisco = session.is_remote_cisco
             instance._errback = None
-            subject.add_listener(instance)
+            session.add_listener(instance)
         return instance
     
     def __str__(self):
@@ -137,22 +144,27 @@ class RPCReplyListener(Listener):
         tag, attrs = root
         if __(tag) != 'rpc-reply':
             return
+        rpc = None
         for key in attrs:
             if __(key) == 'message-id':
                 id = attrs[key]
                 try:
-                    rpc = self._id2rpc[id]
-                    rpc.deliver(raw)
+                    rpc = self._id2rpc.pop(id)
                 except KeyError:
-                    logger.warning('[RPCReplyListener.callback] no RPC '
+                    logger.warning('[RPCReplyListener.callback] no object '
                                    + 'registered for message-id: [%s]' % id)
-                    logger.debug('[RPCReplyListener.callback] registered: %r '
-                                 % dict(self._id2rpc))
                 except Exception as e:
                     logger.debug('[RPCReplyListener.callback] error - %r' % e)
                 break
         else:
-            logger.warning('<rpc-reply> without message-id received: %s' % raw)
+            if self._cisco:
+                assert(len(self._id2rpc) == 1)
+                rpc = self._id2rpc.values()[0]
+                self._id2rpc.clear()
+            else:
+                logger.warning('<rpc-reply> without message-id received: %s' % raw)
+        logger.debug('[RPCReplyListener.callback] delivering to %r' % rpc)
+        rpc.deliver(raw)
     
     def errback(self, err):
         if self._errback is not None:
index 0d225cb..be3dcba 100644 (file)
@@ -27,6 +27,7 @@ class CloseSession(RPC):
     def _delivery_hook(self):
         if self.reply.ok:
             self.session.expect_close()
+        self.session.close()
     
     def request(self):
         return self._request(self.spec)
index 1e6d508..c633ca9 100644 (file)
@@ -241,6 +241,7 @@ class SSHSession(Session):
                 logger.debug(e)
         
         if saved_exception is not None:
+            # PEP 3134 exception chaining would let us keep the original exception and traceback here
             raise SSHAuthenticationError(repr(saved_exception))
         
         raise SSHAuthenticationError('No authentication methods available')
@@ -286,3 +287,7 @@ class SSHSession(Session):
         documentation for details.
         '''
         return self._transport
+    
+    @property
+    def is_remote_cisco(self):
+        return 'Cisco' in self._transport.remote_version
index 414aef9..4ac526a 100644 (file)
@@ -20,8 +20,8 @@ logger = logging.getLogger('PrintListener')
 class PrintListener(Listener):
     
     def callback(self, root, raw):
-        tag, attrs = root
-        print '\n$ RECEIVED MESSAGE with root=[tag=%r, attrs=%r]:\n%r\n' % (tag, attrs, raw)
+        print('\n# RECEIVED MESSAGE with root=[tag=%r, attrs=%r] #\n%r\n' %
+              (root[0], root[1], raw))
     
     def errback(self, err):
-        print '\n$ RECEIVED ERROR:\n%r\n' % err
+        print('\n# RECEIVED ERROR #\n%r\n' % err)