Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 9a9af391

History | View | Annotate | Download (12.5 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = ']]>]]>'
31
TICK = 0.1
32

    
33
def default_unknown_host_cb(host, fingerprint):
34
    """An `unknown host callback` returns :const:`True` if it finds the key
35
    acceptable, and :const:`False` if not.
36

37
    This default callback always returns :const:`False`, which would lead to
38
    :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
39

40
    Supply another valid callback if you need to verify the host key
41
    programatically.
42

43
    :arg host: the hostname that needs to be verified
44
    :type host: string
45

46
    :arg fingerprint: a hex string representing the host key fingerprint
47
    :type fingerprint: string
48
    """
49
    return False
50

    
51
def _colonify(fp):
52
    finga = fp[:2]
53
    for idx  in range(2, len(fp), 2):
54
        finga += ":" + fp[idx:idx+2]
55
    return finga
56

    
57
class SSHSession(Session):
58

    
59
    "Implements a :rfc:`4742` NETCONF session over SSH."
60

    
61
    def __init__(self, capabilities):
62
        Session.__init__(self, capabilities)
63
        self._host_keys = paramiko.HostKeys()
64
        self._transport = None
65
        self._connected = False
66
        self._channel = None
67
        self._buffer = StringIO() # for incoming data
68
        # parsing-related, see _parse()
69
        self._parsing_state = 0
70
        self._parsing_pos = 0
71
    
72
    def _parse(self):
73
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
74
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
75
        across method calls and if a byte has been read it will not be
76
        considered again. '''
77
        delim = MSG_DELIM
78
        n = len(delim) - 1
79
        expect = self._parsing_state
80
        buf = self._buffer
81
        buf.seek(self._parsing_pos)
82
        while True:
83
            x = buf.read(1)
84
            if not x: # done reading
85
                break
86
            elif x == delim[expect]: # what we expected
87
                expect += 1 # expect the next delim char
88
            else:
89
                continue
90
            # loop till last delim char expected, break if other char encountered
91
            for i in range(expect, n):
92
                x = buf.read(1)
93
                if not x: # done reading
94
                    break
95
                if x == delim[expect]: # what we expected
96
                    expect += 1 # expect the next delim char
97
                else:
98
                    expect = 0 # reset
99
                    break
100
            else: # if we didn't break out of the loop, full delim was parsed
101
                msg_till = buf.tell() - n
102
                buf.seek(0)
103
                logger.debug('parsed new message')
104
                self._dispatch_message(buf.read(msg_till).strip())
105
                buf.seek(n+1, os.SEEK_CUR)
106
                rest = buf.read()
107
                buf = StringIO()
108
                buf.write(rest)
109
                buf.seek(0)
110
                expect = 0
111
        self._buffer = buf
112
        self._parsing_state = expect
113
        self._parsing_pos = self._buffer.tell()
114

    
115
    def load_known_hosts(self, filename=None):
116
        """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
117
        times.
118

119
        If *filename* is not specified, looks in the default locations i.e.
120
        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
121
        """
122
        if filename is None:
123
            filename = os.path.expanduser('~/.ssh/known_hosts')
124
            try:
125
                self._host_keys.load(filename)
126
            except IOError:
127
                # for windows
128
                filename = os.path.expanduser('~/ssh/known_hosts')
129
                try:
130
                    self._host_keys.load(filename)
131
                except IOError:
132
                    pass
133
        else:
134
            self._host_keys.load(filename)
135

    
136
    def close(self):
137
        if self._transport.is_active():
138
            self._transport.close()
139
        self._connected = False
140

    
141
    def connect(self, host, port=830, timeout=None,
142
                unknown_host_cb=default_unknown_host_cb,
143
                username=None, password=None,
144
                key_filename=None, allow_agent=True, look_for_keys=True):
145
        """Connect via SSH and initialize the NETCONF session. First attempts
146
        the publickey authentication method and then password authentication.
147

148
        To disable attemting publickey authentication altogether, call with
149
        *allow_agent* and *look_for_keys* as :const:`False`. This may be needed
150
        for Cisco devices which immediately disconnect on an incorrect
151
        authentication attempt.
152

153
        :arg host: the hostname or IP address to connect to
154
        :type host: `string`
155

156
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
157
        :type port: `int`
158

159
        :arg timeout: an optional timeout for the TCP handshake
160
        :type timeout: `int`
161

162
        :arg unknown_host_cb: called when a host key is not recognized
163
        :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
164

165
        :arg username: the username to use for SSH authentication
166
        :type username: `string`
167

168
        :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
169
        :type password: `string`
170

171
        :arg key_filename: a filename where a the private key to be used can be found
172
        :type key_filename: `string`
173

174
        :arg allow_agent: enables querying SSH agent (if found) for keys
175
        :type allow_agent: `bool`
176

177
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
178
        :type look_for_keys: `bool`
179
        """
180

    
181
        if username is None:
182
            raise SSHError("No username specified")
183

    
184
        sock = None
185
        for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
186
            af, socktype, proto, canonname, sa = res
187
            try:
188
                sock = socket.socket(af, socktype, proto)
189
                sock.settimeout(timeout)
190
            except socket.error:
191
                continue
192
            try:
193
                sock.connect(sa)
194
            except socket.error:
195
                sock.close()
196
                continue
197
            break
198
        else:
199
            raise SSHError("Could not open socket")
200

    
201
        t = self._transport = paramiko.Transport(sock)
202
        t.set_log_channel(logger.name)
203

    
204
        try:
205
            t.start_client()
206
        except paramiko.SSHException:
207
            raise SSHError('Negotiation failed')
208

    
209
        # host key verification
210
        server_key = t.get_remote_server_key()
211
        known_host = self._host_keys.check(host, server_key)
212

    
213
        fingerprint = _colonify(hexlify(server_key.get_fingerprint()))
214

    
215
        if not known_host and not unknown_host_cb(host, fingerprint):
216
            raise SSHUnknownHostError(host, fingerprint)
217

    
218
        if key_filename is None:
219
            key_filenames = []
220
        elif isinstance(key_filename, basestring):
221
            key_filenames = [ key_filename ]
222
        else:
223
            key_filenames = key_filename
224

    
225
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
226

    
227
        self._connected = True # there was no error authenticating
228

    
229
        c = self._channel = self._transport.open_session()
230
        c.set_name('netconf')
231
        c.invoke_subsystem('netconf')
232

    
233
        self._post_connect()
234
    
235
    # on the lines of paramiko.SSHClient._auth()
236
    def _auth(self, username, password, key_filenames, allow_agent,
237
              look_for_keys):
238
        saved_exception = None
239

    
240
        for key_filename in key_filenames:
241
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
242
                try:
243
                    key = cls.from_private_key_file(key_filename, password)
244
                    logger.debug('Trying key %s from %s' %
245
                              (hexlify(key.get_fingerprint()), key_filename))
246
                    self._transport.auth_publickey(username, key)
247
                    return
248
                except Exception as e:
249
                    saved_exception = e
250
                    logger.debug(e)
251

    
252
        if allow_agent:
253
            for key in paramiko.Agent().get_keys():
254
                try:
255
                    logger.debug('Trying SSH agent key %s' %
256
                                 hexlify(key.get_fingerprint()))
257
                    self._transport.auth_publickey(username, key)
258
                    return
259
                except Exception as e:
260
                    saved_exception = e
261
                    logger.debug(e)
262

    
263
        keyfiles = []
264
        if look_for_keys:
265
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
266
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
267
            if os.path.isfile(rsa_key):
268
                keyfiles.append((paramiko.RSAKey, rsa_key))
269
            if os.path.isfile(dsa_key):
270
                keyfiles.append((paramiko.DSSKey, dsa_key))
271
            # look in ~/ssh/ for windows users:
272
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
273
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
274
            if os.path.isfile(rsa_key):
275
                keyfiles.append((paramiko.RSAKey, rsa_key))
276
            if os.path.isfile(dsa_key):
277
                keyfiles.append((paramiko.DSSKey, dsa_key))
278

    
279
        for cls, filename in keyfiles:
280
            try:
281
                key = cls.from_private_key_file(filename, password)
282
                logger.debug('Trying discovered key %s in %s' %
283
                          (hexlify(key.get_fingerprint()), filename))
284
                self._transport.auth_publickey(username, key)
285
                return
286
            except Exception as e:
287
                saved_exception = e
288
                logger.debug(e)
289

    
290
        if password is not None:
291
            try:
292
                self._transport.auth_password(username, password)
293
                return
294
            except Exception as e:
295
                saved_exception = e
296
                logger.debug(e)
297

    
298
        if saved_exception is not None:
299
            # need pep-3134 to do this right
300
            raise AuthenticationError(repr(saved_exception))
301

    
302
        raise AuthenticationError('No authentication methods available')
303

    
304
    def run(self):
305
        chan = self._channel
306
        chan.setblocking(0)
307
        q = self._q
308
        try:
309
            while True:
310
                # select on a paramiko ssh channel object does not ever return
311
                # it in the writable list, so it channel's don't exactly emulate
312
                # the socket api
313
                r, w, e = select([chan], [], [], TICK)
314
                # will wakeup evey TICK seconds to check if something
315
                # to send, more if something to read (due to select returning
316
                # chan in readable list)
317
                if r:
318
                    data = chan.recv(BUF_SIZE)
319
                    if data:
320
                        self._buffer.write(data)
321
                        self._parse()
322
                    else:
323
                        raise SessionCloseError(self._buffer.getvalue())
324
                if not q.empty() and chan.send_ready():
325
                    logger.debug('sending message')
326
                    data = q.get() + MSG_DELIM
327
                    while data:
328
                        n = chan.send(data)
329
                        if n <= 0:
330
                            raise SessionCloseError(self._buffer.getvalue(), data)
331
                        data = data[n:]
332
        except Exception as e:
333
            logger.debug('broke out of main loop, error=%r', e)
334
            self.close()
335
            self._dispatch_error(e)
336

    
337
    @property
338
    def transport(self):
339
        """Underlying `paramiko.Transport
340
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
341
        object. This makes it possible to call methods like set_keepalive on it.
342
        """
343
        return self._transport
344