ssh: Add function to get all of user's SSH files
author Michael Hanselmann <hansmi@google.com>
Tue, 23 Oct 2012 23:10:36 +0000 (01:10 +0200)
committer Michael Hanselmann <hansmi@google.com>
Fri, 26 Oct 2012 14:27:11 +0000 (16:27 +0200)
This new function returns the paths of all of a user's SSH-related
files (the RSA and DSA key pairs, and authorized_keys).

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>

lib/ssh.py
test/ganeti.ssh_unittest.py

index cec442d..1307148 100644 (file)
@@ -34,6 +34,7 @@ from ganeti import constants
 from ganeti import netutils
 from ganeti import pathutils
 from ganeti import vcluster
+from ganeti import compat
 
 
 def FormatParamikoFingerprint(fingerprint):
@@ -95,6 +96,29 @@ def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA,
                        "authorized_keys"]]
 
 
+def GetAllUserFiles(user, mkdir=False, dircheck=True, _homedir_fn=None):
+  """Wrapper over L{GetUserFiles} to retrieve files for all SSH key types.
+
+  See L{GetUserFiles} for the meaning of the parameters.
+
+  @rtype: tuple; (string, dict with string as key, tuple of (string, string) as
+    value)
+  @return: Path to authorized_keys and, per key type, (private, public) paths
+
+  """
+  helper = compat.partial(GetUserFiles, user, mkdir=mkdir, dircheck=dircheck,
+                          _homedir_fn=_homedir_fn)
+  result = [(kind, helper(kind=kind)) for kind in constants.SSHK_ALL]
+  # GetUserFiles returns (private key, public key, authorized_keys) paths
+  authorized_keys = [i for (_, (_, _, i)) in result]
+  assert len(frozenset(authorized_keys)) == 1, \
+    "Different paths for authorized_keys were returned"
+
+  return (authorized_keys[0],
+          dict((kind, (privkey, pubkey))
+               for (kind, (privkey, pubkey, _)) in result))
+
+
 class SshRunner:
   """Wrapper for SSH commands.
 
index 419c05e..bb4f015 100755 (executable)
@@ -132,6 +132,26 @@ class TestGetUserFiles(unittest.TestCase):
                        _homedir_fn=self._GetTempHomedir)
       self.assertEqual(os.listdir(self.tmpdir), [])
 
+  def testGetAllUserFiles(self):
+    result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False,
+                                 _homedir_fn=self._GetTempHomedir)
+    self.assertEqual(result,
+      (os.path.join(self.tmpdir, ".ssh", "authorized_keys"), {
+        constants.SSHK_RSA:
+          (os.path.join(self.tmpdir, ".ssh", "id_rsa"),
+           os.path.join(self.tmpdir, ".ssh", "id_rsa.pub")),
+        constants.SSHK_DSA:
+          (os.path.join(self.tmpdir, ".ssh", "id_dsa"),
+           os.path.join(self.tmpdir, ".ssh", "id_dsa.pub")),
+      }))
+    self.assertEqual(os.listdir(self.tmpdir), [])  # nothing was created
+
+  def testGetAllUserFilesNoDirectoryNoMkdir(self):
+    self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles,
+                      "example17270", mkdir=False, dircheck=True,
+                      _homedir_fn=self._GetTempHomedir)  # .ssh does not exist
+    self.assertEqual(os.listdir(self.tmpdir), [])  # nothing was created
+
 
 if __name__ == "__main__":
   testutils.GanetiTestProgram()