Revision 5a684638

/dev/null
1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

  
15
import logging
16
logger = logging.getLogger('ncclient.session')
/dev/null
1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

  
15
class Capabilities:
16
    
17
    def __init__(self, capabilities=None):
18
        self._dict = {}
19
        if isinstance(capabilities, dict):
20
            self._dict = capabilities
21
        elif isinstance(capabilities, list):
22
            for uri in capabilities:
23
                self._dict[uri] = Capabilities.guess_shorthand(uri)
24
    
25
    def __contains__(self, key):
26
        return ( key in self._dict ) or ( key in self._dict.values() )
27
    
28
    def __iter__(self):
29
        return self._dict.keys().__iter__()
30
    
31
    def __repr__(self):
32
        return repr(self._dict.keys())
33
    
34
    def add(self, uri, shorthand=None):
35
        if shorthand is None:
36
            shorthand = Capabilities.guess_shorthand(uri)
37
        self._dict[uri] = shorthand
38
    
39
    set = add
40
    
41
    def remove(self, key):
42
        if key in self._dict:
43
            del self._dict[key]
44
        else:
45
            for uri in self._dict:
46
                if self._dict[uri] == key:
47
                    del self._dict[uri]
48
                    break
49
    
50
    @staticmethod
51
    def guess_shorthand(uri):
52
        if uri.startswith('urn:ietf:params:netconf:capability:'):
53
            return (':' + uri.split(':')[5])
54

  
55

  
56
CAPABILITIES = Capabilities([
57
    'urn:ietf:params:netconf:base:1.0',
58
    'urn:ietf:params:netconf:capability:writable-running:1.0',
59
    'urn:ietf:params:netconf:capability:candidate:1.0',
60
    'urn:ietf:params:netconf:capability:confirmed-commit:1.0',
61
    'urn:ietf:params:netconf:capability:rollback-on-error:1.0',
62
    'urn:ietf:params:netconf:capability:startup:1.0',
63
    'urn:ietf:params:netconf:capability:url:1.0',
64
    'urn:ietf:params:netconf:capability:validate:1.0',
65
    'urn:ietf:params:netconf:capability:xpath:1.0',
66
    'urn:ietf:params:netconf:capability:notification:1.0',
67
    'urn:ietf:params:netconf:capability:interleave:1.0'
68
    ])
69

  
70
if __name__ == "__main__":
71
    assert(':validate' in CAPABILITIES) # test __contains__
/dev/null
1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

  
15
from ncclient import ClientError
16

  
17
class SessionError(ClientError):
18
    pass
19

  
20
class SSHError(SessionError):
21
    pass
22

  
23
class SSHUnknownHostError(SSHError):
24
    
25
    def __init__(self, hostname, key):
26
        self.hostname = hostname
27
        self.key = key
28
    
29
    def __str__(self):
30
        from binascii import hexlify
31
        return ('Unknown host key [%s] for [%s]' %
32
                (hexlify(self.key.get_fingerprint()), self.hostname))
33

  
34
class SSHAuthenticationError(SSHError):
35
    pass
36

  
37
class SSHSessionClosedError(SSHError):
38
    
39
    def __init__(self, in_buf, out_buf=None):
40
        SessionError.__init__(self, "Unexpected session close.")
41
        self._in_buf, self._out_buf = in_buf, out_buf
42
        
43
    def __str__(self):
44
        msg = SessionError(self).__str__()
45
        if self._in_buf:
46
            msg += '\nIN_BUFFER: %s' % self._in_buf
47
        if self._out_buf:
48
            msg += '\nOUT_BUFFER: %s' % self._out_buf
49
        return msg
/dev/null
1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

  
15
from threading import Thread, Lock, Event
16
from Queue import Queue
17

  
18
from . import logger
19
from capabilities import Capabilities, CAPABILITIES
20

  
21

  
22
class Subject:
23

  
24
    def __init__(self):
25
        self._listeners = set([])
26
        self._lock = Lock()
27
        
28
    def has_listener(self, listener):
29
        with self._lock:
30
            return (listener in self._listeners)
31
    
32
    def add_listener(self, listener):
33
        with self._lock:
