1 # Copyright 2009 Shikhar Bhushan
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
17 from binascii import hexlify
18 from cStringIO import StringIO
19 from select import select
24 from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
25 from session import Session
# NETCONF session over SSH: connects with paramiko, verifies the server host
# key, authenticates, and runs the 'netconf' subsystem on an SSH channel.
31 class SSHSession(Session):
# NOTE(review): short lines of this class (including the ``def`` line of the
# constructor below) are not present in this view; the comments describe
# only the visible statements.
# Constructor: initialise the base Session, two empty host-key tables and
# the connection / parsing state.
34 Session.__init__(self)
# Keys loaded or added explicitly for this session (load_host_keys / add_host_key).
35 self._host_keys = paramiko.HostKeys()
# Keys read from the system known_hosts file(s) (load_system_host_keys).
36 self._system_host_keys = paramiko.HostKeys()
# paramiko.Transport; created by connect().
37 self._transport = None
38 self._connected = False
# When True, a SessionCloseError that ends the main loop is not reported.
40 self._expecting_close = False
41 self._buffer = StringIO() # for incoming data
42 # parsing-related, see _parse()
# Number of message-delimiter characters matched so far (kept across _parse() calls).
43 self._parsing_state = 0
45 logger.debug('[SSHSession object created]')
# Scan the receive buffer for the message delimiter, resuming from the
# read position and partial-match state saved by the previous call; each
# complete message is stripped and passed to self._dispatch_received().
# NOTE(review): short lines of this method (its ``def`` line and the lines
# that bind buf, delim, n and x) are not present in this view.
48 '''Messages are delimited by MSG_DELIM. The buffer could have grown by a
49 maximum of BUF_SIZE bytes every time this method is called. Retains state
50 across method calls and if a byte has been read it will not be considered
55 expect = self._parsing_state
# Resume reading where the previous call stopped.
57 buf.seek(self._parsing_pos)
60 if not x: # done reading
62 elif x == delim[expect]: # what we expected
63 expect += 1 # expect the next delim char
66 # loop till last delim char expected, break if other char encountered
67 for i in range(expect, n):
69 if not x: # done reading
71 if x == delim[expect]: # what we expected
72 expect += 1 # expect the next delim char
76 else: # if we didn't break out of the loop, full delim was parsed
# n appears to be the delimiter length, so this is the offset where the
# delimiter starts, i.e. the length of the message text before it.
77 msg_till = buf.tell() - n
79 self._dispatch_received(buf.read(msg_till).strip())
# NOTE(review): skips n+1 bytes, one byte beyond an n-byte delimiter
# (presumably a trailing newline). Confirm this: a peer that starts the next
# message immediately after the delimiter would lose that message's first byte.
80 buf.seek(n+1, os.SEEK_CUR)
# Save the match progress and read position for the next call.
# NOTE(review): the position is read from self._buffer, not the local buf;
# confirm both refer to the same object at this point.
87 self._parsing_state = expect
88 self._parsing_pos = self._buffer.tell()
def expect_close(self):
    '''Announce that the session is about to be closed deliberately.

    Sets ``self._expecting_close``. While this flag is set, the main loop
    does not pass a SessionCloseError on to ``_dispatch_error()``.
    '''
    expecting = True
    self._expecting_close = expecting
93 def load_system_host_keys(self, filename=None):
# Load system host keys into self._system_host_keys (connect() consults
# them after the session's own keys). If *filename* is omitted,
# ~/.ssh/known_hosts is loaded, and ~/ssh/known_hosts appears to be the
# fallback location (e.g. for Windows users).
# NOTE(review): the short control-flow lines around these loads
# (``if``/``try``/``except``) are not present in this view.
95 filename = os.path.expanduser('~/.ssh/known_hosts')
97 self._system_host_keys.load(filename)
100 filename = os.path.expanduser('~/ssh/known_hosts')
102 self._system_host_keys.load(filename)
106 self._system_host_keys.load(filename)
def load_host_keys(self, filename):
    '''Read host keys from *filename* into this session's own key table.

    The keys go into ``self._host_keys``, which ``connect()`` checks before
    the system host keys.
    '''
    table = self._host_keys
    table.load(filename)
def add_host_key(self, key, hostname=None):
    '''Add *key* to this session's host-key table as a key for *hostname*.

    :param key: a paramiko key object; its ``get_name()`` (e.g.
        ``'ssh-rsa'``) is used as the key type.
    :param hostname: the host the key belongs to. It is optional in the
        signature for backward compatibility, but required in practice.
    :raises ValueError: if *hostname* is not given.
    '''
    if hostname is None:
        # HostKeys.add(hostname, keytype, key) cannot infer the host; the
        # old single-argument call add(key) always failed with a TypeError.
        raise ValueError('hostname is required to add a host key')
    self._host_keys.add(hostname, key.get_name(), key)
def save_host_keys(self, filename):
    '''Write this session's host keys (``self._host_keys``) to *filename*.

    Writes one line per (hostname, key type) pair in the form
    ``<hostname> <keytype> <base64 key>``. An existing file is overwritten.
    '''
    # ``with`` guarantees the file is flushed and closed, even if a write
    # or a key's get_base64() raises part-way through.
    with open(filename, 'w') as f:
        for hostname, keys in self._host_keys.iteritems():
            for keytype, key in keys.iteritems():
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
122 if self._transport.is_active():
123 self._transport.close()
124 self._connected = False
126 def connect(self, hostname, port=830, timeout=None,
127 unknown_host_cb=None, username=None, password=None,
128 key_filename=None, allow_agent=True, look_for_keys=True):
# Open a TCP connection to hostname:port (830 by default), negotiate SSH,
# verify the server host key, authenticate as *username*, and start the
# 'netconf' subsystem on a new channel.
# unknown_host_cb(hostname, server_key) decides whether to accept a key
# found in neither host-key table. Without a callback, unknown hosts are
# rejected with SSHUnknownHostError.
# key_filename may be None, a single path, or a sequence of paths.
# NOTE(review): short lines of this method (address selection, the ``try``
# around negotiation, the empty key_filenames default) are not present in
# this view.
# NOTE(review): ``assert`` is skipped under ``python -O``; an explicit raise
# would keep this username check in optimised runs.
130 assert(username is not None)
132 for (family, socktype, proto, canonname, sockaddr) in \
133 socket.getaddrinfo(hostname, port):
134 if socktype==socket.SOCK_STREAM:
# Presumably reached when no stream-socket address was found; verify.
139 raise SSHError('No suitable address family for %s' % hostname)
140 sock = socket.socket(af, socket.SOCK_STREAM)
141 sock.settimeout(timeout)
143 t = self._transport = paramiko.Transport(sock)
144 t.set_log_channel(logger.name)
# NOTE(review): no cleanup of sock/transport is visible on this failure path.
148 except paramiko.SSHException:
149 raise SSHError('Negotiation failed')
151 # host key verification
152 server_key = t.get_remote_server_key()
153 known_host = self._host_keys.check(hostname, server_key) or \
154 self._system_host_keys.check(hostname, server_key)
# The default callback rejects every unknown host key.
156 if unknown_host_cb is None:
157 unknown_host_cb = lambda *args: False
158 if not known_host and not unknown_host_cb(hostname, server_key):
159 raise SSHUnknownHostError(hostname, server_key)
# Normalise key_filename into a list of paths for _auth().
161 if key_filename is None:
163 elif isinstance(key_filename, basestring):
164 key_filenames = [ key_filename ]
166 key_filenames = key_filename
168 self._auth(username, password, key_filenames, allow_agent, look_for_keys)
170 self._connected = True # there was no error authenticating
172 c = self._channel = self._transport.open_session()
173 c.invoke_subsystem('netconf')
174 c.set_name('netconf')
178 # on the lines of paramiko.SSHClient._auth()
179 def _auth(self, username, password, key_filenames, allow_agent,
181 saved_exception = None
183 for key_filename in key_filenames:
184 for cls in (paramiko.RSAKey, paramiko.DSSKey):
186 key = cls.from_private_key_file(key_filename, password)
187 logger.debug('Trying key %s from %s' %
188 (hexlify(key.get_fingerprint()), key_filename))
189 self._transport.auth_publickey(username, key)
191 except Exception as e:
196 for key in paramiko.Agent().get_keys():
198 logger.debug('Trying SSH agent key %s' %
199 hexlify(key.get_fingerprint()))
200 self._transport.auth_publickey(username, key)
202 except Exception as e:
208 rsa_key = os.path.expanduser('~/.ssh/id_rsa')
209 dsa_key = os.path.expanduser('~/.ssh/id_dsa')
210 if os.path.isfile(rsa_key):
211 keyfiles.append((paramiko.RSAKey, rsa_key))
212 if os.path.isfile(dsa_key):
213 keyfiles.append((paramiko.DSSKey, dsa_key))
214 # look in ~/ssh/ for windows users:
215 rsa_key = os.path.expanduser('~/ssh/id_rsa')
216 dsa_key = os.path.expanduser('~/ssh/id_dsa')
217 if os.path.isfile(rsa_key):
218 keyfiles.append((paramiko.RSAKey, rsa_key))
219 if os.path.isfile(dsa_key):
220 keyfiles.append((paramiko.DSSKey, dsa_key))
222 for cls, filename in keyfiles:
224 key = cls.from_private_key_file(filename, password)
225 logger.debug('Trying discovered key %s in %s' %
226 (hexlify(key.get_fingerprint()), filename))
227 self._transport.auth_publickey(username, key)
229 except Exception as e:
233 if password is not None:
235 self._transport.auth_password(username, password)
237 except Exception as e:
241 if saved_exception is not None:
242 raise SSHAuthenticationError(repr(saved_exception))
244 raise SSHAuthenticationError('No authentication methods available')
# Main I/O loop. It polls the channel with select() every TICK seconds,
# appends received data to self._buffer, and sends queued messages
# terminated by MSG_DELIM. The loop ends on an exception, handled by the
# except clause at the bottom.
# NOTE(review): short lines of this method (its ``def`` line, the bindings
# of chan and q, the ``try``/``while`` headers and the send loop) are not
# present in this view.
252 # select on a paramiko ssh channel object does not ever return
253 # it in the writable list, so channels don't exactly emulate
255 r, w, e = select([chan], [], [], TICK)
256 # will wake up every TICK seconds to check if something
257 # to send, more if something to read (due to select returning
258 # chan in readable list)
260 data = chan.recv(BUF_SIZE)
262 self._buffer.write(data)
# Presumably reached when recv() returns no data (peer closed); verify.
265 raise SessionCloseError(self._buffer.getvalue())
266 if not q.empty() and chan.send_ready():
267 data = q.get() + MSG_DELIM
# Presumably reached when send() makes no progress; verify.
271 raise SessionCloseError(self._buffer.getvalue(), data)
273 except Exception as e:
274 logger.debug('*** broke out of main loop ***')
# A SessionCloseError is suppressed only after expect_close() was called.
276 if not (isinstance(e, SessionCloseError) and self._expecting_close):
277 self._dispatch_error(e)
# NOTE(review): the ``def`` line of this accessor is not present in this view.
281 '''Get underlying paramiko transport object; this is provided so methods
282 like set_keepalive can be called on it. See paramiko.Transport
283 documentation for details.
# self._transport is the paramiko.Transport created by connect(), or None
# before connect() has run.
285 return self._transport