1 # Copyright 2009 Shikhar Bhushan
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
17 from binascii import hexlify
18 from cStringIO import StringIO
19 from select import select
24 from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, SSHSessionClosedError
25 from session import Session
31 class SSHSession(Session):
# NOTE(review): the class docstring and the constructor's `def` line are not
# visible in this excerpt; the lines below appear to be the constructor body.
34 Session.__init__(self)
# Two separate key stores: _host_keys is filled by load_host_keys /
# add_host_key and written by save_host_keys; _system_host_keys is filled by
# load_system_host_keys. connect() accepts a server key found in either.
35 self._host_keys = paramiko.HostKeys()
36 self._system_host_keys = paramiko.HostKeys()
# Assigned in connect(); None until then.
37 self._transport = None
# Set to True in connect() once authentication succeeds.
38 self._connected = False
40 self._buffer = StringIO() # for incoming data
41 # parsing-related, see _parse()
# Number of MSG_DELIM characters matched so far (_parse reads it as `expect`).
42 self._parsing_state = 0
# _parse(): scan self._buffer for messages terminated by MSG_DELIM and
# dispatch each complete one. The `def` line and several body lines are not
# visible in this excerpt. Visible behaviour:
#   - resumes from saved state: `expect` starts at self._parsing_state (how
#     many delimiter characters were already matched) and the buffer is
#     re-positioned at self._parsing_pos, so previously scanned bytes are
#     not re-examined;
#   - each character read is compared with delim[expect]; a match advances
#     `expect`;
#   - when the inner for-loop finishes without `break` (the for/else branch),
#     the delimiter is considered complete: msg_till = buf.tell() - n
#     characters are read as the message, passed to
#     self.dispatch('received', msg), and the position is advanced by n+1;
#   - the match count and buffer position are stored back at the end.
# NOTE(review): `buf`, `delim`, `n` and `x` are bound on lines not shown.
# The tell() - n / seek(n+1) arithmetic is consistent with n == len(delim) - 1
# -- confirm. If so, the inner loop (range(expect, n)) stops before comparing
# the final delimiter character; verify whether that character is checked
# anywhere.
46 '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
47 maximum of BUF_SIZE bytes everytime this method is called. Retains state
48 across method calls and if a byte has been read it will not be considered
53 expect = self._parsing_state
55 buf.seek(self._parsing_pos)
58 if not x: # done reading
60 elif x == delim[expect]: # what we expected
61 expect += 1 # expect the next delim char
64 # loop till last delim char expected, break if other char encountered
65 for i in range(expect, n):
67 if not x: # done reading
69 if x == delim[expect]: # what we expected
70 expect += 1 # expect the next delim char
74 else: # if we didn't break out of the loop, full delim was parsed
75 msg_till = buf.tell() - n
77 msg = buf.read(msg_till)
78 self.dispatch('received', msg)
79 buf.seek(n+1, os.SEEK_CUR)
86 self._parsing_state = expect
87 self._parsing_pos = self._buffer.tell()
89 def load_system_host_keys(self, filename=None):
'''Load known-host keys into the system host-key store.

Keys loaded here are kept apart from the session's own host keys:
connect() accepts a server key found in either store, but
save_host_keys() never writes these out.

:param filename: known_hosts-format file to read. The handling of the
    default (filename None) is only partly visible in this excerpt;
    presumably ~/.ssh/known_hosts is tried first and ~/ssh/known_hosts is
    the fallback -- confirm against the full source.
'''
# Conventional OpenSSH location.
91 filename = os.path.expanduser('~/.ssh/known_hosts')
93 self._system_host_keys.load(filename)
# Alternative location (the same ~/ssh/ directory _auth searches "for
# windows users").
96 filename = os.path.expanduser('~/ssh/known_hosts')
98 self._system_host_keys.load(filename)
# Presumably the path taken when the caller supplied a filename.
102 self._system_host_keys.load(filename)
def load_host_keys(self, filename):
    '''Read a known_hosts-style file into this session's own host-key store.

    Keys read here are checked by connect() alongside the system host keys,
    and they are the ones save_host_keys() writes back out. paramiko merges
    the file's entries into the existing store; it does not replace it.

    :param filename: path of the known_hosts-format file to read.
    '''
    store = self._host_keys
    store.load(filename)
def add_host_key(self, key, hostname=None):
    '''Add *key* as a trusted host key for *hostname* to this session's store.

    paramiko's ``HostKeys.add`` needs the hostname and the key type as well
    as the key itself; the key type is taken from ``key.get_name()``
    (e.g. ``ssh-rsa``).

    :param key: a paramiko ``PKey`` (anything providing ``get_name()``).
    :param hostname: host the key belongs to. It defaults to None only for
        signature compatibility; it must be supplied.
    :raises TypeError: if *hostname* is not given.
    '''
    if hostname is None:
        # HostKeys is keyed by hostname, so a key alone cannot be stored.
        raise TypeError('add_host_key() requires a hostname for the key')
    self._host_keys.add(hostname, key.get_name(), key)