34
            self._listeners.add(listener)
35
    
36
    def remove_listener(self, listener):
37
        with self._lock:
38
            self._listeners.discard(listener)
39
    
40
    def dispatch(self, event, *args, **kwds):
41
        # holding the lock while doing callbacks could lead to a deadlock
42
        # if one of the above methods is called
43
        with self._lock:
44
            listeners = list(self._listeners)
45
        for l in listeners:
46
            try:
47
                logger.debug('dispatching [%s] to [%s]' % (event, l))
48
                getattr(l, event)(*args, **kwds)
49
            except Exception as e:
50
                pass # if a listener doesn't care for some event we don't care
51

  
52

  
53
class Session(Thread, Subject):
54
    
55
    def __init__(self):
56
        Thread.__init__(self, name='session')
57
        Subject.__init__(self)
58
        self._client_capabilities = CAPABILITIES
59
        self._server_capabilities = None # yet
60
        self._id = None # session-id
61
        self._q = Queue()
62
        self._connected = False # to be set/cleared by subclass implementation
63
    
64
    def _post_connect(self):
65
        from ncclient.content.builders import HelloBuilder
66
        self.send(HelloBuilder.build(self._client_capabilities))
67
        error = None
68
        init_event = Event()
69
        def ok_cb(id, capabilities):
70
            self._id, self._capabilities = id, Capabilities(capabilities)
71
            init_event.set()
72
        def err_cb(err):
73
            error = err
74
            init_event.set()
75
        listener = HelloListener(ok_cb, err_cb)
76
        self.add_listener(listener)
77
        # start the subclass' main loop
78
        self.start()        
79
        # we expect server's hello message
80
        init_event.wait()
81
        # received hello message or an error happened
82
        self.remove_listener(listener)
83
        if error:
84
            raise error
85
        logger.debug('initialized:session-id:%s' % self._id)
86
    
87
    def send(self, message):
88
        logger.debug('queueing:%s' % message)
89
        self._q.put(message)
90
    
91
    def connect(self):
92
        raise NotImplementedError
93

  
94
    def run(self):
95
        raise NotImplementedError
96
        
97
    def capabilities(self, whose='client'):
98
        if whose == 'client':
99
            return self._client_capabilities
100
        elif whose == 'server':
101
            return self._server_capabilities
102
    
103
    ### Properties
104
    
105
    @property
106
    def client_capabilities(self):
107
        return self._client_capabilities
108
    
109
    @property
110
    def server_capabilities(self):
111
        return self._server_capabilities
112
    
113
    @property
114
    def connected(self):
115
        return self._connected
116
    
117
    @property
118
    def id(self):
119
        return self._id
120

  
121

  
122
class HelloListener:
123
    
124
    def __init__(self, init_cb, error_cb):
125
        self._init_cb, self._error_cb = init_cb, error_cb
126
    
127
    def __str__(self):
128
        return 'HelloListener'
129
    
130
    ### Events
131
    
132
    def received(self, raw):
133
        logger.debug(raw)
134
        from ncclient.content.parsers import HelloParser
135
        try:
136
            id, capabilities = HelloParser.parse(raw)
137
        except Exception as e:
138
            self._error_cb(e)
139
        else:
140
            self._init_cb(id, capabilities)
141
    
142
    def error(self, err):
143
        self._error_cb(err)
144

  
145

  
146
class DebugListener:
147
    
148
    def __str__(self):
149
        return 'DebugListener'
150
    
151
    def received(self, raw):
152
        logger.debug('DebugListener:[received]:%s' % raw)
153
    
154
    def error(self, err):
155
        logger.debug('DebugListener:[error]:%s' % err)
/dev/null
1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

  
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

  
21
import paramiko
22

  
23
import session
24
from . import logger
25
from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, SSHSessionClosedError
26
from session import Session
27

  
28
BUF_SIZE = 4096
29
MSG_DELIM = ']]>]]>'
30
TICK = 0.1
31

  
32
class SSHSession(Session):
33

  
34
    def __init__(self):
35
        Session.__init__(self)
36
        self._system_host_keys = paramiko.HostKeys()
37
        self._host_keys = paramiko.HostKeys()
