return ":".join(re.findall(r"..", fingerprint.lower()))
-def GetUserFiles(user, mkdir=False):
- """Return the paths of a user's ssh files.
-
- The function will return a triplet (priv_key_path, pub_key_path,
- auth_key_path) that are used for ssh authentication. Currently, the
- keys used are DSA keys, so this function will return:
- (~user/.ssh/id_dsa, ~user/.ssh/id_dsa.pub,
- ~user/.ssh/authorized_keys).
-
- If the optional parameter mkdir is True, the ssh directory will be
- created if it doesn't exist.
-
- Regardless of the mkdir parameters, the script will raise an error
- if ~user/.ssh is not a directory.
+def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
+ _homedir_fn=utils.GetHomeDir):
+ """Return the paths of a user's SSH files.
+
+ @type user: string
+ @param user: Username
+ @type mkdir: bool
+ @param mkdir: Whether to create ".ssh" directory if it doesn't exist
+ @type kind: string
+ @param kind: One of L{constants.SSHK_ALL}
+ @type _homedir_fn: callable
+ @param _homedir_fn: Called with C{user}; must return the user's home
+ directory or a false value if it can not be determined (intended to be
+ overridden by unit tests only)
+ @rtype: list; [string, string, string]
+ @return: List containing three file system paths; the private SSH key file,
+ the public SSH key file and the user's C{authorized_keys} file
+ @raise errors.ProgrammerError: When C{kind} is not one of
+ L{constants.SSHK_ALL}
+ @raise errors.OpExecError: When home directory of the user can not be
+ determined
+ @raise errors.OpExecError: Regardless of the C{mkdir} parameters, this
+ exception is raised if C{~$user/.ssh} is not a directory
"""
+ # Validate the key kind before looking at the user: an unknown kind is a
+ # caller bug and must be reported as such even if the home can't be resolved
+ if kind == constants.SSHK_DSA:
+ suffix = "dsa"
+ elif kind == constants.SSHK_RSA:
+ suffix = "rsa"
+ else:
+ raise errors.ProgrammerError("Unknown SSH key kind '%s'" % kind)
+
- user_dir = utils.GetHomeDir(user)
+ user_dir = _homedir_fn(user)
if not user_dir:
- raise errors.OpExecError("Cannot resolve home of user %s" % user)
+ raise errors.OpExecError("Cannot resolve home of user '%s'" % user)
+
ssh_dir = utils.PathJoin(user_dir, ".ssh")
if mkdir:
raise errors.OpExecError("Path %s is not a directory" % ssh_dir)
return [utils.PathJoin(ssh_dir, base)
- for base in ["id_dsa", "id_dsa.pub", "authorized_keys"]]
+ for base in ["id_%s" % suffix, "id_%s.pub" % suffix,
+ "authorized_keys"]]
class SshRunner:
import os
import tempfile
import unittest
+import shutil
import testutils
import mocks
from ganeti import constants
from ganeti import utils
from ganeti import ssh
+from ganeti import errors
class TestKnownHosts(testutils.GanetiTestCase):
self.assertRaises(AssertionError, ssh.FormatParamikoFingerprint, "C0Ffe")
-if __name__ == '__main__':
+class TestGetUserFiles(unittest.TestCase):
+ """Tests for L{ssh.GetUserFiles}.
+
+ The user's home directory is replaced by a per-test temporary directory
+ through the C{_homedir_fn} parameter, so no real account is touched.
+
+ """
+ def setUp(self):
+ self.tmpdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdir)
+
+ @staticmethod
+ def _GetNoHomedir(_):
+ # Simulates a user whose home directory can not be determined
+ return None
+
+ def _GetTempHomedir(self, _):
+ # Uses the temporary directory as the home of any user
+ return self.tmpdir
+
+ def testNonExistentUser(self):
+ for kind in constants.SSHK_ALL:
+ self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
+ kind=kind, _homedir_fn=self._GetNoHomedir)
+
+ def testUnknownKind(self):
+ kind = "something-else"
+ # Use a real assertion for the precondition; bare asserts vanish under -O
+ self.assertFalse(kind in constants.SSHK_ALL)
+ self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
+ kind=kind, _homedir_fn=self._GetTempHomedir)
+
+ # An invalid kind must not cause anything to be created
+ self.assertEqual(os.listdir(self.tmpdir), [])
+
+ def testNoSshDirectory(self):
+ for kind in constants.SSHK_ALL:
+ self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
+ kind=kind, _homedir_fn=self._GetTempHomedir)
+ self.assertEqual(os.listdir(self.tmpdir), [])
+
+ def testSshIsFile(self):
+ utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
+ for kind in constants.SSHK_ALL:
+ self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
+ kind=kind, _homedir_fn=self._GetTempHomedir)
+ self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+
+ def testMakeSshDirectory(self):
+ sshdir = os.path.join(self.tmpdir, ".ssh")
+
+ self.assertEqual(os.listdir(self.tmpdir), [])
+
+ # Looping over all kinds also checks that an existing directory is reused
+ for kind in constants.SSHK_ALL:
+ ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
+ _homedir_fn=self._GetTempHomedir)
+ self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+ self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
+
+ def testFilenames(self):
+ sshdir = os.path.join(self.tmpdir, ".ssh")
+
+ os.mkdir(sshdir)
+
+ # Expected key file suffix per kind, spelled out so the test checks the
+ # real file names instead of the string values of the constants
+ suffixes = {
+ constants.SSHK_DSA: "dsa",
+ constants.SSHK_RSA: "rsa",
+ }
+
+ for kind in constants.SSHK_ALL:
+ result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
+ _homedir_fn=self._GetTempHomedir)
+ self.assertEqual(result, [
+ os.path.join(sshdir, "id_%s" % suffixes[kind]),
+ os.path.join(sshdir, "id_%s.pub" % suffixes[kind]),
+ os.path.join(sshdir, "authorized_keys"),
+ ])
+
+ self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
+ self.assertEqual(os.listdir(sshdir), [])
+
+# When executed as a script, run this module's tests via
+# testutils.GanetiTestProgram
+if __name__ == "__main__":
testutils.GanetiTestProgram()