#!/usr/bin/python # # Copyright (C) 2010, 2012 Google Inc. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation; either version 2 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA # 02110-1301, USA. """Tool to setup the SSH configuration on a remote node. This is needed before we can join the node into the cluster. """ # pylint: disable=C0103 # C0103: Invalid name setup-ssh import getpass import logging import os.path import optparse import sys # workaround paramiko warnings # FIXME: use 'with warnings.catch_warnings' once we drop Python 2.4 import warnings warnings.simplefilter("ignore") import paramiko warnings.resetwarnings() from ganeti import cli from ganeti import constants from ganeti import errors from ganeti import netutils from ganeti import ssconf from ganeti import ssh from ganeti import utils class RemoteCommandError(errors.GenericError): """Exception if remote command was not successful. """ class JoinCheckError(errors.GenericError): """Exception raised if join check fails. """ class HostKeyVerificationError(errors.GenericError): """Exception if host key do not match. """ class AuthError(errors.GenericError): """Exception for authentication errors to hosts. """ def _CheckJoin(transport): """Checks if a join is safe or dangerous. Note: This function relies on the fact, that all hosts have the same configuration at compile time of Ganeti. So that the constants do not mismatch. @param transport: The paramiko transport instance @return: True if the join is safe; False otherwise """ sftp = transport.open_sftp_client() ss = ssconf.SimpleStore() ss_cluster_name_path = ss.KeyToFilename(constants.SS_CLUSTER_NAME) cluster_files = [ (constants.NODED_CERT_FILE, utils.ReadFile(constants.NODED_CERT_FILE)), (ss_cluster_name_path, utils.ReadFile(ss_cluster_name_path)), ] for (filename, local_content) in cluster_files: try: remote_content = _ReadSftpFile(sftp, filename) except IOError, err: # Assume file does not exist. Paramiko's error reporting is lacking. logging.debug("Failed to read %s: %s", filename, err) continue if remote_content != local_content: logging.error("File %s doesn't match local version", filename) return False return True def _RunRemoteCommand(transport, command): """Invokes and wait for the command over SSH. @param transport: The paramiko transport instance @param command: The command to be executed """ chan = transport.open_session() chan.set_combine_stderr(True) output_handler = chan.makefile("r") chan.exec_command(command) result = chan.recv_exit_status() msg = output_handler.read() out_msg = "'%s' exited with status code %s, output %r" % (command, result, msg) # If result is -1 (no exit status provided) we assume it was not successful if result: raise RemoteCommandError(out_msg) if msg: logging.info(out_msg) def _InvokeDaemonUtil(transport, command): """Invokes daemon-util on the remote side. @param transport: The paramiko transport instance @param command: The daemon-util command to be run """ _RunRemoteCommand(transport, "%s %s" % (constants.DAEMON_UTIL, command)) def _ReadSftpFile(sftp, filename): """Reads a file over sftp. @param sftp: An open paramiko SFTP client @param filename: The filename of the file to read @return: The content of the file """ remote_file = sftp.open(filename, "r") try: return remote_file.read() finally: remote_file.close() def _WriteSftpFile(sftp, name, perm, data): """SFTPs data to a remote file. @param sftp: A open paramiko SFTP client @param name: The remote file name @param perm: The remote file permission @param data: The data to write """ remote_file = sftp.open(name, "w") try: sftp.chmod(name, perm) remote_file.write(data) finally: remote_file.close() def SetupSSH(transport): """Sets the SSH up on the other side. @param transport: The paramiko transport instance """ priv_key, pub_key, auth_keys = ssh.GetUserFiles(constants.GANETI_RUNAS) keyfiles = [ (constants.SSH_HOST_DSA_PRIV, 0600), (constants.SSH_HOST_DSA_PUB, 0644), (constants.SSH_HOST_RSA_PRIV, 0600), (constants.SSH_HOST_RSA_PUB, 0644), (priv_key, 0600), (pub_key, 0644), ] sftp = transport.open_sftp_client() filemap = dict((name, (utils.ReadFile(name), perm)) for (name, perm) in keyfiles) auth_path = os.path.dirname(auth_keys) try: sftp.mkdir(auth_path, 0700) except IOError, err: # Sadly paramiko doesn't provide errno or similiar # so we can just assume that the path already exists logging.info("Assuming directory %s on remote node exists: %s", auth_path, err) for name, (data, perm) in filemap.iteritems(): _WriteSftpFile(sftp, name, perm, data) authorized_keys = sftp.open(auth_keys, "a+") try: # Due to the way SFTPFile and BufferedFile are implemented, # opening in a+ mode and then issuing a read(), readline() or # iterating over the file (which uses read() internally) will see # an empty file, since the paramiko internal file position and the # OS-level file-position are desynchronized; therefore, we issue # an explicit seek to resynchronize these; writes should (note # should) still go to the right place authorized_keys.seek(0, 0) # We don't have to close, as the close happened already in AddAuthorizedKey utils.AddAuthorizedKey(authorized_keys, filemap[pub_key][0]) finally: authorized_keys.close() _InvokeDaemonUtil(transport, "reload-ssh-keys") def ParseOptions(): """Parses options passed to program. """ program = os.path.basename(sys.argv[0]) parser = optparse.OptionParser(usage=("%prog [--debug|--verbose] [--force]" " "), prog=program) parser.add_option(cli.DEBUG_OPT) parser.add_option(cli.VERBOSE_OPT) parser.add_option(cli.NOSSH_KEYCHECK_OPT) default_key = ssh.GetUserFiles(constants.GANETI_RUNAS)[0] parser.add_option(optparse.Option("-f", dest="private_key", default=default_key, help="The private key to (try to) use for" "authentication ")) parser.add_option(optparse.Option("--key-type", dest="key_type", choices=("rsa", "dsa"), default="dsa", help="The private key type (rsa or dsa)")) parser.add_option(optparse.Option("-j", "--force-join", dest="force_join", action="store_true", default=False, help="Force the join of the host")) (options, args) = parser.parse_args() if not args: parser.print_help() sys.exit(constants.EXIT_FAILURE) return (options, args) def SetupLogging(options): """Sets up the logging. @param options: Parsed options """ fmt = "%(asctime)s: %(threadName)s " if options.debug or options.verbose: fmt += "%(levelname)s " fmt += "%(message)s" formatter = logging.Formatter(fmt) file_handler = logging.FileHandler(constants.LOG_SETUP_SSH) stderr_handler = logging.StreamHandler() stderr_handler.setFormatter(formatter) file_handler.setFormatter(formatter) file_handler.setLevel(logging.INFO) if options.debug: stderr_handler.setLevel(logging.DEBUG) elif options.verbose: stderr_handler.setLevel(logging.INFO) else: stderr_handler.setLevel(logging.WARNING) root_logger = logging.getLogger("") root_logger.setLevel(logging.NOTSET) root_logger.addHandler(stderr_handler) root_logger.addHandler(file_handler) # This is the paramiko logger instance paramiko_logger = logging.getLogger("paramiko") paramiko_logger.addHandler(file_handler) # We don't want to debug Paramiko, so filter anything below warning paramiko_logger.setLevel(logging.WARNING) def LoadPrivateKeys(options): """Load the list of available private keys. It loads the standard ssh key from disk and then tries to connect to the ssh agent too. @rtype: list @return: a list of C{paramiko.PKey} """ if options.key_type == "rsa": pkclass = paramiko.RSAKey elif options.key_type == "dsa": pkclass = paramiko.DSSKey else: logging.critical("Unknown key type %s selected (choose either rsa or dsa)", options.key_type) sys.exit(1) try: private_key = pkclass.from_private_key_file(options.private_key) except (paramiko.SSHException, EnvironmentError), err: logging.critical("Can't load private key %s: %s", options.private_key, err) sys.exit(1) try: agent = paramiko.Agent() agent_keys = agent.get_keys() except paramiko.SSHException, err: # this will only be seen when the agent is broken/uses invalid # protocol; for non-existing agent, get_keys() will just return an # empty tuple logging.warning("Can't connect to the ssh agent: %s; skipping its use", err) agent_keys = [] return [private_key] + list(agent_keys) def _FormatFingerprint(fpr): """Formats a paramiko.PKey.get_fingerprint() human readable. @param fpr: The fingerprint to be formatted @return: A human readable fingerprint """ return ssh.FormatParamikoFingerprint(paramiko.util.hexify(fpr)) def LoginViaKeys(transport, username, keys): """Try to login on the given transport via a list of keys. @param transport: the transport to use @param username: the username to login as @type keys: list @param keys: list of C{paramiko.PKey} to use for authentication @rtype: boolean @return: True or False depending on whether the login was successfull or not """ for private_key in keys: try: transport.auth_publickey(username, private_key) fpr = _FormatFingerprint(private_key.get_fingerprint()) if isinstance(private_key, paramiko.AgentKey): logging.debug("Authentication via the ssh-agent key %s", fpr) else: logging.debug("Authenticated via public key %s", fpr) return True except paramiko.SSHException: continue else: # all keys exhausted return False def LoadKnownHosts(): """Load the known hosts. @return: paramiko.util.load_host_keys dict """ homedir = utils.GetHomeDir(constants.GANETI_RUNAS) known_hosts = os.path.join(homedir, ".ssh", "known_hosts") try: return paramiko.util.load_host_keys(known_hosts) except EnvironmentError: # We didn't find the path, silently ignore and return an empty dict return {} def _VerifyServerKey(transport, host, host_keys): """Verify the server keys. @param transport: A paramiko.transport instance @param host: Name of the host we verify @param host_keys: Loaded host keys @raises HostkeyVerificationError: When the host identify couldn't be verified """ server_key = transport.get_remote_server_key() keytype = server_key.get_name() our_server_key = host_keys.get(host, {}).get(keytype, None) if not our_server_key: hexified_key = _FormatFingerprint(server_key.get_fingerprint()) msg = ("Unable to verify hostkey of host %s: %s. Do you want to accept" " it?" % (host, hexified_key)) if cli.AskUser(msg): our_server_key = server_key if our_server_key != server_key: raise HostKeyVerificationError("Unable to verify host identity") def main(): """Main routine. """ (options, args) = ParseOptions() SetupLogging(options) all_keys = LoadPrivateKeys(options) passwd = None username = constants.GANETI_RUNAS ssh_port = netutils.GetDaemonPort("ssh") host_keys = LoadKnownHosts() # Below, we need to join() the transport objects, as otherwise the # following happens: # - the main thread finishes # - the atexit functions run (in the main thread), and cause the # logging file to be closed # - a tiny bit later, the transport thread is finally ending, and # wants to log one more message, which fails as the file is closed # now success = True for host in args: logging.info("Configuring %s", host) transport = paramiko.Transport((host, ssh_port)) try: try: transport.start_client() if options.ssh_key_check: _VerifyServerKey(transport, host, host_keys) try: if LoginViaKeys(transport, username, all_keys): logging.info("Authenticated to %s via public key", host) else: if all_keys: logging.warning("Authentication to %s via public key failed," " trying password", host) if passwd is None: passwd = getpass.getpass(prompt="%s password:" % username) transport.auth_password(username=username, password=passwd) logging.info("Authenticated to %s via password", host) except paramiko.SSHException, err: raise AuthError("Auth error TODO" % err) if not _CheckJoin(transport): if not options.force_join: raise JoinCheckError(("Host %s failed join check; Please verify" " that the host was not previously joined" " to another cluster and use --force-join" " to continue") % host) logging.warning("Host %s failed join check, forced to continue", host) SetupSSH(transport) logging.info("%s successfully configured", host) finally: transport.close() # this is needed for compatibility with older Paramiko or Python # versions transport.join() except AuthError, err: logging.error("Authentication error: %s", err) success = False break except HostKeyVerificationError, err: logging.error("Host key verification error: %s", err) success = False except Exception, err: logging.exception("During setup of %s: %s", host, err) success = False if success: sys.exit(constants.EXIT_SUCCESS) sys.exit(constants.EXIT_FAILURE) if __name__ == "__main__": main()