X-Git-Url: https://code.grnet.gr/git/ganeti-local/blobdiff_plain/a3f9f296fd8d14233b1d28a8168e34f8c485d69a..df07c18f5186c3018a45333accd58b5f70a1e581:/lib/ssh.py diff --git a/lib/ssh.py b/lib/ssh.py index e87b19d..c694341 100644 --- a/lib/ssh.py +++ b/lib/ssh.py @@ -1,7 +1,7 @@ # # -# Copyright (C) 2006, 2007 Google Inc. +# Copyright (C) 2006, 2007, 2010, 2011 Google Inc. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -30,52 +30,100 @@ import logging from ganeti import utils from ganeti import errors from ganeti import constants +from ganeti import netutils +from ganeti import pathutils +from ganeti import vcluster +from ganeti import compat + + +def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA, + _homedir_fn=None): + """Return the paths of a user's SSH files. + + @type user: string + @param user: Username + @type mkdir: bool + @param mkdir: Whether to create ".ssh" directory if it doesn't exist + @type dircheck: bool + @param dircheck: Whether to check if ".ssh" directory exists + @type kind: string + @param kind: One of L{constants.SSHK_ALL} + @rtype: tuple; (string, string, string) + @return: Tuple containing three file system paths; the private SSH key file, + the public SSH key file and the user's C{authorized_keys} file + @raise errors.OpExecError: When home directory of the user can not be + determined + @raise errors.OpExecError: Regardless of the C{mkdir} parameters, this + exception is raised if C{~$user/.ssh} is not a directory and C{dircheck} + is set to C{True} + """ + if _homedir_fn is None: + _homedir_fn = utils.GetHomeDir + + user_dir = _homedir_fn(user) + if not user_dir: + raise errors.OpExecError("Cannot resolve home of user '%s'" % user) + + if kind == constants.SSHK_DSA: + suffix = "dsa" + elif kind == constants.SSHK_RSA: + suffix = "rsa" + else: + raise errors.ProgrammerError("Unknown SSH key kind '%s'" % kind) -def GetUserFiles(user, mkdir=False): - """Return the paths of a user's ssh files. + ssh_dir = utils.PathJoin(user_dir, ".ssh") + if mkdir: + utils.EnsureDirs([(ssh_dir, constants.SECURE_DIR_MODE)]) + elif dircheck and not os.path.isdir(ssh_dir): + raise errors.OpExecError("Path %s is not a directory" % ssh_dir) - The function will return a triplet (priv_key_path, pub_key_path, - auth_key_path) that are used for ssh authentication. Currently, the - keys used are DSA keys, so this function will return: - (~user/.ssh/id_dsa, ~user/.ssh/id_dsa.pub, - ~user/.ssh/authorized_keys). + return [utils.PathJoin(ssh_dir, base) + for base in ["id_%s" % suffix, "id_%s.pub" % suffix, + "authorized_keys"]] - If the optional parameter mkdir is True, the ssh directory will be - created if it doesn't exist. - Regardless of the mkdir parameters, the script will raise an error - if ~user/.ssh is not a directory. +def GetAllUserFiles(user, mkdir=False, dircheck=True, _homedir_fn=None): + """Wrapper over L{GetUserFiles} to retrieve files for all SSH key types. + + See L{GetUserFiles} for details. + + @rtype: tuple; (string, dict with string as key, tuple of (string, string) as + value) """ - user_dir = utils.GetHomeDir(user) - if not user_dir: - raise errors.OpExecError("Cannot resolve home of user %s" % user) + helper = compat.partial(GetUserFiles, user, mkdir=mkdir, dircheck=dircheck, + _homedir_fn=_homedir_fn) + result = [(kind, helper(kind=kind)) for kind in constants.SSHK_ALL] - ssh_dir = os.path.join(user_dir, ".ssh") - if not os.path.lexists(ssh_dir): - if mkdir: - try: - os.mkdir(ssh_dir, 0700) - except EnvironmentError, err: - raise errors.OpExecError("Can't create .ssh dir for user %s: %s" % - (user, str(err))) - elif not os.path.isdir(ssh_dir): - raise errors.OpExecError("path ~%s/.ssh is not a directory" % user) + authorized_keys = [i for (_, (_, _, i)) in result] - return [os.path.join(ssh_dir, base) - for base in ["id_dsa", "id_dsa.pub", "authorized_keys"]] + assert len(frozenset(authorized_keys)) == 1, \ + "Different paths for authorized_keys were returned" + + return (authorized_keys[0], + dict((kind, (privkey, pubkey)) + for (kind, (privkey, pubkey, _)) in result)) class SshRunner: """Wrapper for SSH commands. """ - def __init__(self, cluster_name): + def __init__(self, cluster_name, ipv6=False): + """Initializes this class. + + @type cluster_name: str + @param cluster_name: name of the cluster + @type ipv6: bool + @param ipv6: If true, force ssh to use IPv6 addresses only + + """ self.cluster_name = cluster_name + self.ipv6 = ipv6 def _BuildSshOptions(self, batch, ask_key, use_cluster_key, - strict_host_check): + strict_host_check, private_key=None, quiet=True): """Builds a list with needed SSH options. @param batch: same as ssh's batch option @@ -84,21 +132,30 @@ class SshRunner: @param use_cluster_key: if True, use the cluster name as the HostKeyAlias name @param strict_host_check: this makes the host key checking strict + @param private_key: use this private key instead of the default + @param quiet: whether to enable -q to ssh @rtype: list - @return: the list of options ready to use in L{utils.RunCmd} + @return: the list of options ready to use in L{utils.process.RunCmd} """ options = [ "-oEscapeChar=none", "-oHashKnownHosts=no", - "-oGlobalKnownHostsFile=%s" % constants.SSH_KNOWN_HOSTS_FILE, + "-oGlobalKnownHostsFile=%s" % pathutils.SSH_KNOWN_HOSTS_FILE, "-oUserKnownHostsFile=/dev/null", + "-oCheckHostIp=no", ] if use_cluster_key: options.append("-oHostKeyAlias=%s" % self.cluster_name) + if quiet: + options.append("-q") + + if private_key: + options.append("-i%s" % private_key) + # TODO: Too many boolean options, maybe convert them to more descriptive # constants. @@ -114,15 +171,26 @@ class SshRunner: else: options.append("-oStrictHostKeyChecking=no") - elif ask_key: - options.extend([ - "-oStrictHostKeyChecking=ask", - ]) + else: + # non-batch mode + + if ask_key: + options.append("-oStrictHostKeyChecking=ask") + elif strict_host_check: + options.append("-oStrictHostKeyChecking=yes") + else: + options.append("-oStrictHostKeyChecking=no") + + if self.ipv6: + options.append("-6") + else: + options.append("-4") return options def BuildCmd(self, hostname, user, command, batch=True, ask_key=False, - tty=False, use_cluster_key=True, strict_host_check=True): + tty=False, use_cluster_key=True, strict_host_check=True, + private_key=None, quiet=True): """Build an ssh command to execute a command on a remote node. @param hostname: the target host, string @@ -135,16 +203,29 @@ class SshRunner: @param use_cluster_key: whether to expect and use the cluster-global SSH key @param strict_host_check: whether to check the host's SSH key at all + @param private_key: use this private key instead of the default + @param quiet: whether to enable -q to ssh @return: the ssh call to run 'command' on the remote host. """ - argv = [constants.SSH, "-q"] + argv = [constants.SSH] argv.extend(self._BuildSshOptions(batch, ask_key, use_cluster_key, - strict_host_check)) + strict_host_check, private_key, + quiet=quiet)) if tty: - argv.append("-t") - argv.extend(["%s@%s" % (user, hostname), command]) + argv.extend(["-t", "-t"]) + + argv.append("%s@%s" % (user, hostname)) + + # Insert variables for virtual nodes + argv.extend("export %s=%s;" % + (utils.ShellQuote(name), utils.ShellQuote(value)) + for (name, value) in + vcluster.EnvironmentForHost(hostname).items()) + + argv.append(command) + return argv def Run(self, *args, **kwargs): @@ -155,8 +236,8 @@ class SshRunner: Args: see SshRunner.BuildCmd. - @rtype: L{utils.RunResult} - @return: the result as from L{utils.RunCmd()} + @rtype: L{utils.process.RunResult} + @return: the result as from L{utils.process.RunCmd()} """ return utils.RunCmd(self.BuildCmd(*args, **kwargs)) @@ -179,16 +260,19 @@ class SshRunner: logging.error("File %s does not exist", filename) return False - command = [constants.SCP, "-q", "-p"] + command = [constants.SCP, "-p"] command.extend(self._BuildSshOptions(True, False, True, True)) command.append(filename) - command.append("%s:%s" % (node, filename)) + if netutils.IP6Address.IsValid(node): + node = netutils.FormatAddress((node, None)) + + command.append("%s:%s" % (node, vcluster.ExchangeNodeRoot(node, filename))) result = utils.RunCmd(command) if result.failed: - logging.error("Copy to node %s failed (%s) error %s," - " command was %s", + logging.error("Copy to node %s failed (%s) error '%s'," + " command was '%s'", node, result.fail_reason, result.output, result.cmd) return not result.failed @@ -201,7 +285,7 @@ class SshRunner: connected to). This is used to detect problems in ssh known_hosts files - (conflicting known hosts) and incosistencies between dns/hosts + (conflicting known hosts) and inconsistencies between dns/hosts entries and local machine names @param node: nodename of a host to check; can be short or @@ -212,19 +296,32 @@ class SshRunner: - detail: string with details """ - retval = self.Run(node, 'root', 'hostname') + cmd = ("if test -z \"$GANETI_HOSTNAME\"; then" + " hostname --fqdn;" + "else" + " echo \"$GANETI_HOSTNAME\";" + "fi") + retval = self.Run(node, constants.SSH_LOGIN_USER, cmd, quiet=False) if retval.failed: msg = "ssh problem" output = retval.output if output: msg += ": %s" % output + else: + msg += ": %s (no output)" % retval.fail_reason + logging.error("Command %s failed: %s", retval.cmd, msg) return False, msg remotehostname = retval.stdout.strip() if not remotehostname or remotehostname != node: - return False, "hostname mismatch, got %s" % remotehostname + if node.startswith(remotehostname + "."): + msg = "hostname not FQDN" + else: + msg = "hostname mismatch" + return False, ("%s: expected %s but got %s" % + (msg, node, remotehostname)) return True, "host matches" @@ -233,6 +330,10 @@ def WriteKnownHostsFile(cfg, file_name): """Writes the cluster-wide equally known_hosts file. """ - utils.WriteFile(file_name, mode=0600, - data="%s ssh-rsa %s\n" % (cfg.GetClusterName(), - cfg.GetHostKey())) + data = "" + if cfg.GetRsaHostKey(): + data += "%s ssh-rsa %s\n" % (cfg.GetClusterName(), cfg.GetRsaHostKey()) + if cfg.GetDsaHostKey(): + data += "%s ssh-dss %s\n" % (cfg.GetClusterName(), cfg.GetDsaHostKey()) + + utils.WriteFile(file_name, mode=0600, data=data)