1 # Copyright 2009 Shikhar Bhushan
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
17 from binascii import hexlify
18 from cStringIO import StringIO
19 from select import select
23 from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24 from session import Session
# Module-level logger. Its name is also handed to paramiko as the log channel
# name via t.set_log_channel(logger.name) when a transport is created.
27 logger = logging.getLogger('ncclient.transport.ssh')
def default_unknown_host_cb(host, key):
    """An `unknown host callback` returns :const:`True` if it finds the key
    acceptable, and :const:`False` if not.

    This default callback always returns :const:`False`, which would lead to
    :meth:`connect` raising a :exc:`SSHUnknownHostError` exception.

    Supply another valid callback if you need to verify the host key
    programmatically.

    :arg host: the host for whom key needs to be verified
    :type host: `string`

    :arg key: a hex string representing the host key fingerprint
    :type key: `string`

    :returns: always :const:`False`
    """
    # Rejecting by default is the safe choice: an unrecognised key is never
    # trusted unless the caller explicitly supplies a callback that says so.
    return False
52 class SSHSession(Session):
54 "Implements a :rfc:`4742` NETCONF session over SSH."
def __init__(self, capabilities):
    """Create an unconnected SSH session.

    :arg capabilities: passed through unchanged to :class:`Session`
    """
    Session.__init__(self, capabilities)
    self._host_keys = paramiko.HostKeys()
    self._transport = None
    # assigned by connect() once the session channel has been opened
    self._channel = None
    self._connected = False
    self._expecting_close = False
    self._buffer = StringIO()  # for incoming data
    # parsing-related, see _parse()
    self._parsing_state = 0
    # _parse() seeks to this offset before it has stored one itself, so it
    # has to exist (and point at the start of the buffer) from the outset
    self._parsing_pos = 0
# NOTE(review): this listing has dropped source lines. The embedded line
# numbers jump (71 -> 75, 77 -> 80, 92 -> 96, 101 -> 108), and neither the
# `def` line nor the assignments of `delim`, `n`, `buf` and `x` are visible.
# The comments below describe only what the visible lines establish.
#
# Visible behaviour: resume scanning at self._parsing_pos with
# self._parsing_state delimiter characters already matched, pass each complete
# message (stripped) to self._dispatch_message(), then store the scan state
# back on the instance for the next call.
69 '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
70 maximum of BUF_SIZE bytes everytime this method is called. Retains state
71 across method calls and if a byte has been read it will not be
# `expect` indexes into `delim`: it is the position of the next delimiter
# character that still has to match.
75 expect = self._parsing_state
# presumably `buf` is self._buffer -- its assignment is not shown; TODO confirm
77 buf.seek(self._parsing_pos)
80 if not x: # done reading
82 elif x == delim[expect]: # what we expected
83 expect += 1 # expect the next delim char
86 # loop till last delim char expected, break if other char encountered
87 for i in range(expect, n):
89 if not x: # done reading
91 if x == delim[expect]: # what we expected
92 expect += 1 # expect the next delim char
96 else: # if we didn't break out of the loop, full delim was parsed
# presumably buf.tell() is just past the delimiter here, which would make
# msg_till the end of the message text -- depends on the unseen value of `n`;
# TODO confirm
97 msg_till = buf.tell() - n
99 logger.debug('parsed new message')
100 self._dispatch_message(buf.read(msg_till).strip())
# presumably steps over the delimiter that followed the message -- TODO confirm
101 buf.seek(n+1, os.SEEK_CUR)
# persist progress so the next call can pick up where this one stopped
108 self._parsing_state = expect
109 self._parsing_pos = self._buffer.tell()
def load_known_hosts(self, filename=None):
    """Load host keys from a :file:`known_hosts`-style file. Can be called
    multiple times.

    If *filename* is not specified, looks in the default locations i.e.
    :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
    The default lookup is best-effort: a location that cannot be read is
    skipped, and it is not an error if neither exists.

    :arg filename: path of the file to load; when given, read errors from
        the host key store propagate to the caller
    :type filename: `string`
    """
    if filename is not None:
        # An explicitly requested file must be readable; let errors surface.
        self._host_keys.load(filename)
        return
    # The second entry is the location used on Windows. Stop at the first
    # location that loads so the fallback is only consulted when needed.
    for location in ('~/.ssh/known_hosts', '~/ssh/known_hosts'):
        try:
            self._host_keys.load(os.path.expanduser(location))
        except IOError:
            continue
        break
