rename content to xml_
[ncclient] / ncclient / transport / ssh.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import os
16 import socket
17 from binascii import hexlify
18 from cStringIO import StringIO
19 from select import select
20
21 import paramiko
22
23 from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
24 from session import Session
25
26 import logging
27 logger = logging.getLogger('ncclient.transport.ssh')
28
29 BUF_SIZE = 4096
30 MSG_DELIM = ']]>]]>'
31 TICK = 0.1
32
33 def default_unknown_host_cb(host, key):
34     """An `unknown host callback` returns :const:`True` if it finds the key
35     acceptable, and :const:`False` if not.
36
37     This default callback always returns :const:`False`, which would lead to
38     :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
39
40     Supply another valid callback if you need to verify the host key
41     programatically.
42
43     :arg host: the host for whom key needs to be verified
44     :type host: string
45
46     :arg key: a hex string representing the host key fingerprint
47     :type key: string
48     """
49     return False
50
51
52 class SSHSession(Session):
53
54     "Implements a :rfc:`4742` NETCONF session over SSH."
55
56     def __init__(self, capabilities):
57         Session.__init__(self, capabilities)
58         self._host_keys = paramiko.HostKeys()
59         self._transport = None
60         self._connected = False
61         self._channel = None
62         self._expecting_close = False
63         self._buffer = StringIO() # for incoming data
64         # parsing-related, see _parse()
65         self._parsing_state = 0
66         self._parsing_pos = 0
67
68     def _parse(self):
69         '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
70         maximum of BUF_SIZE bytes everytime this method is called. Retains state
71         across method calls and if a byte has been read it will not be
72         considered again. '''
73         delim = MSG_DELIM
74         n = len(delim) - 1
75         expect = self._parsing_state
76         buf = self._buffer
77         buf.seek(self._parsing_pos)
78         while True:
79             x = buf.read(1)
80             if not x: # done reading
81                 break
82             elif x == delim[expect]: # what we expected
83                 expect += 1 # expect the next delim char
84             else:
85                 continue
86             # loop till last delim char expected, break if other char encountered
87             for i in range(expect, n):
88                 x = buf.read(1)
89                 if not x: # done reading
90                     break
91                 if x == delim[expect]: # what we expected
92                     expect += 1 # expect the next delim char
93                 else:
94                     expect = 0 # reset
95                     break
96             else: # if we didn't break out of the loop, full delim was parsed
97                 msg_till = buf.tell() - n
98                 buf.seek(0)
99                 logger.debug('parsed new message')
100                 self._dispatch_message(buf.read(msg_till).strip())
101                 buf.seek(n+1, os.SEEK_CUR)
102                 rest = buf.read()
103                 buf = StringIO()
104                 buf.write(rest)
105                 buf.seek(0)
106                 expect = 0
107         self._buffer = buf
108         self._parsing_state = expect
109         self._parsing_pos = self._buffer.tell()
110
111     def load_known_hosts(self, filename=None):
112         """Load host keys from a :file:`known_hosts`-style file. Can be called multiple
113         times.
114
115         If *filename* is not specified, looks in the default locations i.e.
116         :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
117         """
118         if filename is None:
119             filename = os.path.expanduser('~/.ssh/known_hosts')
120             try:
121                 self._host_keys.load(filename)
122             except IOError:
123                 # for windows
124                 filename = os.path.expanduser('~/ssh/known_hosts')
125                 try:
126                     self._host_keys.load(filename)
127                 except IOError:
128                     pass
129         else:
130             self._host_keys.load(filename)
131
132     def close(self):
133         self._expecting_close = True
134         if self._transport.is_active():
135             self._transport.close()
136         self._connected = False
137
138     def connect(self, host, port=830, timeout=None,
139                 unknown_host_cb=default_unknown_host_cb,
140                 username=None, password=None,
141                 key_filename=None, allow_agent=True, look_for_keys=True):
142         """Connect via SSH and initialize the NETCONF session. First attempts
143         the publickey authentication method and then password authentication.
144
145         To disable attemting publickey authentication altogether, call with
146         *allow_agent* and *look_for_keys* as :const:`False`. This may be needed
147         for Cisco devices which immediately disconnect on an incorrect
148         authentication attempt.
149
150         :arg host: the hostname or IP address to connect to
151         :type host: `string`
152
153         :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
154         :type port: `int`
155
156         :arg timeout: an optional timeout for the TCP handshake
157         :type timeout: `int`
158
159         :arg unknown_host_cb: called when a host key is not recognized
160         :type unknown_host_cb: see :meth:`signature <ssh.default_unknown_host_cb>`
161
162         :arg username: the username to use for SSH authentication
163         :type username: `string`
164
165         :arg password: the password used if using password authentication, or the passphrase to use for unlocking keys that require it
166         :type password: `string`
167
168         :arg key_filename: a filename where a the private key to be used can be found
169         :type key_filename: `string`
170
171         :arg allow_agent: enables querying SSH agent (if found) for keys
172         :type allow_agent: `bool`
173
174         :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
175         :type look_for_keys: `bool`
176         """
177
178         if username is None:
179             raise SSHError("No username specified")
180
181         sock = None
182         for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
183             af, socktype, proto, canonname, sa = res
184             try:
185                 sock = socket.socket(af, socktype, proto)
186                 sock.settimeout(timeout)
187             except socket.error:
188                 continue
189             try:
190                 sock.connect(sa)
191             except socket.error:
192                 sock.close()
193                 continue
194             break
195         else:
196             raise SSHError("Could not open socket")
197
198         t = self._transport = paramiko.Transport(sock)
199         t.set_log_channel(logger.name)
200
201         try:
202             t.start_client()
203         except paramiko.SSHException:
204             raise SSHError('Negotiation failed')
205
206         # host key verification
207         server_key = t.get_remote_server_key()
208         known_host = self._host_keys.check(host, server_key)
209
210         fingerprint = hexlify(server_key.get_fingerprint())
211
212         if not known_host and not unknown_host_cb(host, fingerprint):
213             raise SSHUnknownHostError(host, fingerprint)
214
215         if key_filename is None:
216             key_filenames = []
217         elif isinstance(key_filename, basestring):
218             key_filenames = [ key_filename ]
219         else:
220             key_filenames = key_filename
221
222         self._auth(username, password, key_filenames, allow_agent, look_for_keys)
223
224         self._connected = True # there was no error authenticating
225
226         c = self._channel = self._transport.open_session()
227         c.set_name('netconf')
228         c.invoke_subsystem('netconf')
229
230         self._post_connect()
231
232     # on the lines of paramiko.SSHClient._auth()
233     def _auth(self, username, password, key_filenames, allow_agent,
234               look_for_keys):
235         saved_exception = None
236
237         for key_filename in key_filenames:
238             for cls in (paramiko.RSAKey, paramiko.DSSKey):
239                 try:
240                     key = cls.from_private_key_file(key_filename, password)
241                     logger.debug('Trying key %s from %s' %
242                               (hexlify(key.get_fingerprint()), key_filename))
243                     self._transport.auth_publickey(username, key)
244                     return
245                 except Exception as e:
246                     saved_exception = e
247                     logger.debug(e)
248
249         if allow_agent:
250             for key in paramiko.Agent().get_keys():
251                 try:
252                     logger.debug('Trying SSH agent key %s' %
253                                  hexlify(key.get_fingerprint()))
254                     self._transport.auth_publickey(username, key)
255                     return
256                 except Exception as e:
257                     saved_exception = e
258                     logger.debug(e)
259
260         keyfiles = []
261         if look_for_keys:
262             rsa_key = os.path.expanduser('~/.ssh/id_rsa')
263             dsa_key = os.path.expanduser('~/.ssh/id_dsa')
264             if os.path.isfile(rsa_key):
265                 keyfiles.append((paramiko.RSAKey, rsa_key))
266             if os.path.isfile(dsa_key):
267                 keyfiles.append((paramiko.DSSKey, dsa_key))
268             # look in ~/ssh/ for windows users:
269             rsa_key = os.path.expanduser('~/ssh/id_rsa')
270             dsa_key = os.path.expanduser('~/ssh/id_dsa')
271             if os.path.isfile(rsa_key):
272                 keyfiles.append((paramiko.RSAKey, rsa_key))
273             if os.path.isfile(dsa_key):
274                 keyfiles.append((paramiko.DSSKey, dsa_key))
275
276         for cls, filename in keyfiles:
277             try:
278                 key = cls.from_private_key_file(filename, password)
279                 logger.debug('Trying discovered key %s in %s' %
280                           (hexlify(key.get_fingerprint()), filename))
281                 self._transport.auth_publickey(username, key)
282                 return
283             except Exception as e:
284                 saved_exception = e
285                 logger.debug(e)
286
287         if password is not None:
288             try:
289                 self._transport.auth_password(username, password)
290                 return
291             except Exception as e:
292                 saved_exception = e
293                 logger.debug(e)
294
295         if saved_exception is not None:
296             # need pep-3134 to do this right
297             raise AuthenticationError(repr(saved_exception))
298
299         raise AuthenticationError('No authentication methods available')
300
301     def run(self):
302         chan = self._channel
303         chan.setblocking(0)
304         q = self._q
305         try:
306             while True:
307                 # select on a paramiko ssh channel object does not ever return
308                 # it in the writable list, so it channel's don't exactly emulate
309                 # the socket api
310                 r, w, e = select([chan], [], [], TICK)
311                 # will wakeup evey TICK seconds to check if something
312                 # to send, more if something to read (due to select returning
313                 # chan in readable list)
314                 if r:
315                     data = chan.recv(BUF_SIZE)
316                     if data:
317                         self._buffer.write(data)
318                         self._parse()
319                     else:
320                         raise SessionCloseError(self._buffer.getvalue())
321                 if not q.empty() and chan.send_ready():
322                     logger.debug('sending message')
323                     data = q.get() + MSG_DELIM
324                     while data:
325                         n = chan.send(data)
326                         if n <= 0:
327                             raise SessionCloseError(self._buffer.getvalue(), data)
328                         data = data[n:]
329         except Exception as e:
330             logger.debug('broke out of main loop')
331             expecting = self._expecting_close
332             self.close()
333             logger.debug('error=%r' % e)
334             logger.debug('expecting_close=%r' % expecting)
335             if not (isinstance(e, SessionCloseError) and expecting):
336                 self._dispatch_error(e)
337
338     @property
339     def transport(self):
340         """Underlying `paramiko.Transport
341         <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
342         object. This makes it possible to call methods like set_keepalive on it.
343         """
344         return self._transport
345
346     @property
347     def can_pipeline(self):
348         if 'Cisco' in self._transport.remote_version:
349             return False
350         # elif ..
351         return True