Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 68ac4439

History | View | Annotate | Download (12.4 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = "]]>]]>"
31
TICK = 0.1
32

    
33
def default_unknown_host_cb(host, fingerprint):
34
    """An `unknown host callback` returns :const:`True` if it finds the key
35
    acceptable, and :const:`False` if not.
36

37
    This default callback always returns :const:`False`, which would lead to
38
    :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
39

40
    Supply another valid callback if you need to verify the host key
41
    programatically.
42

43
    :arg host: the hostname that needs to be verified
44
    :type host: string
45

46
    :arg fingerprint: a hex string representing the host key fingerprint
47
    :type fingerprint: string
48
    """
49
    return False
50

    
51
def _colonify(fp):
52
    finga = fp[:2]
53
    for idx  in range(2, len(fp), 2):
54
        finga += ":" + fp[idx:idx+2]
55
    return finga
56

    
57
class SSHSession(Session):
58

    
59
    "Implements a :rfc:`4742` NETCONF session over SSH."
60

    
61
    def __init__(self, capabilities):
62
        Session.__init__(self, capabilities)
63
        self._host_keys = paramiko.HostKeys()
64
        self._transport = None
65
        self._connected = False
66
        self._channel = None
67
        self._buffer = StringIO() # for incoming data
68
        # parsing-related, see _parse()
69
        self._parsing_state = 0
70
        self._parsing_pos = 0
71
    
72
    def _parse(self):
73
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
74
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
75
        across method calls and if a byte has been read it will not be
76
        considered again. '''
77
        delim = MSG_DELIM
78
        n = len(delim) - 1
79
        expect = self._parsing_state
80
        buf = self._buffer
81
        buf.seek(self._parsing_pos)
82
        while True:
83
            x = buf.read(1)
84
            if not x: # done reading
85
                break
86
            elif x == delim[expect]: # what we expected
87
                expect += 1 # expect the next delim char
88
            else:
89
                expect = 0
90
                continue
91
            # loop till last delim char expected, break if other char encountered
92
            for i in range(expect, n):
93
                x = buf.read(1)
94
                if not x: # done reading
95
                    break
96
                if x == delim[expect]: # what we expected
97
                    expect += 1 # expect the next delim char
98
                else:
99
                    expect = 0 # reset
100
                    break
101
            else: # if we didn't break out of the loop, full delim was parsed
102
                msg_till = buf.tell() - n
103
                buf.seek(0)
104
                logger.debug('parsed new message')
105
                self._dispatch_message(buf.read(msg_till).strip())
106
                buf.seek(n+1, os.SEEK_CUR)
107
                rest = buf.read()
108
                buf = StringIO()
109
                buf.write(rest)
110
                buf.seek(0)
111
                expect = 0
112
        self._buffer = buf
113
        self._parsing_state = expect
114
        self._parsing_pos = self._buffer.tell()
115

    
116
    def load_known_hosts(self, filename=None):
117
        """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
118
        times.
119

120
        If *filename* is not specified, looks in the default locations i.e.
121
        :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
122
        """
123
        if filename is None:
124
            filename = os.path.expanduser('~/.ssh/known_hosts')
125
            try:
126
                self._host_keys.load(filename)
127
            except IOError:
128
                # for windows
129
                filename = os.path.expanduser('~/ssh/known_hosts')
130
                try:
131
                    self._host_keys.load(filename)
132
                except IOError:
133
                    pass
134
        else:
135
            self._host_keys.load(filename)
136

    
137
    def close(self):
138
        if self._transport.is_active():
139
            self._transport.close()
140
        self._connected = False
141

    
142
    def connect(self, host, port=830, timeout=None,
143
                unknown_host_cb=default_unknown_host_cb,
144
                username=None, password=None,
145
                key_filename=None, allow_agent=True, look_for_keys=True):
146
        """Connect via SSH and initialize the NETCONF session. First attempts
147
        the publickey authentication method and then password authentication.
148

149
        To disable attemting publickey authentication altogether, call with
150
        *allow_agent* and *look_for_keys* as :const:`False`.
151

152
        :arg host: the hostname or IP address to connect to
153
        :type host: `string`
154

155
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
156
        :type port: `int`
157

158
        :arg timeout: an optional timeout for the TCP handshake
159
        :type timeout: `int`
160

161
        :arg unknown_host_cb: called when a host key is not recognized
162
        :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
163

164
        :arg username: the username to use for SSH authentication
165
        :type username: `string`
166

167
        :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
168
        :type password: `string`
169

170
        :arg key_filename: a filename where a the private key to be used can be found
171
        :type key_filename: `string`
172

173
        :arg allow_agent: enables querying SSH agent (if found) for keys
174
        :type allow_agent: `bool`
175

176
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
177
        :type look_for_keys: `bool`
178
        """
179

    
180
        if username is None:
181
            username = getpass.getuser()
182
        
183
        sock = None
184
        for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
185
            af, socktype, proto, canonname, sa = res
186
            try:
187
                sock = socket.socket(af, socktype, proto)
188
                sock.settimeout(timeout)
189
            except socket.error:
190
                continue
191
            try:
192
                sock.connect(sa)
193
            except socket.error:
194
                sock.close()
195
                continue
196
            break
197
        else:
198
            raise SSHError("Could not open socket")
199

    
200
        t = self._transport = paramiko.Transport(sock)
201
        t.set_log_channel(logger.name)
202

    
203
        try:
204
            t.start_client()
205
        except paramiko.SSHException:
206
            raise SSHError('Negotiation failed')
207

    
208
        # host key verification
209
        server_key = t.get_remote_server_key()
210
        known_host = self._host_keys.check(host, server_key)
211

    
212
        fingerprint = _colonify(hexlify(server_key.get_fingerprint()))
213

    
214
        if not known_host and not unknown_host_cb(host, fingerprint):
215
            raise SSHUnknownHostError(host, fingerprint)
216

    
217
        if key_filename is None:
218
            key_filenames = []
219
        elif isinstance(key_filename, basestring):
220
            key_filenames = [ key_filename ]
221
        else:
222
            key_filenames = key_filename
223

    
224
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
225

    
226
        self._connected = True # there was no error authenticating
227

    
228
        c = self._channel = self._transport.open_session()
229
        c.set_name("netconf")
230
        c.invoke_subsystem("netconf")
231

    
232
        self._post_connect()
233
    
234
    # on the lines of paramiko.SSHClient._auth()
235
    def _auth(self, username, password, key_filenames, allow_agent,
236
              look_for_keys):
237
        saved_exception = None
238

    
239
        for key_filename in key_filenames:
240
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
241
                try:
242
                    key = cls.from_private_key_file(key_filename, password)
243
                    logger.debug("Trying key %s from %s" %
244
                              (hexlify(key.get_fingerprint()), key_filename))
245
                    self._transport.auth_publickey(username, key)
246
                    return
247
                except Exception as e:
248
                    saved_exception = e
249
                    logger.debug(e)
250

    
251
        if allow_agent:
252
            for key in paramiko.Agent().get_keys():
253
                try:
254
                    logger.debug("Trying SSH agent key %s" %
255
                                 hexlify(key.get_fingerprint()))
256
                    self._transport.auth_publickey(username, key)
257
                    return
258
                except Exception as e:
259
                    saved_exception = e
260
                    logger.debug(e)
261

    
262
        keyfiles = []
263
        if look_for_keys:
264
            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
265
            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
266
            if os.path.isfile(rsa_key):
267
                keyfiles.append((paramiko.RSAKey, rsa_key))
268
            if os.path.isfile(dsa_key):
269
                keyfiles.append((paramiko.DSSKey, dsa_key))
270
            # look in ~/ssh/ for windows users:
271
            rsa_key = os.path.expanduser("~/ssh/id_rsa")
272
            dsa_key = os.path.expanduser("~/ssh/id_dsa")
273
            if os.path.isfile(rsa_key):
274
                keyfiles.append((paramiko.RSAKey, rsa_key))
275
            if os.path.isfile(dsa_key):
276
                keyfiles.append((paramiko.DSSKey, dsa_key))
277

    
278
        for cls, filename in keyfiles:
279
            try:
280
                key = cls.from_private_key_file(filename, password)
281
                logger.debug("Trying discovered key %s in %s" %
282
                          (hexlify(key.get_fingerprint()), filename))
283
                self._transport.auth_publickey(username, key)
284
                return
285
            except Exception as e:
286
                saved_exception = e
287
                logger.debug(e)
288

    
289
        if password is not None:
290
            try:
291
                self._transport.auth_password(username, password)
292
                return
293
            except Exception as e:
294
                saved_exception = e
295
                logger.debug(e)
296

    
297
        if saved_exception is not None:
298
            # need pep-3134 to do this right
299
            raise AuthenticationError(repr(saved_exception))
300

    
301
        raise AuthenticationError("No authentication methods available")
302

    
303
    def run(self):
304
        chan = self._channel
305
        chan.setblocking(0)
306
        q = self._q
307
        try:
308
            while True:
309
                # select on a paramiko ssh channel object does not ever return
310
                # it in the writable list, so it channel's don't exactly emulate
311
                # the socket api
312
                r, w, e = select([chan], [], [], TICK)
313
                # will wakeup evey TICK seconds to check if something
314
                # to send, more if something to read (due to select returning
315
                # chan in readable list)
316
                if r:
317
                    data = chan.recv(BUF_SIZE)
318
                    if data:
319
                        self._buffer.write(data)
320
                        self._parse()
321
                    else:
322
                        raise SessionCloseError(self._buffer.getvalue())
323
                if not q.empty() and chan.send_ready():
324
                    logger.debug("Sending message")
325
                    data = q.get() + MSG_DELIM
326
                    while data:
327
                        n = chan.send(data)
328
                        if n <= 0:
329
                            raise SessionCloseError(self._buffer.getvalue(), data)
330
                        data = data[n:]
331
        except Exception as e:
332
            logger.debug("Broke out of main loop, error=%r", e)
333
            self.close()
334
            self._dispatch_error(e)
335

    
336
    @property
337
    def transport(self):
338
        """Underlying `paramiko.Transport
339
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
340
        object. This makes it possible to call methods like set_keepalive on it.
341
        """
342
        return self._transport