NodeQuery: don't query non-vm_capable nodes
[ganeti-local] / tools / setup-ssh
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21 """Tool to setup the SSH configuration on a remote node.
22
23 This is needed before we can join the node into the cluster.
24
25 """
26
27 # pylint: disable-msg=C0103
28 # C0103: Invalid name setup-ssh
29
30 import getpass
31 import logging
32 import paramiko
33 import os.path
34 import optparse
35 import sys
36
37 from ganeti import cli
38 from ganeti import constants
39 from ganeti import errors
40 from ganeti import netutils
41 from ganeti import ssconf
42 from ganeti import ssh
43 from ganeti import utils
44
45
46 class RemoteCommandError(errors.GenericError):
47   """Exception if remote command was not successful.
48
49   """
50
51
52 class JoinCheckError(errors.GenericError):
53   """Exception raised if join check fails.
54
55   """
56
57
58 class HostKeyVerificationError(errors.GenericError):
59   """Exception if host key do not match.
60
61   """
62
63
64 class AuthError(errors.GenericError):
65   """Exception for authentication errors to hosts.
66
67   """
68
69
70 def _CheckJoin(transport):
71   """Checks if a join is safe or dangerous.
72
73   Note: This function relies on the fact, that all
74   hosts have the same configuration at compile time of
75   Ganeti. So that the constants do not mismatch.
76
77   @param transport: The paramiko transport instance
78   @return: True if the join is safe; False otherwise
79
80   """
81   sftp = transport.open_sftp_client()
82   ss = ssconf.SimpleStore()
83   ss_cluster_name_path = ss.KeyToFilename(constants.SS_CLUSTER_NAME)
84
85   cluster_files = [
86     (constants.NODED_CERT_FILE, utils.ReadFile(constants.NODED_CERT_FILE)),
87     (ss_cluster_name_path, utils.ReadFile(ss_cluster_name_path)),
88     ]
89
90   for (filename, local_content) in cluster_files:
91     try:
92       remote_content = _ReadSftpFile(sftp, filename)
93     except IOError, err:
94       # Assume file does not exist. Paramiko's error reporting is lacking.
95       logging.debug("Failed to read %s: %s", filename, err)
96       continue
97
98     if remote_content != local_content:
99       logging.error("File %s doesn't match local version", filename)
100       return False
101
102   return True
103
104
105 def _RunRemoteCommand(transport, command):
106   """Invokes and wait for the command over SSH.
107
108   @param transport: The paramiko transport instance
109   @param command: The command to be executed
110
111   """
112   chan = transport.open_session()
113   chan.set_combine_stderr(True)
114   output_handler = chan.makefile("r")
115   chan.exec_command(command)
116
117   result = chan.recv_exit_status()
118   msg = output_handler.read()
119
120   out_msg = "'%s' exited with status code %s, output %r" % (command, result,
121                                                             msg)
122
123   # If result is -1 (no exit status provided) we assume it was not successful
124   if result:
125     raise RemoteCommandError(out_msg)
126
127   if msg:
128     logging.info(out_msg)
129
130
131 def _InvokeDaemonUtil(transport, command):
132   """Invokes daemon-util on the remote side.
133
134   @param transport: The paramiko transport instance
135   @param command: The daemon-util command to be run
136
137   """
138   _RunRemoteCommand(transport, "%s %s" % (constants.DAEMON_UTIL, command))
139
140
141 def _ReadSftpFile(sftp, filename):
142   """Reads a file over sftp.
143
144   @param sftp: An open paramiko SFTP client
145   @param filename: The filename of the file to read
146   @return: The content of the file
147
148   """
149   remote_file = sftp.open(filename, "r")
150   try:
151     return remote_file.read()
152   finally:
153     remote_file.close()
154
155
156 def _WriteSftpFile(sftp, name, perm, data):
157   """SFTPs data to a remote file.
158
159   @param sftp: A open paramiko SFTP client
160   @param name: The remote file name
161   @param perm: The remote file permission
162   @param data: The data to write
163
164   """
165   remote_file = sftp.open(name, "w")
166   try:
167     sftp.chmod(name, perm)
168     remote_file.write(data)
169   finally:
170     remote_file.close()
171
172
173 def SetupSSH(transport):
174   """Sets the SSH up on the other side.
175
176   @param transport: The paramiko transport instance
177
178   """
179   priv_key, pub_key, auth_keys = ssh.GetUserFiles(constants.GANETI_RUNAS)
180   keyfiles = [
181     (constants.SSH_HOST_DSA_PRIV, 0600),
182     (constants.SSH_HOST_DSA_PUB, 0644),
183     (constants.SSH_HOST_RSA_PRIV, 0600),
184     (constants.SSH_HOST_RSA_PUB, 0644),
185     (priv_key, 0600),
186     (pub_key, 0644),
187     ]
188
189   sftp = transport.open_sftp_client()
190
191   filemap = dict((name, (utils.ReadFile(name), perm))
192                  for (name, perm) in keyfiles)
193
194   auth_path = os.path.dirname(auth_keys)
195
196   try:
197     sftp.mkdir(auth_path, 0700)
198   except IOError, err:
199     # Sadly paramiko doesn't provide errno or similiar
200     # so we can just assume that the path already exists
201     logging.info("Assuming directory %s on remote node exists: %s",
202                  auth_path, err)
203
204   for name, (data, perm) in filemap.iteritems():
205     _WriteSftpFile(sftp, name, perm, data)
206
207   authorized_keys = sftp.open(auth_keys, "a+")
208   try:
209     # Due to the way SFTPFile and BufferedFile are implemented,
210     # opening in a+ mode and then issuing a read(), readline() or
211     # iterating over the file (which uses read() internally) will see
212     # an empty file, since the paramiko internal file position and the
213     # OS-level file-position are desynchronized; therefore, we issue
214     # an explicit seek to resynchronize these; writes should (note
215     # should) still go to the right place
216     authorized_keys.seek(0, 0)
217     # We don't have to close, as the close happened already in AddAuthorizedKey
218     utils.AddAuthorizedKey(authorized_keys, filemap[pub_key][0])
219   finally:
220     authorized_keys.close()
221
222   _InvokeDaemonUtil(transport, "reload-ssh-keys")
223
224
225 def ParseOptions():
226   """Parses options passed to program.
227
228   """
229   program = os.path.basename(sys.argv[0])
230
231   parser = optparse.OptionParser(usage=("%prog [--debug|--verbose] [--force]"
232                                         " <node> <node...>"), prog=program)
233   parser.add_option(cli.DEBUG_OPT)
234   parser.add_option(cli.VERBOSE_OPT)
235   parser.add_option(cli.NOSSH_KEYCHECK_OPT)
236   default_key = ssh.GetUserFiles(constants.GANETI_RUNAS)[0]
237   parser.add_option(optparse.Option("-f", dest="private_key",
238                                     default=default_key,
239                                     help="The private key to (try to) use for"
240                                     "authentication "))
241   parser.add_option(optparse.Option("--key-type", dest="key_type",
242                                     choices=("rsa", "dsa"), default="dsa",
243                                     help="The private key type (rsa or dsa)"))
244   parser.add_option(optparse.Option("-j", "--force-join", dest="force_join",
245                                     action="store_true", default=False,
246                                     help="Force the join of the host"))
247
248   (options, args) = parser.parse_args()
249
250   return (options, args)
251
252
253 def SetupLogging(options):
254   """Sets up the logging.
255
256   @param options: Parsed options
257
258   """
259   fmt = "%(asctime)s: %(threadName)s "
260   if options.debug or options.verbose:
261     fmt += "%(levelname)s "
262   fmt += "%(message)s"
263
264   formatter = logging.Formatter(fmt)
265
266   file_handler = logging.FileHandler(constants.LOG_SETUP_SSH)
267   stderr_handler = logging.StreamHandler()
268   stderr_handler.setFormatter(formatter)
269   file_handler.setFormatter(formatter)
270   file_handler.setLevel(logging.INFO)
271
272   if options.debug:
273     stderr_handler.setLevel(logging.DEBUG)
274   elif options.verbose:
275     stderr_handler.setLevel(logging.INFO)
276   else:
277     stderr_handler.setLevel(logging.WARNING)
278
279   root_logger = logging.getLogger("")
280   root_logger.setLevel(logging.NOTSET)
281   root_logger.addHandler(stderr_handler)
282   root_logger.addHandler(file_handler)
283
284   # This is the paramiko logger instance
285   paramiko_logger = logging.getLogger("paramiko")
286   paramiko_logger.addHandler(file_handler)
287   # We don't want to debug Paramiko, so filter anything below warning
288   paramiko_logger.setLevel(logging.WARNING)
289
290
291 def LoadPrivateKeys(options):
292   """Load the list of available private keys.
293
294   It loads the standard ssh key from disk and then tries to connect to
295   the ssh agent too.
296
297   @rtype: list
298   @return: a list of C{paramiko.PKey}
299
300   """
301   if options.key_type == "rsa":
302     pkclass = paramiko.RSAKey
303   elif options.key_type == "dsa":
304     pkclass = paramiko.DSSKey
305   else:
306     logging.critical("Unknown key type %s selected (choose either rsa or dsa)",
307                      options.key_type)
308     sys.exit(1)
309
310   try:
311     private_key = pkclass.from_private_key_file(options.private_key)
312   except (paramiko.SSHException, EnvironmentError), err:
313     logging.critical("Can't load private key %s: %s", options.private_key, err)
314     sys.exit(1)
315
316   try:
317     agent = paramiko.Agent()
318     agent_keys = agent.get_keys()
319   except paramiko.SSHException, err:
320     # this will only be seen when the agent is broken/uses invalid
321     # protocol; for non-existing agent, get_keys() will just return an
322     # empty tuple
323     logging.warning("Can't connect to the ssh agent: %s; skipping its use",
324                     err)
325     agent_keys = []
326
327   return [private_key] + list(agent_keys)
328
329
330 def _FormatFingerprint(fpr):
331   """Formats a paramiko.PKey.get_fingerprint() human readable.
332
333   @param fpr: The fingerprint to be formatted
334   @return: A human readable fingerprint
335
336   """
337   return ssh.FormatParamikoFingerprint(paramiko.util.hexify(fpr))
338
339
340 def LoginViaKeys(transport, username, keys):
341   """Try to login on the given transport via a list of keys.
342
343   @param transport: the transport to use
344   @param username: the username to login as
345   @type keys: list
346   @param keys: list of C{paramiko.PKey} to use for authentication
347   @rtype: boolean
348   @return: True or False depending on whether the login was
349       successfull or not
350
351   """
352   for private_key in keys:
353     try:
354       transport.auth_publickey(username, private_key)
355       fpr = _FormatFingerprint(private_key.get_fingerprint())
356       if isinstance(private_key, paramiko.AgentKey):
357         logging.debug("Authentication via the ssh-agent key %s", fpr)
358       else:
359         logging.debug("Authenticated via public key %s", fpr)
360       return True
361     except paramiko.SSHException:
362       continue
363   else:
364     # all keys exhausted
365     return False
366
367
368 def LoadKnownHosts():
369   """Load the known hosts.
370
371   @return: paramiko.util.load_host_keys dict
372
373   """
374   homedir = utils.GetHomeDir(constants.GANETI_RUNAS)
375   known_hosts = os.path.join(homedir, ".ssh", "known_hosts")
376
377   try:
378     return paramiko.util.load_host_keys(known_hosts)
379   except EnvironmentError:
380     # We didn't find the path, silently ignore and return an empty dict
381     return {}
382
383
384 def _VerifyServerKey(transport, host, host_keys):
385   """Verify the server keys.
386
387   @param transport: A paramiko.transport instance
388   @param host: Name of the host we verify
389   @param host_keys: Loaded host keys
390   @raises HostkeyVerificationError: When the host identify couldn't be verified
391
392   """
393
394   server_key = transport.get_remote_server_key()
395   keytype = server_key.get_name()
396
397   our_server_key = host_keys.get(host, {}).get(keytype, None)
398   if not our_server_key:
399     hexified_key = _FormatFingerprint(server_key.get_fingerprint())
400     msg = ("Unable to verify hostkey of host %s: %s. Do you want to accept"
401            " it?" % (host, hexified_key))
402
403     if cli.AskUser(msg):
404       our_server_key = server_key
405
406   if our_server_key != server_key:
407     raise HostKeyVerificationError("Unable to verify host identity")
408
409
410 def main():
411   """Main routine.
412
413   """
414   (options, args) = ParseOptions()
415
416   SetupLogging(options)
417
418   all_keys = LoadPrivateKeys(options)
419
420   passwd = None
421   username = constants.GANETI_RUNAS
422   ssh_port = netutils.GetDaemonPort("ssh")
423   host_keys = LoadKnownHosts()
424
425   # Below, we need to join() the transport objects, as otherwise the
426   # following happens:
427   # - the main thread finishes
428   # - the atexit functions run (in the main thread), and cause the
429   #   logging file to be closed
430   # - a tiny bit later, the transport thread is finally ending, and
431   #   wants to log one more message, which fails as the file is closed
432   #   now
433
434   success = True
435
436   for host in args:
437     logging.info("Configuring %s", host)
438
439     transport = paramiko.Transport((host, ssh_port))
440     try:
441       try:
442         transport.start_client()
443
444         if options.ssh_key_check:
445           _VerifyServerKey(transport, host, host_keys)
446
447         try:
448           if LoginViaKeys(transport, username, all_keys):
449             logging.info("Authenticated to %s via public key", host)
450           else:
451             if all_keys:
452               logging.warning("Authentication to %s via public key failed,"
453                               " trying password", host)
454             if passwd is None:
455               passwd = getpass.getpass(prompt="%s password:" % username)
456             transport.auth_password(username=username, password=passwd)
457             logging.info("Authenticated to %s via password", host)
458         except paramiko.SSHException, err:
459           raise AuthError("Auth error TODO" % err)
460
461         if not _CheckJoin(transport):
462           if not options.force_join:
463             raise JoinCheckError(("Host %s failed join check; Please verify"
464                                   " that the host was not previously joined"
465                                   " to another cluster and use --force-join"
466                                   " to continue") % host)
467
468           logging.warning("Host %s failed join check, forced to continue",
469                           host)
470
471         SetupSSH(transport)
472         logging.info("%s successfully configured", host)
473       finally:
474         transport.close()
475         # this is needed for compatibility with older Paramiko or Python
476         # versions
477         transport.join()
478     except AuthError, err:
479       logging.error("Authentication error: %s", err)
480       success = False
481       break
482     except HostKeyVerificationError, err:
483       logging.error("Host key verification error: %s", err)
484       success = False
485     except Exception, err:
486       logging.exception("During setup of %s: %s", host, err)
487       success = False
488
489   if success:
490     sys.exit(constants.EXIT_SUCCESS)
491
492   sys.exit(constants.EXIT_FAILURE)
493
494
495 if __name__ == "__main__":
496   main()