Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 030b950d

History | View | Annotate | Download (12.5 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = "]]>]]>"
31
TICK = 0.1
32

    
33
def default_unknown_host_cb(host, fingerprint):
34
    """An `unknown host callback` returns :const:`True` if it finds the key
35
    acceptable, and :const:`False` if not.
36

37
    This default callback always returns :const:`False`, which would lead to
38
    :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
39

40
    Supply another valid callback if you need to verify the host key
41
    programatically.
42

43
    :arg host: the hostname that needs to be verified
44
    :type host: string
45

46
    :arg fingerprint: a hex string representing the host key fingerprint
47
    :type fingerprint: string
48
    """
49
    return False
50

    
51
def _colonify(fp):
52
    finga = fp[:2]
53
    for idx  in range(2, len(fp), 2):
54
        finga += ":" + fp[idx:idx+2]
55
    return finga
56

    
57
class SSHSession(Session):
58

    
59
    "Implements a :rfc:`4742` NETCONF session over SSH."
60

    
61
    def __init__(self, capabilities):
62
        Session.__init__(self, capabilities)
63
        self._host_keys = paramiko.HostKeys()
64
        self._transport = None
65
        self._connected = False
66
        self._channel = None
67
        self._buffer = StringIO() # for incoming data
68
        # parsing-related, see _parse()
69
        self._parsing_state = 0
70
        self._parsing_pos = 0
71
    
72
    def _parse(self):
73
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
74
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
75
        across method calls and if a byte has been read it will not be
76
        considered again. '''
77
        delim = MSG_DELIM
78
        n = len(delim) - 1
79
        expect = self._parsing_state
80
        buf = self._buffer
81
        buf.seek(self._parsing_pos)
82
        while True:
83
            x = buf.read(1)
84
            if not x: # done reading
85
                break
86
            elif x == delim[expect]: # what we expected
87
                expect += 1 # expect the next delim char
88
            else:
89
                expect = 0
90
                continue
91
            # loop till last delim char expected, break if other char encountered
92
            for i in range(expect, n):
93
                x = buf.read(1)
94
                if not x: # done reading
95
                    break
96
                if x == delim[expect]: # what we expected
97
                    expect += 1 # expect the next delim char
98
                else:
99
                    expect = 0 # reset
100
                    break
101
            else: # if we didn't break out of the loop, full delim was parsed
102
                msg_till = buf.tell() - n
103
                buf.seek(0)
104
                logger.debug('parsed new message')
105
                self._dispatch_message(buf.read(msg_till).strip())
106
                buf.seek(n+1, os.SEEK_CUR)
107
                rest = buf.read()
108
                buf = StringIO()
109
                buf.write(rest)
110
                buf.seek(0)
111
                expect = 0
112
        self._buffer = buf
113
        self._parsing_state = expect
114
        self._parsing_pos = self._buffer.tell()
115

    
116
    def load_known_hosts(self, filename=None):
117
        """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
118
        times.
119

120
        If *filename* is not specified, looks in the default locations i.e.
121
        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
122
        """
123
        if filename is None:
124
            filename = os.path.expanduser('~/.ssh/known_hosts')
125
            try:
126
                self._host_keys.load(filename)
127
            except IOError:
128
                # for windows
129
                filename = os.path.expanduser('~/ssh/known_hosts')
130
                try:
131
                    self._host_keys.load(filename)
132
                except IOError:
133
                    pass
134
        else:
135
            self._host_keys.load(filename)
136

    
137
    def close(self):
138
        if self._transport.is_active():
139
            self._transport.close()
140
        self._connected = False
141

    
142
    def connect(self, host, port=830, timeout=None,
143
                unknown_host_cb=default_unknown_host_cb,
144
                username=None, password=None,
145
                key_filename=None, allow_agent=True, look_for_keys=True):
146
        """Connect via SSH and initialize the NETCONF session. First attempts
147
        the publickey authentication method and then password authentication.
148

149
        To disable attemting publickey authentication altogether, call with
150
        *allow_agent* and *look_for_keys* as :const:`False`. This may be needed
151
        for Cisco devices which immediately disconnect on an incorrect
152
        authentication attempt.
153

154
        :arg host: the hostname or IP address to connect to
155
        :type host: `string`
156

157
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
158
        :type port: `int`
159

160
        :arg timeout: an optional timeout for the TCP handshake
161
        :type timeout: `int`
162

163
        :arg unknown_host_cb: called when a host key is not recognized
164
        :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
165

166
        :arg username: the username to use for SSH authentication
167
        :type username: `string`
168

169
        :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
170
        :type password: `string`
171

172
        :arg key_filename: a filename where a the private key to be used can be found
173
        :type key_filename: `string`
174

175
        :arg allow_agent: enables querying SSH agent (if found) for keys
176
        :type allow_agent: `bool`
177

178
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
179
        :type look_for_keys: `bool`
180
        """
181

    
182
        if username is None:
183
            raise SSHError("No username specified")
184

    
185
        sock = None
186
        for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
187
            af, socktype, proto, canonname, sa = res
188
            try:
189
                sock = socket.socket(af, socktype, proto)
190
                sock.settimeout(timeout)
191
            except socket.error:
192
                continue
193
            try:
194
                sock.connect(sa)
195
            except socket.error:
196
                sock.close()
197
                continue
198
            break
199
        else:
200
            raise SSHError("Could not open socket")
201

    
202
        t = self._transport = paramiko.Transport(sock)
203
        t.set_log_channel(logger.name)
204

    
205
        try:
206
            t.start_client()
207
        except paramiko.SSHException:
208
            raise SSHError('Negotiation failed')
209

    
210
        # host key verification
211
        server_key = t.get_remote_server_key()
212
        known_host = self._host_keys.check(host, server_key)
213

    
214
        fingerprint = _colonify(hexlify(server_key.get_fingerprint()))
215

    
216
        if not known_host and not unknown_host_cb(host, fingerprint):
217
            raise SSHUnknownHostError(host, fingerprint)
218

    
219
        if key_filename is None:
220
            key_filenames = []
221
        elif isinstance(key_filename, basestring):
222
            key_filenames = [ key_filename ]
223
        else:
224
            key_filenames = key_filename
225

    
226
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
227

    
228
        self._connected = True # there was no error authenticating
229

    
230
        c = self._channel = self._transport.open_session()
231
        c.set_name("netconf")
232
        c.invoke_subsystem("netconf")
233

    
234
        self._post_connect()
235
    
236
    # on the lines of paramiko.SSHClient._auth()
237
    def _auth(self, username, password, key_filenames, allow_agent,
238
              look_for_keys):
239
        saved_exception = None
240

    
241
        for key_filename in key_filenames:
242
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
243
                try:
244
                    key = cls.from_private_key_file(key_filename, password)
245
                    logger.debug("Trying key %s from %s" %
246
                              (hexlify(key.get_fingerprint()), key_filename))
247
                    self._transport.auth_publickey(username, key)
248
                    return
249
                except Exception as e:
250
                    saved_exception = e
251
                    logger.debug(e)
252

    
253
        if allow_agent:
254
            for key in paramiko.Agent().get_keys():
255
                try:
256
                    logger.debug("Trying SSH agent key %s" %
257
                                 hexlify(key.get_fingerprint()))
258
                    self._transport.auth_publickey(username, key)
259
                    return
260
                except Exception as e:
261
                    saved_exception = e
262
                    logger.debug(e)
263

    
264
        keyfiles = []
265
        if look_for_keys:
266
            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
267
            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
268
            if os.path.isfile(rsa_key):
269
                keyfiles.append((paramiko.RSAKey, rsa_key))
270
            if os.path.isfile(dsa_key):
271
                keyfiles.append((paramiko.DSSKey, dsa_key))
272
            # look in ~/ssh/ for windows users:
273
            rsa_key = os.path.expanduser("~/ssh/id_rsa")
274
            dsa_key = os.path.expanduser("~/ssh/id_dsa")
275
            if os.path.isfile(rsa_key):
276
                keyfiles.append((paramiko.RSAKey, rsa_key))
277
            if os.path.isfile(dsa_key):
278
                keyfiles.append((paramiko.DSSKey, dsa_key))
279

    
280
        for cls, filename in keyfiles:
281
            try:
282
                key = cls.from_private_key_file(filename, password)
283
                logger.debug("Trying discovered key %s in %s" %
284
                          (hexlify(key.get_fingerprint()), filename))
285
                self._transport.auth_publickey(username, key)
286
                return
287
            except Exception as e:
288
                saved_exception = e
289
                logger.debug(e)
290

    
291
        if password is not None:
292
            try:
293
                self._transport.auth_password(username, password)
294
                return
295
            except Exception as e:
296
                saved_exception = e
297
                logger.debug(e)
298

    
299
        if saved_exception is not None:
300
            # need pep-3134 to do this right
301
            raise AuthenticationError(repr(saved_exception))
302

    
303
        raise AuthenticationError("No authentication methods available")
304

    
305
    def run(self):
306
        chan = self._channel
307
        chan.setblocking(0)
308
        q = self._q
309
        try:
310
            while True:
311
                # select on a paramiko ssh channel object does not ever return
312
                # it in the writable list, so it channel's don't exactly emulate
313
                # the socket api
314
                r, w, e = select([chan], [], [], TICK)
315
                # will wakeup evey TICK seconds to check if something
316
                # to send, more if something to read (due to select returning
317
                # chan in readable list)
318
                if r:
319
                    data = chan.recv(BUF_SIZE)
320
                    if data:
321
                        self._buffer.write(data)
322
                        self._parse()
323
                    else:
324
                        raise SessionCloseError(self._buffer.getvalue())
325
                if not q.empty() and chan.send_ready():
326
                    logger.debug("Sending message")
327
                    data = q.get() + MSG_DELIM
328
                    while data:
329
                        n = chan.send(data)
330
                        if n <= 0:
331
                            raise SessionCloseError(self._buffer.getvalue(), data)
332
                        data = data[n:]
333
        except Exception as e:
334
            logger.debug("Broke out of main loop, error=%r", e)
335
            self.close()
336
            self._dispatch_error(e)
337

    
338
    @property
339
    def transport(self):
340
        """Underlying `paramiko.Transport
341
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
342
        object. This makes it possible to call methods like set_keepalive on it.
343
        """
344
        return self._transport
345