Revision f5c75f88

b/ncclient/__init__.py
18 18
    raise RuntimeError('You need python 2.5 for this module.')
19 19

  
20 20
__version__ = "0.05"
21

  
22
class ClientError(Exception):
23
    pass
b/ncclient/content/__init__.py
15 15
'This module serves as an XML abstraction layer'
16 16

  
17 17
import logging
18
logger = logging.getLogger('ncclient.content')
18
logger = logging.getLogger('ncclient.content')
b/ncclient/content/builders.py
66 66
    
67 67
    @staticmethod
68 68
    def build(msgid, op, encoding='utf-8'):
69
        if isinstance(opspec, basestring):
70
            return build_from_string(msgid, op, encoding)
69
        if isinstance(op, basestring):
70
            return RPCBuilder.build_from_string(msgid, op, encoding)
71 71
        else:
72
            return build_from_spec(msgid, op, encoding)
72
            return RPCBuilder.build_from_spec(msgid, op, encoding)
73 73
    
74 74
    @staticmethod
75 75
    def build_from_spec(msgid, opspec, encoding='utf-8'):
b/ncclient/content/parsers.py
24 24
        'Returns tuple of (session-id, ["capability_uri", ...])'
25 25
        sid, capabilities = 0, []
26 26
        root = ET.fromstring(raw)
27
        if root.tag == _('hello', BASE_NS):
27
        if root.tag in ('hello', _('hello', BASE_NS)):
28 28
            for child in root.getchildren():
29
                if child.tag == _('session-id', BASE_NS):
29
                if child.tag in ('session-id', _('session-id', BASE_NS)):
30 30
                    sid = child.text
31
                elif child.tag == _('capabilities', BASE_NS):
31
                elif child.tag in ('capabilities', _('capabilities', BASE_NS)):
32
                    for cap in child.getiterator('capability'): 
33
                        capabilities.append(cap.text)
32 34
                    for cap in child.getiterator(_('capability', BASE_NS)):
33 35
                        capabilities.append(cap.text)
34 36
        return sid, capabilities
b/ncclient/operations/session.py
15 15
'Session-related NETCONF operations'
16 16

  
17 17
class CloseSession(RPC):
18
    
19 18
    pass
20 19

  
21 20
class KillSession(RPC):
22
    
23 21
    pass
b/ncclient/session/__init__.py
13 13
# limitations under the License.
14 14

  
15 15
import logging
16
logger = logging.getLogger('ncclient.session')
17

  
18
from session import DebugListener, SessionError, SessionCloseError
19
from ssh import SSHSession
20
from capabilities import CAPABILITIES, Capabilities
21

  
22
__all__ = [
23
    'DebugListener'
24
    'Session'
25
    'SSHSession',
26
    'SessionError',
27
    'SessionCloseError',
28
    'Capabilities',
29
    'CAPABILITIES'
30
]
16
logger = logging.getLogger('ncclient.session')
b/ncclient/session/capabilities.py
52 52
        if uri.startswith('urn:ietf:params:netconf:capability:'):
53 53
            return (':' + uri.split(':')[5])
54 54

  
55
    
