Shared file support for tools/burnin
[ganeti-local] / tools / setup-ssh
index ff9880d..112caf6 100755 (executable)
@@ -38,6 +38,7 @@ from ganeti import cli
 from ganeti import constants
 from ganeti import errors
 from ganeti import netutils
+from ganeti import ssconf
 from ganeti import ssh
 from ganeti import utils
 
@@ -48,6 +49,59 @@ class RemoteCommandError(errors.GenericError):
   """
 
 
+class JoinCheckError(errors.GenericError):
+  """Exception raised if join check fails.
+
+  """
+
+
+class HostKeyVerificationError(errors.GenericError):
+  """Exception if host key do not match.
+
+  """
+
+
+class AuthError(errors.GenericError):
+  """Exception for authentication errors to hosts.
+
+  """
+
+
+def _CheckJoin(transport):
+  """Checks if a join is safe or dangerous.
+
+  Note: This function relies on the fact, that all
+  hosts have the same configuration at compile time of
+  Ganeti. So that the constants do not mismatch.
+
+  @param transport: The paramiko transport instance
+  @return: True if the join is safe; False otherwise
+
+  """
+  sftp = transport.open_sftp_client()
+  ss = ssconf.SimpleStore()
+  ss_cluster_name_path = ss.KeyToFilename(constants.SS_CLUSTER_NAME)
+
+  cluster_files = [
+    (constants.NODED_CERT_FILE, utils.ReadFile(constants.NODED_CERT_FILE)),
+    (ss_cluster_name_path, utils.ReadFile(ss_cluster_name_path)),
+    ]
+
+  for (filename, local_content) in cluster_files:
+    try:
+      remote_content = _ReadSftpFile(sftp, filename)
+    except IOError, err:
+      # Assume file does not exist. Paramiko's error reporting is lacking.
+      logging.debug("Failed to read %s: %s", filename, err)
+      continue
+
+    if remote_content != local_content:
+      logging.error("File %s doesn't match local version", filename)
+      return False
+
+  return True
+
+
 def _RunRemoteCommand(transport, command):
   """Invokes and wait for the command over SSH.
 
@@ -84,6 +138,21 @@ def _InvokeDaemonUtil(transport, command):
   _RunRemoteCommand(transport, "%s %s" % (constants.DAEMON_UTIL, command))
 
 
