Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 94803aaf

History | View | Annotate | Download (10.8 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = ']]>]]>'
31
TICK = 0.1
32

    
33
class SSHSession(Session):
34
    
35
    def __init__(self):
36
        Session.__init__(self)
37
        self._host_keys = paramiko.HostKeys()
38
        self._system_host_keys = paramiko.HostKeys()
39
        self._transport = None
40
        self._connected = False
41
        self._channel = None
42
        self._expecting_close = False
43
        self._buffer = StringIO() # for incoming data
44
        # parsing-related, see _parse()
45
        self._parsing_state = 0 
46
        self._parsing_pos = 0
47
    
48
    def _parse(self):
49
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
50
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
51
        across method calls and if a byte has been read it will not be considered
52
        again.
53
        '''
54
        delim = MSG_DELIM
55
        n = len(delim) - 1
56
        expect = self._parsing_state
57
        buf = self._buffer
58
        buf.seek(self._parsing_pos)
59
        while True:
60
            x = buf.read(1)
61
            if not x: # done reading
62
                break
63
            elif x == delim[expect]: # what we expected
64
                expect += 1 # expect the next delim char
65
            else:
66
                continue
67
            # loop till last delim char expected, break if other char encountered
68
            for i in range(expect, n):
69
                x = buf.read(1)
70
                if not x: # done reading
71
                    break
72
                if x == delim[expect]: # what we expected
73
                    expect += 1 # expect the next delim char
74
                else:
75
                    expect = 0 # reset
76
                    break
77
            else: # if we didn't break out of the loop, full delim was parsed
78
                msg_till = buf.tell() - n
79
                buf.seek(0)
80
                logger.debug('parsed new message')
81
                self._dispatch_message(buf.read(msg_till).strip())
82
                buf.seek(n+1, os.SEEK_CUR)
83
                rest = buf.read()
84
                buf = StringIO()
85
                buf.write(rest)
86
                buf.seek(0)
87
                expect = 0
88
        self._buffer = buf
89
        self._parsing_state = expect
90
        self._parsing_pos = self._buffer.tell()
91
    
92
    def expect_close(self):
93
        self._expecting_close = True
94
    
95
    def load_system_host_keys(self, filename=None):
96
        if filename is None:
97
            filename = os.path.expanduser('~/.ssh/known_hosts')
98
            try:
99
                self._system_host_keys.load(filename)
100
            except IOError:
101
                # for windows
102
                filename = os.path.expanduser('~/ssh/known_hosts')
103
                try:
104
                    self._system_host_keys.load(filename)
105
                except IOError:
106
                    pass
107
            return
108
        self._system_host_keys.load(filename)
109
    
110
    def load_host_keys(self, filename):
111
        self._host_keys.load(filename)
112

    
113
    def add_host_key(self, key):
114
        self._host_keys.add(key)
115
    
116
    def save_host_keys(self, filename):
117
        f = open(filename, 'w')
118
        for hostname, keys in self._host_keys.iteritems():
119
            for keytype, key in keys.iteritems():
120
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
121
        f.close()    
122
    
123
    def close(self):
124
        if self._transport.is_active():
125
            self._transport.close()
126
        self._connected = False
127
    
128
    def connect(self, hostname, port=830, timeout=None,
129
                unknown_host_cb=None, username=None, password=None,
130
                key_filename=None, allow_agent=True, look_for_keys=True):
131
        
132
        assert(username is not None)
133
        
134
        for (family, socktype, proto, canonname, sockaddr) in \
135
        socket.getaddrinfo(hostname, port):
136
            if socktype==socket.SOCK_STREAM:
137
                af = family
138
                addr = sockaddr
139
                break
140
        else:
141
            raise SSHError('No suitable address family for %s' % hostname)
142
        sock = socket.socket(af, socket.SOCK_STREAM)
143
        sock.settimeout(timeout)
144
        sock.connect(addr)
145
        t = self._transport = paramiko.Transport(sock)
146
        t.set_log_channel(logger.name)
147
        
148
        try:
149
            t.start_client()
150
        except paramiko.SSHException:
151
            raise SSHError('Negotiation failed')
152
        
153
        # host key verification
154
        server_key = t.get_remote_server_key()
155
        known_host = self._host_keys.check(hostname, server_key) or \
156
                        self._system_host_keys.check(hostname, server_key)
157
        
158
        if unknown_host_cb is None:
159
            unknown_host_cb = lambda *args: False
160
        if not known_host and not unknown_host_cb(hostname, server_key):
161
                raise SSHUnknownHostError(hostname, server_key)
162
        
163
        if key_filename is None:
164
            key_filenames = []
165
        elif isinstance(key_filename, basestring):
166
            key_filenames = [ key_filename ]
167
        else:
168
            key_filenames = key_filename
169
        
170
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
171
        
172
        self._connected = True # there was no error authenticating
173
        
174
        c = self._channel = self._transport.open_session()
175
        c.set_name('netconf')
176
        c.invoke_subsystem('netconf')
177
        
178
        self._post_connect()
179
    
180
    # on the lines of paramiko.SSHClient._auth()
181
    def _auth(self, username, password, key_filenames, allow_agent,
182
              look_for_keys):
183
        saved_exception = None
184
        
185
        for key_filename in key_filenames:
186
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
187
                try:
188
                    key = cls.from_private_key_file(key_filename, password)
189
                    logger.debug('Trying key %s from %s' %
190
                              (hexlify(key.get_fingerprint()), key_filename))
191
                    self._transport.auth_publickey(username, key)
192
                    return
193
                except Exception as e:
194
                    saved_exception = e
195
                    logger.debug(e)
196
        
197
        if allow_agent:
198
            for key in paramiko.Agent().get_keys():
199
                try:
200
                    logger.debug('Trying SSH agent key %s' %
201
                                 hexlify(key.get_fingerprint()))
202
                    self._transport.auth_publickey(username, key)
203
                    return
204
                except Exception as e:
205
                    saved_exception = e
206
                    logger.debug(e)
207
        
208
        keyfiles = []
209
        if look_for_keys:
210
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
211
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
212
            if os.path.isfile(rsa_key):
213
                keyfiles.append((paramiko.RSAKey, rsa_key))
214
            if os.path.isfile(dsa_key):
215
                keyfiles.append((paramiko.DSSKey, dsa_key))
216
            # look in ~/ssh/ for windows users:
217
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
218
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
219
            if os.path.isfile(rsa_key):
220
                keyfiles.append((paramiko.RSAKey, rsa_key))
221
            if os.path.isfile(dsa_key):
222
                keyfiles.append((paramiko.DSSKey, dsa_key))
223
        
224
        for cls, filename in keyfiles:
225
            try:
226
                key = cls.from_private_key_file(filename, password)
227
                logger.debug('Trying discovered key %s in %s' %
228
                          (hexlify(key.get_fingerprint()), filename))
229
                self._transport.auth_publickey(username, key)
230
                return
231
            except Exception as e:
232
                saved_exception = e
233
                logger.debug(e)
234
        
235
        if password is not None:
236
            try:
237
                self._transport.auth_password(username, password)
238
                return
239
            except Exception as e:
240
                saved_exception = e
241
                logger.debug(e)
242
        
243
        if saved_exception is not None:
244
            # need pep-3134 to do this right
245
            raise SSHAuthenticationError(repr(saved_exception))
246
        
247
        raise SSHAuthenticationError('No authentication methods available')
248
    
249
    def run(self):
250
        chan = self._channel
251
        chan.setblocking(0)
252
        q = self._q
253
        try:
254
            while True:
255
                # select on a paramiko ssh channel object does not ever return
256
                # it in the writable list, so it channel's don't exactly emulate 
257
                # the socket api
258
                r, w, e = select([chan], [], [], TICK)
259
                # will wakeup evey TICK seconds to check if something
260
                # to send, more if something to read (due to select returning
261
                # chan in readable list)
262
                if r:
263
                    data = chan.recv(BUF_SIZE)
264
                    if data:
265
                        self._buffer.write(data)
266
                        self._parse()
267
                    else:
268
                        raise SessionCloseError(self._buffer.getvalue())
269
                if not q.empty() and chan.send_ready():
270
                    logger.debug('sending message')
271
                    data = q.get() + MSG_DELIM
272
                    while data:
273
                        n = chan.send(data)
274
                        if n <= 0:
275
                            raise SessionCloseError(self._buffer.getvalue(), data)
276
                        data = data[n:]
277
        except Exception as e:
278
            logger.debug('broke out of main loop')
279
            self.close()
280
            if not (isinstance(e, SessionCloseError) and self._expecting_close):
281
                self._dispatch_error(e)
282
    
283
    @property
284
    def transport(self):
285
        '''Get underlying paramiko transport object; this is provided so methods
286
        like set_keepalive can be called on it. See paramiko.Transport
287
        documentation for details.
288
        '''
289
        return self._transport
290
    
291
    @property
292
    def can_pipeline(self):
293
        if 'Cisco' in self._transport.remote_version:
294
            return False
295
        # elif ..
296
        return True