Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ e0e01d37

History | View | Annotate | Download (12.3 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = ']]>]]>'
31
TICK = 0.1
32

    
33
def default_unknown_host_cb(host, key):
34
    """An `unknown host callback` returns :const:`True` if it finds the key
35
    acceptable, and :const:`False` if not.
36

37
    This default callback always returns :const:`False`, which would lead to
38
    :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
39

40
    Supply another valid callback if you need to verify the host key
41
    programatically.
42

43
    :arg host: the host for whom key needs to be verified
44
    :type host: string
45

46
    :arg key: a hex string representing the host key fingerprint
47
    :type key: string
48
    """
49
    return False
50

    
51

    
52
class SSHSession(Session):
53

    
54
    "Implements a :rfc:`4742` NETCONF session over SSH."
55

    
56
    def __init__(self, capabilities):
57
        Session.__init__(self, capabilities)
58
        self._host_keys = paramiko.HostKeys()
59
        self._transport = None
60
        self._connected = False
61
        self._channel = None
62
        self._buffer = StringIO() # for incoming data
63
        # parsing-related, see _parse()
64
        self._parsing_state = 0
65
        self._parsing_pos = 0
66

    
67
    def _parse(self):
68
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
69
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
70
        across method calls and if a byte has been read it will not be
71
        considered again. '''
72
        delim = MSG_DELIM
73
        n = len(delim) - 1
74
        expect = self._parsing_state
75
        buf = self._buffer
76
        buf.seek(self._parsing_pos)
77
        while True:
78
            x = buf.read(1)
79
            if not x: # done reading
80
                break
81
            elif x == delim[expect]: # what we expected
82
                expect += 1 # expect the next delim char
83
            else:
84
                continue
85
            # loop till last delim char expected, break if other char encountered
86
            for i in range(expect, n):
87
                x = buf.read(1)
88
                if not x: # done reading
89
                    break
90
                if x == delim[expect]: # what we expected
91
                    expect += 1 # expect the next delim char
92
                else:
93
                    expect = 0 # reset
94
                    break
95
            else: # if we didn't break out of the loop, full delim was parsed
96
                msg_till = buf.tell() - n
97
                buf.seek(0)
98
                logger.debug('parsed new message')
99
                self._dispatch_message(buf.read(msg_till).strip())
100
                buf.seek(n+1, os.SEEK_CUR)
101
                rest = buf.read()
102
                buf = StringIO()
103
                buf.write(rest)
104
                buf.seek(0)
105
                expect = 0
106
        self._buffer = buf
107
        self._parsing_state = expect
108
        self._parsing_pos = self._buffer.tell()
109

    
110
    def load_known_hosts(self, filename=None):
111
        """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
112
        times.
113

114
        If *filename* is not specified, looks in the default locations i.e.
115
        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
116
        """
117
        if filename is None:
118
            filename = os.path.expanduser('~/.ssh/known_hosts')
119
            try:
120
                self._host_keys.load(filename)
121
            except IOError:
122
                # for windows
123
                filename = os.path.expanduser('~/ssh/known_hosts')
124
                try:
125
                    self._host_keys.load(filename)
126
                except IOError:
127
                    pass
128
        else:
129
            self._host_keys.load(filename)
130

    
131
    def close(self):
132
        if self._transport.is_active():
133
            self._transport.close()
134
        self._connected = False
135

    
136
    def connect(self, host, port=830, timeout=None,
137
                unknown_host_cb=default_unknown_host_cb,
138
                username=None, password=None,
139
                key_filename=None, allow_agent=True, look_for_keys=True):
140
        """Connect via SSH and initialize the NETCONF session. First attempts
141
        the publickey authentication method and then password authentication.
142

143
        To disable attemting publickey authentication altogether, call with
144
        *allow_agent* and *look_for_keys* as :const:`False`. This may be needed
145
        for Cisco devices which immediately disconnect on an incorrect
146
        authentication attempt.
147

148
        :arg host: the hostname or IP address to connect to
149
        :type host: `string`
150

151
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
152
        :type port: `int`
153

154
        :arg timeout: an optional timeout for the TCP handshake
155
        :type timeout: `int`
156

157
        :arg unknown_host_cb: called when a host key is not recognized
158
        :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
159

160
        :arg username: the username to use for SSH authentication
161
        :type username: `string`
162

163
        :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
164
        :type password: `string`
165

166
        :arg key_filename: a filename where a the private key to be used can be found
167
        :type key_filename: `string`
168

169
        :arg allow_agent: enables querying SSH agent (if found) for keys
170
        :type allow_agent: `bool`
171

172
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
173
        :type look_for_keys: `bool`
174
        """
175

    
176
        if username is None:
177
            raise SSHError("No username specified")
178

    
179
        sock = None
180
        for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
181
            af, socktype, proto, canonname, sa = res
182
            try:
183
                sock = socket.socket(af, socktype, proto)
184
                sock.settimeout(timeout)
185
            except socket.error:
186
                continue
187
            try:
188
                sock.connect(sa)
189
            except socket.error:
190
                sock.close()
191
                continue
192
            break
193
        else:
194
            raise SSHError("Could not open socket")
195

    
196
        t = self._transport = paramiko.Transport(sock)
197
        t.set_log_channel(logger.name)
198

    
199
        try:
200
            t.start_client()
201
        except paramiko.SSHException:
202
            raise SSHError('Negotiation failed')
203

    
204
        # host key verification
205
        server_key = t.get_remote_server_key()
206
        known_host = self._host_keys.check(host, server_key)
207

    
208
        fingerprint = hexlify(server_key.get_fingerprint())
209

    
210
        if not known_host and not unknown_host_cb(host, fingerprint):
211
            raise SSHUnknownHostError(host, fingerprint)
212

    
213
        if key_filename is None:
214
            key_filenames = []
215
        elif isinstance(key_filename, basestring):
216
            key_filenames = [ key_filename ]
217
        else:
218
            key_filenames = key_filename
219

    
220
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
221

    
222
        self._connected = True # there was no error authenticating
223

    
224
        c = self._channel = self._transport.open_session()
225
        c.set_name('netconf')
226
        c.invoke_subsystem('netconf')
227

    
228
        self._post_connect()
229

    
230
    # on the lines of paramiko.SSHClient._auth()
231
    def _auth(self, username, password, key_filenames, allow_agent,
232
              look_for_keys):
233
        saved_exception = None
234

    
235
        for key_filename in key_filenames:
236
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
237
                try:
238
                    key = cls.from_private_key_file(key_filename, password)
239
                    logger.debug('Trying key %s from %s' %
240
                              (hexlify(key.get_fingerprint()), key_filename))
241
                    self._transport.auth_publickey(username, key)
242
                    return
243
                except Exception as e:
244
                    saved_exception = e
245
                    logger.debug(e)
246

    
247
        if allow_agent:
248
            for key in paramiko.Agent().get_keys():
249
                try:
250
                    logger.debug('Trying SSH agent key %s' %
251
                                 hexlify(key.get_fingerprint()))
252
                    self._transport.auth_publickey(username, key)
253
                    return
254
                except Exception as e:
255
                    saved_exception = e
256
                    logger.debug(e)
257

    
258
        keyfiles = []
259
        if look_for_keys:
260
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
261
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
262
            if os.path.isfile(rsa_key):
263
                keyfiles.append((paramiko.RSAKey, rsa_key))
264
            if os.path.isfile(dsa_key):
265
                keyfiles.append((paramiko.DSSKey, dsa_key))
266
            # look in ~/ssh/ for windows users:
267
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
268
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
269
            if os.path.isfile(rsa_key):
270
                keyfiles.append((paramiko.RSAKey, rsa_key))
271
            if os.path.isfile(dsa_key):
272
                keyfiles.append((paramiko.DSSKey, dsa_key))
273

    
274
        for cls, filename in keyfiles:
275
            try:
276
                key = cls.from_private_key_file(filename, password)
277
                logger.debug('Trying discovered key %s in %s' %
278
                          (hexlify(key.get_fingerprint()), filename))
279
                self._transport.auth_publickey(username, key)
280
                return
281
            except Exception as e:
282
                saved_exception = e
283
                logger.debug(e)
284

    
285
        if password is not None:
286
            try:
287
                self._transport.auth_password(username, password)
288
                return
289
            except Exception as e:
290
                saved_exception = e
291
                logger.debug(e)
292

    
293
        if saved_exception is not None:
294
            # need pep-3134 to do this right
295
            raise AuthenticationError(repr(saved_exception))
296

    
297
        raise AuthenticationError('No authentication methods available')
298

    
299
    def run(self):
300
        chan = self._channel
301
        chan.setblocking(0)
302
        q = self._q
303
        try:
304
            while True:
305
                # select on a paramiko ssh channel object does not ever return
306
                # it in the writable list, so it channel's don't exactly emulate
307
                # the socket api
308
                r, w, e = select([chan], [], [], TICK)
309
                # will wakeup evey TICK seconds to check if something
310
                # to send, more if something to read (due to select returning
311
                # chan in readable list)
312
                if r:
313
                    data = chan.recv(BUF_SIZE)
314
                    if data:
315
                        self._buffer.write(data)
316
                        self._parse()
317
                    else:
318
                        raise SessionCloseError(self._buffer.getvalue())
319
                if not q.empty() and chan.send_ready():
320
                    logger.debug('sending message')
321
                    data = q.get() + MSG_DELIM
322
                    while data:
323
                        n = chan.send(data)
324
                        if n <= 0:
325
                            raise SessionCloseError(self._buffer.getvalue(), data)
326
                        data = data[n:]
327
        except Exception as e:
328
            logger.debug('broke out of main loop, error=%r', e)
329
            self.close()
330
            self._dispatch_error(e)
331

    
332
    @property
333
    def transport(self):
334
        """Underlying `paramiko.Transport
335
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
336
        object. This makes it possible to call methods like set_keepalive on it.
337
        """
338
        return self._transport
339