Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ d095a59e

History | View | Annotate | Download (10.2 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
import session
24
from . import logger
25
from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, SSHSessionClosedError
26
from session import Session
27

    
28
BUF_SIZE = 4096
29
MSG_DELIM = ']]>]]>'
30
TICK = 0.1
31

    
32
class SSHSession(Session):
33

    
34
    def __init__(self):
35
        Session.__init__(self)
36
        self._system_host_keys = paramiko.HostKeys()
37
        self._host_keys = paramiko.HostKeys()
38
        self._host_keys_filename = None
39
        self._transport = None
40
        self._connected = False
41
        self._channel = None
42
        self._buffer = StringIO() # for incoming data
43
        # parsing-related, see _parse()
44
        self._parsing_state = 0 
45
        self._parsing_pos = 0
46
    
47
    def _parse(self):
48
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
49
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
50
        across method calls and if a byte has been read it will not be considered
51
        again.
52
        '''
53
        delim = MSG_DELIM
54
        n = len(delim) - 1
55
        expect = self._parsing_state
56
        buf = self._buffer
57
        buf.seek(self._parsing_pos)
58
        while True:
59
            x = buf.read(1)
60
            if not x: # done reading
61
                break
62
            elif x == delim[expect]: # what we expected
63
                expect += 1 # expect the next delim char
64
            else:
65
                continue
66
            # loop till last delim char expected, break if other char encountered
67
            for i in range(expect, n):
68
                x = buf.read(1)
69
                if not x: # done reading
70
                    break
71
                if x == delim[expect]: # what we expected
72
                    expect += 1 # expect the next delim char
73
                else:
74
                    expect = 0 # reset
75
                    break
76
            else: # if we didn't break out of the loop, full delim was parsed
77
                msg_till = buf.tell() - n
78
                buf.seek(0)
79
                msg = buf.read(msg_till)
80
                self.dispatch('received', msg)
81
                buf.seek(n+1, os.SEEK_CUR)
82
                rest = buf.read()
83
                buf = StringIO()
84
                buf.write(rest)
85
                buf.seek(0)
86
                state = 0
87
        self._buffer = buf
88
        self._parsing_state = expect
89
        self._parsing_pos = self._buffer.tell()
90
    
91
    def load_system_host_keys(self, filename=None):
92
        if filename is None:
93
            # try the user's .ssh key file, and mask exceptions
94
            filename = os.path.expanduser('~/.ssh/known_hosts')
95
            try:
96
                self._system_host_keys.load(filename)
97
            except IOError:
98
                pass
99
            return
100
        self._system_host_keys.load(filename)
101
    
102
    def load_host_keys(self, filename):
103
        self._host_keys_filename = filename
104
        self._host_keys.load(filename)
105

    
106
    def add_host_key(self, key):
107
        self._host_keys.add(key)
108
    
109
    def save_host_keys(self, filename):
110
        f = open(filename, 'w')
111
        for hostname, keys in self._host_keys.iteritems():
112
            for keytype, key in keys.iteritems():
113
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
114
        f.close()    
115
    
116
    def close(self):
117
        if self._transport.is_active():
118
            self._transport.close()
119
        self._connected = False
120
    
121
    def connect(self, hostname, port=830, timeout=None,
122
                unknown_host_cb=None, username=None, password=None,
123
                key_filename=None, allow_agent=True, look_for_keys=True):
124
        
125
        assert(username is not None)
126
        
127
        for (family, socktype, proto, canonname, sockaddr) in \
128
        socket.getaddrinfo(hostname, port):
129
            if socktype==socket.SOCK_STREAM:
130
                af = family
131
                addr = sockaddr
132
                break
133
        else:
134
            raise SSHError('No suitable address family for %s' % hostname)
135
        sock = socket.socket(af, socket.SOCK_STREAM)
136
        sock.settimeout(timeout)
137
        sock.connect(addr)
138
        t = self._transport = paramiko.Transport(sock)
139
        t.set_log_channel(logger.name)
140
        
141
        try:
142
            t.start_client()
143
        except paramiko.SSHException:
144
            raise SSHError('Negotiation failed')
145
        
146
        # host key verification
147
        server_key = t.get_remote_server_key()
148
        known_host = self._host_keys.check(hostname, server_key) or \
149
                        self._system_host_keys.check(hostname, server_key)
150
        
151
        if unknown_host_cb is None:
152
            unknown_host_cb = lambda *args: False
153
        if not known_host and not unknown_host_cb(hostname, server_key):
154
                raise SSHUnknownHostError(hostname, server_key)
155
        
156
        if key_filename is None:
157
            key_filenames = []
158
        elif isinstance(key_filename, basestring):
159
            key_filenames = [ key_filename ]
160
        else:
161
            key_filenames = key_filename
162
        
163
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
164
        
165
        self._connected = True # there was no error authenticating
166
        
167
        c = self._channel = self._transport.open_session()
168
        c.invoke_subsystem('netconf')
169
        c.set_name('netconf')
170
        
171
        self._post_connect()
172
    
173
    # on the lines of paramiko.SSHClient._auth()
174
    def _auth(self, username, password, key_filenames, allow_agent,
175
              look_for_keys):
176
        saved_exception = None
177
        
178
        for key_filename in key_filenames:
179
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
180
                try:
181
                    key = cls.from_private_key_file(key_filename, password)
182
                    logger.debug('Trying key %s from %s' %
183
                              (hexlify(key.get_fingerprint()), key_filename))
184
                    self._transport.auth_publickey(username, key)
185
                    return
186
                except Exception as e:
187
                    saved_exception = e
188
                    logger.debug(e)
189
        
190
        if allow_agent:
191
            for key in paramiko.Agent().get_keys():
192
                try:
193
                    logger.debug('Trying SSH agent key %s' %
194
                                 hexlify(key.get_fingerprint()))
195
                    self._transport.auth_publickey(username, key)
196
                    return
197
                except Exception as e:
198
                    saved_exception = e
199
                    logger.debug(e)
200
        
201
        keyfiles = []
202
        if look_for_keys:
203
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
204
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
205
            if os.path.isfile(rsa_key):
206
                keyfiles.append((paramiko.RSAKey, rsa_key))
207
            if os.path.isfile(dsa_key):
208
                keyfiles.append((paramiko.DSSKey, dsa_key))
209
            # look in ~/ssh/ for windows users:
210
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
211
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
212
            if os.path.isfile(rsa_key):
213
                keyfiles.append((paramiko.RSAKey, rsa_key))
214
            if os.path.isfile(dsa_key):
215
                keyfiles.append((paramiko.DSSKey, dsa_key))
216
        
217
        for cls, filename in keyfiles:
218
            try:
219
                key = cls.from_private_key_file(filename, password)
220
                logger.debug('Trying discovered key %s in %s' %
221
                          (hexlify(key.get_fingerprint()), filename))
222
                self._transport.auth_publickey(username, key)
223
                return
224
            except Exception as e:
225
                saved_exception = e
226
                logger.debug(e)
227
        
228
        if password is not None:
229
            try:
230
                self._transport.auth_password(username, password)
231
                return
232
            except Exception as e:
233
                saved_exception = e
234
                logger.debug(e)
235
        
236
        if saved_exception is not None:
237
            raise AuthenticationError(repr(saved_exception))
238
        
239
        raise AuthenticationError('No authentication methods available')
240
    
241
    def run(self):
242
        chan = self._channel
243
        chan.setblocking(0)
244
        q = self._q
245
        try:
246
            while True:
247
                # select on a paramiko ssh channel object does not ever return
248
                # it in the writable list, so it channel's don't exactly emulate 
249
                # the socket api
250
                r, w, e = select([chan], [], [], TICK)
251
                # will wakeup evey TICK seconds to check if something
252
                # to send, more if something to read (due to select returning
253
                # chan in readable list)
254
                if r:
255
                    data = chan.recv(BUF_SIZE)
256
                    if data:
257
                        self._buffer.write(data)
258
                        self._parse()
259
                    else:
260
                        raise SSHSessionClosedError(self._buffer.getvalue())
261
                if not q.empty() and chan.send_ready():
262
                    data = q.get() + MSG_DELIM
263
                    while data:
264
                        n = chan.send(data)
265
                        if n <= 0:
266
                            raise SSHSessionClosedError(self._buffer.getvalue(), data)
267
                        data = data[n:]
268
        except Exception as e:
269
            self.close()
270
            logger.debug('*** broke out of main loop ***')
271
            self.dispatch('error', e)
272
    
273
    @property
274
    def transport(self):
275
        '''Get underlying paramiko.transport object; this is provided so methods
276
        like transport.set_keepalive can be called.
277
        '''
278
        return self._transport