Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 4a3d4804

History | View | Annotate | Download (12.1 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
import getpass
18
from binascii import hexlify
19
from cStringIO import StringIO
20
from select import select
21

    
22
import paramiko
23

    
24
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
25
from session import Session
26

    
27
import logging
28
logger = logging.getLogger("ncclient.transport.ssh")
29

    
30
BUF_SIZE = 4096
31
MSG_DELIM = "]]>]]>"
32
TICK = 0.1
33

    
34
def default_unknown_host_cb(host, fingerprint):
35
    """An unknown host callback returns `True` if it finds the key acceptable, and `False` if not.
36

37
    This default callback always returns `False`, which would lead to :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
38
    
39
    Supply another valid callback if you need to verify the host key programatically.
40

41
    *host* is the hostname that needs to be verified
42

43
    *fingerprint* is a hex string representing the host key fingerprint, colon-delimited e.g. `"4b:69:6c:72:6f:79:20:77:61:73:20:68:65:72:65:21"`
44
    """
45
    return False
46

    
47
def _colonify(fp):
48
    finga = fp[:2]
49
    for idx  in range(2, len(fp), 2):
50
        finga += ":" + fp[idx:idx+2]
51
    return finga
52

    
53
class SSHSession(Session):
54

    
55
    "Implements a :rfc:`4742` NETCONF session over SSH."
56

    
57
    def __init__(self, capabilities):
58
        Session.__init__(self, capabilities)
59
        self._host_keys = paramiko.HostKeys()
60
        self._transport = None
61
        self._connected = False
62
        self._channel = None
63
        self._buffer = StringIO() # for incoming data
64
        # parsing-related, see _parse()
65
        self._parsing_state = 0
66
        self._parsing_pos = 0
67
    
68
    def _parse(self):
69
        "Messages ae delimited by MSG_DELIM. The buffer could have grown by a maximum of BUF_SIZE bytes everytime this method is called. Retains state across method calls and if a byte has been read it will not be considered again."
70
        delim = MSG_DELIM
71
        n = len(delim) - 1
72
        expect = self._parsing_state
73
        buf = self._buffer
74
        buf.seek(self._parsing_pos)
75
        while True:
76
            x = buf.read(1)
77
            if not x: # done reading
78
                break
79
            elif x == delim[expect]: # what we expected
80
                expect += 1 # expect the next delim char
81
            else:
82
                expect = 0
83
                continue
84
            # loop till last delim char expected, break if other char encountered
85
            for i in range(expect, n):
86
                x = buf.read(1)
87
                if not x: # done reading
88
                    break
89
                if x == delim[expect]: # what we expected
90
                    expect += 1 # expect the next delim char
91
                else:
92
                    expect = 0 # reset
93
                    break
94
            else: # if we didn't break out of the loop, full delim was parsed
95
                msg_till = buf.tell() - n
96
                buf.seek(0)
97
                logger.debug('parsed new message')
98
                self._dispatch_message(buf.read(msg_till).strip())
99
                buf.seek(n+1, os.SEEK_CUR)
100
                rest = buf.read()
101
                buf = StringIO()
102
                buf.write(rest)
103
                buf.seek(0)
104
                expect = 0
105
        self._buffer = buf
106
        self._parsing_state = expect
107
        self._parsing_pos = self._buffer.tell()
108

    
109
    def load_known_hosts(self, filename=None):
110
        """Load host keys from an openssh :file:`known_hosts`-style file. Can be called multiple times.
111

112
        If *filename* is not specified, looks in the default locations i.e. :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
113
        """
114
        if filename is None:
115
            filename = os.path.expanduser('~/.ssh/known_hosts')
116
            try:
117
                self._host_keys.load(filename)
118
            except IOError:
119
                # for windows
120
                filename = os.path.expanduser('~/ssh/known_hosts')
121
                try:
122
                    self._host_keys.load(filename)
123
                except IOError:
124
                    pass
125
        else:
126
            self._host_keys.load(filename)
127

    
128
    def close(self):
129
        if self._transport.is_active():
130
            self._transport.close()
131
        self._connected = False
132

    
133
    # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
134
    def connect(self, host, port=830, timeout=None, unknown_host_cb=default_unknown_host_cb,
135
                username=None, password=None, key_filename=None, allow_agent=True, look_for_keys=True):
136
        """Connect via SSH and initialize the NETCONF session. First attempts the publickey authentication method and then password authentication.
137

138
        To disable attempting publickey authentication altogether, call with *allow_agent* and *look_for_keys* as `False`.
139

140
        *host* is the hostname or IP address to connect to
141

142
        *port* is by default 830, but some devices use the default SSH port of 22 so this may need to be specified
143

144
        *timeout* is an optional timeout for socket connect
145

146
        *unknown_host_cb* is called when the server host key is not recognized. It takes two arguments, the hostname and the fingerprint (see the signature of :func:`default_unknown_host_cb`)
147

148
        *username* is the username to use for SSH authentication
149

150
        *password* is the password used if using password authentication, or the passphrase to use for unlocking keys that require it
151

152
        *key_filename* is a filename where a the private key to be used can be found
153

154
        *allow_agent* enables querying SSH agent (if found) for keys
155

156
        *look_for_keys* enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
157
        """
158
        if username is None:
159
            username = getpass.getuser()
160
        
161
        sock = None
162
        for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
163
            af, socktype, proto, canonname, sa = res
164
            try:
165
                sock = socket.socket(af, socktype, proto)
166
                sock.settimeout(timeout)
167
            except socket.error:
168
                continue
169
            try:
170
                sock.connect(sa)
171
            except socket.error:
172
                sock.close()
173
                continue
174
            break
175
        else:
176
            raise SSHError("Could not open socket to %s:%s" % (host, port))
177

    
178
        t = self._transport = paramiko.Transport(sock)
179
        t.set_log_channel(logger.name)
180

    
181
        try:
182
            t.start_client()
183
        except paramiko.SSHException:
184
            raise SSHError('Negotiation failed')
185

    
186
        # host key verification
187
        server_key = t.get_remote_server_key()
188
        known_host = self._host_keys.check(host, server_key)
189

    
190
        fingerprint = _colonify(hexlify(server_key.get_fingerprint()))
191

    
192
        if not known_host and not unknown_host_cb(host, fingerprint):
193
            raise SSHUnknownHostError(host, fingerprint)
194

    
195
        if key_filename is None:
196
            key_filenames = []
197
        elif isinstance(key_filename, basestring):
198
            key_filenames = [ key_filename ]
199
        else:
200
            key_filenames = key_filename
201

    
202
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
203

    
204
        self._connected = True # there was no error authenticating
205

    
206
        c = self._channel = self._transport.open_session()
207
        c.set_name("netconf")
208
        c.invoke_subsystem("netconf")
209

    
210
        self._post_connect()
211
    
212
    # on the lines of paramiko.SSHClient._auth()
213
    def _auth(self, username, password, key_filenames, allow_agent,
214
              look_for_keys):
215
        saved_exception = None
216

    
217
        for key_filename in key_filenames:
218
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
219
                try:
220
                    key = cls.from_private_key_file(key_filename, password)
221
                    logger.debug("Trying key %s from %s" %
222
                              (hexlify(key.get_fingerprint()), key_filename))
223
                    self._transport.auth_publickey(username, key)
224
                    return
225
                except Exception as e:
226
                    saved_exception = e
227
                    logger.debug(e)
228

    
229
        if allow_agent:
230
            for key in paramiko.Agent().get_keys():
231
                try:
232
                    logger.debug("Trying SSH agent key %s" %
233
                                 hexlify(key.get_fingerprint()))
234
                    self._transport.auth_publickey(username, key)
235
                    return
236
                except Exception as e:
237
                    saved_exception = e
238
                    logger.debug(e)
239

    
240
        keyfiles = []
241
        if look_for_keys:
242
            rsa_key = os.path.expanduser("~/.ssh/id_rsa")
243
            dsa_key = os.path.expanduser("~/.ssh/id_dsa")
244
            if os.path.isfile(rsa_key):
245
                keyfiles.append((paramiko.RSAKey, rsa_key))
246
            if os.path.isfile(dsa_key):
247
                keyfiles.append((paramiko.DSSKey, dsa_key))
248
            # look in ~/ssh/ for windows users:
249
            rsa_key = os.path.expanduser("~/ssh/id_rsa")
250
            dsa_key = os.path.expanduser("~/ssh/id_dsa")
251
            if os.path.isfile(rsa_key):
252
                keyfiles.append((paramiko.RSAKey, rsa_key))
253
            if os.path.isfile(dsa_key):
254
                keyfiles.append((paramiko.DSSKey, dsa_key))
255

    
256
        for cls, filename in keyfiles:
257
            try:
258
                key = cls.from_private_key_file(filename, password)
259
                logger.debug("Trying discovered key %s in %s" %
260
                          (hexlify(key.get_fingerprint()), filename))
261
                self._transport.auth_publickey(username, key)
262
                return
263
            except Exception as e:
264
                saved_exception = e
265
                logger.debug(e)
266

    
267
        if password is not None:
268
            try:
269
                self._transport.auth_password(username, password)
270
                return
271
            except Exception as e:
272
                saved_exception = e
273
                logger.debug(e)
274

    
275
        if saved_exception is not None:
276
            # need pep-3134 to do this right
277
            raise AuthenticationError(repr(saved_exception))
278

    
279
        raise AuthenticationError("No authentication methods available")
280

    
281
    def run(self):
282
        chan = self._channel
283
        q = self._q
284
        try:
285
            while True:
286
                # select on a paramiko ssh channel object does not ever return it in the writable list, so channels don't exactly emulate the socket api
287
                r, w, e = select([chan], [], [], TICK)
288
                # will wakeup evey TICK seconds to check if something to send, more if something to read (due to select returning chan in readable list)
289
                if r:
290
                    data = chan.recv(BUF_SIZE)
291
                    if data:
292
                        self._buffer.write(data)
293
                        self._parse()
294
                    else:
295
                        raise SessionCloseError(self._buffer.getvalue())
296
                if not q.empty() and chan.send_ready():
297
                    logger.debug("Sending message")
298
                    data = q.get() + MSG_DELIM
299
                    while data:
300
                        n = chan.send(data)
301
                        if n <= 0:
302
                            raise SessionCloseError(self._buffer.getvalue(), data)
303
                        data = data[n:]
304
        except Exception as e:
305
            logger.debug("Broke out of main loop, error=%r", e)
306
            self.close()
307
            self._dispatch_error(e)
308

    
309
    @property
310
    def transport(self):
311
        "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
312
        return self._transport