more doc updates
[ncclient] / ncclient / transport / ssh.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import os
16 import socket
17 import getpass
18 from binascii import hexlify
19 from cStringIO import StringIO
20 from select import select
21
22 import paramiko
23
24 from errors import AuthenticationError, SessionCloseError, SSHError, SSHUnknownHostError
25 from session import Session
26
27 import logging
28 logger = logging.getLogger("ncclient.transport.ssh")
29
30 BUF_SIZE = 4096
31 MSG_DELIM = "]]>]]>"
32 TICK = 0.1
33
34 def default_unknown_host_cb(host, fingerprint):
35     """An unknown host callback returns `True` if it finds the key acceptable, and `False` if not.
36
37     This default callback always returns `False`, which would lead to :meth:`connect` raising a :exc:`SSHUnknownHost` exception.
38     
39     Supply another valid callback if you need to verify the host key programatically.
40
41     *host* is the hostname that needs to be verified
42
43     *fingerprint* is a hex string representing the host key fingerprint, colon-delimited e.g. `"4b:69:6c:72:6f:79:20:77:61:73:20:68:65:72:65:21"`
44     """
45     return False
46
47 def _colonify(fp):
48     finga = fp[:2]
49     for idx  in range(2, len(fp), 2):
50         finga += ":" + fp[idx:idx+2]
51     return finga
52
53 class SSHSession(Session):
54
55     "Implements a :rfc:`4742` NETCONF session over SSH."
56
57     def __init__(self, capabilities):
58         Session.__init__(self, capabilities)
59         self._host_keys = paramiko.HostKeys()
60         self._transport = None
61         self._connected = False
62         self._channel = None
63         self._buffer = StringIO() # for incoming data
64         # parsing-related, see _parse()
65         self._parsing_state = 0
66         self._parsing_pos = 0
67     
68     def _parse(self):
69         "Messages ae delimited by MSG_DELIM. The buffer could have grown by a maximum of BUF_SIZE bytes everytime this method is called. Retains state across method calls and if a byte has been read it will not be considered again."
70         delim = MSG_DELIM
71         n = len(delim) - 1
72         expect = self._parsing_state
73         buf = self._buffer
74         buf.seek(self._parsing_pos)
75         while True:
76             x = buf.read(1)
77             if not x: # done reading
78                 break
79             elif x == delim[expect]: # what we expected
80                 expect += 1 # expect the next delim char
81             else:
82                 expect = 0
83                 continue
84             # loop till last delim char expected, break if other char encountered
85             for i in range(expect, n):
86                 x = buf.read(1)
87                 if not x: # done reading
88                     break
89                 if x == delim[expect]: # what we expected
90                     expect += 1 # expect the next delim char
91                 else:
92                     expect = 0 # reset
93                     break
94             else: # if we didn't break out of the loop, full delim was parsed
95                 msg_till = buf.tell() - n
96                 buf.seek(0)
97                 logger.debug('parsed new message')
98                 self._dispatch_message(buf.read(msg_till).strip())
99                 buf.seek(n+1, os.SEEK_CUR)
100                 rest = buf.read()
101                 buf = StringIO()
102                 buf.write(rest)
103                 buf.seek(0)
104                 expect = 0
105         self._buffer = buf
106         self._parsing_state = expect
107         self._parsing_pos = self._buffer.tell()
108
109     def load_known_hosts(self, filename=None):
110         """Load host keys from an openssh :file:`known_hosts`-style file. Can be called multiple times.
111
112         If *filename* is not specified, looks in the default locations i.e. :file:`~/.ssh/known_hosts` and :file:`~/ssh/known_hosts` for Windows.
113         """
114         if filename is None:
115             filename = os.path.expanduser('~/.ssh/known_hosts')
116             try:
117                 self._host_keys.load(filename)
118             except IOError:
119                 # for windows
120                 filename = os.path.expanduser('~/ssh/known_hosts')
121                 try:
122                     self._host_keys.load(filename)
123                 except IOError:
124                     pass
125         else:
126             self._host_keys.load(filename)
127
128     def close(self):
129         if self._transport.is_active():
130             self._transport.close()
131         self._connected = False
132
133     # REMEMBER to update transport.rst if sig. changes, since it is hardcoded there
134     def connect(self, host, port=830, timeout=None, unknown_host_cb=default_unknown_host_cb,
135                 username=None, password=None, key_filename=None, allow_agent=True, look_for_keys=True):
136         """Connect via SSH and initialize the NETCONF session. First attempts the publickey authentication method and then password authentication.
137
138         To disable attempting publickey authentication altogether, call with *allow_agent* and *look_for_keys* as `False`.
139
140         *host* is the hostname or IP address to connect to
141
142         *port* is by default 830, but some devices use the default SSH port of 22 so this may need to be specified
143
144         *timeout* is an optional timeout for socket connect
145
146         *unknown_host_cb* is called when the server host key is not recognized. It takes two arguments, the hostname and the fingerprint (see the signature of :func:`default_unknown_host_cb`)
147
148         *username* is the username to use for SSH authentication
149
150         *password* is the password used if using password authentication, or the passphrase to use for unlocking keys that require it
151
152         *key_filename* is a filename where a the private key to be used can be found
153
154         *allow_agent* enables querying SSH agent (if found) for keys
155
156         *look_for_keys* enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
157         """
158         if username is None:
159             username = getpass.getuser()
160         
161         sock = None
162         for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
163             af, socktype, proto, canonname, sa = res
164             try:
165                 sock = socket.socket(af, socktype, proto)
166                 sock.settimeout(timeout)
167             except socket.error:
168                 continue
169             try:
170                 sock.connect(sa)
171             except socket.error:
172                 sock.close()
173                 continue
174             break
175         else:
176             raise SSHError("Could not open socket to %s:%s" % (host, port))
177
178         t = self._transport = paramiko.Transport(sock)
179         t.set_log_channel(logger.name)
180
181         try:
182             t.start_client()
183         except paramiko.SSHException:
184             raise SSHError('Negotiation failed')
185
186         # host key verification
187         server_key = t.get_remote_server_key()
188         known_host = self._host_keys.check(host, server_key)
189
190         fingerprint = _colonify(hexlify(server_key.get_fingerprint()))
191
192         if not known_host and not unknown_host_cb(host, fingerprint):
193             raise SSHUnknownHostError(host, fingerprint)
194
195         if key_filename is None:
196             key_filenames = []
197         elif isinstance(key_filename, basestring):
198             key_filenames = [ key_filename ]
199         else:
200             key_filenames = key_filename
201
202         self._auth(username, password, key_filenames, allow_agent, look_for_keys)
203
204         self._connected = True # there was no error authenticating
205
206         c = self._channel = self._transport.open_session()
207         c.set_name("netconf")
208         c.invoke_subsystem("netconf")
209
210         self._post_connect()
211     
212     # on the lines of paramiko.SSHClient._auth()
213     def _auth(self, username, password, key_filenames, allow_agent,
214               look_for_keys):
215         saved_exception = None
216
217         for key_filename in key_filenames:
218             for cls in (paramiko.RSAKey, paramiko.DSSKey):
219                 try:
220                     key = cls.from_private_key_file(key_filename, password)
221                     logger.debug("Trying key %s from %s" %
222                               (hexlify(key.get_fingerprint()), key_filename))
223                     self._transport.auth_publickey(username, key)
224                     return
225                 except Exception as e:
226                     saved_exception = e
227                     logger.debug(e)
228
229         if allow_agent:
230             for key in paramiko.Agent().get_keys():
231                 try:
232                     logger.debug("Trying SSH agent key %s" %
233                                  hexlify(key.get_fingerprint()))
234                     self._transport.auth_publickey(username, key)
235                     return
236                 except Exception as e:
237                     saved_exception = e
238                     logger.debug(e)
239
240         keyfiles = []
241         if look_for_keys:
242             rsa_key = os.path.expanduser("~/.ssh/id_rsa")
243             dsa_key = os.path.expanduser("~/.ssh/id_dsa")
244             if os.path.isfile(rsa_key):
245                 keyfiles.append((paramiko.RSAKey, rsa_key))
246             if os.path.isfile(dsa_key):
247                 keyfiles.append((paramiko.DSSKey, dsa_key))
248             # look in ~/ssh/ for windows users:
249             rsa_key = os.path.expanduser("~/ssh/id_rsa")
250             dsa_key = os.path.expanduser("~/ssh/id_dsa")
251             if os.path.isfile(rsa_key):
252                 keyfiles.append((paramiko.RSAKey, rsa_key))
253             if os.path.isfile(dsa_key):
254                 keyfiles.append((paramiko.DSSKey, dsa_key))
255
256         for cls, filename in keyfiles:
257             try:
258                 key = cls.from_private_key_file(filename, password)
259                 logger.debug("Trying discovered key %s in %s" %
260                           (hexlify(key.get_fingerprint()), filename))
261                 self._transport.auth_publickey(username, key)
262                 return
263             except Exception as e:
264                 saved_exception = e
265                 logger.debug(e)
266
267         if password is not None:
268             try:
269                 self._transport.auth_password(username, password)
270                 return
271             except Exception as e:
272                 saved_exception = e
273                 logger.debug(e)
274
275         if saved_exception is not None:
276             # need pep-3134 to do this right
277             raise AuthenticationError(repr(saved_exception))
278
279         raise AuthenticationError("No authentication methods available")
280
281     def run(self):
282         chan = self._channel
283         chan.setblocking(0)
284         q = self._q
285         try:
286             while True:
287                 # select on a paramiko ssh channel object does not ever return it in the writable list, so channels don't exactly emulate the socket api
288                 r, w, e = select([chan], [], [], TICK)
289                 # will wakeup evey TICK seconds to check if something to send, more if something to read (due to select returning chan in readable list)
290                 if r:
291                     data = chan.recv(BUF_SIZE)
292                     if data:
293                         self._buffer.write(data)
294                         self._parse()
295                     else:
296                         raise SessionCloseError(self._buffer.getvalue())
297                 if not q.empty() and chan.send_ready():
298                     logger.debug("Sending message")
299                     data = q.get() + MSG_DELIM
300                     while data:
301                         n = chan.send(data)
302                         if n <= 0:
303                             raise SessionCloseError(self._buffer.getvalue(), data)
304                         data = data[n:]
305         except Exception as e:
306             logger.debug("Broke out of main loop, error=%r", e)
307             self.close()
308             self._dispatch_error(e)
309
310     @property
311     def transport(self):
312         "Underlying `paramiko.Transport <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_ object. This makes it possible to call methods like :meth:`~paramiko.Transport.set_keepalive` on it."
313         return self._transport