1 # Copyright 2009 Shikhar Bhushan
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
17 from binascii import hexlify
18 from cStringIO import StringIO
19 from select import select
23 from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24 from session import Session
# Module-level logger; connect() also routes paramiko's transport log output
# here by passing this logger's name to set_log_channel().
27 logger = logging.getLogger('ncclient.transport.ssh')
33 class SSHSession(Session):
# NETCONF-over-SSH session built on paramiko: connect() negotiates SSH,
# verifies the server host key, authenticates and starts the 'netconf'
# subsystem; outgoing messages are terminated with MSG_DELIM.
# NOTE(review): the `def __init__(self):` header for the following lines is
# not present in this view.
36 Session.__init__(self)
# Host keys are checked against this per-session store first, then against
# the system store (see connect()).
37 self._host_keys = paramiko.HostKeys()
38 self._system_host_keys = paramiko.HostKeys()
# paramiko Transport; None until connect() creates it.
39 self._transport = None
# Set to True by connect() once authentication has succeeded.
40 self._connected = False
# Set by expect_close(); while True, a SessionCloseError ending the I/O loop
# is not dispatched as an error.
42 self._expecting_close = False
43 self._buffer = StringIO() # for incoming data
44 # parsing-related, see _parse()
45 self._parsing_state = 0
49 '''Messages are delimited by MSG_DELIM. The buffer could have grown by a
50 maximum of BUF_SIZE bytes every time this method is called. Retains state
51 across method calls and if a byte has been read it will not be considered
56 expect = self._parsing_state
# `expect` is how many delimiter characters were already matched (carried over
# from the previous call); reading resumes at the saved buffer position.
# NOTE(review): `delim`, `n`, `buf` and the read loop header are defined on
# lines not present in this view.
58 buf.seek(self._parsing_pos)
61 if not x: # done reading
63 elif x == delim[expect]: # what we expected
64 expect += 1 # expect the next delim char
67 # loop till last delim char expected, break if other char encountered
68 for i in range(expect, n):
70 if not x: # done reading
72 if x == delim[expect]: # what we expected
73 expect += 1 # expect the next delim char
77 else: # if we didn't break out of the loop, full delim was parsed
# Hand the text before the delimiter (whitespace-stripped) to
# _dispatch_message, then move the read position past the delimiter.
# NOTE(review): this presumes the buffer is rewound to its start on a line
# not shown here, and that n relates to len(delim) — confirm.
78 msg_till = buf.tell() - n
80 logger.debug('parsed new message')
81 self._dispatch_message(buf.read(msg_till).strip())
82 buf.seek(n+1, os.SEEK_CUR)
# Persist the partial-match count and read position for the next call.
89 self._parsing_state = expect
90 self._parsing_pos = self._buffer.tell()
def expect_close(self):
    """Announce that the session is about to be closed on purpose.

    Raises the flag that the I/O loop consults when it terminates: a
    SessionCloseError that ends the loop while this flag is set is not
    passed on to _dispatch_error().
    """
    self._expecting_close = True
def load_system_host_keys(self, filename=None):
    """Load known_hosts entries into the system host-key store.

    connect() consults this store when a server key is not found in the
    per-session store.

    :param filename: known_hosts file to load. If None, ``~/.ssh/known_hosts``
        is tried first and, if it cannot be read, ``~/ssh/known_hosts`` (the
        location some Windows users have). If neither default file can be
        read, nothing is loaded and no error is raised.

    Errors from loading an explicitly given *filename* propagate to the caller.
    """
    if filename is not None:
        self._system_host_keys.load(filename)
        return
    # Default locations, in order of preference; stop at the first readable one.
    for default_path in ('~/.ssh/known_hosts', '~/ssh/known_hosts'):
        try:
            self._system_host_keys.load(os.path.expanduser(default_path))
        except IOError:
            continue  # missing/unreadable: try the next location
        return
def load_host_keys(self, filename):
    """Read host keys from *filename* into this session's own key store.

    That store is consulted first during host-key verification in connect()
    and is the one written out by save_host_keys().
    """
    host_key_store = self._host_keys
    host_key_store.load(filename)
def add_host_key(self, key, hostname=None):
    """Add *key* to this session's own host-key store.

    That store is consulted first during host-key verification in connect()
    and is what save_host_keys() writes out.

    :param key: paramiko key object, e.g. the server key handed to
        connect()'s ``unknown_host_cb``.
    :param hostname: host the key belongs to. Required in practice:
        paramiko's ``HostKeys.add()`` indexes keys by (hostname, key type),
        so the previous one-argument call could never succeed.
    :raises TypeError: if *hostname* is omitted.
    """
    if hostname is None:
        raise TypeError('add_host_key() requires the hostname the key belongs to')
    # The key type (e.g. 'ssh-rsa') is reported by the key object itself.
    self._host_keys.add(hostname, key.get_name(), key)
def save_host_keys(self, filename):
    """Write this session's own host keys to *filename*, overwriting it.

    Each key becomes one ``<hostname> <keytype> <base64-key>`` line. Only the
    per-session store is written; system host keys are not included.

    :param filename: path of the file to (over)write.
    """
    # `with` closes the file even if a write fails part-way through.
    with open(filename, 'w') as f:
        # .items() (rather than the Python-2-only .iteritems()) needs only the
        # standard mapping interface.
        for hostname, keys in self._host_keys.items():
            for keytype, key in keys.items():
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
# NOTE(review): the header of the method these lines belong to is not
# present in this view.
# Close the paramiko transport if it is still active, then mark the session
# as disconnected.
# NOTE(review): self._transport starts as None and is only assigned in
# connect(), so reaching this before connect() raises AttributeError;
# consider guarding with `self._transport is not None`.
124 if self._transport.is_active():
125 self._transport.close()
126 self._connected = False
128 def connect(self, hostname, port=830, timeout=None,
129 unknown_host_cb=None, username=None, password=None,
130 key_filename=None, allow_agent=True, look_for_keys=True):
# Create a SOCK_STREAM socket for hostname:port, wrap it in a paramiko
# Transport, verify the server's host key, authenticate via _auth(), and
# open a channel running the 'netconf' subsystem.
#
# timeout: applied to the socket with settimeout().
# unknown_host_cb: called as unknown_host_cb(hostname, server_key) when the
#   server key is in neither host-key store; a falsy result aborts with
#   SSHUnknownHostError. When None, every unknown host is rejected.
# key_filename: None, a single path string, or an iterable of paths of
#   private-key files, handed to _auth().
# allow_agent / look_for_keys: forwarded unchanged to _auth().
#
# Raises SSHError when no SOCK_STREAM address is found or when paramiko
# raises SSHException ('Negotiation failed'), and SSHUnknownHostError as
# described above; authentication failures are raised by _auth().
# NOTE(review): `assert` is skipped under `python -O`, so this username check
# disappears there; raising ValueError would enforce it unconditionally.
132 assert(username is not None)
# Look for a SOCK_STREAM address for hostname:port (the lines that record the
# match and the loop's else-branch header are not present in this view).
134 for (family, socktype, proto, canonname, sockaddr) in \
135 socket.getaddrinfo(hostname, port):
136 if socktype==socket.SOCK_STREAM:
141 raise SSHError('No suitable address family for %s' % hostname)
# NOTE(review): `af` is assigned on a line not in this view. Nothing visible
# here closes `sock` or the transport if a later step raises.
142 sock = socket.socket(af, socket.SOCK_STREAM)
143 sock.settimeout(timeout)
# paramiko's transport log output is routed to this module's logger.
145 t = self._transport = paramiko.Transport(sock)
146 t.set_log_channel(logger.name)
# NOTE(review): the `try:` (and the call it guards) paired with this handler
# is not present in this view.
150 except paramiko.SSHException:
151 raise SSHError('Negotiation failed')
153 # host key verification
# The key is accepted if either the per-session store or the system store
# knows it for hostname; otherwise unknown_host_cb decides.
154 server_key = t.get_remote_server_key()
155 known_host = self._host_keys.check(hostname, server_key) or \
156 self._system_host_keys.check(hostname, server_key)
# Default callback: reject every unknown host.
158 if unknown_host_cb is None:
159 unknown_host_cb = lambda *args: False
160 if not known_host and not unknown_host_cb(hostname, server_key):
161 raise SSHUnknownHostError(hostname, server_key)
# Normalise key_filename for _auth(): a single string becomes a one-element
# list, any other non-None value is used as-is (the None branch's assignment
# and the final else header are not present in this view).
163 if key_filename is None:
165 elif isinstance(key_filename, basestring):
166 key_filenames = [ key_filename ]
168 key_filenames = key_filename
170 self._auth(username, password, key_filenames, allow_agent, look_for_keys)
172 self._connected = True # there was no error authenticating
# Open a session channel, name it, and start the NETCONF subsystem on it.
174 c = self._channel = self._transport.open_session()
175 c.set_name('netconf')
176 c.invoke_subsystem('netconf')
# on the lines of paramiko.SSHClient._auth()
def _auth(self, username, password, key_filenames, allow_agent,
          look_for_keys):
    """Authenticate *username* on the already-negotiated transport.

    Methods are tried in this order, stopping at the first success:

    1. each file in *key_filenames*, loaded first as an RSA key, then as a
       DSS key (*password* doubles as the private-key passphrase);
    2. keys offered by an SSH agent, if *allow_agent*;
    3. ``id_rsa``/``id_dsa`` found in ``~/.ssh/`` and then ``~/ssh/``
       (for Windows users), if *look_for_keys*;
    4. password authentication, if *password* is not None.

    :raises AuthenticationError: if every attempted method failed (the message
        is the repr of the last failure) or no method was available.
    """
    saved_exception = None

    for key_filename in key_filenames:
        for cls in (paramiko.RSAKey, paramiko.DSSKey):
            try:
                key = cls.from_private_key_file(key_filename, password)
                logger.debug('Trying key %s from %s' %
                             (hexlify(key.get_fingerprint()), key_filename))
                self._transport.auth_publickey(username, key)
                return
            except Exception as e:
                # Remember the failure and fall through to the next method.
                saved_exception = e
                logger.debug(e)

    if allow_agent:
        for key in paramiko.Agent().get_keys():
            try:
                logger.debug('Trying SSH agent key %s' %
                             hexlify(key.get_fingerprint()))
                self._transport.auth_publickey(username, key)
                return
            except Exception as e:
                saved_exception = e
                logger.debug(e)

    keyfiles = []
    if look_for_keys:
        # ~/.ssh/ is the usual location; ~/ssh/ is also checked for windows
        # users. Order: RSA before DSA within each directory.
        for key_dir in ('~/.ssh', '~/ssh'):
            for cls, key_name in ((paramiko.RSAKey, 'id_rsa'),
                                  (paramiko.DSSKey, 'id_dsa')):
                path = os.path.expanduser('%s/%s' % (key_dir, key_name))
                if os.path.isfile(path):
                    keyfiles.append((cls, path))

    for cls, filename in keyfiles:
        try:
            key = cls.from_private_key_file(filename, password)
            logger.debug('Trying discovered key %s in %s' %
                         (hexlify(key.get_fingerprint()), filename))
            self._transport.auth_publickey(username, key)
            return
        except Exception as e:
            saved_exception = e
            logger.debug(e)

    if password is not None:
        try:
            self._transport.auth_password(username, password)
            return
        except Exception as e:
            saved_exception = e
            logger.debug(e)

    # AuthenticationError is the name this module imports from errors; the
    # previously raised SSHAuthenticationError was never imported, so every
    # authentication failure surfaced as a NameError instead.
    if saved_exception is not None:
        # need pep-3134 to do this right
        raise AuthenticationError(repr(saved_exception))

    raise AuthenticationError('No authentication methods available')
# NOTE(review): these lines are the body of the session's I/O loop; its method
# header, the surrounding try/while, and the branch conditions around
# recv()/send() are not present in this view.
255 # select() on a paramiko SSH channel object never returns it in the
256 # writable list, so channels don't exactly emulate sockets here.
258 r, w, e = select([chan], [], [], TICK)
259 # wakes up every TICK seconds to check whether there is something
260 # to send, or sooner when there is something to read (select
261 # returns chan in the readable list)
# Append data received from the channel to the incoming buffer.
263 data = chan.recv(BUF_SIZE)
265 self._buffer.write(data)
# Presumably reached when recv() returns no data (channel closed by the
# peer); the guarding condition is not in this view.
268 raise SessionCloseError(self._buffer.getvalue())
# Send one queued outgoing message, terminated by MSG_DELIM.
269 if not q.empty() and chan.send_ready():
270 logger.debug('sending message')
271 data = q.get() + MSG_DELIM
# Presumably raised when sending fails; the condition is not in this view.
275 raise SessionCloseError(self._buffer.getvalue(), data)
# Any exception ends the loop. It is passed to _dispatch_error() unless it is
# a SessionCloseError raised after expect_close() set _expecting_close.
277 except Exception as e:
278 logger.debug('broke out of main loop')
280 if not (isinstance(e, SessionCloseError) and self._expecting_close):
281 self._dispatch_error(e)
# NOTE(review): the header of this accessor and the closing quotes of its
# docstring are not present in this view.
# Returns the paramiko Transport stored by connect(); it is None until
# connect() has run.
285 '''Get underlying paramiko transport object; this is provided so methods
286 like set_keepalive can be called on it. See paramiko.Transport
287 documentation for details.
289 return self._transport
292 def can_pipeline(self):
293 if 'Cisco' in self._transport.remote_version: