Revision 004042be

b/ncclient/session/error.py
17 17
class SessionError(ClientError):
18 18
    pass
19 19

  
20
class RemoteClosedError(SessionError):
21
    
22
    def __init__(self, in_buf, out_buf=None):
23
        SessionError.__init__(self)
24
        self._in_buf, self._out_buf = in_buf, out_buf
25
        
26
    def __str__(self):
27
        msg = 'Session closed by remote endpoint.'
28
        if self._in_buf:
29
            msg += '\nIN_BUFFER: %s' % self._in_buf
30
        if self._out_buf:
31
            msg += '\nOUT_BUFFER: %s' % self._out_buf
32
        return msg
33

  
34
class AuthenticationError(SessionError):
35
    pass
36

  
37 20
class SSHError(SessionError):
38 21
    pass
39 22

  
......
48 31
        return ('Unknown host key [%s] for [%s]' %
49 32
                (hexlify(self.key.get_fingerprint()), self.hostname))
50 33

  
51
class SSHAuthenticationError(AuthenticationError, SSHError):
52
    'wraps a paramiko exception that occured during auth'
53
    
54
    def __init__(self, ex):
55
        self.ex = ex
34
class SSHAuthenticationError(SSHError):
35
    pass
36

  
37
class SSHSessionClosedError(SSHError):
56 38
    
57
    def __repr__(self):
58
        return repr(ex)
39
    def __init__(self, in_buf, out_buf=None):
40
        SessionError.__init__(self, "Unexpected session close.")
41
        self._in_buf, self._out_buf = in_buf, out_buf
42
        
43
    def __str__(self):
44
        msg = SessionError(self).__str__()
45
        if self._in_buf:
46
            msg += '\nIN_BUFFER: %s' % self._in_buf
47
        if self._out_buf:
48
            msg += '\nOUT_BUFFER: %s' % self._out_buf
49
        return msg
b/ncclient/session/ssh.py
22 22

  
23 23
import session
24 24
from . import logger
25
from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, RemoteClosedError
25
from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, SSHSessionClosedError
26 26
from session import Session
27 27

  
28 28
BUF_SIZE = 4096
......
40 40
        self._connected = False
41 41
        self._channel = None
42 42
        self._buffer = StringIO() # for incoming data
43
        # parsing-related, see _fresh_data()
43
        # parsing-related, see _parse()
44 44
        self._parsing_state = 0 
45 45
        self._parsing_pos = 0
46 46
    
47
    def _fresh_data(self):
48
        '''The buffer could have grown by a maximum of BUF_SIZE bytes everytime 
49
        this method is called. Retains state across method calls and if a byte
50
        has been read it will not be parsed again.
47
    def _parse(self):
48
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
49
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
50
        across method calls and if a byte has been read it will not be considered
51
        again.
51 52
        '''
52 53
        delim = MSG_DELIM
53 54
        n = len(delim) - 1
54
        state = self._parsing_state
55
        expect = self._parsing_state
55 56
        buf = self._buffer
56 57
        buf.seek(self._parsing_pos)
57 58
        while True:
58 59
            x = buf.read(1)
59 60
            if not x: # done reading
60 61
                break
61
            elif x == delim[state]:
62
                state += 1
62
            elif x == delim[expect]: # what we expected
63
                expect += 1 # expect the next delim char
63 64
            else:
64 65
                continue
65 66
            # loop till last delim char expected, break if other char encountered
66
            for i in range(state, n):
67
            for i in range(expect, n):
67 68
                x = buf.read(1)
68 69
                if not x: # done reading
69 70
                    break
70
                if x==delim[state]: # what we expected
71
                    state += 1 # expect the next delim char
71
                if x == delim[expect]: # what we expected
72
                    expect += 1 # expect the next delim char
72 73
                else:
73
                    state = 0 # reset
74
                    expect = 0 # reset
74 75
                    break
75
            else: # if we didn't break out of above loop, full delim parsed
76
                till = buf.tell() - n
76
            else: # if we didn't break out of the loop, full delim was parsed
77
                msg_till = buf.tell() - n
77 78
                buf.seek(0)
78
                msg = buf.read(till)
79
                msg = buf.read(msg_till)
79 80
                self.dispatch('received', msg)
80 81
                buf.seek(n+1, os.SEEK_CUR)
81 82
                rest = buf.read()
......
84 85
                buf.seek(0)
85 86
                state = 0
86 87
        self._buffer = buf
87
        self._parsing_state = state
88
        self._parsing_state = expect
88 89
        self._parsing_pos = self._buffer.tell()
89 90
    
90 91
    def load_system_host_keys(self, filename=None):
......
167 168
        c.invoke_subsystem('netconf')
