Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 1d540e60

History | View | Annotate | Download (10.5 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from . import logger
24
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
25
from session import Session
26

    
27
BUF_SIZE = 4096
28
MSG_DELIM = ']]>]]>'
29
TICK = 0.1
30

    
31
class SSHSession(Session):
32
    
33
    def __init__(self):
34
        Session.__init__(self)
35
        self._host_keys = paramiko.HostKeys()
36
        self._system_host_keys = paramiko.HostKeys()
37
        self._transport = None
38
        self._connected = False
39
        self._channel = None
40
        self._expecting_close = False
41
        self._buffer = StringIO() # for incoming data
42
        # parsing-related, see _parse()
43
        self._parsing_state = 0 
44
        self._parsing_pos = 0
45
        logger.debug('[SSHSession object created]')
46
    
47
    def _parse(self):
48
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
49
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
50
        across method calls and if a byte has been read it will not be considered
51
        again.
52
        '''
53
        delim = MSG_DELIM
54
        n = len(delim) - 1
55
        expect = self._parsing_state
56
        buf = self._buffer
57
        buf.seek(self._parsing_pos)
58
        while True:
59
            x = buf.read(1)
60
            if not x: # done reading
61
                break
62
            elif x == delim[expect]: # what we expected
63
                expect += 1 # expect the next delim char
64
            else:
65
                continue
66
            # loop till last delim char expected, break if other char encountered
67
            for i in range(expect, n):
68
                x = buf.read(1)
69
                if not x: # done reading
70
                    break
71
                if x == delim[expect]: # what we expected
72
                    expect += 1 # expect the next delim char
73
                else:
74
                    expect = 0 # reset
75
                    break
76
            else: # if we didn't break out of the loop, full delim was parsed
77
                msg_till = buf.tell() - n
78
                buf.seek(0)
79
                self._dispatch_received(buf.read(msg_till).strip())
80
                buf.seek(n+1, os.SEEK_CUR)
81
                rest = buf.read()
82
                buf = StringIO()
83
                buf.write(rest)
84
                buf.seek(0)
85
                expect = 0
86
        self._buffer = buf
87
        self._parsing_state = expect
88
        self._parsing_pos = self._buffer.tell()
89
    
90
    def expect_close(self):
91
        self._expecting_close = True
92
    
93
    def load_system_host_keys(self, filename=None):
94
        if filename is None:
95
            filename = os.path.expanduser('~/.ssh/known_hosts')
96
            try:
97
                self._system_host_keys.load(filename)
98
            except IOError:
99
                # for windows
100
                filename = os.path.expanduser('~/ssh/known_hosts')
101
                try:
102
                    self._system_host_keys.load(filename)
103
                except IOError:
104
                    pass
105
            return
106
        self._system_host_keys.load(filename)
107
    
108
    def load_host_keys(self, filename):
109
        self._host_keys.load(filename)
110

    
111
    def add_host_key(self, key):
112
        self._host_keys.add(key)
113
    
114
    def save_host_keys(self, filename):
115
        f = open(filename, 'w')
116
        for hostname, keys in self._host_keys.iteritems():
117
            for keytype, key in keys.iteritems():
118
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
119
        f.close()    
120
    
121
    def close(self):
122
        if self._transport.is_active():
123
            self._transport.close()
124
        self._connected = False
125
    
126
    def connect(self, hostname, port=830, timeout=None,
127
                unknown_host_cb=None, username=None, password=None,
128
                key_filename=None, allow_agent=True, look_for_keys=True):
129
        
130
        assert(username is not None)
131
        
132
        for (family, socktype, proto, canonname, sockaddr) in \
133
        socket.getaddrinfo(hostname, port):
134
            if socktype==socket.SOCK_STREAM:
135
                af = family
136
                addr = sockaddr
137
                break
138
        else:
139
            raise SSHError('No suitable address family for %s' % hostname)
140
        sock = socket.socket(af, socket.SOCK_STREAM)
141
        sock.settimeout(timeout)
142
        sock.connect(addr)
143
        t = self._transport = paramiko.Transport(sock)
144
        t.set_log_channel(logger.name)
145
        
146
        try:
147
            t.start_client()
148
        except paramiko.SSHException:
149
            raise SSHError('Negotiation failed')
150
        
151
        # host key verification
152
        server_key = t.get_remote_server_key()
153
        known_host = self._host_keys.check(hostname, server_key) or \
154
                        self._system_host_keys.check(hostname, server_key)
155
        
156
        if unknown_host_cb is None:
157
            unknown_host_cb = lambda *args: False
158
        if not known_host and not unknown_host_cb(hostname, server_key):
159
                raise SSHUnknownHostError(hostname, server_key)
160
        
161
        if key_filename is None:
162
            key_filenames = []
163
        elif isinstance(key_filename, basestring):
164
            key_filenames = [ key_filename ]
165
        else:
166
            key_filenames = key_filename
167
        
168
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
169
        
170
        self._connected = True # there was no error authenticating
171
        
172
        c = self._channel = self._transport.open_session()
173
        c.invoke_subsystem('netconf')
174
        c.set_name('netconf')
175
        
176
        self._post_connect()
177
    
178
    # on the lines of paramiko.SSHClient._auth()
179
    def _auth(self, username, password, key_filenames, allow_agent,
180
              look_for_keys):
181
        saved_exception = None
182
        
183
        for key_filename in key_filenames:
184
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
185
                try:
186
                    key = cls.from_private_key_file(key_filename, password)
187
                    logger.debug('Trying key %s from %s' %
188
                              (hexlify(key.get_fingerprint()), key_filename))
189
                    self._transport.auth_publickey(username, key)
190
                    return
191
                except Exception as e:
192
                    saved_exception = e
193
                    logger.debug(e)
194
        
195
        if allow_agent:
196
            for key in paramiko.Agent().get_keys():
197
                try:
198
                    logger.debug('Trying SSH agent key %s' %
199
                                 hexlify(key.get_fingerprint()))
200
                    self._transport.auth_publickey(username, key)
201
                    return
202
                except Exception as e:
203
                    saved_exception = e
204
                    logger.debug(e)
205
        
206
        keyfiles = []
207
        if look_for_keys:
208
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
209
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
210
            if os.path.isfile(rsa_key):
211
                keyfiles.append((paramiko.RSAKey, rsa_key))
212
            if os.path.isfile(dsa_key):
213
                keyfiles.append((paramiko.DSSKey, dsa_key))
214
            # look in ~/ssh/ for windows users:
215
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
216
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
217
            if os.path.isfile(rsa_key):
218
                keyfiles.append((paramiko.RSAKey, rsa_key))
219
            if os.path.isfile(dsa_key):
220
                keyfiles.append((paramiko.DSSKey, dsa_key))
221
        
222
        for cls, filename in keyfiles:
223
            try:
224
                key = cls.from_private_key_file(filename, password)
225
                logger.debug('Trying discovered key %s in %s' %
226
                          (hexlify(key.get_fingerprint()), filename))
227
                self._transport.auth_publickey(username, key)
228
                return
229
            except Exception as e:
230
                saved_exception = e
231
                logger.debug(e)
232
        
233
        if password is not None:
234
            try:
235
                self._transport.auth_password(username, password)
236
                return
237
            except Exception as e:
238
                saved_exception = e
239
                logger.debug(e)
240
        
241
        if saved_exception is not None:
242
            raise SSHAuthenticationError(repr(saved_exception))
243
        
244
        raise SSHAuthenticationError('No authentication methods available')
245
    
246
    def run(self):
247
        chan = self._channel
248
        chan.setblocking(0)
249
        q = self._q
250
        try:
251
            while True:
252
                # select on a paramiko ssh channel object does not ever return
253
                # it in the writable list, so it channel's don't exactly emulate 
254
                # the socket api
255
                r, w, e = select([chan], [], [], TICK)
256
                # will wakeup evey TICK seconds to check if something
257
                # to send, more if something to read (due to select returning
258
                # chan in readable list)
259
                if r:
260
                    data = chan.recv(BUF_SIZE)
261
                    if data:
262
                        self._buffer.write(data)
263
                        self._parse()
264
                    else:
265
                        raise SessionCloseError(self._buffer.getvalue())
266
                if not q.empty() and chan.send_ready():
267
                    data = q.get() + MSG_DELIM
268
                    while data:
269
                        n = chan.send(data)
270
                        if n <= 0:
271
                            raise SessionCloseError(self._buffer.getvalue(), data)
272
                        data = data[n:]
273
        except Exception as e:
274
            logger.debug('*** broke out of main loop ***')
275
            self.close()
276
            if not (isinstance(e, SessionCloseError) and self._expecting_close):
277
                self._dispatch_error(e)
278
    
279
    @property
280
    def transport(self):
281
        '''Get underlying paramiko transport object; this is provided so methods
282
        like set_keepalive can be called on it. See paramiko.Transport
283
        documentation for details.
284
        '''
285
        return self._transport