Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 4f650d54

History | View | Annotate | Download (12.1 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = ']]>]]>'
31
TICK = 0.1
32

    
33
def default_unknown_host_cb(host, key):
34
    """An `unknown host callback` returns :const:`True` if it finds the key
35
    acceptable, and :const:`False` if not.
36

37
    :arg host: the hostname/address which needs to be verified
38

39
    :arg key: a hex string representing the host key fingerprint
40

41
    :returns: this default callback always returns :const:`False`
42
    """
43
    return False
44

    
45

    
46
class SSHSession(Session):
47

    
48
    "Implements a :rfc:`4742` NETCONF session over SSH."
49

    
50
    def __init__(self, capabilities):
51
        Session.__init__(self, capabilities)
52
        self._host_keys = paramiko.HostKeys()
53
        self._system_host_keys = paramiko.HostKeys()
54
        self._transport = None
55
        self._connected = False
56
        self._channel = None
57
        self._expecting_close = False
58
        self._buffer = StringIO() # for incoming data
59
        # parsing-related, see _parse()
60
        self._parsing_state = 0
61
        self._parsing_pos = 0
62

    
63
    def _parse(self):
64
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
65
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
66
        across method calls and if a byte has been read it will not be
67
        considered again. '''
68
        delim = MSG_DELIM
69
        n = len(delim) - 1
70
        expect = self._parsing_state
71
        buf = self._buffer
72
        buf.seek(self._parsing_pos)
73
        while True:
74
            x = buf.read(1)
75
            if not x: # done reading
76
                break
77
            elif x == delim[expect]: # what we expected
78
                expect += 1 # expect the next delim char
79
            else:
80
                continue
81
            # loop till last delim char expected, break if other char encountered
82
            for i in range(expect, n):
83
                x = buf.read(1)
84
                if not x: # done reading
85
                    break
86
                if x == delim[expect]: # what we expected
87
                    expect += 1 # expect the next delim char
88
                else:
89
                    expect = 0 # reset
90
                    break
91
            else: # if we didn't break out of the loop, full delim was parsed
92
                msg_till = buf.tell() - n
93
                buf.seek(0)
94
                logger.debug('parsed new message')
95
                self._dispatch_message(buf.read(msg_till).strip())
96
                buf.seek(n+1, os.SEEK_CUR)
97
                rest = buf.read()
98
                buf = StringIO()
99
                buf.write(rest)
100
                buf.seek(0)
101
                expect = 0
102
        self._buffer = buf
103
        self._parsing_state = expect
104
        self._parsing_pos = self._buffer.tell()
105

    
106
    def load_system_host_keys(self, filename=None):
107
        if filename is None:
108
            filename = os.path.expanduser('~/.ssh/known_hosts')
109
            try:
110
                self._system_host_keys.load(filename)
111
            except IOError:
112
                # for windows
113
                filename = os.path.expanduser('~/ssh/known_hosts')
114
                try:
115
                    self._system_host_keys.load(filename)
116
                except IOError:
117
                    pass
118
            return
119
        self._system_host_keys.load(filename)
120

    
121
    def load_host_keys(self, filename):
122
        self._host_keys.load(filename)
123

    
124
    def add_host_key(self, key):
125
        self._host_keys.add(key)
126

    
127
    def save_host_keys(self, filename):
128
        f = open(filename, 'w')
129
        for host, keys in self._host_keys.iteritems():
130
            for keytype, key in keys.iteritems():
131
                f.write('%s %s %s\n' % (host, keytype, key.get_base64()))
132
        f.close()
133

    
134
    def close(self):
135
        self._expecting_close = True
136
        if self._transport.is_active():
137
            self._transport.close()
138
        self._connected = False
139

    
140
    def connect(self, host, port=830, timeout=None,
141
                unknown_host_cb=default_unknown_host_cb,
142
                username=None, password=None,
143
                key_filename=None, allow_agent=True, look_for_keys=True):
144
        """Connect via SSH and initialize the NETCONF session. First attempts
145
        the publickey authentication method and then password authentication.
146

147
        To disable publickey authentication, call with *allow_agent* and
148
        *look_for_keys* as :const:`False`
149

150
        :arg host: the hostname or IP address to connect to
151

152
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
153

154
        :arg timeout: an optional timeout for the TCP handshake
155

156
        :arg unknown_host_cb: called when a host key is not known. See :func:`unknown_host_cb` for details on signature
157

158
        :arg username: the username to use for SSH authentication
159

160
        :arg password: the password used if using password authentication, or the passphrase to use in order to unlock keys that require it
161

162
        :arg key_filename: a filename where a the private key to be used can be found
163

164
        :arg allow_agent: enables querying SSH agent (if found) for keys
165

166
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
167
        """
168

    
169
        assert(username is not None)
170

    
171
        for (family, socktype, proto, canonname, sockaddr) in \
172
        socket.getaddrinfo(host, port):
173
            if socktype == socket.SOCK_STREAM:
174
                af = family
175
                addr = sockaddr
176
                break
177
        else:
178
            raise SSHError('No suitable address family for %s' % host)
179
        sock = socket.socket(af, socket.SOCK_STREAM)
180
        sock.settimeout(timeout)
181
        sock.connect(addr)
182
        t = self._transport = paramiko.Transport(sock)
183
        t.set_log_channel(logger.name)
184

    
185
        try:
186
            t.start_client()
187
        except paramiko.SSHException:
188
            raise SSHError('Negotiation failed')
189

    
190
        # host key verification
191
        server_key = t.get_remote_server_key()
192
        known_host = self._host_keys.check(host, server_key) or \
193
                        self._system_host_keys.check(host, server_key)
194

    
195
        fp = hexlify(server_key.get_fingerprint())
196
        if not known_host and not unknown_host_cb(host, fp):
197
            raise SSHUnknownHostError(host, fp)
198

    
199
        if key_filename is None:
200
            key_filenames = []
201
        elif isinstance(key_filename, basestring):
202
            key_filenames = [ key_filename ]
203
        else:
204
            key_filenames = key_filename
205

    
206
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
207

    
208
        self._connected = True # there was no error authenticating
209

    
210
        c = self._channel = self._transport.open_session()
211
        c.set_name('netconf')
212
        c.invoke_subsystem('netconf')
213

    
214
        self._post_connect()
215

    
216
    # on the lines of paramiko.SSHClient._auth()
217
    def _auth(self, username, password, key_filenames, allow_agent,
218
              look_for_keys):
219
        saved_exception = None
220

    
221
        for key_filename in key_filenames:
222
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
223
                try:
224
                    key = cls.from_private_key_file(key_filename, password)
225
                    logger.debug('Trying key %s from %s' %
226
                              (hexlify(key.get_fingerprint()), key_filename))
227
                    self._transport.auth_publickey(username, key)
228
                    return
229
                except Exception as e:
230
                    saved_exception = e
231
                    logger.debug(e)
232

    
233
        if allow_agent:
234
            for key in paramiko.Agent().get_keys():
235
                try:
236
                    logger.debug('Trying SSH agent key %s' %
237
                                 hexlify(key.get_fingerprint()))
238
                    self._transport.auth_publickey(username, key)
239
                    return
240
                except Exception as e:
241
                    saved_exception = e
242
                    logger.debug(e)
243

    
244
        keyfiles = []
245
        if look_for_keys:
246
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
247
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
248
            if os.path.isfile(rsa_key):
249
                keyfiles.append((paramiko.RSAKey, rsa_key))
250
            if os.path.isfile(dsa_key):
251
                keyfiles.append((paramiko.DSSKey, dsa_key))
252
            # look in ~/ssh/ for windows users:
253
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
254
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
255
            if os.path.isfile(rsa_key):
256
                keyfiles.append((paramiko.RSAKey, rsa_key))
257
            if os.path.isfile(dsa_key):
258
                keyfiles.append((paramiko.DSSKey, dsa_key))
259

    
260
        for cls, filename in keyfiles:
261
            try:
262
                key = cls.from_private_key_file(filename, password)
263
                logger.debug('Trying discovered key %s in %s' %
264
                          (hexlify(key.get_fingerprint()), filename))
265
                self._transport.auth_publickey(username, key)
266
                return
267
            except Exception as e:
268
                saved_exception = e
269
                logger.debug(e)
270

    
271
        if password is not None:
272
            try:
273
                self._transport.auth_password(username, password)
274
                return
275
            except Exception as e:
276
                saved_exception = e
277
                logger.debug(e)
278

    
279
        if saved_exception is not None:
280
            # need pep-3134 to do this right
281
            raise SSHAuthenticationError(repr(saved_exception))
282

    
283
        raise SSHAuthenticationError('No authentication methods available')
284

    
285
    def run(self):
286
        chan = self._channel
287
        chan.setblocking(0)
288
        q = self._q
289
        try:
290
            while True:
291
                # select on a paramiko ssh channel object does not ever return
292
                # it in the writable list, so it channel's don't exactly emulate
293
                # the socket api
294
                r, w, e = select([chan], [], [], TICK)
295
                # will wakeup evey TICK seconds to check if something
296
                # to send, more if something to read (due to select returning
297
                # chan in readable list)
298
                if r:
299
                    data = chan.recv(BUF_SIZE)
300
                    if data:
301
                        self._buffer.write(data)
302
                        self._parse()
303
                    else:
304
                        raise SessionCloseError(self._buffer.getvalue())
305
                if not q.empty() and chan.send_ready():
306
                    logger.debug('sending message')
307
                    data = q.get() + MSG_DELIM
308
                    while data:
309
                        n = chan.send(data)
310
                        if n <= 0:
311
                            raise SessionCloseError(self._buffer.getvalue(), data)
312
                        data = data[n:]
313
        except Exception as e:
314
            logger.debug('broke out of main loop')
315
            self.close()
316
            if not (isinstance(e, SessionCloseError) and self._expecting_close):
317
                self._dispatch_error(e)
318

    
319
    @property
320
    def transport(self):
321
        """The underlying `paramiko.Transport
322
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
323
        object. This makes it possible to call methods like set_keepalive on it.
324
        """
325
        return self._transport
326

    
327
    @property
328
    def can_pipeline(self):
329
        if 'Cisco' in self._transport.remote_version:
330
            return False
331
        # elif ..
332
        return True