commitdump
authorShikhar Bhushan <shikhar@schmizz.net>
Wed, 22 Apr 2009 04:40:32 +0000 (04:40 +0000)
committerShikhar Bhushan <shikhar@schmizz.net>
Wed, 22 Apr 2009 04:40:32 +0000 (04:40 +0000)
git-svn-id: http://ncclient.googlecode.com/svn/trunk@39 6dbcf712-26ac-11de-a2f3-1373824ab735

ncclient/capabilities.py [moved from ncclient/capability.py with 100% similarity]
ncclient/content.py
ncclient/listeners.py
ncclient/rpc.py
ncclient/session.py
ncclient/ssh.py

index d3e841a..861a885 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from xml.etree import cElementTree as ElementTree
 
-NAMESPACE = 'urn:ietf:params:xml:ns:netconf:base:1.0'
+logger = logging.getLogger('ncclient.content')
 
-def qualify(tag, ns=NAMESPACE):
+BASE_NS = 'urn:ietf:params:xml:ns:netconf:base:1.0'
+NOTIFICATION_NS = 'urn:ietf:params:xml:ns:netconf:notification:1.0'
+
+def qualify(tag, ns=BASE_NS):
     return '{%s}%s' % (ns, tag)
 
 _ = qualify
 
 def make_hello(capabilities):
-    return '<hello xmlns="%s">%s</hello>' % (NAMESPACE, capabilities)
+    return '<hello xmlns="%s">%s</hello>' % (BASE_NS, capabilities)
 
 def make_rpc(id, op):
-    return '<rpc message-id="%s" xmlns="%s">%s</rpc>' % (id, NAMESPACE, op)
+    return '<rpc message-id="%s" xmlns="%s">%s</rpc>' % (id, BASE_NS, op)
 
 def parse_hello(raw):
-    from capability import Capabilities
+    from capabilities import Capabilities
     id, capabilities = 0, Capabilities()
     root = ElementTree.fromstring(raw)
     if root.tag == _('hello'):
@@ -40,9 +44,13 @@ def parse_hello(raw):
                     capabilities.add(cap.text)
     return id, capabilities
 
-def parse_message_type(raw):
-    
-    target = RootElementParser()
-    parser = ElementTree.XMLTreeBuilder(target=target)
-    parser.feed(raw)
-    return target.id
+def parse_message_root(raw):
+    from cStringIO import StringIO
+    fp = StringIO(raw)
+    for event, element in ElementTree.iterparse(fp, events=('start',)):
+        if element.tag == _('rpc'):
+            return element.attrib['message-id']
+        elif element.tag == _('notification', NOTIFICATION_NS):
+            return 'notification'
+        else:
+            return None
\ No newline at end of file
index 711e09a..3169746 100644 (file)
@@ -1,4 +1,4 @@
-                                                                                                                                    # Copyright 2009 Shikhar Bhushan
+# Copyright 2009 Shikhar Bhushan
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # limitations under the License.
 
 import logging
-import weakref
+from weakref import WeakValueDictionary
+
+import content
 
 logger = logging.getLogger('ncclient.listeners')
 
-import content
+session_listeners = {}
+def session_listener_factory(session):
+    try:
+        return session_listeners[session]
+    except KeyError:
+        session_listeners[session] = SessionListener()
+        return session_listeners[session]
 
 class SessionListener(object):
     
-    'A multiton - one listener per session'
-    
-    instances = weakref.WeakValueDictionary()
-    
-    def __new__(cls, sid):
-        if sid in instances:# not been gc'd
-            return cls.instances[sid]
-        else:
-            inst = object.__new__(cls)
-            cls.instances[sid] = inst
-            return inst
+    def __init__(self):
+        self._id2rpc = WeakValueDictionary()
+        self._expecting_close = False
+        self._subscription = None
     
     def __str__(self):
         return 'SessionListener'
     
-    def set_subscription(self, id):     
+    def set_subscription(self, id):   
         self._subscription = id
     
+    def expect_close(self):
+        self._expecting_close = True
+    
     def register(self, id, op):
         self._id2rpc[id] = op
     
-    def unregister(self, id):
-        del self._id2prc[id]
-    
     ### Events
     
     def reply(self, raw):
-        id = content.parse_message(raw)
-        if id:
-            self._id2rpc[id]._deliver(raw)
-        else:
-            self._id2rpc[self._sub_id]._notify(raw)
+        try:
+            id = content.parse_message_root(raw)
+            if id is None:
+                pass
+            elif id == 'notification':
+                self._id2rpc[self._sub_id]._notify(raw)
+            else:
+                self._id2rpc[id]._response_cb(raw)
+        except Exception as e:
+            logger.warning(e)
     
