Statistics
| Branch: | Tag: | Revision:

root / ncclient / session / ssh.py @ 33a4aa10

History | View | Annotate | Download (10.5 kB)

1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

    
15
import os
16
import socket
17
from binascii import hexlify
18
from cStringIO import StringIO
19
from select import select
20

    
21
import paramiko
22

    
23
import session
24
from . import logger
25
from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, RemoteClosedError
26
from session import Session
27

    
28
BUF_SIZE = 4096
29
MSG_DELIM = ']]>]]>'
30
TICK = 0.1
31

    
32
class SSHSession(Session):
33

    
34
    def __init__(self):
35
        Session.__init__(self)
36
        self._system_host_keys = paramiko.HostKeys()
37
        self._host_keys = paramiko.HostKeys()
38
        self._host_keys_filename = None
39
        self._transport = None
40
        self._connected = False
41
        self._channel = None
42
        self._buffer = StringIO() # for incoming data
43
        # parsing-related, see _fresh_data()
44
        self._parsing_state = 0 
45
        self._parsing_pos = 0
46
    
47
    def _fresh_data(self):
48
        '''The buffer could have grown by a maximum of BUF_SIZE bytes everytime 
49
        this method is called. Retains state across method calls and if a byte
50
        has been read it will not be parsed again.
51
        '''
52
        delim = MSG_DELIM
53
        n = len(delim) - 1
54
        state = self._parsing_state
55
        buf = self._buffer
56
        buf.seek(self._parsing_pos)
57
        while True:
58
            x = buf.read(1)
59
            if not x: # done reading
60
                break
61
            elif x == delim[state]:
62
                state += 1
63
            else:
64
                continue
65
            # loop till last delim char expected, break if other char encountered
66
            for i in range(state, n):
67
                x = buf.read(1)
68
                if not x: # done reading
69
                    break
70
                if x==delim[state]: # what we expected
71
                    state += 1 # expect the next delim char
72
                else:
73
                    state = 0 # reset
74
                    break
75
            else: # if we didn't break out of above loop, full delim parsed
76
                till = buf.tell() - n
77
                buf.seek(0)
78
                msg = buf.read(till)
79
                self.dispatch('received', msg)
80
                buf.seek(n+1, os.SEEK_CUR)
81
                rest = buf.read()
82
                buf = StringIO()
83
                buf.write(rest)
84
                buf.seek(0)
85
                state = 0
86
        self._buffer = buf
87
        self._parsing_state = state
88
        self._parsing_pos = self._buffer.tell()
89
    
90
    def load_system_host_keys(self, filename=None):
91
        if filename is None:
92
            # try the user's .ssh key file, and mask exceptions
93
            filename = os.path.expanduser('~/.ssh/known_hosts')
94
            try:
95
                self._system_host_keys.load(filename)
96
            except IOError:
97
                pass
98
            return
99
        self._system_host_keys.load(filename)
100
    
101
    def load_host_keys(self, filename):
102
        self._host_keys_filename = filename
103
        self._host_keys.load(filename)
104

    
105
    def add_host_key(self, key):
106
        self._host_keys.add(key)
107
    
108
    def save_host_keys(self, filename):
109
        f = open(filename, 'w')
110
        for hostname, keys in self._host_keys.iteritems():
111
            for keytype, key in keys.iteritems():
112
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
113
        f.close()    
114
    
115
    def close(self):
116
        if self._transport.is_active():
117
            self._transport.close()
118
        self._connected = False
119
    
120
    def connect(self, hostname, port=830, timeout=None,
121
                unknown_host_cb=None, username=None, password=None,
122
                key_filename=None, allow_agent=True, look_for_keys=True):
123
        
124
        assert(username is not None)
125
        
126
        for (family, socktype, proto, canonname, sockaddr) in \
127
        socket.getaddrinfo(hostname, port):
128
            if socktype==socket.SOCK_STREAM:
129
                af = family
130
                addr = sockaddr
131
                break
132
        else:
133
            raise SSHError('No suitable address family for %s' % hostname)
134
        sock = socket.socket(af, socket.SOCK_STREAM)
135
        sock.settimeout(timeout)
136
        sock.connect(addr)
137
        t = self._transport = paramiko.Transport(sock)
138
        t.set_log_channel(logger.name)
139
        
140
        try:
141
            t.start_client()
142
        except paramiko.SSHException:
143
            raise SSHError('Negotiation failed')
144
        
145
        # host key verification
146
        server_key = t.get_remote_server_key()
147
        known_host = self._host_keys.check(hostname, server_key) or \
148
                        self._system_host_keys.check(hostname, server_key)
149
        
150
        if unknown_host_cb is None:
151
            unknown_host_cb = lambda *args: False
152
        if not known_host and not unknown_host_cb(hostname, server_key):
153
                raise SSHUnknownHostError(hostname, server_key)