def save_host_keys(self, filename):
    '''Write this session's own host keys to *filename* in known_hosts format.

    Each entry becomes one line ``<hostname> <keytype> <base64 key>``. Only
    the keys in the session's own store are written, not the system host
    keys; these are the keys loaded via load_host_keys()/add_host_key(). An
    existing file is overwritten.

    :param filename: path of the file to (over)write.
    :raises IOError: if the file cannot be opened or written.
    '''
    # `with` closes the handle even if a write fails part-way through.
    with open(filename, 'w') as f:
        # .items() works on paramiko's HostKeys under both Python 2 and
        # Python 3, whereas .iteritems() exists only on Python 2.
        for hostname, keys in self._host_keys.items():
            for keytype, key in keys.items():
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
# Tear down the SSH transport and mark the session disconnected. (The `def`
# line is not visible in this excerpt; this appears to be close().)
# NOTE(review): self._transport is None until connect() assigns it, so this
# would raise AttributeError if called before connect() -- unless a guard
# exists on a line not shown.
118 if self._transport.is_active():
119 self._transport.close()
120 self._connected = False
122 def connect(self, hostname, port=830, timeout=None,
123 unknown_host_cb=None, username=None, password=None,
124 key_filename=None, allow_agent=True, look_for_keys=True):
'''Open an SSH connection to *hostname* and start the ``netconf`` subsystem.

:param hostname: host to connect to (resolved with socket.getaddrinfo).
:param port: TCP port; 830 is the IANA-assigned NETCONF-over-SSH port
    (RFC 6242).
:param timeout: passed to settimeout() on the new socket.
:param unknown_host_cb: called as ``unknown_host_cb(hostname, server_key)``
    when the server key is in neither host-key store; a truthy result
    accepts the host. When None, every unknown host is rejected.
:param username: required (checked with an assert).
:param password: forwarded to _auth().
:param key_filename: one private-key path or a sequence of paths; it is
    normalised to a list before being passed to _auth().
:param allow_agent: forwarded to _auth().
:param look_for_keys: forwarded to _auth().
:raises SSHError: if no SOCK_STREAM address is found for the host, or if
    SSH negotiation raises paramiko.SSHException.
:raises SSHUnknownHostError: if the server key is unknown and
    unknown_host_cb does not accept it.
:raises SSHAuthenticationError: raised by _auth() when authentication fails.
'''
# NOTE(review): assert is stripped under `python -O`; an explicit exception
# would make this check reliable.
126 assert(username is not None)
# Pick the first stream-socket address; the lines binding `af` (used below)
# and the loop's else/break are not visible in this excerpt.
128 for (family, socktype, proto, canonname, sockaddr) in \
129 socket.getaddrinfo(hostname, port):
130 if socktype==socket.SOCK_STREAM:
135 raise SSHError('No suitable address family for %s' % hostname)
136 sock = socket.socket(af, socket.SOCK_STREAM)
137 sock.settimeout(timeout)
# paramiko logs through this module's logger.
139 t = self._transport = paramiko.Transport(sock)
140 t.set_log_channel(logger.name)
# The negotiation call guarded by this handler is on a line not shown.
144 except paramiko.SSHException:
145 raise SSHError('Negotiation failed')
147 # host key verification
148 server_key = t.get_remote_server_key()
149 known_host = self._host_keys.check(hostname, server_key) or \
150 self._system_host_keys.check(hostname, server_key)
# Default policy: reject every host whose key is not already known.
152 if unknown_host_cb is None:
153 unknown_host_cb = lambda *args: False
154 if not known_host and not unknown_host_cb(hostname, server_key):
155 raise SSHUnknownHostError(hostname, server_key)
# Normalise key_filename to a list (the None branch's assignment is not
# visible here; presumably an empty list).
157 if key_filename is None:
159 elif isinstance(key_filename, basestring):
160 key_filenames = [ key_filename ]
162 key_filenames = key_filename
# NOTE(review): no visible cleanup closes the transport/socket if _auth()
# raises -- confirm whether the caller is expected to call close().
164 self._auth(username, password, key_filenames, allow_agent, look_for_keys)
166 self._connected = True # there was no error authenticating
# Open a session channel and request the NETCONF subsystem on it.
168 c = self._channel = self._transport.open_session()
169 c.invoke_subsystem('netconf')
170 c.set_name('netconf')
174 # on the lines of paramiko.SSHClient._auth()
# Authenticate self._transport as `username`, trying these methods in order:
#   1. private-key files from key_filenames (each tried as RSA, then DSS);
#   2. keys offered by the SSH agent (presumably only when allow_agent);
#   3. id_rsa / id_dsa found in ~/.ssh/ and ~/ssh/ (presumably only when
#      look_for_keys);
#   4. password authentication, when a password is given.
# `password` is also passed to from_private_key_file() as the passphrase
# for encrypted private keys.
# Failures are caught with `except Exception`; the lines that record them in
# saved_exception and return on success are not visible here. If nothing
# succeeds, SSHAuthenticationError is raised with repr() of the recorded
# exception, or 'No authentication methods available' if none was recorded.
# NOTE(review): only RSA and DSS key classes are tried (no ECDSA/Ed25519),
# and ssh-dss is disabled by default in modern OpenSSH servers.
175 def _auth(self, username, password, key_filenames, allow_agent,
177 saved_exception = None
# 1. Explicitly supplied key files.
179 for key_filename in key_filenames:
180 for cls in (paramiko.RSAKey, paramiko.DSSKey):
182 key = cls.from_private_key_file(key_filename, password)
183 logger.debug('Trying key %s from %s' %
184 (hexlify(key.get_fingerprint()), key_filename))
185 self._transport.auth_publickey(username, key)
187 except Exception as e:
# 2. Keys held by a running SSH agent.
192 for key in paramiko.Agent().get_keys():
194 logger.debug('Trying SSH agent key %s' %
195 hexlify(key.get_fingerprint()))
196 self._transport.auth_publickey(username, key)
198 except Exception as e:
# 3. Default key files; `keyfiles` is initialised on a line not shown.
204 rsa_key = os.path.expanduser('~/.ssh/id_rsa')
205 dsa_key = os.path.expanduser('~/.ssh/id_dsa')
206 if os.path.isfile(rsa_key):
207 keyfiles.append((paramiko.RSAKey, rsa_key))
208 if os.path.isfile(dsa_key):
209 keyfiles.append((paramiko.DSSKey, dsa_key))
210 # look in ~/ssh/ for windows users:
211 rsa_key = os.path.expanduser('~/ssh/id_rsa')
212 dsa_key = os.path.expanduser('~/ssh/id_dsa')
213 if os.path.isfile(rsa_key):
214 keyfiles.append((paramiko.RSAKey, rsa_key))
215 if os.path.isfile(dsa_key):
216 keyfiles.append((paramiko.DSSKey, dsa_key))
# Each discovered file is loaded with the key class it was paired with.
218 for cls, filename in keyfiles:
220 key = cls.from_private_key_file(filename, password)
221 logger.debug('Trying discovered key %s in %s' %
222 (hexlify(key.get_fingerprint()), filename))
223 self._transport.auth_publickey(username, key)
225 except Exception as e:
# 4. Password authentication, the last method attempted.
229 if password is not None:
231 self._transport.auth_password(username, password)
233 except Exception as e:
# Nothing succeeded: report the recorded failure, if there is one.
237 if saved_exception is not None:
238 raise SSHAuthenticationError(repr(saved_exception))
240 raise SSHAuthenticationError('No authentication methods available')
# Body of the session's I/O loop. Its `def` line and the setup of `chan`
# (presumably self._channel) and `q` (an outgoing-message queue) are not
# visible in this excerpt. Each pass:
#   - waits up to TICK seconds for `chan` to become readable;
#   - appends received data to self._buffer;
#   - sends at most one queued message, with MSG_DELIM appended.
# Any exception ends the loop and is reported via self.dispatch('error', e).
248 # select on a paramiko ssh channel object does not ever return
249 # it in the writable list, so channels don't exactly emulate
251 r, w, e = select([chan], [], [], TICK)
252 # will wake up every TICK seconds to check if something
253 # to send, more often if there is something to read (due to select returning
254 # chan in readable list)
256 data = chan.recv(BUF_SIZE)
258 self._buffer.write(data)
# Presumably reached when recv() returned no data (channel closed); the
# buffered but unparsed input is attached to the error.
261 raise SSHSessionClosedError(self._buffer.getvalue())
262 if not q.empty() and chan.send_ready():
263 data = q.get() + MSG_DELIM
# Presumably reached when send() fails to write the data; the unsent data is
# passed along with the buffered input.
267 raise SSHSessionClosedError(self._buffer.getvalue(), data)
269 except Exception as e:
271 logger.debug('*** broke out of main loop ***')
272 self.dispatch('error', e)
# Accessor for the underlying paramiko Transport; its def/decorator line and
# the docstring's closing quotes are not visible in this excerpt.
# self._transport is set to None in __init__ and assigned in connect().
276 '''Get underlying paramiko.transport object; this is provided so methods
277 like transport.set_keepalive can be called.
279 return self._transport