fixes
[ncclient] / ncclient / transport / ssh.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import os
16 import socket
17 from binascii import hexlify
18 from cStringIO import StringIO
19 from select import select
20
21 import paramiko
22
23 from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24 from session import Session
25
26 import logging
27 logger = logging.getLogger('ncclient.transport.ssh')
28
29 BUF_SIZE = 4096
30 MSG_DELIM = ']]>]]>'
31 TICK = 0.1
32
33 def default_unknown_host_cb(host, key):
34     """An `unknown host callback` returns :const:`True` if it finds the key
35     acceptable, and :const:`False` if not.
36
37     This default callback always returns :const:`False`, which would lead to
38     :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
39
40     Supply another valid callback if you need to verify the host key
41     programatically.
42
43     :arg host: the host for whom key needs to be verified
44     :type host: string
45
46     :arg key: a hex string representing the host key fingerprint
47     :type key: string
48     """
49     return False
50
51
52 class SSHSession(Session):
53
54     "Implements a :rfc:`4742` NETCONF session over SSH."
55
56     def __init__(self, capabilities):
57         Session.__init__(self, capabilities)
58         self._host_keys = paramiko.HostKeys()
59         self._transport = None
60         self._connected = False
61         self._channel = None
62         self._expecting_close = False
63         self._buffer = StringIO() # for incoming data
64         # parsing-related, see _parse()
65         self._parsing_state = 0
66         self._parsing_pos = 0
67
68     def _parse(self):
69         '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
70         maximum of BUF_SIZE bytes everytime this method is called. Retains state
71         across method calls and if a byte has been read it will not be
72         considered again. '''
73         delim = MSG_DELIM
74         n = len(delim) - 1
75         expect = self._parsing_state
76         buf = self._buffer
77         buf.seek(self._parsing_pos)
78         while True:
79             x = buf.read(1)
80             if not x: # done reading
81                 break
82             elif x == delim[expect]: # what we expected
83                 expect += 1 # expect the next delim char
84             else:
85                 continue
86             # loop till last delim char expected, break if other char encountered
87             for i in range(expect, n):
88                 x = buf.read(1)
89                 if not x: # done reading
90                     break
91                 if x == delim[expect]: # what we expected
92                     expect += 1 # expect the next delim char
93                 else:
94                     expect = 0 # reset
95                     break
96             else: # if we didn't break out of the loop, full delim was parsed
97                 msg_till = buf.tell() - n
98                 buf.seek(0)
99                 logger.debug('parsed new message')
100                 self._dispatch_message(buf.read(msg_till).strip())
101                 buf.seek(n+1, os.SEEK_CUR)
102                 rest = buf.read()
103                 buf = StringIO()
104                 buf.write(rest)
105                 buf.seek(0)
106                 expect = 0
107         self._buffer = buf
108         self._parsing_state = expect
109         self._parsing_pos = self._buffer.tell()
110
111     def load_known_hosts(self, filename=None):
112         """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
113         times.
114
115         If *filename* is not specified, looks in the default locations i.e.
116         :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
117         """
118         if filename is None:
119             filename = os.path.expanduser('~/.ssh/known_hosts')
120             try:
121                 self._host_keys.load(filename)
122             except IOError:
123                 # for windows
124                 filename = os.path.expanduser('~/ssh/known_hosts')
125                 try:
126                     self._host_keys.load(filename)
127                 except IOError:
128                     pass
129         else:
130             self._host_keys.load(filename)
131
132     def close(self):
133         self._expecting_close = True
134         if self._transport.is_active():
135             self._transport.close()
136         self._connected = False
137
138     def connect(self, host, port=830, timeout=None,
139                 unknown_host_cb=default_unknown_host_cb,
140                 username=None, password=None,
141                 key_filename=None, allow_agent=True, look_for_keys=True):
142         """Connect via SSH and initialize the NETCONF session. First attempts
143         the publickey authentication method and then password authentication.
144
145         To disable attemting publickey authentication altogether, call with
146         *allow_agent* and *look_for_keys* as :const:`False`. This may be needed
147         for Cisco devices which immediately disconnect on an incorrect
148         authentication attempt.
149
150         :arg host: the hostname or IP address to connect to
151         :type host: `string`
152
153         :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
154         :type port: `int`
155
156         :arg timeout: an optional timeout for the TCP handshake
157         :type timeout: `int`
158
159         :arg unknown_host_cb: called when a host key is not recognized
160         :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
161
162         :arg username: the username to use for SSH authentication
163         :type username: `string`
164
165         :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
166         :type password: `string`
167
168         :arg key_filename: a filename where a the private key to be used can be found
169         :type key_filename: `string`
170
171         :arg allow_agent: enables querying SSH agent (if found) for keys
172         :type allow_agent: `bool`
173
174         :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
175         :type look_for_keys: `bool`
176         """
177
178         assert(username is not None)
179
180         for (family, socktype, proto, canonname, sockaddr) in \
181         socket.getaddrinfo(host, port):
182             if socktype == socket.SOCK_STREAM:
183                 af = family
184                 addr = sockaddr
185                 break
186         else:
187             raise SSHError('No suitable address family for %s' % host)
188         sock = socket.socket(af, socket.SOCK_STREAM)
189         sock.settimeout(timeout)
190         sock.connect(addr)
191         t = self._transport = paramiko.Transport(sock)
192         t.set_log_channel(logger.name)
193
194         try:
195             t.start_client()
196         except paramiko.SSHException:
197             raise SSHError('Negotiation failed')
198
199         # host key verification
200         server_key = t.get_remote_server_key()
201         known_host = self._host_keys.check(host, server_key)
202
203         fingerprint = hexlify(server_key.get_fingerprint())
204
205         if not known_host and not unknown_host_cb(host, fingerprint):
206             raise SSHUnknownHostError(host, fingerprint)
207
208         if key_filename is None:
209             key_filenames = []
210         elif isinstance(key_filename, basestring):
211             key_filenames = [ key_filename ]
212         else:
213             key_filenames = key_filename
214
215         self._auth(username, password, key_filenames, allow_agent, look_for_keys)
216
217         self._connected = True # there was no error authenticating
218
219         c = self._channel = self._transport.open_session()
220         c.set_name('netconf')
221         c.invoke_subsystem('netconf')
222
223         self._post_connect()
224
225     # on the lines of paramiko.SSHClient._auth()
226     def _auth(self, username, password, key_filenames, allow_agent,
227               look_for_keys):
228         saved_exception = None
229
230         for key_filename in key_filenames:
231             for cls in (paramiko.RSAKey, paramiko.DSSKey):
232                 try:
233                     key = cls.from_private_key_file(key_filename, password)
234                     logger.debug('Trying key %s from %s' %
235                               (hexlify(key.get_fingerprint()), key_filename))
236                     self._transport.auth_publickey(username, key)
237                     return
238                 except Exception as e:
239                     saved_exception = e
240                     logger.debug(e)
241
242         if allow_agent:
243             for key in paramiko.Agent().get_keys():
244                 try:
245                     logger.debug('Trying SSH agent key %s' %
246                                  hexlify(key.get_fingerprint()))
247                     self._transport.auth_publickey(username, key)
248                     return
249                 except Exception as e:
250                     saved_exception = e
251                     logger.debug(e)
252
253         keyfiles = []
254         if look_for_keys:
255             rsa_key = os.path.expanduser('~/.ssh/id_rsa')
256             dsa_key = os.path.expanduser('~/.ssh/id_dsa')
257             if os.path.isfile(rsa_key):
258                 keyfiles.append((paramiko.RSAKey, rsa_key))
259             if os.path.isfile(dsa_key):
260                 keyfiles.append((paramiko.DSSKey, dsa_key))
261             # look in ~/ssh/ for windows users:
262             rsa_key = os.path.expanduser('~/ssh/id_rsa')
263             dsa_key = os.path.expanduser('~/ssh/id_dsa')
264             if os.path.isfile(rsa_key):
265                 keyfiles.append((paramiko.RSAKey, rsa_key))
266             if os.path.isfile(dsa_key):
267                 keyfiles.append((paramiko.DSSKey, dsa_key))
268
269         for cls, filename in keyfiles:
270             try:
271                 key = cls.from_private_key_file(filename, password)
272                 logger.debug('Trying discovered key %s in %s' %
273                           (hexlify(key.get_fingerprint()), filename))
274                 self._transport.auth_publickey(username, key)
275                 return
276             except Exception as e:
277                 saved_exception = e
278                 logger.debug(e)
279
280         if password is not None:
281             try:
282                 self._transport.auth_password(username, password)
283                 return
284             except Exception as e:
285                 saved_exception = e
286                 logger.debug(e)
287
288         if saved_exception is not None:
289             # need pep-3134 to do this right
290             raise AuthenticationError(repr(saved_exception))
291
292         raise AuthenticationError('No authentication methods available')
293
294     def run(self):
295         chan = self._channel
296         chan.setblocking(0)
297         q = self._q
298         try:
299             while True:
300                 # select on a paramiko ssh channel object does not ever return
301                 # it in the writable list, so it channel's don't exactly emulate
302                 # the socket api
303                 r, w, e = select([chan], [], [], TICK)
304                 # will wakeup evey TICK seconds to check if something
305                 # to send, more if something to read (due to select returning
306                 # chan in readable list)
307                 if r:
308                     data = chan.recv(BUF_SIZE)
309                     if data:
310                         self._buffer.write(data)
311                         self._parse()
312                     else:
313                         raise SessionCloseError(self._buffer.getvalue())
314                 if not q.empty() and chan.send_ready():
315                     logger.debug('sending message')
316                     data = q.get() + MSG_DELIM
317                     while data:
318                         n = chan.send(data)
319                         if n <= 0:
320                             raise SessionCloseError(self._buffer.getvalue(), data)
321                         data = data[n:]
322         except Exception as e:
323             logger.debug('broke out of main loop')
324             expecting = self._expecting_close
325             self.close()
326             logger.debug('error=%r' % e)
327             logger.debug('expecting_close=%r' % expecting)
328             if not (isinstance(e, SessionCloseError) and expecting):
329                 self._dispatch_error(e)
330
331     @property
332     def transport(self):
333         """Underlying `paramiko.Transport
334         <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
335         object. This makes it possible to call methods like set_keepalive on it.
336         """
337         return self._transport
338
339     @property
340     def can_pipeline(self):
341         if 'Cisco' in self._transport.remote_version:
342             return False
343         # elif ..
344         return True