Merge branch 'devel-2.6'
[ganeti-local] / test / ganeti.ssh_unittest.py
index bd6c951..a2c13cd 100755 (executable)
@@ -24,6 +24,7 @@
 import os
 import tempfile
 import unittest
+import shutil
 
 import testutils
 import mocks
@@ -31,6 +32,7 @@ import mocks
 from ganeti import constants
 from ganeti import utils
 from ganeti import ssh
+from ganeti import errors
 
 
 class TestKnownHosts(testutils.GanetiTestCase):
@@ -47,12 +49,116 @@ class TestKnownHosts(testutils.GanetiTestCase):
         "%s ssh-rsa %s\n" % (cfg.GetClusterName(),
                              mocks.FAKE_CLUSTER_KEY))
 
-  def testFormatParamikoFingerprintCorrect(self):
-    self.assertEqual(ssh.FormatParamikoFingerprint("C0Ffee"), "c0:ff:ee")
 
-  def testFormatParamikoFingerprintNotDividableByTwo(self):
-    self.assertRaises(AssertionError, ssh.FormatParamikoFingerprint, "C0Ffe")
+class TestGetUserFiles(unittest.TestCase):
+  """Tests for L{ssh.GetUserFiles} and L{ssh.GetAllUserFiles}.
+
+  The home directory lookup is replaced through the C{_homedir_fn} hook, so
+  the tests only ever touch a per-test temporary directory.
+
+  """
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  @staticmethod
+  def _GetNoHomedir(_):
+    """Home directory lookup simulating a user without a home directory."""
+    return None
+
+  def _GetTempHomedir(self, _):
+    """Home directory lookup returning the per-test temporary directory."""
+    return self.tmpdir
+
+  def testNonExistentUser(self):
+    for kind in constants.SSHK_ALL:
+      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
+                        kind=kind, _homedir_fn=self._GetNoHomedir)
+
+  def testUnknownKind(self):
+    kind = "something-else"
+    # Check the test's own premise with a real assertion (not stripped by -O)
+    self.assertFalse(kind in constants.SSHK_ALL)
+    self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
+                      kind=kind, _homedir_fn=self._GetTempHomedir)
+
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def testNoSshDirectory(self):
+    for kind in constants.SSHK_ALL:
+      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
+                        kind=kind, _homedir_fn=self._GetTempHomedir)
+      # The missing ~/.ssh directory must not be created as a side effect
+      self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def testSshIsFile(self):
+    utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
+    for kind in constants.SSHK_ALL:
+      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
+                        kind=kind, _homedir_fn=self._GetTempHomedir)
+      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+
+  def testMakeSshDirectory(self):
+    sshdir = os.path.join(self.tmpdir, ".ssh")
+
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+    for kind in constants.SSHK_ALL:
+      ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
+                       _homedir_fn=self._GetTempHomedir)
+      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+      # The directory must only be accessible by its owner
+      self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
+
+  def testFilenames(self):
+    sshdir = os.path.join(self.tmpdir, ".ssh")
+
+    os.mkdir(sshdir)
+
+    for kind in constants.SSHK_ALL:
+      result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
+                                _homedir_fn=self._GetTempHomedir)
+      self.assertEqual(result, [
+        os.path.join(sshdir, "id_%s" % kind),
+        os.path.join(sshdir, "id_%s.pub" % kind),
+        os.path.join(sshdir, "authorized_keys"),
+        ])
+
+      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+      self.assertEqual(os.listdir(sshdir), [])
+
+  def testNoDirCheck(self):
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+    for kind in constants.SSHK_ALL:
+      ssh.GetUserFiles("example14528", mkdir=False, dircheck=False, kind=kind,
+                       _homedir_fn=self._GetTempHomedir)
+      self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def testGetAllUserFiles(self):
+    sshdir = os.path.join(self.tmpdir, ".ssh")
+
+    result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False,
+                                 _homedir_fn=self._GetTempHomedir)
+    self.assertEqual(result,
+      (os.path.join(sshdir, "authorized_keys"), {
+        constants.SSHK_RSA:
+          (os.path.join(sshdir, "id_rsa"),
+           os.path.join(sshdir, "id_rsa.pub")),
+        constants.SSHK_DSA:
+          (os.path.join(sshdir, "id_dsa"),
+           os.path.join(sshdir, "id_dsa.pub")),
+      }))
+    self.assertEqual(os.listdir(self.tmpdir), [])
+
+  def testGetAllUserFilesNoDirectoryNoMkdir(self):
+    self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles,
+                      "example17270", mkdir=False, dircheck=True,
+                      _homedir_fn=self._GetTempHomedir)
+    self.assertEqual(os.listdir(self.tmpdir), [])
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
   testutils.GanetiTestProgram()