transport layer changes
[ncclient] / ncclient / transport / ssh.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import os
16 import socket
17 from binascii import hexlify
18 from cStringIO import StringIO
19 from select import select
20
21 import paramiko
22
23 from . import logger
24 from errors import SSHError, SSHUnknownHostError, SSHAuthenticationError, SessionCloseError
25 from session import Session
26
27 BUF_SIZE = 4096
28 MSG_DELIM = ']]>]]>'
29 TICK = 0.1
30
31 class SSHSession(Session):
32
33     def __init__(self):
34         Session.__init__(self)
35         self._host_keys = paramiko.HostKeys()
36         self._system_host_keys = paramiko.HostKeys()
37         self._transport = None
38         self._connected = False
39         self._channel = None
40         self._buffer = StringIO() # for incoming data
41         # parsing-related, see _parse()
42         self._parsing_state = 0 
43         self._parsing_pos = 0
44     
45     def _parse(self):
46         '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
47         maximum of BUF_SIZE bytes everytime this method is called. Retains state
48         across method calls and if a byte has been read it will not be considered
49         again.
50         '''
51         delim = MSG_DELIM
52         n = len(delim) - 1
53         expect = self._parsing_state
54         buf = self._buffer
55         buf.seek(self._parsing_pos)
56         while True:
57             x = buf.read(1)
58             if not x: # done reading
59                 break
60             elif x == delim[expect]: # what we expected
61                 expect += 1 # expect the next delim char
62             else:
63                 continue
64             # loop till last delim char expected, break if other char encountered
65             for i in range(expect, n):
66                 x = buf.read(1)
67                 if not x: # done reading
68                     break
69                 if x == delim[expect]: # what we expected
70                     expect += 1 # expect the next delim char
71                 else:
72                     expect = 0 # reset
73                     break
74             else: # if we didn't break out of the loop, full delim was parsed
75                 msg_till = buf.tell() - n
76                 buf.seek(0)
77                 msg = buf.read(msg_till)
78                 self.dispatch('received', msg)
79                 buf.seek(n+1, os.SEEK_CUR)
80                 rest = buf.read()
81                 buf = StringIO()
82                 buf.write(rest)
83                 buf.seek(0)
84                 expect = 0
85         self._buffer = buf
86         self._parsing_state = expect
87         self._parsing_pos = self._buffer.tell()
88     
89     def load_system_host_keys(self, filename=None):
90         if filename is None:
91             filename = os.path.expanduser('~/.ssh/known_hosts')
92             try:
93                 self._system_host_keys.load(filename)
94             except IOError:
95                 # for windows
96                 filename = os.path.expanduser('~/ssh/known_hosts')
97                 try:
98                     self._system_host_keys.load(filename)
99                 except IOError:
100                     pass
101             return
102         self._system_host_keys.load(filename)
103     
104     def load_host_keys(self, filename):
105         self._host_keys.load(filename)
106
107     def add_host_key(self, key):
108         self._host_keys.add(key)
109     
110     def save_host_keys(self, filename):
111         f = open(filename, 'w')
112         for hostname, keys in self._host_keys.iteritems():
113             for keytype, key in keys.iteritems():
114                 f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
115         f.close()    
116     
117     def close(self):
118         if self._transport.is_active():
119             self._transport.close()
120         self._connected = False
121     
122     def connect(self, hostname, port=830, timeout=None,
123                 unknown_host_cb=None, username=None, password=None,
124                 key_filename=None, allow_agent=True, look_for_keys=True):
125         
126         assert(username is not None)
127         
128         for (family, socktype, proto, canonname, sockaddr) in \
129         socket.getaddrinfo(hostname, port):
130             if socktype==socket.SOCK_STREAM:
131                 af = family
132                 addr = sockaddr
133                 break
134         else:
135             raise SSHError('No suitable address family for %s' % hostname)
136         sock = socket.socket(af, socket.SOCK_STREAM)
137         sock.settimeout(timeout)
138         sock.connect(addr)
139         t = self._transport = paramiko.Transport(sock)
140         t.set_log_channel(logger.name)
141         
142         try:
143             t.start_client()
144         except paramiko.SSHException:
145             raise SSHError('Negotiation failed')
146         
147         # host key verification
148         server_key = t.get_remote_server_key()
149         known_host = self._host_keys.check(hostname, server_key) or \
150                         self._system_host_keys.check(hostname, server_key)
151         
152         if unknown_host_cb is None:
153             unknown_host_cb = lambda *args: False
154         if not known_host and not unknown_host_cb(hostname, server_key):
155                 raise SSHUnknownHostError(hostname, server_key)
156         
157         if key_filename is None:
158             key_filenames = []
159         elif isinstance(key_filename, basestring):
160             key_filenames = [ key_filename ]
161         else:
162             key_filenames = key_filename
163         
164         self._auth(username, password, key_filenames, allow_agent, look_for_keys)
165         
166         self._connected = True # there was no error authenticating
167         
168         c = self._channel = self._transport.open_session()
169         c.invoke_subsystem('netconf')
170         c.set_name('netconf')
171         
172         self._post_connect()
173     
174     # on the lines of paramiko.SSHClient._auth()
175     def _auth(self, username, password, key_filenames, allow_agent,
176               look_for_keys):
177         saved_exception = None
178         
179         for key_filename in key_filenames:
180             for cls in (paramiko.RSAKey, paramiko.DSSKey):
181                 try:
182                     key = cls.from_private_key_file(key_filename, password)
183                     logger.debug('Trying key %s from %s' %
184                               (hexlify(key.get_fingerprint()), key_filename))
185                     self._transport.auth_publickey(username, key)
186                     return
187                 except Exception as e:
188                     saved_exception = e
189                     logger.debug(e)
190         
191         if allow_agent:
192             for key in paramiko.Agent().get_keys():
193                 try:
194                     logger.debug('Trying SSH agent key %s' %
195                                  hexlify(key.get_fingerprint()))
196                     self._transport.auth_publickey(username, key)
197                     return
198                 except Exception as e:
199                     saved_exception = e
200                     logger.debug(e)
201         
202         keyfiles = []
203         if look_for_keys:
204             rsa_key = os.path.expanduser('~/.ssh/id_rsa')
205             dsa_key = os.path.expanduser('~/.ssh/id_dsa')
206             if os.path.isfile(rsa_key):
207                 keyfiles.append((paramiko.RSAKey, rsa_key))
208             if os.path.isfile(dsa_key):
209                 keyfiles.append((paramiko.DSSKey, dsa_key))
210             # look in ~/ssh/ for windows users:
211             rsa_key = os.path.expanduser('~/ssh/id_rsa')
212             dsa_key = os.path.expanduser('~/ssh/id_dsa')
213             if os.path.isfile(rsa_key):
214                 keyfiles.append((paramiko.RSAKey, rsa_key))
215             if os.path.isfile(dsa_key):
216                 keyfiles.append((paramiko.DSSKey, dsa_key))
217         
218         for cls, filename in keyfiles:
219             try:
220                 key = cls.from_private_key_file(filename, password)
221                 logger.debug('Trying discovered key %s in %s' %
222                           (hexlify(key.get_fingerprint()), filename))
223                 self._transport.auth_publickey(username, key)
224                 return
225             except Exception as e:
226                 saved_exception = e
227                 logger.debug(e)
228         
229         if password is not None:
230             try:
231                 self._transport.auth_password(username, password)
232                 return
233             except Exception as e:
234                 saved_exception = e
235                 logger.debug(e)
236         
237         if saved_exception is not None:
238             raise SSHAuthenticationError(repr(saved_exception))
239         
240         raise SSHAuthenticationError('No authentication methods available')
241     
242     def run(self):
243         chan = self._channel
244         chan.setblocking(0)
245         q = self._q
246         try:
247             while True:
248                 # select on a paramiko ssh channel object does not ever return
249                 # it in the writable list, so it channel's don't exactly emulate 
250                 # the socket api
251                 r, w, e = select([chan], [], [], TICK)
252                 # will wakeup evey TICK seconds to check if something
253                 # to send, more if something to read (due to select returning
254                 # chan in readable list)
255                 if r:
256                     data = chan.recv(BUF_SIZE)
257                     if data:
258                         self._buffer.write(data)
259                         self._parse()
260                     else:
261                         raise SessionCloseError(self._buffer.getvalue())
262                 if not q.empty() and chan.send_ready():
263                     data = q.get() + MSG_DELIM
264                     while data:
265                         n = chan.send(data)
266                         if n <= 0:
267                             raise SessionCloseError(self._buffer.getvalue(), data)
268                         data = data[n:]
269         except Exception as e:
270             self.close()
271             logger.debug('*** broke out of main loop ***')
272             self.dispatch('error', e)
273     
274     @property
275     def transport(self):
276         '''Get underlying paramiko.transport object; this is provided so methods
277         like transport.set_keepalive can be called.
278         '''
279         return self._transport