Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 179b00d4

History | View | Annotate | Download (10.8 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = ']]>]]>'
31
TICK = 0.1
32

    
33
class SSHSession(Session):
34
    
35
    def __init__(self, *args, **kwds):
36
        Session.__init__(self, *args, **kwds)
37
        self._host_keys = paramiko.HostKeys()
38
        self._system_host_keys = paramiko.HostKeys()
39
        self._transport = None
40
        self._connected = False
41
        self._channel = None
42
        self._expecting_close = False
43
        self._buffer = StringIO() # for incoming data
44
        # parsing-related, see _parse()
45
        self._parsing_state = 0 
46
        self._parsing_pos = 0
47
    
48
    def _parse(self):
49
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
50
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
51
        across method calls and if a byte has been read it will not be considered
52
        again.
53
        '''
54
        delim = MSG_DELIM
55
        n = len(delim) - 1
56
        expect = self._parsing_state
57
        buf = self._buffer
58
        buf.seek(self._parsing_pos)
59
        while True:
60
            x = buf.read(1)
61
            if not x: # done reading
62
                break
63
            elif x == delim[expect]: # what we expected
64
                expect += 1 # expect the next delim char
65
            else:
66
                continue
67
            # loop till last delim char expected, break if other char encountered
68
            for i in range(expect, n):
69
                x = buf.read(1)
70
                if not x: # done reading
71
                    break
72
                if x == delim[expect]: # what we expected
73
                    expect += 1 # expect the next delim char
74
                else:
75
                    expect = 0 # reset
76
                    break
77
            else: # if we didn't break out of the loop, full delim was parsed
78
                msg_till = buf.tell() - n
79
                buf.seek(0)
80
                logger.debug('parsed new message')
81
                self._dispatch_message(buf.read(msg_till).strip())
82
                buf.seek(n+1, os.SEEK_CUR)
83
                rest = buf.read()
84
                buf = StringIO()
85
                buf.write(rest)
86
                buf.seek(0)
87
                expect = 0
88
        self._buffer = buf
89
        self._parsing_state = expect
90
        self._parsing_pos = self._buffer.tell()
91
    
92
    def expect_close(self):
93
        self._expecting_close = True
94
    
95
    def load_system_host_keys(self, filename=None):
96
        if filename is None:
97
            filename = os.path.expanduser('~/.ssh/known_hosts')
98
            try:
99
                self._system_host_keys.load(filename)
100
            except IOError:
101
                # for windows
102
                filename = os.path.expanduser('~/ssh/known_hosts')
103
                try:
104
                    self._system_host_keys.load(filename)
105
                except IOError:
106
                    pass
107
            return
108
        self._system_host_keys.load(filename)
109
    
110
    def load_host_keys(self, filename):
111
        self._host_keys.load(filename)
112

    
113
    def add_host_key(self, key):
114
        self._host_keys.add(key)
115
    
116
    def save_host_keys(self, filename):
117
        f = open(filename, 'w')
118
        for hostname, keys in self._host_keys.iteritems():
119
            for keytype, key in keys.iteritems():
120
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
121
        f.close()    
122
    
123
    def close(self):
124
        self.expect_close()
125
        if self._transport.is_active():
126
            self._transport.close()
127
        self._connected = False
128
    
129
    def connect(self, hostname, port=830, timeout=None,
130
                unknown_host_cb=None, username=None, password=None,
131
                key_filename=None, allow_agent=True, look_for_keys=True):
132
        
133
        assert(username is not None)
134
        
135
        for (family, socktype, proto, canonname, sockaddr) in \
136
        socket.getaddrinfo(hostname, port):
137
            if socktype==socket.SOCK_STREAM:
138
                af = family
139
                addr = sockaddr
140
                break
141
        else:
142
            raise SSHError('No suitable address family for %s' % hostname)
143
        sock = socket.socket(af, socket.SOCK_STREAM)
144
        sock.settimeout(timeout)
145
        sock.connect(addr)
146
        t = self._transport = paramiko.Transport(sock)
147
        t.set_log_channel(logger.name)
148
        
149
        try:
150
            t.start_client()
151
        except paramiko.SSHException:
152
            raise SSHError('Negotiation failed')
153
        
154
        # host key verification
155
        server_key = t.get_remote_server_key()
156
        known_host = self._host_keys.check(hostname, server_key) or \
157
                        self._system_host_keys.check(hostname, server_key)
158
        
159
        if unknown_host_cb is None:
160
            unknown_host_cb = lambda *args: False
161
        if not known_host and not unknown_host_cb(hostname, server_key):
162
                raise SSHUnknownHostError(hostname, server_key)
163
        
164
        if key_filename is None:
165
            key_filenames = []
166
        elif isinstance(key_filename, basestring):
167
            key_filenames = [ key_filename ]
168
        else:
169
            key_filenames = key_filename
170
        
171
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
172
        
173
        self._connected = True # there was no error authenticating
174
        
175
        c = self._channel = self._transport.open_session()
176
        c.set_name('netconf')
177
        c.invoke_subsystem('netconf')
178
        
179
        self._post_connect()
180
    
181
    # on the lines of paramiko.SSHClient._auth()
182
    def _auth(self, username, password, key_filenames, allow_agent,
183
              look_for_keys):
184
        saved_exception = None
185
        
186
        for key_filename in key_filenames:
187
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
188
                try:
189
                    key = cls.from_private_key_file(key_filename, password)
190
                    logger.debug('Trying key %s from %s' %
191
                              (hexlify(key.get_fingerprint()), key_filename))
192
                    self._transport.auth_publickey(username, key)
193
                    return
194
                except Exception as e:
195
                    saved_exception = e
196
                    logger.debug(e)
197
        
198
        if allow_agent:
199
            for key in paramiko.Agent().get_keys():
200
                try:
201
                    logger.debug('Trying SSH agent key %s' %
202
                                 hexlify(key.get_fingerprint()))
203
                    self._transport.auth_publickey(username, key)
204
                    return
205
                except Exception as e:
206
                    saved_exception = e
207
                    logger.debug(e)
208
        
209
        keyfiles = []
210
        if look_for_keys:
211
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
212
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
213
            if os.path.isfile(rsa_key):
214
                keyfiles.append((paramiko.RSAKey, rsa_key))
215
            if os.path.isfile(dsa_key):
216
                keyfiles.append((paramiko.DSSKey, dsa_key))
217
            # look in ~/ssh/ for windows users:
218
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
219
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
220
            if os.path.isfile(rsa_key):
221
                keyfiles.append((paramiko.RSAKey, rsa_key))
222
            if os.path.isfile(dsa_key):
223
                keyfiles.append((paramiko.DSSKey, dsa_key))
224
        
225
        for cls, filename in keyfiles:
226
            try:
227
                key = cls.from_private_key_file(filename, password)
228
                logger.debug('Trying discovered key %s in %s' %
229
                          (hexlify(key.get_fingerprint()), filename))
230
                self._transport.auth_publickey(username, key)
231
                return
232
            except Exception as e:
233
                saved_exception = e
234
                logger.debug(e)
235
        
236
        if password is not None:
237
            try:
238
                self._transport.auth_password(username, password)
239
                return
240
            except Exception as e:
241
                saved_exception = e
242
                logger.debug(e)
243
        
244
        if saved_exception is not None:
245
            # need pep-3134 to do this right
246
            raise SSHAuthenticationError(repr(saved_exception))
247
        
248
        raise SSHAuthenticationError('No authentication methods available')
249
    
250
    def run(self):
251
        chan = self._channel
252
        chan.setblocking(0)
253
        q = self._q
254
        try:
255
            while True:
256
                # select on a paramiko ssh channel object does not ever return
257
                # it in the writable list, so it channel's don't exactly emulate 
258
                # the socket api
259
                r, w, e = select([chan], [], [], TICK)
260
                # will wakeup evey TICK seconds to check if something
261
                # to send, more if something to read (due to select returning
262
                # chan in readable list)
263
                if r:
264
                    data = chan.recv(BUF_SIZE)
265
                    if data:
266
                        self._buffer.write(data)
267
                        self._parse()
268
                    else:
269
                        raise SessionCloseError(self._buffer.getvalue())
270
                if not q.empty() and chan.send_ready():
271
                    logger.debug('sending message')
272
                    data = q.get() + MSG_DELIM
273
                    while data:
274
                        n = chan.send(data)
275
                        if n <= 0:
276
                            raise SessionCloseError(self._buffer.getvalue(), data)
277
                        data = data[n:]
278
        except Exception as e:
279
            logger.debug('broke out of main loop')
280
            self.close()
281
            if not (isinstance(e, SessionCloseError) and self._expecting_close):
282
                self._dispatch_error(e)
283
    
284
    @property
285
    def transport(self):
286
        '''Get underlying paramiko transport object; this is provided so methods
287
        like set_keepalive can be called on it. See paramiko.Transport
288
        documentation for details.
289
        '''
290
        return self._transport
291
    
292
    @property
293
    def can_pipeline(self):
294
        if 'Cisco' in self._transport.remote_version:
295
            return False
296
        # elif ..
297
        return True