# NOTE(review): the `def` line for this method is not in the listing (the
# embedded numbering starts at 133); presumably this is the session's close
# method -- confirm against the full file.
# The flag is raised before the transport is closed: the error handler that
# reads self._expecting_close skips _dispatch_error() for a SessionCloseError
# when the flag is set.
133 self._expecting_close = True
134 if self._transport.is_active():
135 self._transport.close()
136 self._connected = False
# Creates a stream socket for host:port (resolved with socket.getaddrinfo),
# wraps it in a paramiko.Transport, verifies the server host key,
# authenticates through self._auth() and requests the 'netconf' subsystem on
# a new session channel.
# Raises SSHError when no suitable address is found or negotiation fails, and
# SSHUnknownHostError when the host key is neither known nor accepted by
# unknown_host_cb.
# NOTE(review): this listing has dropped lines (the embedded numbering skips
# e.g. 183-186, 190, 193-195, 209, 212 and 222-224), so parts of the control
# flow are not visible here.
138 def connect(self, host, port=830, timeout=None,
139 unknown_host_cb=default_unknown_host_cb,
140 username=None, password=None,
141 key_filename=None, allow_agent=True, look_for_keys=True):
142 """Connect via SSH and initialize the NETCONF session. First attempts
143 the publickey authentication method and then password authentication.
145 To disable attemting publickey authentication altogether, call with
146 *allow_agent* and *look_for_keys* as :const:`False`. This may be needed
147 for Cisco devices which immediately disconnect on an incorrect
148 authentication attempt.
150 :arg host: the hostname or IP address to connect to
153 :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
156 :arg timeout: an optional timeout for the TCP handshake
159 :arg unknown_host_cb: called when a host key is not recognized
160 :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
162 :arg username: the username to use for SSH authentication
163 :type username: `string`
165 :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
166 :type password: `string`
168 :arg key_filename: a filename where a the private key to be used can be found
169 :type key_filename: `string`
171 :arg allow_agent: enables querying SSH agent (if found) for keys
172 :type allow_agent: `bool`
174 :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
175 :type look_for_keys: `bool`
# NOTE(review): username defaults to None but is effectively required; an
# assert is stripped under `python -O`, so this check can silently vanish.
178 assert(username is not None)
180 for (family, socktype, proto, canonname, sockaddr) in \
181 socket.getaddrinfo(host, port):
182 if socktype == socket.SOCK_STREAM:
187 raise SSHError('No suitable address family for %s' % host)
# NOTE(review): the assignment of `af` (embedded 183-186) is not visible
188 sock = socket.socket(af, socket.SOCK_STREAM)
189 sock.settimeout(timeout)
191 t = self._transport = paramiko.Transport(sock)
192 t.set_log_channel(logger.name)
# NOTE(review): the `try:` this handler belongs to (embedded 193-195) is not
# visible; paramiko's SSHException is translated into the package's SSHError
196 except paramiko.SSHException:
197 raise SSHError('Negotiation failed')
199 # host key verification
200 server_key = t.get_remote_server_key()
201 known_host = self._host_keys.check(host, server_key)
# the callback receives the fingerprint as a hex string, not the key object
203 fingerprint = hexlify(server_key.get_fingerprint())
# an unrecognised key is only fatal if the callback rejects it as well
205 if not known_host and not unknown_host_cb(host, fingerprint):
206 raise SSHUnknownHostError(host, fingerprint)
# normalise key_filename into key_filenames for _auth(): a single string is
# wrapped in a list (the branch body for None, embedded 209, is not visible)
208 if key_filename is None:
210 elif isinstance(key_filename, basestring):
211 key_filenames = [ key_filename ]
213 key_filenames = key_filename
215 self._auth(username, password, key_filenames, allow_agent, look_for_keys)
217 self._connected = True # there was no error authenticating
# open a session channel and request the 'netconf' subsystem on it
219 c = self._channel = self._transport.open_session()
220 c.set_name('netconf')
221 c.invoke_subsystem('netconf')
# Tries authentication methods in this order, as far as the visible lines
# show:
#   1. each file in key_filenames, loaded as an RSA and then as a DSS key
#   2. keys offered by the SSH agent
#   3. id_rsa / id_dsa files discovered under ~/.ssh and ~/ssh
#   4. password authentication, if a password was given
# `password` doubles as the passphrase when private key files are loaded.
# Raises AuthenticationError carrying repr() of the saved exception, or the
# message 'No authentication methods available' when none was saved.
# NOTE(review): this listing has dropped lines -- the rest of the signature
# (embedded 227), every `try:` line, the bodies of the `except` handlers and
# the conditions guarding steps 2 and 3 (presumably allow_agent and
# look_for_keys) are not visible.
225 # on the lines of paramiko.SSHClient._auth()
226 def _auth(self, username, password, key_filenames, allow_agent,
228 saved_exception = None
# step 1: explicitly supplied key files; each file is tried with both key types
230 for key_filename in key_filenames:
231 for cls in (paramiko.RSAKey, paramiko.DSSKey):
233 key = cls.from_private_key_file(key_filename, password)
234 logger.debug('Trying key %s from %s' %
235 (hexlify(key.get_fingerprint()), key_filename))
236 self._transport.auth_publickey(username, key)
238 except Exception as e:
# step 2: keys held by a running SSH agent
243 for key in paramiko.Agent().get_keys():
245 logger.debug('Trying SSH agent key %s' %
246 hexlify(key.get_fingerprint()))
247 self._transport.auth_publickey(username, key)
249 except Exception as e:
# step 3: collect (key class, path) pairs for key files that exist on disk
# NOTE(review): the initialisation of `keyfiles` is not visible
255 rsa_key = os.path.expanduser('~/.ssh/id_rsa')
256 dsa_key = os.path.expanduser('~/.ssh/id_dsa')
257 if os.path.isfile(rsa_key):
258 keyfiles.append((paramiko.RSAKey, rsa_key))
259 if os.path.isfile(dsa_key):
260 keyfiles.append((paramiko.DSSKey, dsa_key))
261 # look in ~/ssh/ for windows users:
262 rsa_key = os.path.expanduser('~/ssh/id_rsa')
263 dsa_key = os.path.expanduser('~/ssh/id_dsa')
264 if os.path.isfile(rsa_key):
265 keyfiles.append((paramiko.RSAKey, rsa_key))
266 if os.path.isfile(dsa_key):
267 keyfiles.append((paramiko.DSSKey, dsa_key))
269 for cls, filename in keyfiles:
271 key = cls.from_private_key_file(filename, password)
272 logger.debug('Trying discovered key %s in %s' %
273 (hexlify(key.get_fingerprint()), filename))
274 self._transport.auth_publickey(username, key)
276 except Exception as e:
# step 4: plain password authentication
280 if password is not None:
282 self._transport.auth_password(username, password)
284 except Exception as e:
# NOTE(review): the only visible assignment to saved_exception is the None at
# embedded line 228; the handlers that presumably store `e` in it are missing
288 if saved_exception is not None:
289 # need pep-3134 to do this right
290 raise AuthenticationError(repr(saved_exception))
292 raise AuthenticationError('No authentication methods available')
# NOTE(review): the `def` line and the setup of `chan` and `q` are missing
# from this listing (the embedded numbering starts at 300 and skips lines
# throughout), so the loop structure itself is not visible.
# Visible behaviour: wait up to TICK seconds for `chan` to become readable,
# write received data to self._buffer, and when `q` is non-empty and
# chan.send_ready() take one message from `q` with MSG_DELIM appended.
# SessionCloseError is raised with the buffered data (plus the outgoing
# message on the send path). Any exception is logged and passed to
# self._dispatch_error(), unless it is a SessionCloseError that was expected.
300 # select on a paramiko ssh channel object does not ever return
301 # it in the writable list, so channels don't exactly emulate
303 r, w, e = select([chan], [], [], TICK)
304 # will wake up every TICK seconds to check if there is something
305 # to send, more often if there is something to read (due to select returning
306 # chan in readable list)
308 data = chan.recv(BUF_SIZE)
310 self._buffer.write(data)
# the data buffered so far travels with the exception
313 raise SessionCloseError(self._buffer.getvalue())
314 if not q.empty() and chan.send_ready():
315 logger.debug('sending message')
# every outgoing message is terminated with the delimiter
316 data = q.get() + MSG_DELIM
320 raise SessionCloseError(self._buffer.getvalue(), data)
322 except Exception as e:
323 logger.debug('broke out of main loop')
324 expecting = self._expecting_close
326 logger.debug('error=%r' % e)
327 logger.debug('expecting_close=%r' % expecting)
# a close that was asked for is not an error; everything else is reported
328 if not (isinstance(e, SessionCloseError) and expecting):
329 self._dispatch_error(e)
# NOTE(review): the decorator/`def` lines for this accessor (embedded
# 331-332) are not in the listing; presumably a read-only property -- confirm
# against the full file.
# Hands back self._transport as-is, which __init__ sets to None and connect()
# later replaces with the paramiko.Transport it creates.
333 """Underlying `paramiko.Transport
334 <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
335 object. This makes it possible to call methods like set_keepalive on it.
337 return self._transport
340 def can_pipeline(self):
341 if 'Cisco' in self._transport.remote_version: