Add constants for storage types to constants.py
[ganeti-local] / lib / ssh.py
index b61f7c4..dba13df 100644 (file)
@@ -26,7 +26,6 @@
 
 import os
 import logging
-import re
 
 from ganeti import utils
 from ganeti import errors
@@ -34,28 +33,19 @@ from ganeti import constants
 from ganeti import netutils
 from ganeti import pathutils
 from ganeti import vcluster
+from ganeti import compat
 
 
-def FormatParamikoFingerprint(fingerprint):
-  """Format paramiko PKey fingerprint.
-
-  @type fingerprint: str
-  @param fingerprint: PKey fingerprint
-  @return: The string hex representation of the fingerprint
-
-  """
-  assert len(fingerprint) % 2 == 0
-  return ":".join(re.findall(r"..", fingerprint.lower()))
-
-
-def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
-                 _homedir_fn=utils.GetHomeDir):
+def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA,
+                 _homedir_fn=None):
   """Return the paths of a user's SSH files.
 
   @type user: string
   @param user: Username
   @type mkdir: bool
   @param mkdir: Whether to create ".ssh" directory if it doesn't exist
+  @type dircheck: bool
+  @param dircheck: Whether to check if ".ssh" directory exists
   @type kind: string
   @param kind: One of L{constants.SSHK_ALL}
   @rtype: tuple; (string, string, string)
@@ -64,9 +54,13 @@ def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
   @raise errors.OpExecError: When home directory of the user can not be
     determined
   @raise errors.OpExecError: Regardless of the C{mkdir} parameters, this
-    exception is raised if C{~$user/.ssh} is not a directory
+    exception is raised if C{~$user/.ssh} is not a directory and C{dircheck}
+    is set to C{True}
 
   """
+  if _homedir_fn is None:
+    _homedir_fn = utils.GetHomeDir
+
   user_dir = _homedir_fn(user)
   if not user_dir:
     raise errors.OpExecError("Cannot resolve home of user '%s'" % user)
@@ -81,7 +75,7 @@ def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
   ssh_dir = utils.PathJoin(user_dir, ".ssh")
   if mkdir:
     utils.EnsureDirs([(ssh_dir, constants.SECURE_DIR_MODE)])
-  elif not os.path.isdir(ssh_dir):
+  elif dircheck and not os.path.isdir(ssh_dir):
     raise errors.OpExecError("Path %s is not a directory" % ssh_dir)
 
   return [utils.PathJoin(ssh_dir, base)
@@ -89,6 +83,29 @@ def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
                        "authorized_keys"]]
 
 
+def GetAllUserFiles(user, mkdir=False, dircheck=True, _homedir_fn=None):
+  """Wrapper over L{GetUserFiles} to retrieve files for all SSH key types.
+
+  See L{GetUserFiles} for details.
+
+  @rtype: tuple; (string, dict with string as key, tuple of (string, string) as
+    value)
+
+  """
+  helper = compat.partial(GetUserFiles, user, mkdir=mkdir, dircheck=dircheck,
+                          _homedir_fn=_homedir_fn)
+  result = [(kind, helper(kind=kind)) for kind in constants.SSHK_ALL]
+
+  authorized_keys = [i for (_, (_, _, i)) in result]
+
+  assert len(frozenset(authorized_keys)) == 1, \
+    "Different paths for authorized_keys were returned"
+
+  return (authorized_keys[0],
+          dict((kind, (privkey, pubkey))
+               for (kind, (privkey, pubkey, _)) in result))
+
+
 class SshRunner:
   """Wrapper for SSH commands.