moving forward
[ncclient] / ncclient / transport / ssh.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import os
16 import socket
17 from binascii import hexlify
18 from cStringIO import StringIO
19 from select import select
20
21 import paramiko
22
23 from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24 from session import Session
25
26 import logging
27 logger = logging.getLogger('ncclient.transport.ssh')
28
29 BUF_SIZE = 4096
30 MSG_DELIM = ']]>]]>'
31 TICK = 0.1
32
33 class SSHSession(Session):
34     
35     def __init__(self):
36         Session.__init__(self)
37         self._host_keys = paramiko.HostKeys()
38         self._system_host_keys = paramiko.HostKeys()
39         self._transport = None
40         self._connected = False
41         self._channel = None
42         self._expecting_close = False
43         self._buffer = StringIO() # for incoming data
44         # parsing-related, see _parse()
45         self._parsing_state = 0 
46         self._parsing_pos = 0
47     
48     def _parse(self):
49         '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
50         maximum of BUF_SIZE bytes everytime this method is called. Retains state
51         across method calls and if a byte has been read it will not be considered
52         again.
53         '''
54         delim = MSG_DELIM
55         n = len(delim) - 1
56         expect = self._parsing_state
57         buf = self._buffer
58         buf.seek(self._parsing_pos)
59         while True:
60             x = buf.read(1)
61             if not x: # done reading
62                 break
63             elif x == delim[expect]: # what we expected
64                 expect += 1 # expect the next delim char
65             else:
66                 continue
67             # loop till last delim char expected, break if other char encountered
68             for i in range(expect, n):
69                 x = buf.read(1)
70                 if not x: # done reading
71                     break
72                 if x == delim[expect]: # what we expected
73                     expect += 1 # expect the next delim char
74                 else:
75                     expect = 0 # reset
76                     break
77             else: # if we didn't break out of the loop, full delim was parsed
78                 msg_till = buf.tell() - n
79                 buf.seek(0)
80                 logger.debug('parsed new message')
81                 self._dispatch_message(buf.read(msg_till).strip())
82                 buf.seek(n+1, os.SEEK_CUR)
83                 rest = buf.read()
84                 buf = StringIO()
85                 buf.write(rest)
86                 buf.seek(0)
87                 expect = 0
88         self._buffer = buf
89         self._parsing_state = expect
90         self._parsing_pos = self._buffer.tell()
91     
92     def expect_close(self):
93         self._expecting_close = True
94     
95     def load_system_host_keys(self, filename=None):
96         if filename is None:
97             filename = os.path.expanduser('~/.ssh/known_hosts')
98             try:
99                 self._system_host_keys.load(filename)
100             except IOError:
101                 # for windows
102                 filename = os.path.expanduser('~/ssh/known_hosts')
103                 try:
104                     self._system_host_keys.load(filename)
105                 except IOError:
106                     pass
107             return
108         self._system_host_keys.load(filename)
109     
110     def load_host_keys(self, filename):
111         self._host_keys.load(filename)
112
113     def add_host_key(self, key):
114         self._host_keys.add(key)
115     
116     def save_host_keys(self, filename):
117         f = open(filename, 'w')
118         for hostname, keys in self._host_keys.iteritems():
119             for keytype, key in keys.iteritems():
120                 f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
121         f.close()    
122     
123     def close(self):
124         if self._transport.is_active():
125             self._transport.close()
126         self._connected = False
127     
128     def connect(self, hostname, port=830, timeout=None,
129                 unknown_host_cb=None, username=None, password=None,
130                 key_filename=None, allow_agent=True, look_for_keys=True):
131         
132         assert(username is not None)
133         
134         for (family, socktype, proto, canonname, sockaddr) in \
135         socket.getaddrinfo(hostname, port):
136             if socktype==socket.SOCK_STREAM:
137                 af = family
138                 addr = sockaddr
139                 break
140         else:
141             raise SSHError('No suitable address family for %s' % hostname)
142         sock = socket.socket(af, socket.SOCK_STREAM)
143         sock.settimeout(timeout)
144         sock.connect(addr)
145         t = self._transport = paramiko.Transport(sock)
146         t.set_log_channel(logger.name)
147         
148         try:
149             t.start_client()
150         except paramiko.SSHException:
151             raise SSHError('Negotiation failed')
152         
153         # host key verification
154         server_key = t.get_remote_server_key()
155         known_host = self._host_keys.check(hostname, server_key) or \
156                         self._system_host_keys.check(hostname, server_key)
157         
158         if unknown_host_cb is None:
159             unknown_host_cb = lambda *args: False
160         if not known_host and not unknown_host_cb(hostname, server_key):
161                 raise SSHUnknownHostError(hostname, server_key)
162         
163         if key_filename is None:
164             key_filenames = []
165         elif isinstance(key_filename, basestring):
166             key_filenames = [ key_filename ]
167         else:
168             key_filenames = key_filename
169         
170         self._auth(username, password, key_filenames, allow_agent, look_for_keys)
171         
172         self._connected = True # there was no error authenticating
173         
174         c = self._channel = self._transport.open_session()
175         c.set_name('netconf')
176         c.invoke_subsystem('netconf')
177         
178         self._post_connect()
179     
180     # on the lines of paramiko.SSHClient._auth()
181     def _auth(self, username, password, key_filenames, allow_agent,
182               look_for_keys):
183         saved_exception = None
184         
185         for key_filename in key_filenames:
186             for cls in (paramiko.RSAKey, paramiko.DSSKey):
187                 try:
188                     key = cls.from_private_key_file(key_filename, password)
189                     logger.debug('Trying key %s from %s' %
190                               (hexlify(key.get_fingerprint()), key_filename))
191                     self._transport.auth_publickey(username, key)
192                     return
193                 except Exception as e:
194                     saved_exception = e
195                     logger.debug(e)
196         
197         if allow_agent:
198             for key in paramiko.Agent().get_keys():
199                 try:
200                     logger.debug('Trying SSH agent key %s' %
201                                  hexlify(key.get_fingerprint()))
202                     self._transport.auth_publickey(username, key)
203                     return
204                 except Exception as e:
205                     saved_exception = e
206                     logger.debug(e)
207         
208         keyfiles = []
209         if look_for_keys:
210             rsa_key = os.path.expanduser('~/.ssh/id_rsa')
211             dsa_key = os.path.expanduser('~/.ssh/id_dsa')
212             if os.path.isfile(rsa_key):
213                 keyfiles.append((paramiko.RSAKey, rsa_key))
214             if os.path.isfile(dsa_key):
215                 keyfiles.append((paramiko.DSSKey, dsa_key))
216             # look in ~/ssh/ for windows users:
217             rsa_key = os.path.expanduser('~/ssh/id_rsa')
218             dsa_key = os.path.expanduser('~/ssh/id_dsa')
219             if os.path.isfile(rsa_key):
220                 keyfiles.append((paramiko.RSAKey, rsa_key))
221             if os.path.isfile(dsa_key):
222                 keyfiles.append((paramiko.DSSKey, dsa_key))
223         
224         for cls, filename in keyfiles:
225             try:
226                 key = cls.from_private_key_file(filename, password)
227                 logger.debug('Trying discovered key %s in %s' %
228                           (hexlify(key.get_fingerprint()), filename))
229                 self._transport.auth_publickey(username, key)
230                 return
231             except Exception as e:
232                 saved_exception = e
233                 logger.debug(e)
234         
235         if password is not None:
236             try:
237                 self._transport.auth_password(username, password)
238                 return
239             except Exception as e:
240                 saved_exception = e
241                 logger.debug(e)
242         
243         if saved_exception is not None:
244             # need pep-3134 to do this right
245             raise SSHAuthenticationError(repr(saved_exception))
246         
247         raise SSHAuthenticationError('No authentication methods available')
248     
249     def run(self):
250         chan = self._channel
251         chan.setblocking(0)
252         q = self._q
253         try:
254             while True:
255                 # select on a paramiko ssh channel object does not ever return
256                 # it in the writable list, so it channel's don't exactly emulate 
257                 # the socket api
258                 r, w, e = select([chan], [], [], TICK)
259                 # will wakeup evey TICK seconds to check if something
260                 # to send, more if something to read (due to select returning
261                 # chan in readable list)
262                 if r:
263                     data = chan.recv(BUF_SIZE)
264                     if data:
265                         self._buffer.write(data)
266                         self._parse()
267                     else:
268                         raise SessionCloseError(self._buffer.getvalue())
269                 if not q.empty() and chan.send_ready():
270                     logger.debug('sending message')
271                     data = q.get() + MSG_DELIM
272                     while data:
273                         n = chan.send(data)
274                         if n <= 0:
275                             raise SessionCloseError(self._buffer.getvalue(), data)
276                         data = data[n:]
277         except Exception as e:
278             logger.debug('broke out of main loop')
279             self.close()
280             if not (isinstance(e, SessionCloseError) and self._expecting_close):
281                 self._dispatch_error(e)
282     
283     @property
284     def transport(self):
285         '''Get underlying paramiko transport object; this is provided so methods
286         like set_keepalive can be called on it. See paramiko.Transport
287         documentation for details.
288         '''
289         return self._transport
290     
291     @property
292     def can_pipeline(self):
293         if 'Cisco' in self._transport.remote_version:
294             return False
295         # elif ..
296         return True