+def _ReadSftpFile(sftp, filename):
+  """Reads a file over sftp.
+
+  @param sftp: An open paramiko SFTP client
+  @param filename: The filename of the file to read
+  @return: The content of the file
+
+  """
+  remote_file = sftp.open(filename, "r")
+  try:
+    return remote_file.read()
+  finally:
+    remote_file.close()
+
+
 def _WriteSftpFile(sftp, name, perm, data):
   """SFTPs data to a remote file.
 
@@ -126,17 +195,25 @@ def SetupSSH(transport):
 
   try:
     sftp.mkdir(auth_path, 0700)
-  except IOError:
+  except IOError, err:
     # Sadly paramiko doesn't provide errno or similiar
     # so we can just assume that the path already exists
-    logging.info("Path %s seems already to exist on remote node. Ignoring.",
-                 auth_path)
+    logging.info("Assuming directory %s on remote node exists: %s",
+                 auth_path, err)
 
   for name, (data, perm) in filemap.iteritems():
     _WriteSftpFile(sftp, name, perm, data)
 
   authorized_keys = sftp.open(auth_keys, "a+")
   try:
+    # Due to the way SFTPFile and BufferedFile are implemented,
+    # opening in a+ mode and then issuing a read(), readline() or
+    # iterating over the file (which uses read() internally) will see
+    # an empty file, since the paramiko internal file position and the
+    # OS-level file-position are desynchronized; therefore, we issue
+    # an explicit seek to resynchronize these; writes should (note
+    # should) still go to the right place
+    authorized_keys.seek(0, 0)
     # We don't have to close, as the close happened already in AddAuthorizedKey
     utils.AddAuthorizedKey(authorized_keys, filemap[pub_key][0])
   finally:
@@ -145,30 +222,28 @@ def SetupSSH(transport):
   _InvokeDaemonUtil(transport, "reload-ssh-keys")
 
 
-def SetupNodeDaemon(transport):
-  """Sets the node daemon up on the other side.
-
-  @param transport: The paramiko transport instance
-
-  """
-  noded_cert = utils.ReadFile(constants.NODED_CERT_FILE)
-
-  sftp = transport.open_sftp_client()
-  _WriteSftpFile(sftp, constants.NODED_CERT_FILE, 0400, noded_cert)
-
-  _InvokeDaemonUtil(transport, "start %s" % constants.NODED)
-
-
 def ParseOptions():
   """Parses options passed to program.
 
   """
   program = os.path.basename(sys.argv[0])
 
-  parser = optparse.OptionParser(usage=("%prog [--debug|--verbose] <node>"
-                                        " <node...>"), prog=program)
+  parser = optparse.OptionParser(usage=("%prog [--debug|--verbose] [--force]"
+                                        " <node> <node...>"), prog=program)
   parser.add_option(cli.DEBUG_OPT)
   parser.add_option(cli.VERBOSE_OPT)
+  parser.add_option(cli.NOSSH_KEYCHECK_OPT)
+  default_key = ssh.GetUserFiles(constants.GANETI_RUNAS)[0]
+  parser.add_option(optparse.Option("-f", dest="private_key",
+                                    default=default_key,
+                                    help="The private key to (try to) use for"
+                                    "authentication "))
+  parser.add_option(optparse.Option("--key-type", dest="key_type",
+                                    choices=("rsa", "dsa"), default="dsa",
+                                    help="The private key type (rsa or dsa)"))
+  parser.add_option(optparse.Option("-j", "--force-join", dest="force_join",
+                                    action="store_true", default=False,
+                                    help="Force the join of the host"))
 
   (options, args) = parser.parse_args()
 
@@ -202,7 +277,7 @@ def SetupLogging(options):
     stderr_handler.setLevel(logging.WARNING)
 
   root_logger = logging.getLogger("")
-  root_logger.setLevel(logging.INFO)
+  root_logger.setLevel(logging.NOTSET)
   root_logger.addHandler(stderr_handler)
   root_logger.addHandler(file_handler)
 
@@ -213,6 +288,125 @@ def SetupLogging(options):
   paramiko_logger.setLevel(logging.WARNING)
 
 
+def LoadPrivateKeys(options):
+  """Load the list of available private keys.
+
+  It loads the standard ssh key from disk and then tries to connect to
+  the ssh agent too.
+
+  @rtype: list
+  @return: a list of C{paramiko.PKey}
+
+  """
+  if options.key_type == "rsa":
+    pkclass = paramiko.RSAKey
+  elif options.key_type == "dsa":
+    pkclass = paramiko.DSSKey
+  else:
+    logging.critical("Unknown key type %s selected (choose either rsa or dsa)",
+                     options.key_type)
+    sys.exit(1)
+
+  try:
+    private_key = pkclass.from_private_key_file(options.private_key)
+  except (paramiko.SSHException, EnvironmentError), err:
+    logging.critical("Can't load private key %s: %s", options.private_key, err)
+    sys.exit(1)
+
+  try:
+    agent = paramiko.Agent()
+    agent_keys = agent.get_keys()
+  except paramiko.SSHException, err:
+    # this will only be seen when the agent is broken/uses invalid
+    # protocol; for non-existing agent, get_keys() will just return an
+    # empty tuple
+    logging.warning("Can't connect to the ssh agent: %s; skipping its use",
+                    err)
+    agent_keys = []
+
+  return [private_key] + list(agent_keys)
+
+
+def _FormatFingerprint(fpr):
+  """Formats a paramiko.PKey.get_fingerprint() human readable.
+
+  @param fpr: The fingerprint to be formatted
+  @return: A human readable fingerprint
+
+  """
+  return ssh.FormatParamikoFingerprint(paramiko.util.hexify(fpr))
+
+
+def LoginViaKeys(transport, username, keys):
+  """Try to login on the given transport via a list of keys.
+
+  @param transport: the transport to use
+  @param username: the username to login as
+  @type keys: list
+  @param keys: list of C{paramiko.PKey} to use for authentication
+  @rtype: boolean
+  @return: True or False depending on whether the login was
+      successfull or not
+
+  """
+  for private_key in keys:
+    try:
+      transport.auth_publickey(username, private_key)
+      fpr = _FormatFingerprint(private_key.get_fingerprint())
+      if isinstance(private_key, paramiko.AgentKey):
+        logging.debug("Authentication via the ssh-agent key %s", fpr)
+      else:
+        logging.debug("Authenticated via public key %s", fpr)
+      return True
+    except paramiko.SSHException:
+      continue
+  else:
+    # all keys exhausted
+    return False
+
+
+def LoadKnownHosts():
+  """Load the known hosts.
+
+  @return: paramiko.util.load_host_keys dict
+
+  """
+  homedir = utils.GetHomeDir(constants.GANETI_RUNAS)
+  known_hosts = os.path.join(homedir, ".ssh", "known_hosts")
+
+  try:
+    return paramiko.util.load_host_keys(known_hosts)
+  except EnvironmentError:
+    # We didn't find the path, silently ignore and return an empty dict
+    return {}
+
+
+def _VerifyServerKey(transport, host, host_keys):
+  """Verify the server keys.
+
+  @param transport: A paramiko.transport instance
+  @param host: Name of the host we verify
+  @param host_keys: Loaded host keys
+  @raises HostkeyVerificationError: When the host identify couldn't be verified
+
+  """
+
+  server_key = transport.get_remote_server_key()
+  keytype = server_key.get_name()
+
+  our_server_key = host_keys.get(host, {}).get(keytype, None)
+  if not our_server_key:
+    hexified_key = _FormatFingerprint(server_key.get_fingerprint())
+    msg = ("Unable to verify hostkey of host %s: %s. Do you want to accept"
+           " it?" % (host, hexified_key))
+
+    if cli.AskUser(msg):
+      our_server_key = server_key
+
+  if our_server_key != server_key:
+    raise HostKeyVerificationError("Unable to verify host identity")
+
+
 def main():
   """Main routine.
 
