Statistics
| Branch: | Tag: | Revision:

root / ncclient / transport / ssh.py @ 88e9a79a

History | View | Annotate | Download (10.6 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24
from session import Session
25

    
26
import logging
27
logger = logging.getLogger('ncclient.transport.ssh')
28

    
29
BUF_SIZE = 4096
30
MSG_DELIM = ']]>]]>'
31
TICK = 0.1
32

    
33
class SSHSession(Session):
34
    
35
    "A NETCONF SSH session, per :rfc:`4742`"
36
    
37
    def __init__(self, capabilities):
38
        Session.__init__(self, capabilities)
39
        self._host_keys = paramiko.HostKeys()
40
        self._system_host_keys = paramiko.HostKeys()
41
        self._transport = None
42
        self._connected = False
43
        self._channel = None
44
        self._expecting_close = False
45
        self._buffer = StringIO() # for incoming data
46
        # parsing-related, see _parse()
47
        self._parsing_state = 0 
48
        self._parsing_pos = 0
49
    
50
    def _parse(self):
51
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
52
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
53
        across method calls and if a byte has been read it will not be considered
54
        again.
55
        '''
56
        delim = MSG_DELIM
57
        n = len(delim) - 1
58
        expect = self._parsing_state
59
        buf = self._buffer
60
        buf.seek(self._parsing_pos)
61
        while True:
62
            x = buf.read(1)
63
            if not x: # done reading
64
                break
65
            elif x == delim[expect]: # what we expected
66
                expect += 1 # expect the next delim char
67
            else:
68
                continue
69
            # loop till last delim char expected, break if other char encountered
70
            for i in range(expect, n):
71
                x = buf.read(1)
72
                if not x: # done reading
73
                    break
74
                if x == delim[expect]: # what we expected
75
                    expect += 1 # expect the next delim char
76
                else:
77
                    expect = 0 # reset
78
                    break
79
            else: # if we didn't break out of the loop, full delim was parsed
80
                msg_till = buf.tell() - n
81
                buf.seek(0)
82
                logger.debug('parsed new message')
83
                self._dispatch_message(buf.read(msg_till).strip())
84
                buf.seek(n+1, os.SEEK_CUR)
85
                rest = buf.read()
86
                buf = StringIO()
87
                buf.write(rest)
88
                buf.seek(0)
89
                expect = 0
90
        self._buffer = buf
91
        self._parsing_state = expect
92
        self._parsing_pos = self._buffer.tell()
93
    
94
    def load_system_host_keys(self, filename=None):
95
        if filename is None:
96
            filename = os.path.expanduser('~/.ssh/known_hosts')
97
            try:
98
                self._system_host_keys.load(filename)
99
            except IOError:
100
                # for windows
101
                filename = os.path.expanduser('~/ssh/known_hosts')
102
                try:
103
                    self._system_host_keys.load(filename)
104
                except IOError:
105
                    pass
106
            return
107
        self._system_host_keys.load(filename)
108
    
109
    def load_host_keys(self, filename):
110
        self._host_keys.load(filename)
111

    
112
    def add_host_key(self, key):
113
        self._host_keys.add(key)
114
    
115
    def save_host_keys(self, filename):
116
        f = open(filename, 'w')
117
        for host, keys in self._host_keys.iteritems():
118
            for keytype, key in keys.iteritems():
119
                f.write('%s %s %s\n' % (host, keytype, key.get_base64()))
120
        f.close()    
121
    
122
    def close(self):
123
        self._expecting_close = True
124
        if self._transport.is_active():
125
            self._transport.close()
126
        self._connected = False
127
    
128
    def connect(self, host, port=830, timeout=None,
129
                unknown_host_cb=None, username=None, password=None,
130
                key_filename=None, allow_agent=True, look_for_keys=True):
131
        assert(username is not None)
132
        
133
        for (family, socktype, proto, canonname, sockaddr) in \
134
        socket.getaddrinfo(host, port):
135
            if socktype == socket.SOCK_STREAM:
136
                af = family
137
                addr = sockaddr
138
                break
139
        else:
140
            raise SSHError('No suitable address family for %s' % host)
141
        sock = socket.socket(af, socket.SOCK_STREAM)
142
        sock.settimeout(timeout)
143
        sock.connect(addr)
144
        t = self._transport = paramiko.Transport(sock)
145
        t.set_log_channel(logger.name)
146
        
147
        try:
148
            t.start_client()
149
        except paramiko.SSHException:
150
            raise SSHError('Negotiation failed')
151
        
152
        # host key verification
153
        server_key = t.get_remote_server_key()
154
        known_host = self._host_keys.check(host, server_key) or \
155
                        self._system_host_keys.check(host, server_key)
156
        
157
        if unknown_host_cb is None:
158
            unknown_host_cb = lambda *args: False
159
        if not known_host and not unknown_host_cb(host, server_key):
160
                raise SSHUnknownHostError(host, server_key)
161
        
162
        if key_filename is None:
163
            key_filenames = []
164
        elif isinstance(key_filename, basestring):
165
            key_filenames = [ key_filename ]
166
        else:
167
            key_filenames = key_filename
168
        
169
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
170
        
171
        self._connected = True # there was no error authenticating
172
        
173
        c = self._channel = self._transport.open_session()
174
        c.set_name('netconf')
175
        c.invoke_subsystem('netconf')
176
        
177
        self._post_connect()
178
    
179
    # on the lines of paramiko.SSHClient._auth()
180
    def _auth(self, username, password, key_filenames, allow_agent,
181
              look_for_keys):
182
        saved_exception = None
183
        
184
        for key_filename in key_filenames:
185
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
186
                try:
187
                    key = cls.from_private_key_file(key_filename, password)
188
                    logger.debug('Trying key %s from %s' %
189
                              (hexlify(key.get_fingerprint()), key_filename))
190
                    self._transport.auth_publickey(username, key)
191
                    return
192
                except Exception as e:
193
                    saved_exception = e
194
                    logger.debug(e)
195
        
196
        if allow_agent:
197
            for key in paramiko.Agent().get_keys():
198
                try:
199
                    logger.debug('Trying SSH agent key %s' %
200
                                 hexlify(key.get_fingerprint()))
201
                    self._transport.auth_publickey(username, key)
202
                    return
203
                except Exception as e:
204
                    saved_exception = e
205
                    logger.debug(e)
206
        
207
        keyfiles = []
208
        if look_for_keys:
209
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
210
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
211
            if os.path.isfile(rsa_key):
212
                keyfiles.append((paramiko.RSAKey, rsa_key))
213
            if os.path.isfile(dsa_key):
214
                keyfiles.append((paramiko.DSSKey, dsa_key))
215
            # look in ~/ssh/ for windows users:
216
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
217
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
218
            if os.path.isfile(rsa_key):
219
                keyfiles.append((paramiko.RSAKey, rsa_key))
220
            if os.path.isfile(dsa_key):
221
                keyfiles.append((paramiko.DSSKey, dsa_key))
222
        
223
        for cls, filename in keyfiles:
224
            try:
225
                key = cls.from_private_key_file(filename, password)
226
                logger.debug('Trying discovered key %s in %s' %
227
                          (hexlify(key.get_fingerprint()), filename))
228
                self._transport.auth_publickey(username, key)
229
                return
230
            except Exception as e:
231
                saved_exception = e
232
                logger.debug(e)
233
        
234
        if password is not None:
235
            try:
236
                self._transport.auth_password(username, password)
237
                return
238
            except Exception as e:
239
                saved_exception = e
240
                logger.debug(e)
241
        
242
        if saved_exception is not None:
243
            # need pep-3134 to do this right
244
            raise SSHAuthenticationError(repr(saved_exception))
245
        
246
        raise SSHAuthenticationError('No authentication methods available')
247
    
248
    def run(self):
249
        chan = self._channel
250
        chan.setblocking(0)
251
        q = self._q
252
        try:
253
            while True:
254
                # select on a paramiko ssh channel object does not ever return
255
                # it in the writable list, so it channel's don't exactly emulate 
256
                # the socket api
257
                r, w, e = select([chan], [], [], TICK)
258
                # will wakeup evey TICK seconds to check if something
259
                # to send, more if something to read (due to select returning
260
                # chan in readable list)
261
                if r:
262
                    data = chan.recv(BUF_SIZE)
263
                    if data:
264
                        self._buffer.write(data)
265
                        self._parse()
266
                    else:
267
                        raise SessionCloseError(self._buffer.getvalue())
268
                if not q.empty() and chan.send_ready():
269
                    logger.debug('sending message')
270
                    data = q.get() + MSG_DELIM
271
                    while data:
272
                        n = chan.send(data)
273
                        if n <= 0:
274
                            raise SessionCloseError(self._buffer.getvalue(), data)
275
                        data = data[n:]
276
        except Exception as e:
277
            logger.debug('broke out of main loop')
278
            self.close()
279
            if not (isinstance(e, SessionCloseError) and self._expecting_close):
280
                self._dispatch_error(e)
281
    
282
    @property
283
    def transport(self):
284
        "gug"
285
        return self._transport
286
    
287
    @property
288
    def can_pipeline(self):
289
        if 'Cisco' in self._transport.remote_version:
290
            return False
291
        # elif ..
292
        return True