Remove old "reason" implementation
[ganeti-local] / lib / ssh.py
index eb0be65..dba13df 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
 #
 #
 
-# Copyright (C) 2006, 2007, 2010 Google Inc.
+# Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 
 import os
 import logging
 
 import os
 import logging
-import re
 
 from ganeti import utils
 from ganeti import errors
 from ganeti import constants
 from ganeti import netutils
 
 from ganeti import utils
 from ganeti import errors
 from ganeti import constants
 from ganeti import netutils
+from ganeti import pathutils
+from ganeti import vcluster
+from ganeti import compat
+
+
+def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA,
+                 _homedir_fn=None):
+  """Return the paths of a user's SSH files.
+
+  @type user: string
+  @param user: Username
+  @type mkdir: bool
+  @param mkdir: Whether to create ".ssh" directory if it doesn't exist
+  @type dircheck: bool
+  @param dircheck: Whether to check if ".ssh" directory exists
+  @type kind: string
+  @param kind: One of L{constants.SSHK_ALL}
+  @rtype: tuple; (string, string, string)
+  @return: Tuple containing three file system paths; the private SSH key file,
+    the public SSH key file and the user's C{authorized_keys} file
+  @raise errors.OpExecError: When home directory of the user can not be
+    determined
+  @raise errors.OpExecError: Regardless of the C{mkdir} parameters, this
+    exception is raised if C{~$user/.ssh} is not a directory and C{dircheck}
+    is set to C{True}
 
 
+  """
+  if _homedir_fn is None:
+    _homedir_fn = utils.GetHomeDir
 
 
-def FormatParamikoFingerprint(fingerprint):
-  """Formats the fingerprint of L{paramiko.PKey.get_fingerprint()}
+  user_dir = _homedir_fn(user)
+  if not user_dir:
+    raise errors.OpExecError("Cannot resolve home of user '%s'" % user)
 
 
-  @type fingerprint: str
-  @param fingerprint: PKey fingerprint
-  @return The string hex representation of the fingerprint
+  if kind == constants.SSHK_DSA:
+    suffix = "dsa"
+  elif kind == constants.SSHK_RSA:
+    suffix = "rsa"
+  else:
+    raise errors.ProgrammerError("Unknown SSH key kind '%s'" % kind)
 
 
-  """
-  assert len(fingerprint) % 2 == 0
-  return ":".join(re.findall(r"..", fingerprint.lower()))
+  ssh_dir = utils.PathJoin(user_dir, ".ssh")
+  if mkdir:
+    utils.EnsureDirs([(ssh_dir, constants.SECURE_DIR_MODE)])
+  elif dircheck and not os.path.isdir(ssh_dir):
+    raise errors.OpExecError("Path %s is not a directory" % ssh_dir)
 
 
+  return [utils.PathJoin(ssh_dir, base)
+          for base in ["id_%s" % suffix, "id_%s.pub" % suffix,
+                       "authorized_keys"]]
 
 
-def GetUserFiles(user, mkdir=False):
-  """Return the paths of a user's ssh files.
 
 
-  The function will return a triplet (priv_key_path, pub_key_path,
-  auth_key_path) that are used for ssh authentication. Currently, the
-  keys used are DSA keys, so this function will return:
-  (~user/.ssh/id_dsa, ~user/.ssh/id_dsa.pub,
-  ~user/.ssh/authorized_keys).
+def GetAllUserFiles(user, mkdir=False, dircheck=True, _homedir_fn=None):
+  """Wrapper over L{GetUserFiles} to retrieve files for all SSH key types.
 
 
-  If the optional parameter mkdir is True, the ssh directory will be
-  created if it doesn't exist.
+  See L{GetUserFiles} for details.
 
 
-  Regardless of the mkdir parameters, the script will raise an error
-  if ~user/.ssh is not a directory.
+  @rtype: tuple; (string, dict with string as key, tuple of (string, string) as
+    value)
 
   """
 
   """
