git-svn-id: http://ncclient.googlecode.com/svn/trunk@86 6dbcf712-26ac-11de-a2f3-13738...
[ncclient] / ncclient / transport / ssh.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import os
16 import socket
17 from binascii import hexlify
18 from cStringIO import StringIO
19 from select import select
20
21 import paramiko
22
23 from . import logger
24 from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
25 from session import Session
26
27 BUF_SIZE = 4096
28 MSG_DELIM = ']]>]]>'
29 TICK = 0.1
30
31 class SSHSession(Session):
32     
33     def __init__(self):
34         Session.__init__(self)
35         self._host_keys = paramiko.HostKeys()
36         self._system_host_keys = paramiko.HostKeys()
37         self._transport = None
38         self._connected = False
39         self._channel = None
40         self._expecting_close = False
41         self._buffer = StringIO() # for incoming data
42         # parsing-related, see _parse()
43         self._parsing_state = 0 
44         self._parsing_pos = 0
45         logger.debug('[SSHSession object created]')
46     
47     def _parse(self):
48         '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
49         maximum of BUF_SIZE bytes everytime this method is called. Retains state
50         across method calls and if a byte has been read it will not be considered
51         again.
52         '''
53         delim = MSG_DELIM
54         n = len(delim) - 1
55         expect = self._parsing_state
56         buf = self._buffer
57         buf.seek(self._parsing_pos)
58         while True:
59             x = buf.read(1)
60             if not x: # done reading
61                 break
62             elif x == delim[expect]: # what we expected
63                 expect += 1 # expect the next delim char
64             else:
65                 continue
66             # loop till last delim char expected, break if other char encountered
67             for i in range(expect, n):
68                 x = buf.read(1)
69                 if not x: # done reading
70                     break
71                 if x == delim[expect]: # what we expected
72                     expect += 1 # expect the next delim char
73                 else:
74                     expect = 0 # reset
75                     break
76             else: # if we didn't break out of the loop, full delim was parsed
77                 msg_till = buf.tell() - n
78                 buf.seek(0)
79                 self._dispatch_received(buf.read(msg_till).strip())
80                 buf.seek(n+1, os.SEEK_CUR)
81                 rest = buf.read()
82                 buf = StringIO()
83                 buf.write(rest)
84                 buf.seek(0)
85                 expect = 0
86         self._buffer = buf
87         self._parsing_state = expect
88         self._parsing_pos = self._buffer.tell()
89     
90     def expect_close(self):
91         self._expecting_close = True
92     
93     def load_system_host_keys(self, filename=None):
94         if filename is None:
95             filename = os.path.expanduser('~/.ssh/known_hosts')
96             try:
97                 self._system_host_keys.load(filename)
98             except IOError:
99                 # for windows
100                 filename = os.path.expanduser('~/ssh/known_hosts')
101                 try:
102                     self._system_host_keys.load(filename)
103                 except IOError:
104                     pass
105             return
106         self._system_host_keys.load(filename)
107     
108     def load_host_keys(self, filename):
109         self._host_keys.load(filename)
110
111     def add_host_key(self, key):
112         self._host_keys.add(key)
113     
114     def save_host_keys(self, filename):
115         f = open(filename, 'w')
116         for hostname, keys in self._host_keys.iteritems():
117             for keytype, key in keys.iteritems():
118                 f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
119         f.close()    
120     
121     def close(self):
122         if self._transport.is_active():
123             self._transport.close()
124         self._connected = False
125     
126     def connect(self, hostname, port=830, timeout=None,
127                 unknown_host_cb=None, username=None, password=None,
128                 key_filename=None, allow_agent=True, look_for_keys=True):
129         
130         assert(username is not None)
131         
132         for (family, socktype, proto, canonname, sockaddr) in \
133         socket.getaddrinfo(hostname, port):
134             if socktype==socket.SOCK_STREAM:
135                 af = family
136                 addr = sockaddr
137                 break
138         else:
139             raise SSHError('No suitable address family for %s' % hostname)
140         sock = socket.socket(af, socket.SOCK_STREAM)
141         sock.settimeout(timeout)
142         sock.connect(addr)
143         t = self._transport = paramiko.Transport(sock)
144         t.set_log_channel(logger.name)
145         
146         try:
147             t.start_client()
148         except paramiko.SSHException:
149             raise SSHError('Negotiation failed')
150         
151         # host key verification
152         server_key = t.get_remote_server_key()
153         known_host = self._host_keys.check(hostname, server_key) or \
154                         self._system_host_keys.check(hostname, server_key)
155         
156         if unknown_host_cb is None:
157             unknown_host_cb = lambda *args: False
158         if not known_host and not unknown_host_cb(hostname, server_key):
159                 raise SSHUnknownHostError(hostname, server_key)
160         
161         if key_filename is None:
162             key_filenames = []
163         elif isinstance(key_filename, basestring):
164             key_filenames = [ key_filename ]
165         else:
166             key_filenames = key_filename
167         
168         self._auth(username, password, key_filenames, allow_agent, look_for_keys)
169         
170         self._connected = True # there was no error authenticating
171         
172         c = self._channel = self._transport.open_session()
173         c.invoke_subsystem('netconf')
174         c.set_name('netconf')
175         
176         self._post_connect()
177     
178     # on the lines of paramiko.SSHClient._auth()
179     def _auth(self, username, password, key_filenames, allow_agent,
180               look_for_keys):
181         saved_exception = None
182         
183         for key_filename in key_filenames:
184             for cls in (paramiko.RSAKey, paramiko.DSSKey):
185                 try:
186                     key = cls.from_private_key_file(key_filename, password)
187                     logger.debug('Trying key %s from %s' %
188                               (hexlify(key.get_fingerprint()), key_filename))
189                     self._transport.auth_publickey(username, key)
190                     return
191                 except Exception as e:
192                     saved_exception = e
193                     logger.debug(e)
194         
195         if allow_agent:
196             for key in paramiko.Agent().get_keys():
197                 try:
198                     logger.debug('Trying SSH agent key %s' %
199                                  hexlify(key.get_fingerprint()))
200                     self._transport.auth_publickey(username, key)
201                     return
202                 except Exception as e:
203                     saved_exception = e
204                     logger.debug(e)
205         
206         keyfiles = []
207         if look_for_keys:
208             rsa_key = os.path.expanduser('~/.ssh/id_rsa')
209             dsa_key = os.path.expanduser('~/.ssh/id_dsa')
210             if os.path.isfile(rsa_key):
211                 keyfiles.append((paramiko.RSAKey, rsa_key))
212             if os.path.isfile(dsa_key):
213                 keyfiles.append((paramiko.DSSKey, dsa_key))
214             # look in ~/ssh/ for windows users:
215             rsa_key = os.path.expanduser('~/ssh/id_rsa')
216             dsa_key = os.path.expanduser('~/ssh/id_dsa')
217             if os.path.isfile(rsa_key):
218                 keyfiles.append((paramiko.RSAKey, rsa_key))
219             if os.path.isfile(dsa_key):
220                 keyfiles.append((paramiko.DSSKey, dsa_key))
221         
222         for cls, filename in keyfiles:
223             try:
224                 key = cls.from_private_key_file(filename, password)
225                 logger.debug('Trying discovered key %s in %s' %
226                           (hexlify(key.get_fingerprint()), filename))
227                 self._transport.auth_publickey(username, key)
228                 return
229             except Exception as e:
230                 saved_exception = e
231                 logger.debug(e)
232         
233         if password is not None:
234             try:
235                 self._transport.auth_password(username, password)
236                 return
237             except Exception as e:
238                 saved_exception = e
239                 logger.debug(e)
240         
241         if saved_exception is not None:
242             raise SSHAuthenticationError(repr(saved_exception))
243         
244         raise SSHAuthenticationError('No authentication methods available')
245     
246     def run(self):
247         chan = self._channel
248         chan.setblocking(0)
249         q = self._q
250         try:
251             while True:
252                 # select on a paramiko ssh channel object does not ever return
253                 # it in the writable list, so it channel's don't exactly emulate 
254                 # the socket api
255                 r, w, e = select([chan], [], [], TICK)
256                 # will wakeup evey TICK seconds to check if something
257                 # to send, more if something to read (due to select returning
258                 # chan in readable list)
259                 if r:
260                     data = chan.recv(BUF_SIZE)
261                     if data:
262                         self._buffer.write(data)
263                         self._parse()
264                     else:
265                         raise SessionCloseError(self._buffer.getvalue())
266                 if not q.empty() and chan.send_ready():
267                     data = q.get() + MSG_DELIM
268                     while data:
269                         n = chan.send(data)
270                         if n <= 0:
271                             raise SessionCloseError(self._buffer.getvalue(), data)
272                         data = data[n:]
273         except Exception as e:
274             logger.debug('*** broke out of main loop ***')
275             self.close()
276             if not (isinstance(e, SessionCloseError) and self._expecting_close):
277                 self._dispatch_error(e)
278     
279     @property
280     def transport(self):
281         '''Get underlying paramiko transport object; this is provided so methods
282         like set_keepalive can be called on it. See paramiko.Transport
283         documentation for details.
284         '''
285         return self._transport