-    def close(self, buf):
-        pass # TODO
+    def error(self, err):
+        from ssh import SessionCloseError
+        if err is SessionCloseError:
+            logger.debug('received session close, expecting_close=%s' %
+                         self._expecting_close)
+            if not self._expecting_close:
+                raise err
 
 class DebugListener:
     
@@ -63,7 +74,7 @@ class DebugListener:
         return 'DebugListener'
     
     def reply(self, raw):
-        logger.debug('reply:\n%s' % raw)
+        logger.debug('DebugListener:reply:\n%s' % raw)
     
     def error(self, err):
-        logger.debug(err)
+        logger.debug('DebugListener:error:\n%s' % err)
index 196ef72..e65e655 100644 (file)
 # limitations under the License.
 
 from threading import Event, Lock
-
-from listener import SessionListener
-
 from uuid import uuid1
 
+import content
+from listeners import session_listener_factory
+
 class RPC:
     
-    def __init__(self, session, async=False):
+    def __init__(self, session, async=False, parse=True):
         self._session = session
         self._async = async
+        self._id = uuid1().urn
+        self._listener = session_listener_factory(self._session)
+        listener.register(self._id, self)
+        session.add_listener(self._listener)
         self._reply = None
         self._reply_event = Event()
-        self._id = uuid1().urn
 
-    def _listener(self):
-        if not RPC.listeners.has_key(self._session.id):
-            RPC.listeners[self._session.id] = SessionListener()
-        return RPC.listeners[self._session.id]
-
-    def request(self, async=False):
-        self._async = async
-        listener = SessionListener(self._session.id)
-        session.add_listener(listener)
-        listener.register(self._id, self)
-        self._session.send(self.to_xml())
-    
-    def response_cb(self, reply):
-        self._reply = reply # does callback parse??
+    def _response_cb(self, reply):
+        self._reply = reply
         self._event.set()
     
+    def _do_request(self, op):
+        self._session.send(content.make_rpc(self._id, op))
+        if not self._async:
+            self._reply_event.wait()
+        return self._reply
+    
+    def request(self):
+        raise NotImplementedError
+    
+    def wait_for_reply(self, timeout=None):
+        self._reply_event.wait(timeout)
+    
     @property
     def has_reply(self):
         return self._reply_event.isSet()
     
-    def wait_on_reply(self, timeout=None):
-        self._reply_event.wait(timeout)
-    
     @property
     def is_async(self):
         return self._async
     
     @property
+    def reply(self):
+        return self._reply
+    
+    @property
     def id(self):
-        return self._id
\ No newline at end of file
+        return self._id
+    
+    @property
+    def session(self):
+        return self._session
+
+
+class RPCReply:
+    pass
+
+class RPCError:
+    pass
\ No newline at end of file
index fb17d90..dbafb3b 100644 (file)
 # limitations under the License.
 
 import logging
-
-import content
-
 from threading import Thread, Event
 from Queue import Queue
 
-from capability import CAPABILITIES
+import content
+from capabilities import CAPABILITIES
 from error import ClientError
 from subject import Subject
 
 logger = logging.getLogger('ncclient.session')
 
-class SessionError(ClientError):
-    
-    pass
+class SessionError(ClientError): pass
 
 class Session(Thread, Subject):
     
@@ -37,13 +33,26 @@ class Session(Thread, Subject):
         self._client_capabilities = CAPABILITIES
         self._server_capabilities = None # yet
         self._id = None # session-id
-        self._connected = False # subclasses should set this
         self._error = None
         self._init_event = Event()
         self._q = Queue()
+        self._connected = False # to be set/cleared by subclass
+    
+    def _post_connect(self):
+        # start the subclass' main loop
+        self.start()
+        # queue client's hello message for sending
+        self.send(content.make_hello(self._client_capabilities))
+        # we expect server's hello message, wait for _init_event to be set by HelloListener
+        self._init_event.wait()
+        # there may have been an error
+        if self._error:
+            self._close()
+            raise self._error
     
     def send(self, message):
-        message = (u'<?xml version="1.0" encoding="UTF-8"?>%s' % message).encode('utf-8')
+        message = (u'<?xml version="1.0" encoding="UTF-8"?>%s' %
+                   message).encode('utf-8')
         logger.debug('queueing message: \n%s' % message)
         self._q.put(message)
     