-  user_dir = utils.GetHomeDir(user)
-  if not user_dir:
-    raise errors.OpExecError("Cannot resolve home of user %s" % user)
+  helper = compat.partial(GetUserFiles, user, mkdir=mkdir, dircheck=dircheck,
+                          _homedir_fn=_homedir_fn)
+  result = [(kind, helper(kind=kind)) for kind in constants.SSHK_ALL]
 
 
-  ssh_dir = utils.PathJoin(user_dir, ".ssh")
-  if mkdir:
-    utils.EnsureDirs([(ssh_dir, constants.SECURE_DIR_MODE)])
-  elif not os.path.isdir(ssh_dir):
-    raise errors.OpExecError("Path %s is not a directory" % ssh_dir)
+  authorized_keys = [i for (_, (_, _, i)) in result]
 
 
-  return [utils.PathJoin(ssh_dir, base)
-          for base in ["id_dsa", "id_dsa.pub", "authorized_keys"]]
+  assert len(frozenset(authorized_keys)) == 1, \
+    "Different paths for authorized_keys were returned"
+
+  return (authorized_keys[0],
+          dict((kind, (privkey, pubkey))
+               for (kind, (privkey, pubkey, _)) in result))
 
 
 class SshRunner:
 
 
 class SshRunner:
@@ -106,13 +136,13 @@ class SshRunner:
     @param quiet: whether to enable -q to ssh
 
     @rtype: list
     @param quiet: whether to enable -q to ssh
 
     @rtype: list
-    @return: the list of options ready to use in L{utils.RunCmd}
+    @return: the list of options ready to use in L{utils.process.RunCmd}
 
     """
     options = [
       "-oEscapeChar=none",
       "-oHashKnownHosts=no",
 
     """
     options = [
       "-oEscapeChar=none",
       "-oHashKnownHosts=no",
-      "-oGlobalKnownHostsFile=%s" % constants.SSH_KNOWN_HOSTS_FILE,
+      "-oGlobalKnownHostsFile=%s" % pathutils.SSH_KNOWN_HOSTS_FILE,
       "-oUserKnownHostsFile=/dev/null",
       "-oCheckHostIp=no",
       ]
       "-oUserKnownHostsFile=/dev/null",
       "-oCheckHostIp=no",
       ]
@@ -183,7 +213,17 @@ class SshRunner:
                                       quiet=quiet))
     if tty:
       argv.extend(["-t", "-t"])
                                       quiet=quiet))
     if tty:
       argv.extend(["-t", "-t"])
-    argv.extend(["%s@%s" % (user, hostname), command])
+
+    argv.append("%s@%s" % (user, hostname))
+
+    # Insert variables for virtual nodes
+    argv.extend("export %s=%s;" %
+                (utils.ShellQuote(name), utils.ShellQuote(value))
+                for (name, value) in
+                  vcluster.EnvironmentForHost(hostname).items())
+
+    argv.append(command)
+
     return argv
 
   def Run(self, *args, **kwargs):
     return argv
 
   def Run(self, *args, **kwargs):
@@ -194,8 +234,8 @@ class SshRunner:
 
     Args: see SshRunner.BuildCmd.
 
 
     Args: see SshRunner.BuildCmd.
 
-    @rtype: L{utils.RunResult}
-    @return: the result as from L{utils.RunCmd()}
+    @rtype: L{utils.process.RunResult}
+    @return: the result as from L{utils.process.RunCmd()}
 
     """
     return utils.RunCmd(self.BuildCmd(*args, **kwargs))
 
     """
     return utils.RunCmd(self.BuildCmd(*args, **kwargs))
@@ -224,13 +264,13 @@ class SshRunner:
     if netutils.IP6Address.IsValid(node):
       node = netutils.FormatAddress((node, None))
 
     if netutils.IP6Address.IsValid(node):
       node = netutils.FormatAddress((node, None))
 
-    command.append("%s:%s" % (node, filename))
+    command.append("%s:%s" % (node, vcluster.ExchangeNodeRoot(node, filename)))
 
     result = utils.RunCmd(command)
 
     if result.failed:
 
     result = utils.RunCmd(command)
 
     if result.failed:
-      logging.error("Copy to node %s failed (%s) error %s,"
-                    " command was %s",
+      logging.error("Copy to node %s failed (%s) error '%s',"
+                    " command was '%s'",
                     node, result.fail_reason, result.output, result.cmd)
 
     return not result.failed
                     node, result.fail_reason, result.output, result.cmd)
 
     return not result.failed
@@ -254,7 +294,12 @@ class SshRunner:
         - detail: string with details
 
     """
         - detail: string with details
 
     """
-    retval = self.Run(node, 'root', 'hostname --fqdn')
+    cmd = ("if test -z \"$GANETI_HOSTNAME\"; then"
+           "  hostname --fqdn;"
+           "else"
+           "  echo \"$GANETI_HOSTNAME\";"
+           "fi")
+    retval = self.Run(node, constants.SSH_LOGIN_USER, cmd, quiet=False)
 
     if retval.failed:
       msg = "ssh problem"
 
     if retval.failed:
       msg = "ssh problem"