@@ -221,8 +415,12 @@ def main():
 
   SetupLogging(options)
 
-  passwd = getpass.getpass(prompt="%s password:" % constants.GANETI_RUNAS)
+  all_keys = LoadPrivateKeys(options)
+
+  passwd = None
+  username = constants.GANETI_RUNAS
   ssh_port = netutils.GetDaemonPort("ssh")
+  host_keys = LoadKnownHosts()
 
   # Below, we need to join() the transport objects, as otherwise the
   # following happens:
@@ -233,30 +431,65 @@ def main():
   #   wants to log one more message, which fails as the file is closed
   #   now
 
+  success = True
+
   for host in args:
+    logging.info("Configuring %s", host)
+
     transport = paramiko.Transport((host, ssh_port))
     try:
-      transport.connect(username=constants.GANETI_RUNAS, password=passwd)
-    except Exception, err:
-      logging.error("Connection or authentication failed to host %s: %s",
-                    host, err)
-      transport.close()
-      # this is needed for compatibility with older Paramiko or Python
-      # versions
-      transport.join()
-      continue
-    try:
       try:
+        transport.start_client()
+
+        if options.ssh_key_check:
+          _VerifyServerKey(transport, host, host_keys)
+
+        try:
+          if LoginViaKeys(transport, username, all_keys):
+            logging.info("Authenticated to %s via public key", host)
+          else:
+            if all_keys:
+              logging.warning("Authentication to %s via public key failed,"
+                              " trying password", host)
+            if passwd is None:
+              passwd = getpass.getpass(prompt="%s password:" % username)
+            transport.auth_password(username=username, password=passwd)
+            logging.info("Authenticated to %s via password", host)
+        except paramiko.SSHException, err:
+          raise AuthError("Auth error TODO" % err)
+
+        if not _CheckJoin(transport):
+          if not options.force_join:
+            raise JoinCheckError(("Host %s failed join check; Please verify"
+                                  " that the host was not previously joined"
+                                  " to another cluster and use --force-join"
+                                  " to continue") % host)
+
+          logging.warning("Host %s failed join check, forced to continue",
+                          host)
+
         SetupSSH(transport)
-        SetupNodeDaemon(transport)
-      except errors.GenericError, err:
-        logging.error("While doing setup on host %s an error occured: %s",
-                      host, err)
-    finally:
-      transport.close()
-      # this is needed for compatibility with older Paramiko or Python
-      # versions
-      transport.join()
+        logging.info("%s successfully configured", host)
+      finally:
+        transport.close()
+        # this is needed for compatibility with older Paramiko or Python
+        # versions
+        transport.join()
+    except AuthError, err:
+      logging.error("Authentication error: %s", err)
+      success = False
+      break
+    except HostKeyVerificationError, err:
+      logging.error("Host key verification error: %s", err)
+      success = False
+    except Exception, err:
+      logging.exception("During setup of %s: %s", host, err)
+      success = False
+
+  if success:
+    sys.exit(constants.EXIT_SUCCESS)
+
+  sys.exit(constants.EXIT_FAILURE)
 
 
 if __name__ == "__main__":