ssh.GetUserFiles: RSA support, unit tests
author Michael Hanselmann <hansmi@google.com>
Thu, 18 Oct 2012 15:34:02 +0000 (17:34 +0200)
committer Michael Hanselmann <hansmi@google.com>
Tue, 23 Oct 2012 12:59:05 +0000 (14:59 +0200)
This patch changes “ssh.GetUserFiles” to support two different kinds of
SSH keys, RSA and DSA. Previously it always used DSA. New unit tests are
included.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>

lib/constants.py
lib/ssh.py
test/ganeti.ssh_unittest.py

index 0342ddf..39b6895 100644 (file)
@@ -2044,5 +2044,10 @@ IALLOC_HAIL = "hail"
 FAKE_OP_MASTER_TURNUP = "OP_CLUSTER_IP_TURNUP"
 FAKE_OP_MASTER_TURNDOWN = "OP_CLUSTER_IP_TURNDOWN"
 
+# SSH key types (values accepted by the "kind" parameter of ssh.GetUserFiles)
+SSHK_RSA = "rsa"
+SSHK_DSA = "dsa"
+SSHK_ALL = frozenset([SSHK_RSA, SSHK_DSA])
+
 # Do not re-export imported modules
 del re, _vcsversion, _autoconf, socket, pathutils
index 4c4a18c..b61f7c4 100644 (file)
@@ -48,25 +48,35 @@ def FormatParamikoFingerprint(fingerprint):
   return ":".join(re.findall(r"..", fingerprint.lower()))
 
 
-def GetUserFiles(user, mkdir=False):
-  """Return the paths of a user's ssh files.
-
-  The function will return a triplet (priv_key_path, pub_key_path,
-  auth_key_path) that are used for ssh authentication. Currently, the
-  keys used are DSA keys, so this function will return:
-  (~user/.ssh/id_dsa, ~user/.ssh/id_dsa.pub,
-  ~user/.ssh/authorized_keys).
-
-  If the optional parameter mkdir is True, the ssh directory will be
-  created if it doesn't exist.
-
-  Regardless of the mkdir parameters, the script will raise an error
-  if ~user/.ssh is not a directory.
+def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
+                 _homedir_fn=utils.GetHomeDir):
+  """Return the paths of a user's SSH files.
+
+  @type user: string
+  @param user: Username
+  @type mkdir: bool
+  @param mkdir: Whether to create ".ssh" directory if it doesn't exist
+  @type kind: string
+  @param kind: One of L{constants.SSHK_ALL}
+  @rtype: list; [string, string, string]
+  @return: List containing three file system paths; the private SSH key file,
+    the public SSH key file and the user's C{authorized_keys} file
+  @raise errors.OpExecError: When home directory of the user cannot be
+    determined
+  @raise errors.OpExecError: Regardless of the C{mkdir} parameter, this
+    exception is raised if C{~$user/.ssh} is not a directory
 
   """
-  user_dir = utils.GetHomeDir(user)
+  user_dir = _homedir_fn(user)
   if not user_dir:
-    raise errors.OpExecError("Cannot resolve home of user %s" % user)
+    raise errors.OpExecError("Cannot resolve home of user '%s'" % user)
+
+  if kind == constants.SSHK_DSA:
+    suffix = "dsa"
+  elif kind == constants.SSHK_RSA:
+    suffix = "rsa"
+  else:
+    raise errors.ProgrammerError("Unknown SSH key kind '%s'" % kind)
 
   ssh_dir = utils.PathJoin(user_dir, ".ssh")
   if mkdir:
@@ -75,7 +85,8 @@ def GetUserFiles(user, mkdir=False):
     raise errors.OpExecError("Path %s is not a directory" % ssh_dir)
 
   return [utils.PathJoin(ssh_dir, base)
-          for base in ["id_dsa", "id_dsa.pub", "authorized_keys"]]
+          for base in ["id_%s" % suffix, "id_%s.pub" % suffix,
+                       "authorized_keys"]]
 
 
 class SshRunner:
index bd6c951..77960c2 100755 (executable)
@@ -24,6 +24,7 @@
 import os
 import tempfile
 import unittest
+import shutil
 
 import testutils
 import mocks
@@ -31,6 +32,7 @@ import mocks
 from ganeti import constants
 from ganeti import utils
 from ganeti import ssh
+from ganeti import errors
 
 
 class TestKnownHosts(testutils.GanetiTestCase):
@@ -54,5 +56,74 @@ class TestKnownHosts(testutils.GanetiTestCase):
     self.assertRaises(AssertionError, ssh.FormatParamikoFingerprint, "C0Ffe")
 
 
-if __name__ == '__main__':
+class TestGetUserFiles(unittest.TestCase):
+  """Tests for L{ssh.GetUserFiles}."""
+
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  @staticmethod
+  def _GetNoHomedir(_):
+    return None
+
+  def _GetTempHomedir(self, _):
+    return self.tmpdir
+
+  def testNonExistentUser(self):
+    for kind in constants.SSHK_ALL:
+      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
+                        kind=kind, _homedir_fn=self._GetNoHomedir)
+
+  def testUnknownKind(self):
+    kind = "something-else"
+    assert kind not in constants.SSHK_ALL
+    # mkdir=True ensures the kind is validated before anything is created
+    self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
+                      mkdir=True, kind=kind, _homedir_fn=self._GetTempHomedir)
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def testNoSshDirectory(self):
+    for kind in constants.SSHK_ALL:
+      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
+                        kind=kind, _homedir_fn=self._GetTempHomedir)
+      self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def testSshIsFile(self):
+    utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
+    for kind in constants.SSHK_ALL:
+      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
+                        kind=kind, _homedir_fn=self._GetTempHomedir)
+      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+
+  def testMakeSshDirectory(self):
+    sshdir = os.path.join(self.tmpdir, ".ssh")
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+    for kind in constants.SSHK_ALL:
+      ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
+                       _homedir_fn=self._GetTempHomedir)
+      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+      self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
+
+  def testFilenames(self):
+    sshdir = os.path.join(self.tmpdir, ".ssh")
+    os.mkdir(sshdir)
+    # Spell out suffixes rather than deriving them from the constants' values
+    suffixes = {constants.SSHK_DSA: "dsa", constants.SSHK_RSA: "rsa"}
+    self.assertEqual(frozenset(suffixes), constants.SSHK_ALL)
+    for (kind, suffix) in suffixes.items():
+      result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
+                                _homedir_fn=self._GetTempHomedir)
+      self.assertEqual(result, [
+        os.path.join(sshdir, "id_%s" % suffix),
+        os.path.join(sshdir, "id_%s.pub" % suffix),
+        os.path.join(sshdir, "authorized_keys"),
+        ])
+      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+      self.assertEqual(os.listdir(sshdir), [])
+
+
+if __name__ == "__main__":
   testutils.GanetiTestProgram()