Statistics
| Branch: | Tag: | Revision:

root / ncclient / session / ssh.py @ f5c75f88

History | View | Annotate | Download (10.6 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
import session
24
from . import logger
25
from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, RemoteClosedError
26
from session import Session
27

    
28
BUF_SIZE = 4096
29
MSG_DELIM = ']]>]]>'
30
TICK = 0.1
31

    
32
class SSHSession(Session):
33

    
34
    def __init__(self):
35
        Session.__init__(self)
36
        self._system_host_keys = paramiko.HostKeys()
37
        self._host_keys = paramiko.HostKeys()
38
        self._host_keys_filename = None
39
        self._transport = None
40
        self._connected = False
41
        self._channel = None
42
        self._buffer = StringIO() # for incoming data
43
        # parsing-related, see _fresh_data()
44
        self._parsing_state = 0 
45
        self._parsing_pos = 0
46
    
47
    def _fresh_data(self):
48
        '''The buffer could have grown by a maximum of BUF_SIZE bytes everytime 
49
        this method is called. Retains state across method calls and if a byte
50
        has been read it will not be parsed again.
51
        '''
52
        delim = MSG_DELIM
53
        n = len(delim) - 1
54
        state = self._parsing_state
55
        buf = self._buffer
56
        buf.seek(self._parsing_pos)
57
        while True:
58
            x = buf.read(1)
59
            if not x: # done reading
60
                break
61
            elif x == delim[state]:
62
                state += 1
63
            else:
64
                continue
65
            # loop till last delim char expected, break if other char encountered
66
            for i in range(state, n):
67
                x = buf.read(1)
68
                if not x: # done reading
69
                    break
70
                if x==delim[state]: # what we expected
71
                    state += 1 # expect the next delim char
72
                else:
73
                    state = 0 # reset
74
                    break
75
            else: # if we didn't break out of above loop, full delim parsed
76
                till = buf.tell() - n
77
                buf.seek(0)
78
                msg = buf.read(till)
79
                self.dispatch('received', msg)
80
                buf.seek(n+1, os.SEEK_CUR)
81
                rest = buf.read()
82
                buf = StringIO()
83
                buf.write(rest)
84
                buf.seek(0)
85
                state = 0
86
        self._buffer = buf
87
        self._parsing_state = state
88
        self._parsing_pos = self._buffer.tell()
89
    
90
    def load_system_host_keys(self, filename=None):
91
        if filename is None:
92
            # try the user's .ssh key file, and mask exceptions
93
            filename = os.path.expanduser('~/.ssh/known_hosts')
94
            try:
95
                self._system_host_keys.load(filename)
96
            except IOError:
97
                pass
98
            return
99
        self._system_host_keys.load(filename)
100
    
101
    def load_host_keys(self, filename):
102
        self._host_keys_filename = filename
103
        self._host_keys.load(filename)
104

    
105
    def add_host_key(self, key):
106
        self._host_keys.add(key)
107
    
108
    def save_host_keys(self, filename):
109
        f = open(filename, 'w')
110
        for hostname, keys in self._host_keys.iteritems():
111
            for keytype, key in keys.iteritems():
112
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
113
        f.close()    
114
    
115
    def close(self):
116
        if self._transport.is_active():
117
            self._transport.close()
118
        self._connected = False
119
    
120
    def connect(self, hostname, port=830, timeout=None,
121
                unknown_host_cb=None, username=None, password=None,
122
                key_filename=None, allow_agent=True, look_for_keys=True,
123
                authtypes=['publickey', 'password', 'keyboard-interactive']):
124
        
125
        assert(username is not None)
126
        
127
        for (family, socktype, proto, canonname, sockaddr) in \
128
        socket.getaddrinfo(hostname, port):
129
            if socktype==socket.SOCK_STREAM:
130
                af = family
131
                addr = sockaddr
132
                break
133
        else:
134
            raise SSHError('No suitable address family for %s' % hostname)
135
        sock = socket.socket(af, socket.SOCK_STREAM)
136
        sock.settimeout(timeout)
137
        sock.connect(addr)
138
        t = self._transport = paramiko.Transport(sock)
139
        t.set_log_channel(logger.name)
140
        
141
        try:
142
            t.start_client()
143
        except paramiko.SSHException:
144
            raise SSHError('Negotiation failed')
145
        
146
        # host key verification
147
        server_key = t.get_remote_server_key()
148
        known_host = self._host_keys.check(hostname, server_key) or \
149
                        self._system_host_keys.check(hostname, server_key)
150
        
151
        if unknown_host_cb is None:
152
            unknown_host_cb = lambda *args: False
153
        if not known_host and not unknown_host_cb(hostname, server_key):
154
                raise SSHUnknownHostError(hostname, server_key)
155
        
156
        if key_filename is None:
157
            key_filenames = []
158
        elif isinstance(key_filename, basestring):
159
            key_filenames = [ key_filename ]
160
        else:
161
            key_filenames = key_filename
162
        
163
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
164
        
165
        self._connected = True # there was no error authenticating
166
        
167
        c = self._channel = self._transport.open_session()
168
        c.invoke_subsystem('netconf')
169
        c.set_name('netconf')
170
        
171
        Session._post_connect(self)
172
    
173
    # on the lines of paramiko.SSHClient._auth()
174
    def _auth(self, username, password, key_filenames, allow_agent,
175
              look_for_keys):
176
        saved_exception = None
177
        
178
        allowed = ['publickey', 'keyboard-interactive', 'password']
179
        
180
        for key_filename in key_filenames:
181
            if 'publickey' not in allowed:
182
                    break
183
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
184
                try:
185
                    key = cls.from_private_key_file(key_filename, password)
186
                    logger.debug('Trying key %s from %s' %
187
                              (hexlify(key.get_fingerprint()), key_filename))
188
                    self._transport.auth_publickey(username, key)
189
                    return
190
                except paramiko.BadAuthenticationType as e:
191
                    allowed = e.allowed_types
192
                    logger.debug(e)
193
                except Exception as e:
194
                    saved_exception = e
195
                    logger.debug(e)
196
        
197
        if allow_agent:
198
            for key in paramiko.Agent().get_keys():
199
                if 'publickey' not in allowed:
200
                    break
201
                try:
202
                    logger.debug('Trying SSH agent key %s' %
203
                                 hexlify(key.get_fingerprint()))
204
                    logger.error( self._transport.auth_publickey(username, key) )
205
                    return
206
                except paramiko.BadAuthenticationType as e:
207
                    allowed = e.allowed_types
208
                    logger.debug(e)
209
                except Exception as e:
210
                    saved_exception = e
211
                    logger.debug(e)
212
        
213
        keyfiles = []
214
        if look_for_keys and 'publickey' in allowed:
215
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
216
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
217
            if os.path.isfile(rsa_key):
218
                keyfiles.append((paramiko.RSAKey, rsa_key))
219
            if os.path.isfile(dsa_key):
220
                keyfiles.append((paramiko.DSSKey, dsa_key))
221
            # look in ~/ssh/ for windows users:
222
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
223
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
224
            if os.path.isfile(rsa_key):
225
                keyfiles.append((paramiko.RSAKey, rsa_key))
226
            if os.path.isfile(dsa_key):
227
                keyfiles.append((paramiko.DSSKey, dsa_key))
228
        
229
        for cls, filename in keyfiles:
230
            try:
231
                key = cls.from_private_key_file(filename, password)
232
                logger.debug('Trying discovered key %s in %s' %
233
                          (hexlify(key.get_fingerprint()), filename))
234
                allowed = self._transport.auth_publickey(username, key)
235
                return
236
            except Exception as e:
237
                saved_exception = e
238
                logger.debug(e)
239
        
240
        if password is not None:
241
            try:
242
                self._transport.auth_password(username, password)
243
                return
244
            except Exception as e:
245
                saved_exception = e
246
                logger.debug(e)
247
        
248
        if saved_exception is not None:
249
            raise SSHAuthenticationError(saved_exception)
250
        
251
        raise SSHAuthenticationError('No authentication methods available')
252
    
253
    def run(self):
254
        chan = self._channel
255
        chan.setblocking(0)
256
        q = self._q
257
        try:
258
            while True:
259
                # select on a paramiko ssh channel object does not ever
260
                # return it in the writable list, so it does not exactly
261
                # emulate the socket api
262
                r, w, e = select([chan], [], [], TICK)
263
                # will wakeup evey TICK seconds to check if something
264
                # to send, more if something to read (due to select returning chan
265
                # in readable list)
266
                if r:
267
                    data = chan.recv(BUF_SIZE)
268
                    if data:
269
                        self._buffer.write(data)
270
                        self._fresh_data()
271
                    else:
272
                        raise RemoteClosedError(self._buffer.getvalue())
273
                if not q.empty() and chan.send_ready():
274
                    data = q.get() + MSG_DELIM
275
                    while data:
276
                        n = chan.send(data)
277
                        if n <= 0:
278
                            raise RemoteClosedError(self._buffer.getvalue(), data)
279
                        data = data[n:]
280
        except Exception as e:
281
            self.close()
282
            logger.debug('*** broke out of main loop ***')
283
            self.dispatch('error', e)
284
    
285
    def set_keepalive(self, interval=0):
286
        self._transport.set_keepalive()