@@ -60,7 +69,7 @@ class Session(Thread, Subject):
         return self._client_capabilities
     
     @property
-    def serve_capabilities(self):
+    def server_capabilities(self):
         return self._server_capabilities
     
     @property
@@ -71,18 +80,6 @@ class Session(Thread, Subject):
     def id(self):
         return self._id
     
-    def _post_connect(self):
-        # start the subclass' main loop
-        self.start()
-        # queue client's hello message for sending
-        self.send(content.make_hello(self._client_capabilities))
-        # we expect server's hello message, wait for _init_event to be set by HelloListener
-        self._init_event.wait()
-        # there may have been an error
-        if self._error:
-            self._close()
-            raise self._error
-    
     class HelloListener:
         
         def __str__(self):
index dfb90d6..7d4be34 100644 (file)
 # limitations under the License.
 
 import logging
-import paramiko
-
-from os import SEEK_CUR
 from cStringIO import StringIO
+from os import SEEK_CUR
+
+import paramiko
 
 from session import Session, SessionError
 
@@ -42,6 +42,7 @@ class SSHSession(Session):
                  missing_host_key_policy=paramiko.RejectPolicy):
         Session.__init__(self)
         self._client = paramiko.SSHClient()
+        self._channel = None
         if load_known_hosts:
             self._client.load_system_host_keys()
         self._client.set_missing_host_key_policy(missing_host_key_policy)
@@ -49,6 +50,49 @@ class SSHSession(Session):
         self._parsing_state = 0
         self._parsing_pos = 0
     
+    def _close(self):
+        self._channel.close()
+        self._connected = False
+    
+    def _fresh_data(self):
+        delim = SSHSession.MSG_DELIM
+        n = len(delim) - 1
+        state = self._parsing_state
+        buf = self._in_buf
+        buf.seek(self._parsing_pos)
+        while True:
+            x = buf.read(1)
+            if not x: # done reading
+                break
+            elif x == delim[state]:
+                state += 1
+            else:
+                continue
+            # loop till last delim char expected, break if other char encountered
+            for i in range(state, n):
+                x = buf.read(1)
+                if not x: # done reading
+                    break
+                if x==delim[i]: # what we expected
+                    state += 1 # expect the next delim char
+                else:
+                    state = 0 # reset
+                    break
+            else: # if we didn't break out of above loop, full delim parsed
+                till = buf.tell() - n
+                buf.seek(0)
+                msg = buf.read(till)
+                self.dispatch('reply', msg)
+                buf.seek(n+1, SEEK_CUR)
+                rest = buf.read()
+                buf = StringIO()
+                buf.write(rest)
+                buf.seek(0)
+                state = 0
+        self._in_buf = buf
+        self._parsing_state = state
+        self._parsing_pos = self._in_buf.tell()
+
     def load_host_keys(self, filename):
         self._client.load_host_keys(filename)
     
@@ -96,49 +140,6 @@ class SSHSession(Session):
         except Exception as e:
             logger.debug('*** broke out of main loop ***')
             self.dispatch('error', e)
-    
-    def _close(self):
-        self._channel.close()
-        self._connected = False
-    
-    def _fresh_data(self):
-        delim = SSHSession.MSG_DELIM
-        n = len(delim) - 1
-        state = self._parsing_state
-        buf = self._in_buf
-        buf.seek(self._parsing_pos)
-        while True:
-            x = buf.read(1)
-            if not x: # done reading
-                break
-            elif x == delim[state]:
-                state += 1
-            else:
-                continue
-            # loop till last delim char expected, break if other char encountered
-            for i in range(state, n):
-                x = buf.read(1)
-                if not x: # done reading
-                    break
-                if x==delim[i]: # what we expected
-                    state += 1 # expect the next delim char
-                else:
-                    state = 0 # reset
-                    break
-            else: # if we didn't break out of above loop, full delim parsed
-                till = buf.tell() - n
-                buf.seek(0)
-                msg = buf.read(till)
-                self.dispatch('reply', msg)
-                buf.seek(n+1, SEEK_CUR)
-                rest = buf.read()
-                buf = StringIO()
-                buf.write(rest)
-                buf.seek(0)
-                state = 0
-        self._in_buf = buf
-        self._parsing_state = state
-        self._parsing_pos = self._in_buf.tell()
 
 class MissingHostKeyPolicy(paramiko.MissingHostKeyPolicy):