Revision a956ef07 ncclient/session/ssh.py

b/ncclient/session/ssh.py
19 19

  
20 20
import paramiko
21 21

  
22

  
23
from session import Session, SessionError
22
from session import Session, SessionError, SessionCloseError
24 23

  
25 24
logger = logging.getLogger('ncclient.ssh')
26 25

  
27

  
28
class SessionCloseError(SessionError):
29
    
30
    def __str__(self):
31
        return 'RECEIVED: %s | UNSENT: %s' % (self._in_buf, self._out_buf)
32
    
33
    def __init__(self, in_buf, out_buf=None):
34
        SessionError.__init__(self)
35
        self._in_buf, self._out_buf = in_buf, out_buf
36

  
26
BUF_SIZE = 4096
27
MSG_DELIM = ']]>]]>'
37 28

  
38 29
class SSHSession(Session):
39 30

  
40
    BUF_SIZE = 4096
41
    MSG_DELIM = ']]>]]>'
42
    
43 31
    def __init__(self, load_known_hosts=True,
44 32
                 missing_host_key_policy=paramiko.RejectPolicy()):
45 33
        Session.__init__(self)
......
57 45
        self._connected = False
58 46
    
59 47
    def _fresh_data(self):
60
        delim = SSHSession.MSG_DELIM
48
        delim = MSG_DELIM
61 49
        n = len(delim) - 1
62 50
        state = self._parsing_state
63 51
        buf = self._in_buf
......
95 83
        self._parsing_state = state
96 84
        self._parsing_pos = self._in_buf.tell()
97 85

  
98
    #def load_host_keys(self, filename):
99
    #    self._client.load_host_keys(filename)
100
    #
101
    #def set_missing_host_key_policy(self, policy):
102
    #    self._client.set_missing_host_key_policy(policy)
103
    #
104
    #def connect(self, hostname, port=830, username=None, password=None,
105
    #            key_filename=None, timeout=None, allow_agent=True,
106
    #            look_for_keys=True):
107
    #    self._client.connect(hostname, port=port, username=username,
108
    #                        password=password, key_filename=key_filename,
109
    #                        timeout=timeout, allow_agent=allow_agent,
110
    #                        look_for_keys=look_for_keys)    
111
    #    transport = self._client.get_transport()
112
    #    self._channel = transport.open_session()
113
    #    self._channel.invoke_subsystem('netconf')
114
    #    self._channel.set_name('netconf')
115
    #    self._connected = True
116
    #    self._post_connect()
86
    def load_host_keys(self, filename):
87
        self._client.load_host_keys(filename)
88

  
89
    def set_missing_host_key_policy(self, policy):
90
        self._client.set_missing_host_key_policy(policy)
117 91

  
118 92
    def connect(self, hostname, port=830, username=None, password=None,
119 93
                key_filename=None, timeout=None, allow_agent=True,
120 94
                look_for_keys=True):
121
        self._transport = paramiko.Transport()
95
        self._client.connect(hostname, port=port, username=username,
96
                            password=password, key_filename=key_filename,
97
                            timeout=timeout, allow_agent=allow_agent,
98
                            look_for_keys=look_for_keys)    
99
        transport = self._client.get_transport()
100
        self._channel = transport.open_session()
101
        self._channel.invoke_subsystem('netconf')
102
        self._channel.set_name('netconf')
103
        self._connected = True
104
        self._post_connect()
105
    
122 106
    
123 107
    def run(self):
124 108
        chan = self._channel
125 109
        chan.setblocking(0)
126 110
        q = self._q
127 111
        try:
128
            while True:    
112
            while True:
129 113
                if chan.closed:
130 114
                    raise SessionCloseError(self._in_buf.getvalue())         
131 115
                if chan.send_ready() and not q.empty():
132
                    data = q.get() + SSHSession.MSG_DELIM
116
                    data = q.get() + MSG_DELIM
133 117
                    while data:
134 118
                        n = chan.send(data)
135 119
                        if n <= 0:
136 120
                            raise SessionCloseError(self._in_buf.getvalue(), data)
137 121
                        data = data[n:]
138 122
                if chan.recv_ready():
139
                    data = chan.recv(SSHSession.BUF_SIZE)
123
                    data = chan.recv(BUF_SIZE)
140 124
                    if data:
141 125
                        self._in_buf.write(data)
142 126
                        self._fresh_data()
......
145 129
        except Exception as e:
146 130
            logger.debug('*** broke out of main loop ***')
147 131
            self.dispatch('error', e)
148

  
149
class MissingHostKeyPolicy(paramiko.MissingHostKeyPolicy):
150
    
151
    def __init__(self, cb):
152
        self._cb = cb
153
    
154
    def missing_host_key(self, client, hostname, key):
155
        if not self._cb(hostname, key):
156
            raise SSHError

Also available in: Unified diff