Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 4ba5e843

History | View | Annotate | Download (10.3 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
import session
24
from . import logger
25
from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, SSHSessionClosedError
26
from session import Session
27

    
28
BUF_SIZE = 4096
29
MSG_DELIM = ']]>]]>'
30
TICK = 0.1
31

    
32
class SSHSession(Session):
33

    
34
    def __init__(self):
35
        Session.__init__(self)
36
        self._host_keys = paramiko.HostKeys()
37
        self._system_host_keys = paramiko.HostKeys()
38
        self._transport = None
39
        self._connected = False
40
        self._channel = None
41
        self._buffer = StringIO() # for incoming data
42
        # parsing-related, see _parse()
43
        self._parsing_state = 0 
44
        self._parsing_pos = 0
45
    
46
    def _parse(self):
47
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
48
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
49
        across method calls and if a byte has been read it will not be considered
50
        again.
51
        '''
52
        delim = MSG_DELIM
53
        n = len(delim) - 1
54
        expect = self._parsing_state
55
        buf = self._buffer
56
        buf.seek(self._parsing_pos)
57
        while True:
58
            x = buf.read(1)
59
            if not x: # done reading
60
                break
61
            elif x == delim[expect]: # what we expected
62
                expect += 1 # expect the next delim char
63
            else:
64
                continue
65
            # loop till last delim char expected, break if other char encountered
66
            for i in range(expect, n):
67
                x = buf.read(1)
68
                if not x: # done reading
69
                    break
70
                if x == delim[expect]: # what we expected
71
                    expect += 1 # expect the next delim char
72
                else:
73
                    expect = 0 # reset
74
                    break
75
            else: # if we didn't break out of the loop, full delim was parsed
76
                msg_till = buf.tell() - n
77
                buf.seek(0)
78
                msg = buf.read(msg_till)
79
                self.dispatch('received', msg)
80
                buf.seek(n+1, os.SEEK_CUR)
81
                rest = buf.read()
82
                buf = StringIO()
83
                buf.write(rest)
84
                buf.seek(0)
85
                expect = 0
86
        self._buffer = buf
87
        self._parsing_state = expect
88
        self._parsing_pos = self._buffer.tell()
89
    
90
    def load_system_host_keys(self, filename=None):
91
        if filename is None:
92
            filename = os.path.expanduser('~/.ssh/known_hosts')
93
            try:
94
                self._system_host_keys.load(filename)
95
            except IOError:
96
                # for windows
97
                filename = os.path.expanduser('~/ssh/known_hosts')
98
                try:
99
                    self._system_host_keys.load(filename)
100
                except IOError:
101
                    pass
102
            return
103
        self._system_host_keys.load(filename)
104
    
105
    def load_host_keys(self, filename):
106
        self._host_keys.load(filename)
107

    
108
    def add_host_key(self, key):
109
        self._host_keys.add(key)
110
    
111
    def save_host_keys(self, filename):
112
        f = open(filename, 'w')
113
        for hostname, keys in self._host_keys.iteritems():
114
            for keytype, key in keys.iteritems():
115
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
116
        f.close()    
117
    
118
    def close(self):
119
        if self._transport.is_active():
120
            self._transport.close()
121
        self._connected = False
122
    
123
    def connect(self, hostname, port=830, timeout=None,
124
                unknown_host_cb=None, username=None, password=None,
125
                key_filename=None, allow_agent=True, look_for_keys=True):
126
        
127
        assert(username is not None)
128
        
129
        for (family, socktype, proto, canonname, sockaddr) in \
130
        socket.getaddrinfo(hostname, port):
131
            if socktype==socket.SOCK_STREAM:
132
                af = family
133
                addr = sockaddr
134
                break
135
        else:
136
            raise SSHError('No suitable address family for %s' % hostname)
137
        sock = socket.socket(af, socket.SOCK_STREAM)
138
        sock.settimeout(timeout)
139
        sock.connect(addr)
140
        t = self._transport = paramiko.Transport(sock)
141
        t.set_log_channel(logger.name)
142
        
143
        try:
144
            t.start_client()
145
        except paramiko.SSHException:
146
            raise SSHError('Negotiation failed')
147
        
148
        # host key verification
149
        server_key = t.get_remote_server_key()
150
        known_host = self._host_keys.check(hostname, server_key) or \
151
                        self._system_host_keys.check(hostname, server_key)
152
        
153
        if unknown_host_cb is None:
154
            unknown_host_cb = lambda *args: False
155
        if not known_host and not unknown_host_cb(hostname, server_key):
156
                raise SSHUnknownHostError(hostname, server_key)
157
        
158
        if key_filename is None:
159
            key_filenames = []
160
        elif isinstance(key_filename, basestring):
161
            key_filenames = [ key_filename ]
162
        else:
163
            key_filenames = key_filename
164
        
165
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
166
        
167
        self._connected = True # there was no error authenticating
168
        
169
        c = self._channel = self._transport.open_session()
170
        c.invoke_subsystem('netconf')
171
        c.set_name('netconf')
172
        
173
        self._post_connect()
174
    
175
    # on the lines of paramiko.SSHClient._auth()
176
    def _auth(self, username, password, key_filenames, allow_agent,
177
              look_for_keys):
178
        saved_exception = None
179
        
180
        for key_filename in key_filenames:
181
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
182
                try:
183
                    key = cls.from_private_key_file(key_filename, password)
184
                    logger.debug('Trying key %s from %s' %
185
                              (hexlify(key.get_fingerprint()), key_filename))
186
                    self._transport.auth_publickey(username, key)
187
                    return
188
                except Exception as e:
189
                    saved_exception = e
190
                    logger.debug(e)
191
        
192
        if allow_agent:
193
            for key in paramiko.Agent().get_keys():
194
                try:
195
                    logger.debug('Trying SSH agent key %s' %
196
                                 hexlify(key.get_fingerprint()))
197
                    self._transport.auth_publickey(username, key)
198
                    return
199
                except Exception as e:
200
                    saved_exception = e
201
                    logger.debug(e)
202
        
203
        keyfiles = []
204
        if look_for_keys:
205
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
206
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
207
            if os.path.isfile(rsa_key):
208
                keyfiles.append((paramiko.RSAKey, rsa_key))
209
            if os.path.isfile(dsa_key):
210
                keyfiles.append((paramiko.DSSKey, dsa_key))
211
            # look in ~/ssh/ for windows users:
212
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
213
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
214
            if os.path.isfile(rsa_key):
215
                keyfiles.append((paramiko.RSAKey, rsa_key))
216
            if os.path.isfile(dsa_key):
217
                keyfiles.append((paramiko.DSSKey, dsa_key))
218
        
219
        for cls, filename in keyfiles:
220
            try:
221
                key = cls.from_private_key_file(filename, password)
222
                logger.debug('Trying discovered key %s in %s' %
223
                          (hexlify(key.get_fingerprint()), filename))
224
                self._transport.auth_publickey(username, key)
225
                return
226
            except Exception as e:
227
                saved_exception = e
228
                logger.debug(e)
229
        
230
        if password is not None:
231
            try:
232
                self._transport.auth_password(username, password)
233
                return
234
            except Exception as e:
235
                saved_exception = e
236
                logger.debug(e)
237
        
238
        if saved_exception is not None:
239
            raise SSHAuthenticationError(repr(saved_exception))
240
        
241
        raise SSHAuthenticationError('No authentication methods available')
242
    
243
    def run(self):
244
        chan = self._channel
245
        chan.setblocking(0)
246
        q = self._q
247
        try:
248
            while True:
249
                # select on a paramiko ssh channel object does not ever return
250
                # it in the writable list, so it channel's don't exactly emulate 
251
                # the socket api
252
                r, w, e = select([chan], [], [], TICK)
253
                # will wakeup evey TICK seconds to check if something
254
                # to send, more if something to read (due to select returning
255
                # chan in readable list)
256
                if r:
257
                    data = chan.recv(BUF_SIZE)
258
                    if data:
259
                        self._buffer.write(data)
260
                        self._parse()
261
                    else:
262
                        raise SSHSessionClosedError(self._buffer.getvalue())
263
                if not q.empty() and chan.send_ready():
264
                    data = q.get() + MSG_DELIM
265
                    while data:
266
                        n = chan.send(data)
267
                        if n <= 0:
268
                            raise SSHSessionClosedError(self._buffer.getvalue(), data)
269
                        data = data[n:]
270
        except Exception as e:
271
            self.close()
272
            logger.debug('*** broke out of main loop ***')
273
            self.dispatch('error', e)
274
    
275
    @property
276
    def transport(self):
277
        '''Get underlying paramiko.transport object; this is provided so methods
278
        like transport.set_keepalive can be called.
279
        '''
280
        return self._transport