Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 0b7d3b31

History | View | Annotate | Download (12.6 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = ']]>]]>'
31
TICK = 0.1
32

    
33
def default_unknown_host_cb(host, key):
34
    """An `unknown host callback` returns :const:`True` if it finds the key
35
    acceptable, and :const:`False` if not.
36

37
    This default callback always returns :const:`False`, which would lead to
38
    :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
39

40
    Supply another valid callback if you need to verify the host key
41
    programatically.
42

43
    :arg host: the host for whom key needs to be verified
44
    :type host: string
45

46
    :arg key: a hex string representing the host key fingerprint
47
    :type key: string
48
    """
49
    return False
50

    
51

    
52
class SSHSession(Session):
53

    
54
    "Implements a :rfc:`4742` NETCONF session over SSH."
55

    
56
    def __init__(self, capabilities):
57
        Session.__init__(self, capabilities)
58
        self._host_keys = paramiko.HostKeys()
59
        self._transport = None
60
        self._connected = False
61
        self._channel = None
62
        self._expecting_close = False
63
        self._buffer = StringIO() # for incoming data
64
        # parsing-related, see _parse()
65
        self._parsing_state = 0
66
        self._parsing_pos = 0
67

    
68
    def _parse(self):
69
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
70
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
71
        across method calls and if a byte has been read it will not be
72
        considered again. '''
73
        delim = MSG_DELIM
74
        n = len(delim) - 1
75
        expect = self._parsing_state
76
        buf = self._buffer
77
        buf.seek(self._parsing_pos)
78
        while True:
79
            x = buf.read(1)
80
            if not x: # done reading
81
                break
82
            elif x == delim[expect]: # what we expected
83
                expect += 1 # expect the next delim char
84
            else:
85
                continue
86
            # loop till last delim char expected, break if other char encountered
87
            for i in range(expect, n):
88
                x = buf.read(1)
89
                if not x: # done reading
90
                    break
91
                if x == delim[expect]: # what we expected
92
                    expect += 1 # expect the next delim char
93
                else:
94
                    expect = 0 # reset
95
                    break
96
            else: # if we didn't break out of the loop, full delim was parsed
97
                msg_till = buf.tell() - n
98
                buf.seek(0)
99
                logger.debug('parsed new message')
100
                self._dispatch_message(buf.read(msg_till).strip())
101
                buf.seek(n+1, os.SEEK_CUR)
102
                rest = buf.read()
103
                buf = StringIO()
104
                buf.write(rest)
105
                buf.seek(0)
106
                expect = 0
107
        self._buffer = buf
108
        self._parsing_state = expect
109
        self._parsing_pos = self._buffer.tell()
110

    
111
    def load_known_hosts(self, filename=None):
112
        """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
113
        times.
114

115
        If *filename* is not specified, looks in the default locations i.e.
116
        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
117
        """
118
        if filename is None:
119
            filename = os.path.expanduser('~/.ssh/known_hosts')
120
            try:
121
                self._host_keys.load(filename)
122
            except IOError:
123
                # for windows
124
                filename = os.path.expanduser('~/ssh/known_hosts')
125
                try:
126
                    self._host_keys.load(filename)
127
                except IOError:
128
                    pass
129
        else:
130
            self._host_keys.load(filename)
131

    
132
    def close(self):
133
        self._expecting_close = True
134
        if self._transport.is_active():
135
            self._transport.close()
136
        self._connected = False
137

    
138
    def connect(self, host, port=830, timeout=None,
139
                unknown_host_cb=default_unknown_host_cb,
140
                username=None, password=None,
141
                key_filename=None, allow_agent=True, look_for_keys=True):
142
        """Connect via SSH and initialize the NETCONF session. First attempts
143
        the publickey authentication method and then password authentication.
144

145
        To disable attemting publickey authentication altogether, call with
146
        *allow_agent* and *look_for_keys* as :const:`False`. This may be needed
147
        for Cisco devices which immediately disconnect on an incorrect
148
        authentication attempt.
149

150
        :arg host: the hostname or IP address to connect to
151
        :type host: `string`
152

153
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
154
        :type port: `int`
155

156
        :arg timeout: an optional timeout for the TCP handshake
157
        :type timeout: `int`
158

159
        :arg unknown_host_cb: called when a host key is not recognized
160
        :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
161

162
        :arg username: the username to use for SSH authentication
163
        :type username: `string`
164

165
        :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
166
        :type password: `string`
167

168
        :arg key_filename: a filename where a the private key to be used can be found
169
        :type key_filename: `string`
170

171
        :arg allow_agent: enables querying SSH agent (if found) for keys
172
        :type allow_agent: `bool`
173

174
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
175
        :type look_for_keys: `bool`
176
        """
177

    
178
        assert(username is not None)
179

    
180
        for (family, socktype, proto, canonname, sockaddr) in \
181
        socket.getaddrinfo(host, port):
182
            if socktype == socket.SOCK_STREAM:
183
                af = family
184
                addr = sockaddr
185
                break
186
        else:
187
            raise SSHError('No suitable address family for %s' % host)
188
        sock = socket.socket(af, socket.SOCK_STREAM)
189
        sock.settimeout(timeout)
190
        sock.connect(addr)
191
        t = self._transport = paramiko.Transport(sock)
192
        t.set_log_channel(logger.name)
193

    
194
        try:
195
            t.start_client()
196
        except paramiko.SSHException:
197
            raise SSHError('Negotiation failed')
198

    
199
        # host key verification
200
        server_key = t.get_remote_server_key()
201
        known_host = self._host_keys.check(host, server_key)
202

    
203
        fingerprint = hexlify(server_key.get_fingerprint())
204

    
205
        if not known_host and not unknown_host_cb(host, fingerprint):
206
            raise SSHUnknownHostError(host, fingerprint)
207

    
208
        if key_filename is None:
209
            key_filenames = []
210
        elif isinstance(key_filename, basestring):
211
            key_filenames = [ key_filename ]
212
        else:
213
            key_filenames = key_filename
214

    
215
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
216

    
217
        self._connected = True # there was no error authenticating
218

    
219
        c = self._channel = self._transport.open_session()
220
        c.set_name('netconf')
221
        c.invoke_subsystem('netconf')
222

    
223
        self._post_connect()
224

    
225
    # on the lines of paramiko.SSHClient._auth()
226
    def _auth(self, username, password, key_filenames, allow_agent,
227
              look_for_keys):
228
        saved_exception = None
229

    
230
        for key_filename in key_filenames:
231
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
232
                try:
233
                    key = cls.from_private_key_file(key_filename, password)
234
                    logger.debug('Trying key %s from %s' %
235
                              (hexlify(key.get_fingerprint()), key_filename))
236
                    self._transport.auth_publickey(username, key)
237
                    return
238
                except Exception as e:
239
                    saved_exception = e
240
                    logger.debug(e)
241

    
242
        if allow_agent:
243
            for key in paramiko.Agent().get_keys():
244
                try:
245
                    logger.debug('Trying SSH agent key %s' %
246
                                 hexlify(key.get_fingerprint()))
247
                    self._transport.auth_publickey(username, key)
248
                    return
249
                except Exception as e:
250
                    saved_exception = e
251
                    logger.debug(e)
252

    
253
        keyfiles = []
254
        if look_for_keys:
255
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
256
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
257
            if os.path.isfile(rsa_key):
258
                keyfiles.append((paramiko.RSAKey, rsa_key))
259
            if os.path.isfile(dsa_key):
260
                keyfiles.append((paramiko.DSSKey, dsa_key))
261
            # look in ~/ssh/ for windows users:
262
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
263
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
264
            if os.path.isfile(rsa_key):
265
                keyfiles.append((paramiko.RSAKey, rsa_key))
266
            if os.path.isfile(dsa_key):
267
                keyfiles.append((paramiko.DSSKey, dsa_key))
268

    
269
        for cls, filename in keyfiles:
270
            try:
271
                key = cls.from_private_key_file(filename, password)
272
                logger.debug('Trying discovered key %s in %s' %
273
                          (hexlify(key.get_fingerprint()), filename))
274
                self._transport.auth_publickey(username, key)
275
                return
276
            except Exception as e:
277
                saved_exception = e
278
                logger.debug(e)
279

    
280
        if password is not None:
281
            try:
282
                self._transport.auth_password(username, password)
283
                return
284
            except Exception as e:
285
                saved_exception = e
286
                logger.debug(e)
287

    
288
        if saved_exception is not None:
289
            # need pep-3134 to do this right
290
            raise AuthenticationError(repr(saved_exception))
291

    
292
        raise AuthenticationError('No authentication methods available')
293

    
294
    def run(self):
295
        chan = self._channel
296
        chan.setblocking(0)
297
        q = self._q
298
        try:
299
            while True:
300
                # select on a paramiko ssh channel object does not ever return
301
                # it in the writable list, so it channel's don't exactly emulate
302
                # the socket api
303
                r, w, e = select([chan], [], [], TICK)
304
                # will wakeup evey TICK seconds to check if something
305
                # to send, more if something to read (due to select returning
306
                # chan in readable list)
307
                if r:
308
                    data = chan.recv(BUF_SIZE)
309
                    if data:
310
                        self._buffer.write(data)
311
                        self._parse()
312
                    else:
313
                        raise SessionCloseError(self._buffer.getvalue())
314
                if not q.empty() and chan.send_ready():
315
                    logger.debug('sending message')
316
                    data = q.get() + MSG_DELIM
317
                    while data:
318
                        n = chan.send(data)
319
                        if n <= 0:
320
                            raise SessionCloseError(self._buffer.getvalue(), data)
321
                        data = data[n:]
322
        except Exception as e:
323
            logger.debug('broke out of main loop')
324
            expecting = self._expecting_close
325
            self.close()
326
            logger.debug('error=%r' % e)
327
            logger.debug('expecting_close=%r' % expecting)
328
            if not (isinstance(e, SessionCloseError) and expecting):
329
                self._dispatch_error(e)
330

    
331
    @property
332
    def transport(self):
333
        """Underlying `paramiko.Transport
334
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
335
        object. This makes it possible to call methods like set_keepalive on it.
336
        """
337
        return self._transport
338

    
339
    @property
340
    def can_pipeline(self):
341
        if 'Cisco' in self._transport.remote_version:
342
            return False
343
        # elif ..
344
        return True