Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 495b9bf7

History | View | Annotate | Download (12.4 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
import getpass
18
from binascii import hexlify
19
from cStringIO import StringIO
20
from select import select
21

    
22
import paramiko
23

    
24
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
25
from session import Session
26

    
27
import logging
28
logger = logging.getLogger('ncclient.transport.ssh')
29

    
30
BUF_SIZE = 4096
31
MSG_DELIM = "]]>]]>"
32
TICK = 0.1
33

    
34
def default_unknown_host_cb(host, fingerprint):
35
    """An `unknown host callback` returns :const:`True` if it finds the key
36
    acceptable, and :const:`False` if not.
37

38
    This default callback always returns :const:`False`, which would lead to
39
    :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
40

41
    Supply another valid callback if you need to verify the host key
42
    programatically.
43

44
    :arg host: the hostname that needs to be verified
45
    :type host: string
46

47
    :arg fingerprint: a hex string representing the host key fingerprint
48
    :type fingerprint: string
49
    """
50
    return False
51

    
52
def _colonify(fp):
53
    finga = fp[:2]
54
    for idx  in range(2, len(fp), 2):
55
        finga += ":" + fp[idx:idx+2]
56
    return finga
57

    
58
class SSHSession(Session):
59

    
60
    "Implements a :rfc:`4742` NETCONF session over SSH."
61

    
62
    def __init__(self, capabilities):
63
        Session.__init__(self, capabilities)
64
        self._host_keys = paramiko.HostKeys()
65
        self._transport = None
66
        self._connected = False
67
        self._channel = None
68
        self._buffer = StringIO() # for incoming data
69
        # parsing-related, see _parse()
70
        self._parsing_state = 0
71
        self._parsing_pos = 0
72
    
73
    def _parse(self):
74
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
75
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
76
        across method calls and if a byte has been read it will not be
77
        considered again. '''
78
        delim = MSG_DELIM
79
        n = len(delim) - 1
80
        expect = self._parsing_state
81
        buf = self._buffer
82
        buf.seek(self._parsing_pos)
83
        while True:
84
            x = buf.read(1)
85
            if not x: # done reading
86
                break
87
            elif x == delim[expect]: # what we expected
88
                expect += 1 # expect the next delim char
89
            else:
90
                expect = 0
91
                continue
92
            # loop till last delim char expected, break if other char encountered
93
            for i in range(expect, n):
94
                x = buf.read(1)
95
                if not x: # done reading
96
                    break
97
                if x == delim[expect]: # what we expected
98
                    expect += 1 # expect the next delim char
99
                else:
100
                    expect = 0 # reset
101
                    break
102
            else: # if we didn't break out of the loop, full delim was parsed
103
                msg_till = buf.tell() - n
104
                buf.seek(0)
105
                logger.debug('parsed new message')
106
                self._dispatch_message(buf.read(msg_till).strip())
107
                buf.seek(n+1, os.SEEK_CUR)
108
                rest = buf.read()
109
                buf = StringIO()
110
                buf.write(rest)
111
                buf.seek(0)
112
                expect = 0
113
        self._buffer = buf
114
        self._parsing_state = expect
115
        self._parsing_pos = self._buffer.tell()
116

    
117
    def load_known_hosts(self, filename=None):
118
        """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
119
        times.
120

121
        If *filename* is not specified, looks in the default locations i.e.
122
        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
123
        """
124
        if filename is None:
125
            filename = os.path.expanduser('~/.ssh/known_hosts')
126
            try:
127
                self._host_keys.load(filename)
128
            except IOError:
129
                # for windows
130
                filename = os.path.expanduser('~/ssh/known_hosts')
131
                try:
132
                    self._host_keys.load(filename)
133
                except IOError:
134
                    pass
135
        else:
136
            self._host_keys.load(filename)
137

    
138
    def close(self):
139
        if self._transport.is_active():
140
            self._transport.close()
141
        self._connected = False
142

    
143
    def connect(self, host, port=830, timeout=None,
144
                unknown_host_cb=default_unknown_host_cb,
145
                username=None, password=None,
146
                key_filename=None, allow_agent=True, look_for_keys=True):
147
        """Connect via SSH and initialize the NETCONF session. First attempts
148
        the publickey authentication method and then password authentication.
149

150
        To disable attemting publickey authentication altogether, call with
151
        *allow_agent* and *look_for_keys* as :const:`False`.
152

153
        :arg host: the hostname or IP address to connect to
154
        :type host: `string`
155

156
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
157
        :type port: `int`
158

159
        :arg timeout: an optional timeout for the TCP handshake
160
        :type timeout: `int`
161

162
        :arg unknown_host_cb: called when a host key is not recognized
163
        :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
164

165
        :arg username: the username to use for SSH authentication
166
        :type username: `string`
167

168
        :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
169
        :type password: `string`
170

171
        :arg key_filename: a filename where a the private key to be used can be found
172
        :type key_filename: `string`
173

174
        :arg allow_agent: enables querying SSH agent (if found) for keys
175
        :type allow_agent: `bool`
176

177
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
178
        :type look_for_keys: `bool`
179
        """
180

    
181
        if username is None:
182
            username = getpass.getuser()
183
        
184
        sock = None
185
        for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
186
            af, socktype, proto, canonname, sa = res
187
            try:
188
                sock = socket.socket(af, socktype, proto)
189
                sock.settimeout(timeout)
190
            except socket.error:
191
                continue
192
            try:
193
                sock.connect(sa)
194
            except socket.error:
195
                sock.close()
196
                continue
197
            break
198
        else:
199
            raise SSHError("Could not open socket to %s:%s" % (host, port))
200

    
201
        t = self._transport = paramiko.Transport(sock)
202
        t.set_log_channel(logger.name)
203

    
204
        try:
205
            t.start_client()
206
        except paramiko.SSHException:
207
            raise SSHError('Negotiation failed')
208

    
209
        # host key verification
210
        server_key = t.get_remote_server_key()
211
        known_host = self._host_keys.check(host, server_key)
212

    
213
        fingerprint = _colonify(hexlify(server_key.get_fingerprint()))
214

    
215
        if not known_host and not unknown_host_cb(host, fingerprint):
216
            raise SSHUnknownHostError(host, fingerprint)
217

    
218
        if key_filename is None:
219
            key_filenames = []
220
        elif isinstance(key_filename, basestring):
221
            key_filenames = [ key_filename ]
222
        else:
223
            key_filenames = key_filename
224

    
225
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
226

    
227
        self._connected = True # there was no error authenticating
228

    
229
        c = self._channel = self._transport.open_session()
230
        c.set_name("netconf")
231
        c.invoke_subsystem("netconf")
232

    
233
        self._post_connect()
234
    
235
    # on the lines of paramiko.SSHClient._auth()
236
    def _auth(self, username, password, key_filenames, allow_agent,
237
              look_for_keys):
238
        saved_exception = None
239

    
240
        for key_filename in key_filenames:
241
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
242
                try:
243
                    key = cls.from_private_key_file(key_filename, password)
244
                    logger.debug("Trying key %s from %s" %
245
                              (hexlify(key.get_fingerprint()), key_filename))
246
                    self._transport.auth_publickey(username, key)
247
                    return
248
                except Exception as e:
249
                    saved_exception = e
250
                    logger.debug(e)
251

    
252
        if allow_agent:
253
            for key in paramiko.Agent().get_keys():
254
                try:
255
                    logger.debug("Trying SSH agent key %s" %
256
                                 hexlify(key.get_fingerprint()))
257
                    self._transport.auth_publickey(username, key)
258
                    return
259
                except Exception as e:
260
                    saved_exception = e
261
                    logger.debug(e)
262

    
263
        keyfiles = []
264
        if look_for_keys:
265
            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
266
            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
267
            if os.path.isfile(rsa_key):
268
                keyfiles.append((paramiko.RSAKey, rsa_key))
269
            if os.path.isfile(dsa_key):
270
                keyfiles.append((paramiko.DSSKey, dsa_key))
271
            # look in ~/ssh/ for windows users:
272
            rsa_key = os.path.expanduser("~/ssh/id_rsa")
273
            dsa_key = os.path.expanduser("~/ssh/id_dsa")
274
            if os.path.isfile(rsa_key):
275
                keyfiles.append((paramiko.RSAKey, rsa_key))
276
            if os.path.isfile(dsa_key):
277
                keyfiles.append((paramiko.DSSKey, dsa_key))
278

    
279
        for cls, filename in keyfiles:
280
            try:
281
                key = cls.from_private_key_file(filename, password)
282
                logger.debug("Trying discovered key %s in %s" %
283
                          (hexlify(key.get_fingerprint()), filename))
284
                self._transport.auth_publickey(username, key)
285
                return
286
            except Exception as e:
287
                saved_exception = e
288
                logger.debug(e)
289

    
290
        if password is not None:
291
            try:
292
                self._transport.auth_password(username, password)
293
                return
294
            except Exception as e:
295
                saved_exception = e
296
                logger.debug(e)
297

    
298
        if saved_exception is not None:
299
            # need pep-3134 to do this right
300
            raise AuthenticationError(repr(saved_exception))
301

    
302
        raise AuthenticationError("No authentication methods available")
303

    
304
    def run(self):
305
        chan = self._channel
306
        chan.setblocking(0)
307
        q = self._q
308
        try:
309
            while True:
310
                # select on a paramiko ssh channel object does not ever return
311
                # it in the writable list, so it channel's don't exactly emulate
312
                # the socket api
313
                r, w, e = select([chan], [], [], TICK)
314
                # will wakeup evey TICK seconds to check if something
315
                # to send, more if something to read (due to select returning
316
                # chan in readable list)
317
                if r:
318
                    data = chan.recv(BUF_SIZE)
319
                    if data:
320
                        self._buffer.write(data)
321
                        self._parse()
322
                    else:
323
                        raise SessionCloseError(self._buffer.getvalue())
324
                if not q.empty() and chan.send_ready():
325
                    logger.debug("Sending message")
326
                    data = q.get() + MSG_DELIM
327
                    while data:
328
                        n = chan.send(data)
329
                        if n <= 0:
330
                            raise SessionCloseError(self._buffer.getvalue(), data)
331
                        data = data[n:]
332
        except Exception as e:
333
            logger.debug("Broke out of main loop, error=%r", e)
334
            self.close()
335
            self._dispatch_error(e)
336

    
337
    @property
338
    def transport(self):
339
        """Underlying `paramiko.Transport
340
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
341
        object. This makes it possible to call methods like set_keepalive on it.
342
        """
343
        return self._transport