55

  
56 56
CAPABILITIES = Capabilities([
57 57
    'urn:ietf:params:netconf:base:1.0',
58 58
    'urn:ietf:params:netconf:capability:writable-running:1.0',
b/ncclient/session/error.py
1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

  
15
from ncclient import ClientError
16

  
17
class SessionError(ClientError):
18
    pass
19

  
20
class RemoteClosedError(SessionError):
21
    
22
    def __init__(self, in_buf, out_buf=None):
23
        SessionError.__init__(self)
24
        self._in_buf, self._out_buf = in_buf, out_buf
25
        
26
    def __str__(self):
27
        msg = 'Session closed by remote endpoint.'
28
        if self._in_buf:
29
            msg += '\nIN_BUFFER: %s' % self._in_buf
30
        if self._out_buf:
31
            msg += '\nOUT_BUFFER: %s' % self._out_buf
32
        return msg
33

  
34
class AuthenticationError(SessionError):
35
    pass
36

  
37
class SSHError(SessionError):
38
    pass
39

  
40
class SSHUnknownHostError(SSHError):
41
    
42
    def __init__(self, hostname, key):
43
        self.hostname = hostname
44
        self.key = key
45
    
46
    def __str__(self):
47
        from binascii import hexlify
48
        return ('Unknown host key [%s] for [%s]' %
49
                (hexlify(self.key.get_fingerprint()), self.hostname))
50

  
51
class SSHAuthenticationError(AuthenticationError, SSHError):
52
    'wraps a paramiko exception that occured during auth'
53
    
54
    def __init__(self, ex):
55
        self.ex = ex
56
    
57
    def __repr__(self):
58
        return repr(ex)
b/ncclient/session/session.py
12 12
# See the License for the specific language governing permissions and
13 13
# limitations under the License.
14 14

  
15
import logging
16 15
from threading import Thread, Lock, Event
17 16
from Queue import Queue
18 17

  
18
from . import logger
19 19
from capabilities import Capabilities, CAPABILITIES
20 20

  
21
logger = logging.getLogger('ncclient.session')
22 21

  
23
class SessionError(Exception):
24
    
25
    pass
22
class Subject:
26 23

  
27
class SessionCloseError(SessionError):
28
    
29
    def __init__(self, in_buf, out_buf=None):
30
        SessionError.__init__(self)
31
        self._in_buf, self._out_buf = in_buf, out_buf
24
    def __init__(self):
25
        self._listeners = set([])
26
        self._lock = Lock()
32 27
        
33
    def __str__(self):
34
        msg = 'Session closed by remote endpoint.'
35
        if self._in_buf:
36
            msg += '\nIN_BUFFER: %s' % self._in_buf
37
        if self._out_buf:
38
            msg += '\nOUT_BUFFER: %s' % self._out_buf
39
        return msg
28
    def has_listener(self, listener):
29
        with self._lock:
30
            return (listener in self._listeners)
40 31
    
41
class Session(Thread):
32
    def add_listener(self, listener):
33
        with self._lock:
34
            self._listeners.add(listener)
35
    
36
    def remove_listener(self, listener):
37
        with self._lock:
38
            self._listeners.discard(listener)
39
    
40
    def dispatch(self, event, *args, **kwds):
41
        # holding the lock while doing callbacks could lead to a deadlock
42
        # if one of the above methods is called
43
        with self._lock:
44
            listeners = list(self._listeners)
45
        for l in listeners:
46
            try:
47
                logger.debug('dispatching [%s] to [%s]' % (event, l))
48
                getattr(l, event)(*args, **kwds)
49
            except Exception as e:
50
                pass # if a listener doesn't care for some event we don't care
51

  
52

  
53
class Session(Thread, Subject):
42 54
    
43 55
    def __init__(self):
44 56
        Thread.__init__(self, name='session')
57
        Subject.__init__(self)
45 58
        self._client_capabilities = CAPABILITIES
46 59
        self._server_capabilities = None # yet
47 60
        self._id = None # session-id
48 61
        self._q = Queue()
49 62
        self._connected = False # to be set/cleared by subclass implementation
50
        self._listeners = set([])
51
        self._lock = Lock()
52 63
    
53 64
    def _post_connect(self):
54 65
        from ncclient.content.builders import HelloBuilder
55
        # queue client's hello message for sending
56 66
        self.send(HelloBuilder.build(self._client_capabilities))
57
        
58 67
        error = None
59
        proceed = Event()
68
        init_event = Event()
60 69
        def ok_cb(id, capabilities):
61 70
            self._id, self._capabilities = id, Capabilities(capabilities)
62
            proceed.set()
71
            init_event.set()
63 72
        def err_cb(err):
64 73
            error = err
65
            proceed.set()
74
            init_event.set()
66 75
        listener = HelloListener(ok_cb, err_cb)
67 76
        self.add_listener(listener)
68
        
69 77
        # start the subclass' main loop
70 78
        self.start()        
71 79
        # we expect server's hello message
72
        proceed.wait()
80
        init_event.wait()
73 81
        # received hello message or an error happened
74 82
        self.remove_listener(listener)
75 83
        if error:
76
            self._close()
77
            raise self._error
84
            raise error
85
        logger.debug('initialized:session-id:%s' % self._id)
78 86
    
79 87
    def send(self, message):
80
        logger.debug('queueing message: \n%s' % message)
88
        logger.debug('queueing:%s' % message)
81 89
        self._q.put(message)
82 90
    
83 91
    def connect(self):
......
92 100
        elif whose == 'server':
93 101
            return self._server_capabilities
94 102
    
95
    ### Session is a subject for arbitary listeners
96
    
97
    def has_listener(self, listener):
98
        with self._lock:
99
            return (listener in self._listeners)
100
    
101
    def add_listener(self, listener):
102
        with self._lock:
103
            self._listeners.add(listener)
104
    
105
    def remove_listener(self, listener):
106
        with self._lock:
107
            self._listeners.discard(listener)
108
    
109
    def dispatch(self, event, *args, **kwds):
110
        # holding the lock while doing callbacks could lead to a deadlock
111
        # if one of the above methods is called
112
        with self._lock:
113
            listeners = list(self._listeners)
114
        for l in listeners:
115
            try:
116
                logger.debug('dispatching [%s] to [%s]' % (event, l))
117
                getattr(l, event)(*args, **kwds)
118
            except Exception as e:
119
                logger.warning(e)
120
    
121 103
    ### Properties
122 104
    
123 105
    @property
......
147 129
    
148 130
    ### Events
149 131
    
150
    def reply(self, raw):
132
    def received(self, raw):
133
        logger.debug(raw)
151 134
        from ncclient.content.parsers import HelloParser
152 135
        try:
153 136
            id, capabilities = HelloParser.parse(raw)
......
165 148
    def __str__(self):
166 149
        return 'DebugListener'
167 150
    
168
    def reply(self, raw):
169
        logger.debug('DebugListener:reply:%s' % raw)
151
    def received(self, raw):
152
        logger.debug('DebugListener:[received]:%s' % raw)
170 153
    
171 154
    def error(self, err):
172
        logger.debug('DebugListener:error:%s' % err)
155
        logger.debug('DebugListener:[error]:%s' % err)
b/ncclient/session/ssh.py
12 12
# See the License for the specific language governing permissions and
13 13
# limitations under the License.
14 14

  
15
import os
16
import socket
17
from binascii import hexlify
15 18
from cStringIO import StringIO
16
from os import SEEK_CUR
17 19
from select import select
18 20

  
19 21
import paramiko
20 22

  
23
import session
21 24
from . import logger
22
from session import Session, SessionError, SessionCloseError
25
from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, RemoteClosedError
26
from session import Session
23 27

  
24 28
BUF_SIZE = 4096
25 29
MSG_DELIM = ']]>]]>'
26

  
27
# TODO:
28
# chuck SSHClient and use paramiko low-level api to get cisco compatibility
29
# and finer control over host key verification, authentication, and error
30
# handling
30
TICK = 0.1
31 31

  
32 32
class SSHSession(Session):
33 33

  
34
    def __init__(self, load_known_hosts=True,
35
                 missing_host_key_policy=paramiko.RejectPolicy()):
34
    def __init__(self):
36 35
        Session.__init__(self)
37
        self._client = paramiko.SSHClient()
36
        self._system_host_keys = paramiko.HostKeys()
37
        self._host_keys = paramiko.HostKeys()
38
        self._host_keys_filename = None
39
        self._transport = None
40
        self._connected = False
38 41
        self._channel = None
39
        if load_known_hosts:
40
            self._client.load_system_host_keys()
41
        self._client.set_missing_host_key_policy(missing_host_key_policy)
42
        self._in_buf = StringIO()
43
        self._parsing_state = 0
42
        self._buffer = StringIO() # for incoming data
43
        # parsing-related, see _fresh_data()
44
        self._parsing_state = 0 
44 45
        self._parsing_pos = 0
45 46
    
46
    def _close(self):
47
        self._channel.close()
48
        self._connected = False
49
    
50 47
    def _fresh_data(self):
51 48
        '''The buffer could have grown by a maximum of BUF_SIZE bytes everytime 
52 49
        this method is called. Retains state across method calls and if a byte
......
55 52
        delim = MSG_DELIM
56 53
        n = len(delim) - 1
57 54
        state = self._parsing_state
58
        buf = self._in_buf
55
        buf = self._buffer
59 56
        buf.seek(self._parsing_pos)
60 57
        while True:
61 58
            x = buf.read(1)
......
79 76
                till = buf.tell() - n
80 77
                buf.seek(0)
81 78
                msg = buf.read(till)
82
                self.dispatch('reply', msg)
83
                buf.seek(n+1, SEEK_CUR)
79
                self.dispatch('received', msg)
80
                buf.seek(n+1, os.SEEK_CUR)
84 81
                rest = buf.read()
85 82
                buf = StringIO()
86 83
                buf.write(rest)
87 84
                buf.seek(0)
88 85
                state = 0
89
        self._in_buf = buf
86
        self._buffer = buf
90 87
        self._parsing_state = state
91
        self._parsing_pos = self._in_buf.tell()
92

  
88
        self._parsing_pos = self._buffer.tell()
89
    
90
    def load_system_host_keys(self, filename=None):
91
        if filename is None:
92
            # try the user's .ssh key file, and mask exceptions
93
            filename = os.path.expanduser('~/.ssh/known_hosts')
94
            try:
95
                self._system_host_keys.load(filename)
96
            except IOError:
97
                pass
98
            return
99
        self._system_host_keys.load(filename)
100
    
93 101
    def load_host_keys(self, filename):
94
        self._client.load_host_keys(filename)
102
        self._host_keys_filename = filename
103
        self._host_keys.load(filename)
95 104

  
96
    def set_missing_host_key_policy(self, policy):
97
        self._client.set_missing_host_key_policy(policy)
98

  
99
    def connect(self, hostname, port=830, username=None, password=None,
100
                key_filename=None, timeout=None, allow_agent=True,
101
                look_for_keys=True):
102
        self._client.connect(hostname, port=port, username=username,
103
                            password=password, key_filename=key_filename,
104
                            timeout=timeout, allow_agent=allow_agent,
105
                            look_for_keys=look_for_keys)    
106
        transport = self._client.get_transport()
107
        self._channel = transport.open_session()
108
        self._channel.invoke_subsystem('netconf')
109
        self._channel.set_name('netconf')
110
        self._connected = True
111
        self._post_connect()
105
    def add_host_key(self, key):
106
        self._host_keys.add(key)
107
    
108
    def save_host_keys(self, filename):
109
        f = open(filename, 'w')
110
        for hostname, keys in self._host_keys.iteritems():
111
            for keytype, key in keys.iteritems():
112
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
113
        f.close()    
114
    
115
    def close(self):
116
        if self._transport.is_active():
117
            self._transport.close()
118
        self._connected = False
119
    
120
    def connect(self, hostname, port=830, timeout=None,
121
                unknown_host_cb=None, username=None, password=None,
122
                key_filename=None, allow_agent=True, look_for_keys=True,
123
                authtypes=['publickey', 'password', 'keyboard-interactive']):
124
        
125
        assert(username is not None)
126
        
127
        for (family, socktype, proto, canonname, sockaddr) in \
128
        socket.getaddrinfo(hostname, port):
129
            if socktype==socket.SOCK_STREAM:
130
                af = family
131
                addr = sockaddr
132
                break
133
        else:
134
            raise SSHError('No suitable address family for %s' % hostname)
135
        sock = socket.socket(af, socket.SOCK_STREAM)
136
        sock.settimeout(timeout)
137
        sock.connect(addr)
138
        t = self._transport = paramiko.Transport(sock)
139
        t.set_log_channel(logger.name)
140
        
141
        try:
142
            t.start_client()
143
        except paramiko.SSHException:
144
            raise SSHError('Negotiation failed')
145
        
146
        # host key verification
147
        server_key = t.get_remote_server_key()
148
        known_host = self._host_keys.check(hostname, server_key) or \
149
                        self._system_host_keys.check(hostname, server_key)
150
        
151
        if unknown_host_cb is None:
152
            unknown_host_cb = lambda *args: False
153
        if not known_host and not unknown_host_cb(hostname, server_key):
154
                raise SSHUnknownHostError(hostname, server_key)
155
        
156
        if key_filename is None:
157
            key_filenames = []
158
        elif isinstance(key_filename, basestring):
159
            key_filenames = [ key_filename ]
160
        else:
161
            key_filenames = key_filename
162
        
163
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
164
        
165
        self._connected = True # there was no error authenticating
166
        
167
        c = self._channel = self._transport.open_session()
168
        c.invoke_subsystem('netconf')
169
        c.set_name('netconf')
170
        
171
        Session._post_connect(self)
172
    
173
    # on the lines of paramiko.SSHClient._auth()
174
    def _auth(self, username, password, key_filenames, allow_agent,
175
              look_for_keys):
176
        saved_exception = None
177
        
178
        allowed = ['publickey', 'keyboard-interactive', 'password']
179
        
180
        for key_filename in key_filenames:
181
            if 'publickey' not in allowed:
182
                    break
183
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
184
                try:
185
                    key = cls.from_private_key_file(key_filename, password)
186
                    logger.debug('Trying key %s from %s' %
187
                              (hexlify(key.get_fingerprint()), key_filename))
188
                    self._transport.auth_publickey(username, key)
189
                    return
190
                except paramiko.BadAuthenticationType as e:
191
                    allowed = e.allowed_types
192
                    logger.debug(e)
193
                except Exception as e:
194
                    saved_exception = e
195
                    logger.debug(e)
196
        
197
        if allow_agent:
198
            for key in paramiko.Agent().get_keys():
199
                if 'publickey' not in allowed:
200
                    break
201
                try:
202
                    logger.debug('Trying SSH agent key %s' %
203
                                 hexlify(key.get_fingerprint()))
204
                    logger.error( self._transport.auth_publickey(username, key) )
205
                    return
206
                except paramiko.BadAuthenticationType as e:
207
                    allowed = e.allowed_types
208
                    logger.debug(e)
209
                except Exception as e:
210
                    saved_exception = e
211
                    logger.debug(e)
212
        
213
        keyfiles = []
214
        if look_for_keys and 'publickey' in allowed:
215
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
216
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
217
            if os.path.isfile(rsa_key):
218
                keyfiles.append((paramiko.RSAKey, rsa_key))
219
            if os.path.isfile(dsa_key):
220
                keyfiles.append((paramiko.DSSKey, dsa_key))
221
            # look in ~/ssh/ for windows users:
222
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
223
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
224
            if os.path.isfile(rsa_key):
225
                keyfiles.append((paramiko.RSAKey, rsa_key))
226
            if os.path.isfile(dsa_key):
227
                keyfiles.append((paramiko.DSSKey, dsa_key))
228
        
229
        for cls, filename in keyfiles:
230
            try:
231
                key = cls.from_private_key_file(filename, password)
232
                logger.debug('Trying discovered key %s in %s' %
233
                          (hexlify(key.get_fingerprint()), filename))
234
                allowed = self._transport.auth_publickey(username, key)
235
                return
236
            except Exception as e:
237
                saved_exception = e
238
                logger.debug(e)
239
        
240
        if password is not None:
241
            try:
242
                self._transport.auth_password(username, password)
243
                return
244
            except Exception as e:
245
                saved_exception = e
246
                logger.debug(e)
247
        
248
        if saved_exception is not None:
249
            raise SSHAuthenticationError(saved_exception)
250
        
251
        raise SSHAuthenticationError('No authentication methods available')
112 252
    
113 253
    def run(self):
114 254
        chan = self._channel
......
119 259
                # select on a paramiko ssh channel object does not ever
120 260
                # return it in the writable list, so it does not exactly
121 261
                # emulate the socket api
122
                r, w, e = select([chan], [], [], 0.1)
262
                r, w, e = select([chan], [], [], TICK)
263
                # will wakeup evey TICK seconds to check if something
264
                # to send, more if something to read (due to select returning chan
265
                # in readable list)
123 266
                if r:
124 267
                    data = chan.recv(BUF_SIZE)
125 268
                    if data:
126
                        self._in_buf.write(data)
269
                        self._buffer.write(data)
127 270
                        self._fresh_data()
128 271
                    else:
129
                        raise SessionCloseError(self._in_buf.getvalue())
272
                        raise RemoteClosedError(self._buffer.getvalue())
130 273
                if not q.empty() and chan.send_ready():
131 274
                    data = q.get() + MSG_DELIM
132 275
                    while data:
133 276
                        n = chan.send(data)
134 277
                        if n <= 0:
135
                            raise SessionCloseError(self._in_buf.getvalue(), data)
278
                            raise RemoteClosedError(self._buffer.getvalue(), data)
136 279
                        data = data[n:]
137 280
        except Exception as e:
281
            self.close()
138 282
            logger.debug('*** broke out of main loop ***')
139 283
            self.dispatch('error', e)
284
    
285
    def set_keepalive(self, interval=0):
286
        self._transport.set_keepalive()

Also available in: Unified diff