38
        self._host_keys_filename = None
39
        self._transport = None
40
        self._connected = False
41
        self._channel = None
42
        self._buffer = StringIO() # for incoming data
43
        # parsing-related, see _parse()
44
        self._parsing_state = 0 
45
        self._parsing_pos = 0
46
    
47
    def _parse(self):
48
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
49
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
50
        across method calls and if a byte has been read it will not be considered
51
        again.
52
        '''
53
        delim = MSG_DELIM
54
        n = len(delim) - 1
55
        expect = self._parsing_state
56
        buf = self._buffer
57
        buf.seek(self._parsing_pos)
58
        while True:
59
            x = buf.read(1)
60
            if not x: # done reading
61
                break
62
            elif x == delim[expect]: # what we expected
63
                expect += 1 # expect the next delim char
64
            else:
65
                continue
66
            # loop till last delim char expected, break if other char encountered
67
            for i in range(expect, n):
68
                x = buf.read(1)
69
                if not x: # done reading
70
                    break
71
                if x == delim[expect]: # what we expected
72
                    expect += 1 # expect the next delim char
73
                else:
74
                    expect = 0 # reset
75
                    break
76
            else: # if we didn't break out of the loop, full delim was parsed
77
                msg_till = buf.tell() - n
78
                buf.seek(0)
79
                msg = buf.read(msg_till)
80
                self.dispatch('received', msg)
81
                buf.seek(n+1, os.SEEK_CUR)
82
                rest = buf.read()
83
                buf = StringIO()
84
                buf.write(rest)
85
                buf.seek(0)
86
                state = 0
87
        self._buffer = buf
88
        self._parsing_state = expect
89
        self._parsing_pos = self._buffer.tell()
90
    
91
    def load_system_host_keys(self, filename=None):
92
        if filename is None:
93
            # try the user's .ssh key file, and mask exceptions
94
            filename = os.path.expanduser('~/.ssh/known_hosts')
95
            try:
96
                self._system_host_keys.load(filename)
97
            except IOError:
98
                pass
99
            return
100
        self._system_host_keys.load(filename)
101
    
102
    def load_host_keys(self, filename):
103
        self._host_keys_filename = filename
104
        self._host_keys.load(filename)
105

  
106
    def add_host_key(self, key):
107
        self._host_keys.add(key)
108
    
109
    def save_host_keys(self, filename):
110
        f = open(filename, 'w')
111
        for hostname, keys in self._host_keys.iteritems():
112
            for keytype, key in keys.iteritems():
113
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
114
        f.close()    
115
    
116
    def close(self):
117
        if self._transport.is_active():
118
            self._transport.close()
119
        self._connected = False
120
    
121
    def connect(self, hostname, port=830, timeout=None,
122
                unknown_host_cb=None, username=None, password=None,
123
                key_filename=None, allow_agent=True, look_for_keys=True):
124
        
125
        assert(username is not None)
126
        
127
        for (family, socktype, proto, canonname, sockaddr) in \
128
        socket.getaddrinfo(hostname, port):
129
            if socktype==socket.SOCK_STREAM:
130
                af = family
131
                addr = sockaddr
132
                break
133
        else:
134
            raise SSHError('No suitable address family for %s' % hostname)
135
        sock = socket.socket(af, socket.SOCK_STREAM)
136
        sock.settimeout(timeout)
137
        sock.connect(addr)
138
        t = self._transport = paramiko.Transport(sock)
139
        t.set_log_channel(logger.name)
140
        
141
        try:
142
            t.start_client()
143
        except paramiko.SSHException:
144
            raise SSHError('Negotiation failed')
145
        
146
        # host key verification
147
        server_key = t.get_remote_server_key()
148
        known_host = self._host_keys.check(hostname, server_key) or \
149
                        self._system_host_keys.check(hostname, server_key)
150
        
151
        if unknown_host_cb is None:
152
            unknown_host_cb = lambda *args: False
153
        if not known_host and not unknown_host_cb(hostname, server_key):
154
                raise SSHUnknownHostError(hostname, server_key)
155
        
156
        if key_filename is None:
157
            key_filenames = []
158
        elif isinstance(key_filename, basestring):
159
            key_filenames = [ key_filename ]
160
        else:
161
            key_filenames = key_filename
162
        
163
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
164
        
165
        self._connected = True # there was no error authenticating
166
        
167
        c = self._channel = self._transport.open_session()
168
        c.invoke_subsystem('netconf')
169
        c.set_name('netconf')
170
        
171
        self._post_connect()
172
    
173
    # on the lines of paramiko.SSHClient._auth()
174
    def _auth(self, username, password, key_filenames, allow_agent,
175
              look_for_keys):
176
        saved_exception = None
177
        
178
        for key_filename in key_filenames:
179
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
180
                try:
181
                    key = cls.from_private_key_file(key_filename, password)
182
                    logger.debug('Trying key %s from %s' %
183
                              (hexlify(key.get_fingerprint()), key_filename))
184
                    self._transport.auth_publickey(username, key)
185
                    return
186
                except Exception as e:
187
                    saved_exception = e
188
                    logger.debug(e)
189
        
190
        if allow_agent:
191
            for key in paramiko.Agent().get_keys():
192
                try:
193
                    logger.debug('Trying SSH agent key %s' %
194
                                 hexlify(key.get_fingerprint()))
195
                    self._transport.auth_publickey(username, key)
196
                    return
197
                except Exception as e:
198
                    saved_exception = e
199
                    logger.debug(e)
200
        
201
        keyfiles = []
202
        if look_for_keys:
203
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
204
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
205
            if os.path.isfile(rsa_key):
206
                keyfiles.append((paramiko.RSAKey, rsa_key))
207
            if os.path.isfile(dsa_key):
208
                keyfiles.append((paramiko.DSSKey, dsa_key))
209
            # look in ~/ssh/ for windows users:
210
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
211
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
212
            if os.path.isfile(rsa_key):
213
                keyfiles.append((paramiko.RSAKey, rsa_key))
214
            if os.path.isfile(dsa_key):
215
                keyfiles.append((paramiko.DSSKey, dsa_key))
216
        
217
        for cls, filename in keyfiles:
218
            try:
219
                key = cls.from_private_key_file(filename, password)
220
                logger.debug('Trying discovered key %s in %s' %
221
                          (hexlify(key.get_fingerprint()), filename))
222
                self._transport.auth_publickey(username, key)
223
                return
224
            except Exception as e:
225
                saved_exception = e
226
                logger.debug(e)
227
        
228
        if password is not None:
229
            try:
230
                self._transport.auth_password(username, password)
231
                return
232
            except Exception as e:
233
                saved_exception = e
234
                logger.debug(e)
235
        
236
        if saved_exception is not None:
237
            raise AuthenticationError(repr(saved_exception))
238
        
239
        raise AuthenticationError('No authentication methods available')
240
    
241
    def run(self):
242
        chan = self._channel
243
        chan.setblocking(0)
244
        q = self._q
245
        try:
246
            while True:
247
                # select on a paramiko ssh channel object does not ever return
248
                # it in the writable list, so it channel's don't exactly emulate 
249
                # the socket api
250
                r, w, e = select([chan], [], [], TICK)
251
                # will wakeup evey TICK seconds to check if something
252
                # to send, more if something to read (due to select returning
253
                # chan in readable list)
254
                if r:
255
                    data = chan.recv(BUF_SIZE)
256
                    if data:
257
                        self._buffer.write(data)
258
                        self._parse()
259
                    else:
260
                        raise SSHSessionClosedError(self._buffer.getvalue())
261
                if not q.empty() and chan.send_ready():
262
                    data = q.get() + MSG_DELIM
263
                    while data:
264
                        n = chan.send(data)
265
                        if n <= 0:
266
                            raise SSHSessionClosedError(self._buffer.getvalue(), data)
267
                        data = data[n:]
268
        except Exception as e:
269
            self.close()
270
            logger.debug('*** broke out of main loop ***')
271
            self.dispatch('error', e)
272
    
273
    @property
274
    def transport(self):
275
        '''Get underlying paramiko.transport object; this is provided so methods
276
        like transport.set_keepalive can be called.
277
        '''
278
        return self._transport

Also available in: Unified diff