rename session to transport
authorShikhar Bhushan <shikhar@schmizz.net>
Sat, 25 Apr 2009 15:49:52 +0000 (15:49 +0000)
committerShikhar Bhushan <shikhar@schmizz.net>
Sat, 25 Apr 2009 15:49:52 +0000 (15:49 +0000)
git-svn-id: http://ncclient.googlecode.com/svn/trunk@58 6dbcf712-26ac-11de-a2f3-1373824ab735

ncclient/transport/__init__.py [new file with mode: 0644]
ncclient/transport/capabilities.py [new file with mode: 0644]
ncclient/transport/error.py [new file with mode: 0644]
ncclient/transport/session.py [new file with mode: 0644]
ncclient/transport/ssh.py [new file with mode: 0644]

diff --git a/ncclient/transport/__init__.py b/ncclient/transport/__init__.py
new file mode 100644 (file)
index 0000000..e62d53e
--- /dev/null
@@ -0,0 +1,16 @@
+# Copyright 2009 Shikhar Bhushan
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+logger = logging.getLogger('ncclient.session')
\ No newline at end of file
diff --git a/ncclient/transport/capabilities.py b/ncclient/transport/capabilities.py
new file mode 100644 (file)
index 0000000..7f924d0
--- /dev/null
@@ -0,0 +1,71 @@
+# Copyright 2009 Shikhar Bhushan
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+class Capabilities:
+    
+    def __init__(self, capabilities=None):
+        self._dict = {}
+        if isinstance(capabilities, dict):
+            self._dict = capabilities
+        elif isinstance(capabilities, list):
+            for uri in capabilities:
+                self._dict[uri] = Capabilities.guess_shorthand(uri)
+    
+    def __contains__(self, key):
+        return ( key in self._dict ) or ( key in self._dict.values() )
+    
+    def __iter__(self):
+        return self._dict.keys().__iter__()
+    
+    def __repr__(self):
+        return repr(self._dict.keys())
+    
+    def add(self, uri, shorthand=None):
+        if shorthand is None:
+            shorthand = Capabilities.guess_shorthand(uri)
+        self._dict[uri] = shorthand
+    
+    set = add
+    
+    def remove(self, key):
+        if key in self._dict:
+            del self._dict[key]
+        else:
+            for uri in self._dict:
+                if self._dict[uri] == key:
+                    del self._dict[uri]
+                    break
+    
+    @staticmethod
+    def guess_shorthand(uri):
+        if uri.startswith('urn:ietf:params:netconf:capability:'):
+            return (':' + uri.split(':')[5])
+
+
+CAPABILITIES = Capabilities([
+    'urn:ietf:params:netconf:base:1.0',
+    'urn:ietf:params:netconf:capability:writable-running:1.0',
+    'urn:ietf:params:netconf:capability:candidate:1.0',
+    'urn:ietf:params:netconf:capability:confirmed-commit:1.0',
+    'urn:ietf:params:netconf:capability:rollback-on-error:1.0',
+    'urn:ietf:params:netconf:capability:startup:1.0',
+    'urn:ietf:params:netconf:capability:url:1.0',
+    'urn:ietf:params:netconf:capability:validate:1.0',
+    'urn:ietf:params:netconf:capability:xpath:1.0',
+    'urn:ietf:params:netconf:capability:notification:1.0',
+    'urn:ietf:params:netconf:capability:interleave:1.0'
+    ])
+
+if __name__ == "__main__":
+    assert(':validate' in CAPABILITIES) # test __contains__
\ No newline at end of file
diff --git a/ncclient/transport/error.py b/ncclient/transport/error.py
new file mode 100644 (file)
index 0000000..45dc422
--- /dev/null
@@ -0,0 +1,49 @@
+# Copyright 2009 Shikhar Bhushan
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ncclient import ClientError
+
+class SessionError(ClientError):
+    pass
+
+class SSHError(SessionError):
+    pass
+
+class SSHUnknownHostError(SSHError):
+    
+    def __init__(self, hostname, key):
+        self.hostname = hostname
+        self.key = key
+    
+    def __str__(self):
+        from binascii import hexlify
+        return ('Unknown host key [%s] for [%s]' %
+                (hexlify(self.key.get_fingerprint()), self.hostname))
+
+class SSHAuthenticationError(SSHError):
+    pass
+
+class SSHSessionClosedError(SSHError):
+    
+    def __init__(self, in_buf, out_buf=None):
+        SessionError.__init__(self, "Unexpected session close.")
+        self._in_buf, self._out_buf = in_buf, out_buf
+        
+    def __str__(self):
+        msg = SessionError(self).__str__()
+        if self._in_buf:
+            msg += '\nIN_BUFFER: %s' % self._in_buf
+        if self._out_buf:
+            msg += '\nOUT_BUFFER: %s' % self._out_buf
+        return msg
\ No newline at end of file
diff --git a/ncclient/transport/session.py b/ncclient/transport/session.py
new file mode 100644 (file)
index 0000000..2da6b3c
--- /dev/null
@@ -0,0 +1,155 @@
+# Copyright 2009 Shikhar Bhushan
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from threading import Thread, Lock, Event
+from Queue import Queue
+
+from . import logger
+from capabilities import Capabilities, CAPABILITIES
+
+
+class Subject:
+
+    def __init__(self):
+        self._listeners = set([])
+        self._lock = Lock()
+        
+    def has_listener(self, listener):
+        with self._lock:
+            return (listener in self._listeners)
+    
+    def add_listener(self, listener):
+        with self._lock:
+            self._listeners.add(listener)
+    
+    def remove_listener(self, listener):
+        with self._lock:
+            self._listeners.discard(listener)
+    
+    def dispatch(self, event, *args, **kwds):
+        # holding the lock while doing callbacks could lead to a deadlock
+        # if one of the above methods is called
+        with self._lock:
+            listeners = list(self._listeners)
+        for l in listeners:
+            try:
+                logger.debug('dispatching [%s] to [%s]' % (event, l))
+                getattr(l, event)(*args, **kwds)
+            except Exception as e:
+                pass # if a listener doesn't care for some event we don't care
+
+
+class Session(Thread, Subject):
+    
+    def __init__(self):
+        Thread.__init__(self, name='session')
+        Subject.__init__(self)
+        self._client_capabilities = CAPABILITIES
+        self._server_capabilities = None # yet
+        self._id = None # session-id
+        self._q = Queue()
+        self._connected = False # to be set/cleared by subclass implementation
+    
+    def _post_connect(self):
+        from ncclient.content.builders import HelloBuilder
+        self.send(HelloBuilder.build(self._client_capabilities))
+        error = None
+        init_event = Event()
+        def ok_cb(id, capabilities):
+            self._id, self._capabilities = id, Capabilities(capabilities)
+            init_event.set()
+        def err_cb(err):
+            error = err
+            init_event.set()
+        listener = HelloListener(ok_cb, err_cb)
+        self.add_listener(listener)
+        # start the subclass' main loop
+        self.start()        
+        # we expect server's hello message
+        init_event.wait()
+        # received hello message or an error happened
+        self.remove_listener(listener)
+        if error:
+            raise error
+        logger.debug('initialized:session-id:%s' % self._id)
+    
+    def send(self, message):
+        logger.debug('queueing:%s' % message)
+        self._q.put(message)
+    
+    def connect(self):
+        raise NotImplementedError
+
+    def run(self):
+        raise NotImplementedError
+        
+    def capabilities(self, whose='client'):
+        if whose == 'client':
+            return self._client_capabilities
+        elif whose == 'server':
+            return self._server_capabilities
+    
+    ### Properties
+    
+    @property
+    def client_capabilities(self):
+        return self._client_capabilities
+    
+    @property
+    def server_capabilities(self):
+        return self._server_capabilities
+    
+    @property
+    def connected(self):
+        return self._connected
+    
+    @property
+    def id(self):
+        return self._id
+
+
+class HelloListener:
+    
+    def __init__(self, init_cb, error_cb):
+        self._init_cb, self._error_cb = init_cb, error_cb
+    
+    def __str__(self):
+        return 'HelloListener'
+    
+    ### Events
+    
+    def received(self, raw):
+        logger.debug(raw)
+        from ncclient.content.parsers import HelloParser
+        try:
+            id, capabilities = HelloParser.parse(raw)
+        except Exception as e:
+            self._error_cb(e)
+        else:
+            self._init_cb(id, capabilities)
+    
+    def error(self, err):
+        self._error_cb(err)
+
+
+class DebugListener:
+    
+    def __str__(self):
+        return 'DebugListener'
+    
+    def received(self, raw):
+        logger.debug('DebugListener:[received]:%s' % raw)
+    
+    def error(self, err):
+        logger.debug('DebugListener:[error]:%s' % err)
diff --git a/ncclient/transport/ssh.py b/ncclient/transport/ssh.py
new file mode 100644 (file)
index 0000000..ff0db2b
--- /dev/null
@@ -0,0 +1,278 @@
+# Copyright 2009 Shikhar Bhushan
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import socket
+from binascii import hexlify
+from cStringIO import StringIO
+from select import select
+
+import paramiko
+
+import session
+from . import logger
+from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, SSHSessionClosedError
+from session import Session
+
+BUF_SIZE = 4096
+MSG_DELIM = ']]>]]>'
+TICK = 0.1
+
+class SSHSession(Session):
+
+    def __init__(self):
+        Session.__init__(self)
+        self._system_host_keys = paramiko.HostKeys()
+        self._host_keys = paramiko.HostKeys()
+        self._host_keys_filename = None
+        self._transport = None
+        self._connected = False
+        self._channel = None
+        self._buffer = StringIO() # for incoming data
+        # parsing-related, see _parse()
+        self._parsing_state = 0 
+        self._parsing_pos = 0
+    
+    def _parse(self):
+        '''Messages are delimited by MSG_DELIM. The buffer could have grown by a
+        maximum of BUF_SIZE bytes every time this method is called. Retains state
+        across method calls and if a byte has been read it will not be considered
+        again.
+        '''
+        delim = MSG_DELIM
+        n = len(delim) - 1
+        expect = self._parsing_state
+        buf = self._buffer
+        buf.seek(self._parsing_pos)
+        while True:
+            x = buf.read(1)
+            if not x: # done reading
+                break
+            elif x == delim[expect]: # what we expected
+                expect += 1 # expect the next delim char
+            else:
+                continue
+            # loop till last delim char expected, break if other char encountered
+            for i in range(expect, n):
+                x = buf.read(1)
+                if not x: # done reading
+                    break
+                if x == delim[expect]: # what we expected
+                    expect += 1 # expect the next delim char
+                else:
+                    expect = 0 # reset
+                    break
+            else: # if we didn't break out of the loop, full delim was parsed
+                msg_till = buf.tell() - n
+                buf.seek(0)
+                msg = buf.read(msg_till)
+                self.dispatch('received', msg)
+                buf.seek(n+1, os.SEEK_CUR)
+                rest = buf.read()
+                buf = StringIO()
+                buf.write(rest)
+                buf.seek(0)
+                state = 0
+        self._buffer = buf
+        self._parsing_state = expect
+        self._parsing_pos = self._buffer.tell()
+    
+    def load_system_host_keys(self, filename=None):
+        if filename is None:
+            # try the user's .ssh key file, and mask exceptions
+            filename = os.path.expanduser('~/.ssh/known_hosts')
+            try:
+                self._system_host_keys.load(filename)
+            except IOError:
+                pass
+            return
+        self._system_host_keys.load(filename)
+    
+    def load_host_keys(self, filename):
+        self._host_keys_filename = filename
+        self._host_keys.load(filename)
+
+    def add_host_key(self, key):
+        self._host_keys.add(key)
+    
+    def save_host_keys(self, filename):
+        f = open(filename, 'w')
+        for hostname, keys in self._host_keys.iteritems():
+            for keytype, key in keys.iteritems():
+                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
+        f.close()    
+    
+    def close(self):
+        if self._transport.is_active():
+            self._transport.close()
+        self._connected = False
+    
+    def connect(self, hostname, port=830, timeout=None,
+                unknown_host_cb=None, username=None, password=None,
+                key_filename=None, allow_agent=True, look_for_keys=True):
+        
+        assert(username is not None)
+        
+        for (family, socktype, proto, canonname, sockaddr) in \
+        socket.getaddrinfo(hostname, port):
+            if socktype==socket.SOCK_STREAM:
+                af = family
+                addr = sockaddr
+                break
+        else:
+            raise SSHError('No suitable address family for %s' % hostname)
+        sock = socket.socket(af, socket.SOCK_STREAM)
+        sock.settimeout(timeout)
+        sock.connect(addr)
+        t = self._transport = paramiko.Transport(sock)
+        t.set_log_channel(logger.name)
+        
+        try:
+            t.start_client()
+        except paramiko.SSHException:
+            raise SSHError('Negotiation failed')
+        
+        # host key verification
+        server_key = t.get_remote_server_key()
+        known_host = self._host_keys.check(hostname, server_key) or \
+                        self._system_host_keys.check(hostname, server_key)
+        
+        if unknown_host_cb is None:
+            unknown_host_cb = lambda *args: False
+        if not known_host and not unknown_host_cb(hostname, server_key):
+                raise SSHUnknownHostError(hostname, server_key)
+        
+        if key_filename is None:
+            key_filenames = []
+        elif isinstance(key_filename, basestring):
+            key_filenames = [ key_filename ]
+        else:
+            key_filenames = key_filename
+        
+        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
+        
+        self._connected = True # there was no error authenticating
+        
+        c = self._channel = self._transport.open_session()
+        c.invoke_subsystem('netconf')
+        c.set_name('netconf')
+        
+        self._post_connect()
+    
+    # along the lines of paramiko.SSHClient._auth()
+    def _auth(self, username, password, key_filenames, allow_agent,
+              look_for_keys):
+        saved_exception = None
+        
+        for key_filename in key_filenames:
+            for cls in (paramiko.RSAKey, paramiko.DSSKey):
+                try:
+                    key = cls.from_private_key_file(key_filename, password)
+                    logger.debug('Trying key %s from %s' %
+                              (hexlify(key.get_fingerprint()), key_filename))
+                    self._transport.auth_publickey(username, key)
+                    return
+                except Exception as e:
+                    saved_exception = e
+                    logger.debug(e)
+        
+        if allow_agent:
+            for key in paramiko.Agent().get_keys():
+                try:
+                    logger.debug('Trying SSH agent key %s' %
+                                 hexlify(key.get_fingerprint()))
+                    self._transport.auth_publickey(username, key)
+                    return
+                except Exception as e:
+                    saved_exception = e
+                    logger.debug(e)
+        
+        keyfiles = []
+        if look_for_keys:
+            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
+            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
+            if os.path.isfile(rsa_key):
+                keyfiles.append((paramiko.RSAKey, rsa_key))
+            if os.path.isfile(dsa_key):
+                keyfiles.append((paramiko.DSSKey, dsa_key))
+            # look in ~/ssh/ for windows users:
+            rsa_key = os.path.expanduser('~/ssh/id_rsa')
+            dsa_key = os.path.expanduser('~/ssh/id_dsa')
+            if os.path.isfile(rsa_key):
+                keyfiles.append((paramiko.RSAKey, rsa_key))
+            if os.path.isfile(dsa_key):
+                keyfiles.append((paramiko.DSSKey, dsa_key))
+        
+        for cls, filename in keyfiles:
+            try:
+                key = cls.from_private_key_file(filename, password)
+                logger.debug('Trying discovered key %s in %s' %
+                          (hexlify(key.get_fingerprint()), filename))
+                self._transport.auth_publickey(username, key)
+                return
+            except Exception as e:
+                saved_exception = e
+                logger.debug(e)
+        
+        if password is not None:
+            try:
+                self._transport.auth_password(username, password)
+                return
+            except Exception as e:
+                saved_exception = e
+                logger.debug(e)
+        
+        if saved_exception is not None:
+            raise AuthenticationError(repr(saved_exception))
+        
+        raise AuthenticationError('No authentication methods available')
+    
+    def run(self):
+        chan = self._channel
+        chan.setblocking(0)
+        q = self._q
+        try:
+            while True:
+                # select on a paramiko ssh channel object does not ever return
+                # it in the writable list, so channels don't exactly emulate
+                # the socket api
+                r, w, e = select([chan], [], [], TICK)
+                # will wake up every TICK seconds to check if there is something
+                # to send, more often if there is something to read (due to
+                # select returning chan in the readable list)
+                if r:
+                    data = chan.recv(BUF_SIZE)
+                    if data:
+                        self._buffer.write(data)
+                        self._parse()
+                    else:
+                        raise SSHSessionClosedError(self._buffer.getvalue())
+                if not q.empty() and chan.send_ready():
+                    data = q.get() + MSG_DELIM
+                    while data:
+                        n = chan.send(data)
+                        if n <= 0:
+                            raise SSHSessionClosedError(self._buffer.getvalue(), data)
+                        data = data[n:]
+        except Exception as e:
+            self.close()
+            logger.debug('*** broke out of main loop ***')
+            self.dispatch('error', e)
+    
+    @property
+    def transport(self):
+        '''Get the underlying paramiko.Transport object; this is provided so that
+        methods like transport.set_keepalive can be called.
+        '''
+        return self._transport