Revision 4f650d54 ncclient/transport/ssh.py

b/ncclient/transport/ssh.py
30 30
MSG_DELIM = ']]>]]>'
31 31
TICK = 0.1
32 32

  
33
def default_unknown_host_cb(host, key):
34
    """An `unknown host callback` returns :const:`True` if it finds the key
35
    acceptable, and :const:`False` if not.
36

  
37
    :arg host: the hostname/address which needs to be verified
38

  
39
    :arg key: a hex string representing the host key fingerprint
40

  
41
    :returns: this default callback always returns :const:`False`
42
    """
43
    return False
44

  
45

  
33 46
class SSHSession(Session):
34
    
35
    "A NETCONF SSH session, per :rfc:`4742`"
36
    
47

  
48
    "Implements a :rfc:`4742` NETCONF session over SSH."
49

  
37 50
    def __init__(self, capabilities):
38 51
        Session.__init__(self, capabilities)
39 52
        self._host_keys = paramiko.HostKeys()
......
44 57
        self._expecting_close = False
45 58
        self._buffer = StringIO() # for incoming data
46 59
        # parsing-related, see _parse()
47
        self._parsing_state = 0 
60
        self._parsing_state = 0
48 61
        self._parsing_pos = 0
49
    
62

  
50 63
    def _parse(self):
51 64
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
52 65
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
53
        across method calls and if a byte has been read it will not be considered
54
        again.
55
        '''
66
        across method calls and if a byte has been read it will not be
67
        considered again. '''
56 68
        delim = MSG_DELIM
57 69
        n = len(delim) - 1
58 70
        expect = self._parsing_state
......
90 102
        self._buffer = buf
91 103
        self._parsing_state = expect
92 104
        self._parsing_pos = self._buffer.tell()
93
    
105

  
94 106
    def load_system_host_keys(self, filename=None):
95 107
        if filename is None:
96 108
            filename = os.path.expanduser('~/.ssh/known_hosts')
......
105 117
                    pass
106 118
            return
107 119
        self._system_host_keys.load(filename)
108
    
120

  
109 121
    def load_host_keys(self, filename):
110 122
        self._host_keys.load(filename)
111 123

  
112 124
    def add_host_key(self, key):
113 125
        self._host_keys.add(key)
114
    
126

  
115 127
    def save_host_keys(self, filename):
116 128
        f = open(filename, 'w')
117 129
        for host, keys in self._host_keys.iteritems():
118 130
            for keytype, key in keys.iteritems():
119 131
                f.write('%s %s %s\n' % (host, keytype, key.get_base64()))
120
        f.close()    
121
    
132
        f.close()
133

  
122 134
    def close(self):
123 135
        self._expecting_close = True
124 136
        if self._transport.is_active():
125 137
            self._transport.close()
126 138
        self._connected = False
127
    
139

  
128 140
    def connect(self, host, port=830, timeout=None,
129
                unknown_host_cb=None, username=None, password=None,
141
                unknown_host_cb=default_unknown_host_cb,
142
                username=None, password=None,
130 143
                key_filename=None, allow_agent=True, look_for_keys=True):
144
        """Connect via SSH and initialize the NETCONF session. First attempts
145
        the publickey authentication method and then password authentication.
146

  
147
        To disable publickey authentication, call with *allow_agent* and
148
        *look_for_keys* as :const:`False`
149

  
150
        :arg host: the hostname or IP address to connect to
151

  
152
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
153

  
154
        :arg timeout: an optional timeout for the TCP handshake
155

  
156
        :arg unknown_host_cb: called when a host key is not known. See :func:`unknown_host_cb` for details on signature
157

  
158
        :arg username: the username to use for SSH authentication
159

  
160
        :arg password: the password used if using password authentication, or the passphrase to use in order to unlock keys that require it
161

  
162
        :arg key_filename: a filename where a the private key to be used can be found
163

  
164
        :arg allow_agent: enables querying SSH agent (if found) for keys
165

  
166
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
167
        """
168

  
131 169
        assert(username is not None)
132
        
170

  
133 171
        for (family, socktype, proto, canonname, sockaddr) in \
134 172
        socket.getaddrinfo(host, port):
135 173
            if socktype == socket.SOCK_STREAM:
......
143 181
        sock.connect(addr)