168 169
        c.set_name('netconf')
169 170
        
170
        Session._post_connect(self)
171
        self._post_connect()
171 172
    
172 173
    # on the lines of paramiko.SSHClient._auth()
173 174
    def _auth(self, username, password, key_filenames, allow_agent,
174 175
              look_for_keys):
175 176
        saved_exception = None
176 177
        
177
        allowed = ['publickey', 'keyboard-interactive', 'password']
178
        
179 178
        for key_filename in key_filenames:
180
            if 'publickey' not in allowed:
181
                    break
182 179
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
183 180
                try:
184 181
                    key = cls.from_private_key_file(key_filename, password)
......
186 183
                              (hexlify(key.get_fingerprint()), key_filename))
187 184
                    self._transport.auth_publickey(username, key)
188 185
                    return
189
                except paramiko.BadAuthenticationType as e:
190
                    allowed = e.allowed_types
191
                    logger.debug(e)
192 186
                except Exception as e:
193 187
                    saved_exception = e
194 188
                    logger.debug(e)
195 189
        
196 190
        if allow_agent:
197 191
            for key in paramiko.Agent().get_keys():
198
                if 'publickey' not in allowed:
199
                    break
200 192
                try:
201 193
                    logger.debug('Trying SSH agent key %s' %
202 194
                                 hexlify(key.get_fingerprint()))
203
                    logger.error( self._transport.auth_publickey(username, key) )
195
                    self._transport.auth_publickey(username, key)
204 196
                    return
205
                except paramiko.BadAuthenticationType as e:
206
                    allowed = e.allowed_types
207
                    logger.debug(e)
208 197
                except Exception as e:
209 198
                    saved_exception = e
210 199
                    logger.debug(e)
211 200
        
212 201
        keyfiles = []
213
        if look_for_keys and 'publickey' in allowed:
202
        if look_for_keys:
214 203
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
215 204
            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
216 205
            if os.path.isfile(rsa_key):
......
230 219
                key = cls.from_private_key_file(filename, password)
231 220
                logger.debug('Trying discovered key %s in %s' %
232 221
                          (hexlify(key.get_fingerprint()), filename))
233
                allowed = self._transport.auth_publickey(username, key)
222
                self._transport.auth_publickey(username, key)
234 223
                return
235 224
            except Exception as e:
236 225
                saved_exception = e
......
245 234
                logger.debug(e)
246 235
        
247 236
        if saved_exception is not None:
248
            raise SSHAuthenticationError(saved_exception)
237
            raise AuthenticationError(repr(saved_exception))
249 238
        
250
        raise SSHAuthenticationError('No authentication methods available')
239
        raise AuthenticationError('No authentication methods available')
251 240
    
252 241
    def run(self):
253 242
        chan = self._channel
......
255 244
        q = self._q
256 245
        try:
257 246
            while True:
258
                # select on a paramiko ssh channel object does not ever
259
                # return it in the writable list, so it does not exactly
260
                # emulate the socket api
247
                # select on a paramiko ssh channel object does not ever return
248
                # it in the writable list, so it channel's don't exactly emulate 
249
                # the socket api
261 250
                r, w, e = select([chan], [], [], TICK)
262 251
                # will wakeup evey TICK seconds to check if something
263
                # to send, more if something to read (due to select returning chan
264
                # in readable list)
252
                # to send, more if something to read (due to select returning
253
                # chan in readable list)
265 254
                if r:
266 255
                    data = chan.recv(BUF_SIZE)
267 256
                    if data:
268 257
                        self._buffer.write(data)
269
                        self._fresh_data()
258
                        self._parse()
270 259
                    else:
271
                        raise RemoteClosedError(self._buffer.getvalue())
260
                        raise SSHSessionClosedError(self._buffer.getvalue())
272 261
                if not q.empty() and chan.send_ready():
273 262
                    data = q.get() + MSG_DELIM
274 263
                    while data:
275 264
                        n = chan.send(data)
276 265
                        if n <= 0:
277
                            raise RemoteClosedError(self._buffer.getvalue(), data)
266
                            raise SSHSessionClosedError(self._buffer.getvalue(), data)
278 267
                        data = data[n:]
279 268
        except Exception as e:
280 269
            self.close()
281 270
            logger.debug('*** broke out of main loop ***')
282 271
            self.dispatch('error', e)
283 272
    
284
    def set_keepalive(self, interval=0):
285
        self._transport.set_keepalive(interval)
273
    @property
274
    def transport(self):
275
        '''Get underlying paramiko.transport object; this is provided so methods
276
        like transport.set_keepalive can be called.
277
        '''
278
        return self._transport

Also available in: Unified diff