154
        
155
        if key_filename is None:
156
            key_filenames = []
157
        elif isinstance(key_filename, basestring):
158
            key_filenames = [ key_filename ]
159
        else:
160
            key_filenames = key_filename
161
        
162
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
163
        
164
        self._connected = True # there was no error authenticating
165
        
166
        c = self._channel = self._transport.open_session()
167
        c.invoke_subsystem('netconf')
168
        c.set_name('netconf')
169
        
170
        Session._post_connect(self)
171
    
172
    # on the lines of paramiko.SSHClient._auth()
173
    def _auth(self, username, password, key_filenames, allow_agent,
174
              look_for_keys):
175
        saved_exception = None
176
        
177
        allowed = ['publickey', 'keyboard-interactive', 'password']
178
        
179
        for key_filename in key_filenames:
180
            if 'publickey' not in allowed:
181
                    break
182
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
183
                try:
184
                    key = cls.from_private_key_file(key_filename, password)
185
                    logger.debug('Trying key %s from %s' %
186
                              (hexlify(key.get_fingerprint()), key_filename))
187
                    self._transport.auth_publickey(username, key)
188
                    return
189
                except paramiko.BadAuthenticationType as e:
190
                    allowed = e.allowed_types
191
                    logger.debug(e)
192
                except Exception as e:
193
                    saved_exception = e
194
                    logger.debug(e)
195
        
196
        if allow_agent:
197
            for key in paramiko.Agent().get_keys():
198
                if 'publickey' not in allowed:
199
                    break
200
                try:
201
                    logger.debug('Trying SSH agent key %s' %
202
                                 hexlify(key.get_fingerprint()))
203
                    logger.error( self._transport.auth_publickey(username, key) )
204
                    return
205
                except paramiko.BadAuthenticationType as e:
206
                    allowed = e.allowed_types
207
                    logger.debug(e)
208
                except Exception as e:
209
                    saved_exception = e
210
                    logger.debug(e)
211
        
212
        keyfiles = []
213
        if look_for_keys and 'publickey' in allowed:
214
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
215
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
216
            if os.path.isfile(rsa_key):
217
                keyfiles.append((paramiko.RSAKey, rsa_key))
218
            if os.path.isfile(dsa_key):
219
                keyfiles.append((paramiko.DSSKey, dsa_key))
220
            # look in ~/ssh/ for windows users:
221
            rsa_key = os.path.expanduser('~/ssh/id_rsa')
222
            dsa_key = os.path.expanduser('~/ssh/id_dsa')
223
            if os.path.isfile(rsa_key):
224
                keyfiles.append((paramiko.RSAKey, rsa_key))
225
            if os.path.isfile(dsa_key):
226
                keyfiles.append((paramiko.DSSKey, dsa_key))
227
        
228
        for cls, filename in keyfiles:
229
            try:
230
                key = cls.from_private_key_file(filename, password)
231
                logger.debug('Trying discovered key %s in %s' %
232
                          (hexlify(key.get_fingerprint()), filename))
233
                allowed = self._transport.auth_publickey(username, key)
234
                return
235
            except Exception as e:
236
                saved_exception = e
237
                logger.debug(e)
238
        
239
        if password is not None:
240
            try:
241
                self._transport.auth_password(username, password)
242
                return
243
            except Exception as e:
244
                saved_exception = e
245
                logger.debug(e)
246
        
247
        if saved_exception is not None:
248
            raise SSHAuthenticationError(saved_exception)
249
        
250
        raise SSHAuthenticationError('No authentication methods available')
251
    
252
    def run(self):
253
        chan = self._channel
254
        chan.setblocking(0)
255
        q = self._q
256
        try:
257
            while True:
258
                # select on a paramiko ssh channel object does not ever
259
                # return it in the writable list, so it does not exactly
260
                # emulate the socket api
261
                r, w, e = select([chan], [], [], TICK)
262
                # will wakeup evey TICK seconds to check if something
263
                # to send, more if something to read (due to select returning chan
264
                # in readable list)
265
                if r:
266
                    data = chan.recv(BUF_SIZE)
267
                    if data:
268
                        self._buffer.write(data)
269
                        self._fresh_data()
270
                    else:
271
                        raise RemoteClosedError(self._buffer.getvalue())
272
                if not q.empty() and chan.send_ready():
273
                    data = q.get() + MSG_DELIM
274
                    while data:
275
                        n = chan.send(data)
276
                        if n <= 0:
277
                            raise RemoteClosedError(self._buffer.getvalue(), data)
278
                        data = data[n:]
279
        except Exception as e:
280
            self.close()
281
            logger.debug('*** broke out of main loop ***')
282
            self.dispatch('error', e)
283
    
284
    def set_keepalive(self, interval=0):
285
        self._transport.set_keepalive(interval)