144 182
        t = self._transport = paramiko.Transport(sock)
145 183
        t.set_log_channel(logger.name)
146
        
184

  
147 185
        try:
148 186
            t.start_client()
149 187
        except paramiko.SSHException:
150 188
            raise SSHError('Negotiation failed')
151
        
189

  
152 190
        # host key verification
153 191
        server_key = t.get_remote_server_key()
154 192
        known_host = self._host_keys.check(host, server_key) or \
155 193
                        self._system_host_keys.check(host, server_key)
156
        
157
        if unknown_host_cb is None:
158
            unknown_host_cb = lambda *args: False
159
        if not known_host and not unknown_host_cb(host, server_key):
160
                raise SSHUnknownHostError(host, server_key)
161
        
194

  
195
        fp = hexlify(server_key.get_fingerprint())
196
        if not known_host and not unknown_host_cb(host, fp):
197
            raise SSHUnknownHostError(host, fp)
198

  
162 199
        if key_filename is None:
163 200
            key_filenames = []
164 201
        elif isinstance(key_filename, basestring):
165 202
            key_filenames = [ key_filename ]
166 203
        else:
167 204
            key_filenames = key_filename
168
        
205

  
169 206
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
170
        
207

  
171 208
        self._connected = True # there was no error authenticating
172
        
209

  
173 210
        c = self._channel = self._transport.open_session()
174 211
        c.set_name('netconf')
175 212
        c.invoke_subsystem('netconf')
176
        
213

  
177 214
        self._post_connect()
178
    
215

  
179 216
    # on the lines of paramiko.SSHClient._auth()
180 217
    def _auth(self, username, password, key_filenames, allow_agent,
181 218
              look_for_keys):
182 219
        saved_exception = None
183
        
220

  
184 221
        for key_filename in key_filenames:
185 222
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
186 223
                try:
......
192 229
                except Exception as e:
193 230
                    saved_exception = e
194 231
                    logger.debug(e)
195
        
232

  
196 233
        if allow_agent:
197 234
            for key in paramiko.Agent().get_keys():
198 235
                try:
......
203 240
                except Exception as e:
204 241
                    saved_exception = e
205 242
                    logger.debug(e)
206
        
243

  
207 244
        keyfiles = []
208 245
        if look_for_keys:
209 246
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
......
219 256
                keyfiles.append((paramiko.RSAKey, rsa_key))
220 257
            if os.path.isfile(dsa_key):
221 258
                keyfiles.append((paramiko.DSSKey, dsa_key))
222
        
259

  
223 260
        for cls, filename in keyfiles:
224 261
            try:
225 262
                key = cls.from_private_key_file(filename, password)
......
230 267
            except Exception as e:
231 268
                saved_exception = e
232 269
                logger.debug(e)
233
        
270

  
234 271
        if password is not None:
235 272
            try:
236 273
                self._transport.auth_password(username, password)
......
238 275
            except Exception as e:
239 276
                saved_exception = e
240 277
                logger.debug(e)
241
        
278

  
242 279
        if saved_exception is not None:
243 280
            # need pep-3134 to do this right
244 281
            raise SSHAuthenticationError(repr(saved_exception))
245
        
282

  
246 283
        raise SSHAuthenticationError('No authentication methods available')
247
    
284

  
248 285
    def run(self):
249 286
        chan = self._channel
250 287
        chan.setblocking(0)
......
252 289
        try:
253 290
            while True:
254 291
                # select on a paramiko ssh channel object does not ever return
255
                # it in the writable list, so it channel's don't exactly emulate 
292
                # it in the writable list, so it channel's don't exactly emulate
256 293
                # the socket api
257 294
                r, w, e = select([chan], [], [], TICK)
258 295
                # will wakeup evey TICK seconds to check if something
......
278 315
            self.close()
279 316
            if not (isinstance(e, SessionCloseError) and self._expecting_close):
280 317
                self._dispatch_error(e)
281
    
318

  
282 319
    @property
283 320
    def transport(self):
284
        "gug"
321
        """The underlying `paramiko.Transport
322
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
323
        object. This makes it possible to call methods like set_keepalive on it.
324
        """
285 325
        return self._transport
286
    
326

  
287 327
    @property
288 328
    def can_pipeline(self):
289 329
        if 'Cisco' in self._transport.remote_version:

Also available in: Unified diff