Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ bb700ea5

History | View | Annotate | Download (12.7 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = ']]>]]>'
31
TICK = 0.1
32

    
33
def default_unknown_host_cb(host, key):
34
    """An `unknown host callback` returns :const:`True` if it finds the key
35
    acceptable, and :const:`False` if not.
36

37
    This default callback always returns :const:`False`, which would lead to
38
    :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
39

40
    Supply another valid callback if you need to verify the host key
41
    programatically.
42

43
    :arg host: the host for whom key needs to be verified
44
    :type host: string
45

46
    :arg key: a hex string representing the host key fingerprint
47
    :type key: string
48
    """
49
    return False
50

    
51

    
52
class SSHSession(Session):
53

    
54
    "Implements a :rfc:`4742` NETCONF session over SSH."
55

    
56
    def __init__(self, capabilities):
57
        Session.__init__(self, capabilities)
58
        self._host_keys = paramiko.HostKeys()
59
        self._transport = None
60
        self._connected = False
61
        self._channel = None
62
        self._expecting_close = False
63
        self._buffer = StringIO() # for incoming data
64
        # parsing-related, see _parse()
65
        self._parsing_state = 0
66
        self._parsing_pos = 0
67

    
68
    def _parse(self):
69
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
70
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
71
        across method calls and if a byte has been read it will not be
72
        considered again. '''
73
        delim = MSG_DELIM
74
        n = len(delim) - 1
75
        expect = self._parsing_state
76
        buf = self._buffer
77
        buf.seek(self._parsing_pos)
78
        while True:
79
            x = buf.read(1)
80
            if not x: # done reading
81
                break
82
            elif x == delim[expect]: # what we expected
83
                expect += 1 # expect the next delim char
84
            else:
85
                continue
86
            # loop till last delim char expected, break if other char encountered
87
            for i in range(expect, n):
88
                x = buf.read(1)
89
                if not x: # done reading
90
                    break
91
                if x == delim[expect]: # what we expected
92
                    expect += 1 # expect the next delim char
93
                else:
94
                    expect = 0 # reset
95
                    break
96
            else: # if we didn't break out of the loop, full delim was parsed
97
                msg_till = buf.tell() - n
98
                buf.seek(0)
99
                logger.debug('parsed new message')
100
                self._dispatch_message(buf.read(msg_till).strip())
101
                buf.seek(n+1, os.SEEK_CUR)
102
                rest = buf.read()
103
                buf = StringIO()
104
                buf.write(rest)
105
                buf.seek(0)
106
                expect = 0
107
        self._buffer = buf
108
        self._parsing_state = expect
109
        self._parsing_pos = self._buffer.tell()
110

    
111
    def load_known_hosts(self, filename=None):
112
        """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
113
        times.
114

115
        If *filename* is not specified, looks in the default locations i.e.
116
        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
117
        """
118
        if filename is None:
119
            filename = os.path.expanduser('~/.ssh/known_hosts')
120
            try:
121
                self._host_keys.load(filename)
122
            except IOError:
123
                # for windows
124
                filename = os.path.expanduser('~/ssh/known_hosts')
125
                try:
126
                    self._host_keys.load(filename)
127
                except IOError:
128
                    pass
129
        else:
130
            self._host_keys.load(filename)
131

    
132
    def close(self):
133
        self._expecting_close = True
134
        if self._transport.is_active():
135
            self._transport.close()
136
        self._connected = False
137

    
138
    def connect(self, host, port=830, timeout=None,
139
                unknown_host_cb=default_unknown_host_cb,
140
                username=None, password=None,
141
                key_filename=None, allow_agent=True, look_for_keys=True):
142
        """Connect via SSH and initialize the NETCONF session. First attempts
143
        the publickey authentication method and then password authentication.
144

145
        To disable attemting publickey authentication altogether, call with
146
        *allow_agent* and *look_for_keys* as :const:`False`. This may be needed
147
        for Cisco devices which immediately disconnect on an incorrect
148
        authentication attempt.
149

150
        :arg host: the hostname or IP address to connect to
151
        :type host: `string`
152

153
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
154
        :type port: `int`
155

156
        :arg timeout: an optional timeout for the TCP handshake
157
        :type timeout: `int`
158

159
        :arg unknown_host_cb: called when a host key is not recognized
160
        :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
161

162
        :arg username: the username to use for SSH authentication
163
        :type username: `string`
164

165
        :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
166
        :type password: `string`
167

168
        :arg key_filename: a filename where a the private key to be used can be found
169
        :type key_filename: `string`
170

171
        :arg allow_agent: enables querying SSH agent (if found) for keys
172
        :type allow_agent: `bool`
173

174
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
175
        :type look_for_keys: `bool`
176
        """
177

    
178
        if username is None:
179
            raise SSHError("No username specified")
180

    
181
        sock = None
182
        for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
183
            af, socktype, proto, canonname, sa = res
184
            try:
185
                sock = socket.socket(af, socktype, proto)
186
                sock.settimeout(timeout)
187
            except socket.error:
188
                continue
189
            try:
190
                sock.connect(sa)
191
            except socket.error:
192
                sock.close()
193
                continue
194
            break
195
        else:
196
            raise SSHError("Could not open socket")
197

    
198
        t = self._transport = paramiko.Transport(sock)
199
        t.set_log_channel(logger.name)
200

    
201
        try:
202
            t.start_client()
203
        except paramiko.SSHException:
204
            raise SSHError('Negotiation failed')
205

    
206
        # host key verification
207
        server_key = t.get_remote_server_key()
208
        known_host = self._host_keys.check(host, server_key)
209

    
210
        fingerprint = hexlify(server_key.get_fingerprint())
211

    
212
        if not known_host and not unknown_host_cb(host, fingerprint):
213
            raise SSHUnknownHostError(host, fingerprint)
214

    
215
        if key_filename is None:
216
            key_filenames = []
217
        elif isinstance(key_filename, basestring):
218
            key_filenames = [ key_filename ]
219
        else:
220
            key_filenames = key_filename
221

    
222
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
223

    
224
        self._connected = True # there was no error authenticating
225

    
226
        c = self._channel = self._transport.open_session()
227
        c.set_name('netconf')
228
        c.invoke_subsystem('netconf')
229

    
230
        self._post_connect()
231

    
232
    # on the lines of paramiko.SSHClient._auth()
233
    def _auth(self, username, password, key_filenames, allow_agent,
234
              look_for_keys):
235
        saved_exception = None
236

    
237
        for key_filename in key_filenames:
238
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
239
                try:
240
                    key = cls.from_private_key_file(key_filename, password)
241
                    logger.debug('Trying key %s from %s' %
242
                              (hexlify(key.get_fingerprint()), key_filename))
243
                    self._transport.auth_publickey(username, key)
244
                    return
245
                except Exception as e:
246
                    saved_exception = e
247
                    logger.debug(e)
248

    
249
        if allow_agent:
250
            for key in paramiko.Agent().get_keys():
251
                try:
252
                    logger.debug('Trying SSH agent key %s' %
253
                                 hexlify(key.get_fingerprint()))
254
                    self._transport.auth_publickey(username, key)
255
                    return
256
                except Exception as e:
257
                    saved_exception = e
258
                    logger.debug(e)
259

    
260
        keyfiles = []
261
        if look_for_keys:
262
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
263
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
264
            if os.path.isfile(rsa_key):
265
                keyfiles.append((paramiko.RSAKey, rsa_key))
266
            if os.path.isfile(dsa_key):
267
                keyfiles.append((paramiko.DSSKey, dsa_key))
268
            # look in ~/ssh/ for windows users:
269
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
270
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
271
            if os.path.isfile(rsa_key):
272
                keyfiles.append((paramiko.RSAKey, rsa_key))
273
            if os.path.isfile(dsa_key):
274
                keyfiles.append((paramiko.DSSKey, dsa_key))
275

    
276
        for cls, filename in keyfiles:
277
            try:
278
                key = cls.from_private_key_file(filename, password)
279
                logger.debug('Trying discovered key %s in %s' %
280
                          (hexlify(key.get_fingerprint()), filename))
281
                self._transport.auth_publickey(username, key)
282
                return
283
            except Exception as e:
284
                saved_exception = e
285
                logger.debug(e)
286

    
287
        if password is not None:
288
            try:
289
                self._transport.auth_password(username, password)
290
                return
291
            except Exception as e:
292
                saved_exception = e
293
                logger.debug(e)
294

    
295
        if saved_exception is not None:
296
            # need pep-3134 to do this right
297
            raise AuthenticationError(repr(saved_exception))
298

    
299
        raise AuthenticationError('No authentication methods available')
300

    
301
    def run(self):
302
        chan = self._channel
303
        chan.setblocking(0)
304
        q = self._q
305
        try:
306
            while True:
307
                # select on a paramiko ssh channel object does not ever return
308
                # it in the writable list, so it channel's don't exactly emulate
309
                # the socket api
310
                r, w, e = select([chan], [], [], TICK)
311
                # will wakeup evey TICK seconds to check if something
312
                # to send, more if something to read (due to select returning
313
                # chan in readable list)
314
                if r:
315
                    data = chan.recv(BUF_SIZE)
316
                    if data:
317
                        self._buffer.write(data)
318
                        self._parse()
319
                    else:
320
                        raise SessionCloseError(self._buffer.getvalue())
321
                if not q.empty() and chan.send_ready():
322
                    logger.debug('sending message')
323
                    data = q.get() + MSG_DELIM
324
                    while data:
325
                        n = chan.send(data)
326
                        if n <= 0:
327
                            raise SessionCloseError(self._buffer.getvalue(), data)
328
                        data = data[n:]
329
        except Exception as e:
330
            logger.debug('broke out of main loop')
331
            expecting = self._expecting_close
332
            self.close()
333
            logger.debug('error=%r' % e)
334
            logger.debug('expecting_close=%r' % expecting)
335
            if not (isinstance(e, SessionCloseError) and expecting):
336
                self._dispatch_error(e)
337

    
338
    @property
339
    def transport(self):
340
        """Underlying `paramiko.Transport
341
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
342
        object. This makes it possible to call methods like set_keepalive on it.
343
        """
344
        return self._transport
345

    
346
    @property
347
    def can_pipeline(self):
348
        if 'Cisco' in self._transport.remote_version:
349
            return False
350
        # elif ..
351
        return True