root / ncclient / session / ssh.py @ 33a4aa10
History | View | Annotate | Download (10.5 kB)
1 |
# Copyright 2009 Shikhar Bhushan
|
---|---|
2 |
#
|
3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
# you may not use this file except in compliance with the License.
|
5 |
# You may obtain a copy of the License at
|
6 |
#
|
7 |
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
#
|
9 |
# Unless required by applicable law or agreed to in writing, software
|
10 |
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
|
15 |
import os |
16 |
import socket |
17 |
from binascii import hexlify |
18 |
from cStringIO import StringIO |
19 |
from select import select |
20 |
|
21 |
import paramiko |
22 |
|
23 |
import session |
24 |
from . import logger
|
25 |
from error import SSHError, SSHUnknownHostError, SSHAuthenticationError, RemoteClosedError |
26 |
from session import Session |
27 |
|
28 |
BUF_SIZE = 4096
|
29 |
MSG_DELIM = ']]>]]>'
|
30 |
TICK = 0.1
|
31 |
|
32 |
class SSHSession(Session): |
33 |
|
34 |
def __init__(self): |
35 |
Session.__init__(self)
|
36 |
self._system_host_keys = paramiko.HostKeys()
|
37 |
self._host_keys = paramiko.HostKeys()
|
38 |
self._host_keys_filename = None |
39 |
self._transport = None |
40 |
self._connected = False |
41 |
self._channel = None |
42 |
self._buffer = StringIO() # for incoming data |
43 |
# parsing-related, see _fresh_data()
|
44 |
self._parsing_state = 0 |
45 |
self._parsing_pos = 0 |
46 |
|
47 |
def _fresh_data(self): |
48 |
'''The buffer could have grown by a maximum of BUF_SIZE bytes everytime
|
49 |
this method is called. Retains state across method calls and if a byte
|
50 |
has been read it will not be parsed again.
|
51 |
'''
|
52 |
delim = MSG_DELIM |
53 |
n = len(delim) - 1 |
54 |
state = self._parsing_state
|
55 |
buf = self._buffer
|
56 |
buf.seek(self._parsing_pos)
|
57 |
while True: |
58 |
x = buf.read(1)
|
59 |
if not x: # done reading |
60 |
break
|
61 |
elif x == delim[state]:
|
62 |
state += 1
|
63 |
else:
|
64 |
continue
|
65 |
# loop till last delim char expected, break if other char encountered
|
66 |
for i in range(state, n): |
67 |
x = buf.read(1)
|
68 |
if not x: # done reading |
69 |
break
|
70 |
if x==delim[state]: # what we expected |
71 |
state += 1 # expect the next delim char |
72 |
else:
|
73 |
state = 0 # reset |
74 |
break
|
75 |
else: # if we didn't break out of above loop, full delim parsed |
76 |
till = buf.tell() - n |
77 |
buf.seek(0)
|
78 |
msg = buf.read(till) |
79 |
self.dispatch('received', msg) |
80 |
buf.seek(n+1, os.SEEK_CUR)
|
81 |
rest = buf.read() |
82 |
buf = StringIO() |
83 |
buf.write(rest) |
84 |
buf.seek(0)
|
85 |
state = 0
|
86 |
self._buffer = buf
|
87 |
self._parsing_state = state
|
88 |
self._parsing_pos = self._buffer.tell() |
89 |
|
90 |
def load_system_host_keys(self, filename=None): |
91 |
if filename is None: |
92 |
# try the user's .ssh key file, and mask exceptions
|
93 |
filename = os.path.expanduser('~/.ssh/known_hosts')
|
94 |
try:
|
95 |
self._system_host_keys.load(filename)
|
96 |
except IOError: |
97 |
pass
|
98 |
return
|
99 |
self._system_host_keys.load(filename)
|
100 |
|
101 |
def load_host_keys(self, filename): |
102 |
self._host_keys_filename = filename
|
103 |
self._host_keys.load(filename)
|
104 |
|
105 |
def add_host_key(self, key): |
106 |
self._host_keys.add(key)
|
107 |
|
108 |
def save_host_keys(self, filename): |
109 |
f = open(filename, 'w') |
110 |
for hostname, keys in self._host_keys.iteritems(): |
111 |
for keytype, key in keys.iteritems(): |
112 |
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
|
113 |
f.close() |
114 |
|
115 |
def close(self): |
116 |
if self._transport.is_active(): |
117 |
self._transport.close()
|
118 |
self._connected = False |
119 |
|
120 |
def connect(self, hostname, port=830, timeout=None, |
121 |
unknown_host_cb=None, username=None, password=None, |
122 |
key_filename=None, allow_agent=True, look_for_keys=True): |
123 |
|
124 |
assert(username is not None) |
125 |
|
126 |
for (family, socktype, proto, canonname, sockaddr) in \ |
127 |
socket.getaddrinfo(hostname, port): |
128 |
if socktype==socket.SOCK_STREAM:
|
129 |
af = family |
130 |
addr = sockaddr |
131 |
break
|
132 |
else:
|
133 |
raise SSHError('No suitable address family for %s' % hostname) |
134 |
sock = socket.socket(af, socket.SOCK_STREAM) |
135 |
sock.settimeout(timeout) |
136 |
sock.connect(addr) |
137 |
t = self._transport = paramiko.Transport(sock)
|
138 |
t.set_log_channel(logger.name) |
139 |
|
140 |
try:
|
141 |
t.start_client() |
142 |
except paramiko.SSHException:
|
143 |
raise SSHError('Negotiation failed') |
144 |
|
145 |
# host key verification
|
146 |
server_key = t.get_remote_server_key() |
147 |
known_host = self._host_keys.check(hostname, server_key) or \ |
148 |
self._system_host_keys.check(hostname, server_key)
|
149 |
|
150 |
if unknown_host_cb is None: |
151 |
unknown_host_cb = lambda *args: False |
152 |
if not known_host and not unknown_host_cb(hostname, server_key): |
153 |
raise SSHUnknownHostError(hostname, server_key)
|
154 |
|
155 |
if key_filename is None: |
156 |
key_filenames = [] |
157 |
elif isinstance(key_filename, basestring): |
158 |
key_filenames = [ key_filename ] |
159 |
else:
|
160 |
key_filenames = key_filename |
161 |
|
162 |
self._auth(username, password, key_filenames, allow_agent, look_for_keys)
|
163 |
|
164 |
self._connected = True # there was no error authenticating |
165 |
|
166 |
c = self._channel = self._transport.open_session() |
167 |
c.invoke_subsystem('netconf')
|
168 |
c.set_name('netconf')
|
169 |
|
170 |
Session._post_connect(self)
|
171 |
|
172 |
# on the lines of paramiko.SSHClient._auth()
|
173 |
def _auth(self, username, password, key_filenames, allow_agent, |
174 |
look_for_keys): |
175 |
saved_exception = None
|
176 |
|
177 |
allowed = ['publickey', 'keyboard-interactive', 'password'] |
178 |
|
179 |
for key_filename in key_filenames: |
180 |
if 'publickey' not in allowed: |
181 |
break
|
182 |
for cls in (paramiko.RSAKey, paramiko.DSSKey): |
183 |
try:
|
184 |
key = cls.from_private_key_file(key_filename, password) |
185 |
logger.debug('Trying key %s from %s' %
|
186 |
(hexlify(key.get_fingerprint()), key_filename)) |
187 |
self._transport.auth_publickey(username, key)
|
188 |
return
|
189 |
except paramiko.BadAuthenticationType as e: |
190 |
allowed = e.allowed_types |
191 |
logger.debug(e) |
192 |
except Exception as e: |
193 |
saved_exception = e |
194 |
logger.debug(e) |
195 |
|
196 |
if allow_agent:
|
197 |
for key in paramiko.Agent().get_keys(): |
198 |
if 'publickey' not in allowed: |
199 |
break
|
200 |
try:
|
201 |
logger.debug('Trying SSH agent key %s' %
|
202 |
hexlify(key.get_fingerprint())) |
203 |
logger.error( self._transport.auth_publickey(username, key) )
|
204 |
return
|
205 |
except paramiko.BadAuthenticationType as e: |
206 |
allowed = e.allowed_types |
207 |
logger.debug(e) |
208 |
except Exception as e: |
209 |
saved_exception = e |
210 |
logger.debug(e) |
211 |
|
212 |
keyfiles = [] |
213 |
if look_for_keys and 'publickey' in allowed: |
214 |
rsa_key = os.path.expanduser('~/.ssh/id_rsa')
|
215 |
dsa_key = os.path.expanduser('~/.ssh/id_dsa')
|
216 |
if os.path.isfile(rsa_key):
|
217 |
keyfiles.append((paramiko.RSAKey, rsa_key)) |
218 |
if os.path.isfile(dsa_key):
|
219 |
keyfiles.append((paramiko.DSSKey, dsa_key)) |
220 |
# look in ~/ssh/ for windows users:
|
221 |
rsa_key = os.path.expanduser('~/ssh/id_rsa')
|
222 |
dsa_key = os.path.expanduser('~/ssh/id_dsa')
|
223 |
if os.path.isfile(rsa_key):
|
224 |
keyfiles.append((paramiko.RSAKey, rsa_key)) |
225 |
if os.path.isfile(dsa_key):
|
226 |
keyfiles.append((paramiko.DSSKey, dsa_key)) |
227 |
|
228 |
for cls, filename in keyfiles: |
229 |
try:
|
230 |
key = cls.from_private_key_file(filename, password) |
231 |
logger.debug('Trying discovered key %s in %s' %
|
232 |
(hexlify(key.get_fingerprint()), filename)) |
233 |
allowed = self._transport.auth_publickey(username, key)
|
234 |
return
|
235 |
except Exception as e: |
236 |
saved_exception = e |
237 |
logger.debug(e) |
238 |
|
239 |
if password is not None: |
240 |
try:
|
241 |
self._transport.auth_password(username, password)
|
242 |
return
|
243 |
except Exception as e: |
244 |
saved_exception = e |
245 |
logger.debug(e) |
246 |
|
247 |
if saved_exception is not None: |
248 |
raise SSHAuthenticationError(saved_exception)
|
249 |
|
250 |
raise SSHAuthenticationError('No authentication methods available') |
251 |
|
252 |
def run(self): |
253 |
chan = self._channel
|
254 |
chan.setblocking(0)
|
255 |
q = self._q
|
256 |
try:
|
257 |
while True: |
258 |
# select on a paramiko ssh channel object does not ever
|
259 |
# return it in the writable list, so it does not exactly
|
260 |
# emulate the socket api
|
261 |
r, w, e = select([chan], [], [], TICK) |
262 |
# will wakeup evey TICK seconds to check if something
|
263 |
# to send, more if something to read (due to select returning chan
|
264 |
# in readable list)
|
265 |
if r:
|
266 |
data = chan.recv(BUF_SIZE) |
267 |
if data:
|
268 |
self._buffer.write(data)
|
269 |
self._fresh_data()
|
270 |
else:
|
271 |
raise RemoteClosedError(self._buffer.getvalue()) |
272 |
if not q.empty() and chan.send_ready(): |
273 |
data = q.get() + MSG_DELIM |
274 |
while data:
|
275 |
n = chan.send(data) |
276 |
if n <= 0: |
277 |
raise RemoteClosedError(self._buffer.getvalue(), data) |
278 |
data = data[n:] |
279 |
except Exception as e: |
280 |
self.close()
|
281 |
logger.debug('*** broke out of main loop ***')
|
282 |
self.dispatch('error', e) |
283 |
|
284 |
def set_keepalive(self, interval=0): |
285 |
self._transport.set_keepalive(interval)
|