4 # Copyright (C) 2006, 2007, 2008 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Script for unittesting the ssh module"""
32 from ganeti import constants
33 from ganeti import utils
34 from ganeti import ssh
35 from ganeti import errors
class TestKnownHosts(testutils.GanetiTestCase):
  """Checks the known_hosts file produced by ssh.WriteKnownHostsFile."""

  def setUp(self):
    """Runs the base fixture and allocates a temporary output file."""
    testutils.GanetiTestCase.setUp(self)
    self.tmpfile = self._CreateTempFile()

  def test(self):
    """The cluster name is listed once with ssh-rsa and once with ssh-dss."""
    cfg = mocks.FakeConfig()
    ssh.WriteKnownHostsFile(cfg, self.tmpfile)

    # Both lines use the same fake cluster key, in rsa-then-dss order
    expected = ("%s ssh-rsa %s\n%s ssh-dss %s\n" %
                (cfg.GetClusterName(), mocks.FAKE_CLUSTER_KEY,
                 cfg.GetClusterName(), mocks.FAKE_CLUSTER_KEY))
    self.assertFileContent(self.tmpfile, expected)
class TestGetUserFiles(unittest.TestCase):
  """Tests for ssh.GetUserFiles and ssh.GetAllUserFiles.

  Each test gets a fresh temporary directory that stands in for the user's
  home directory; the home directory lookup is injected through the
  C{_homedir_fn} argument, so no real account is involved.

  """

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  @staticmethod
  def _GetNoHomedir(_):
    """Home directory lookup simulating a user that cannot be resolved."""
    return None

  def _GetTempHomedir(self, _):
    """Home directory lookup always resolving to the scratch directory."""
    return self.tmpdir

  def testNonExistantUser(self):
    """An unresolvable home directory raises OpExecError for every kind."""
    for kind in constants.SSHK_ALL:
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
                        kind=kind, _homedir_fn=self._GetNoHomedir)

  def testUnknownKind(self):
    """An unsupported key kind raises ProgrammerError and creates nothing."""
    kind = "something-else"
    # Fixture sanity check; unlike a bare "assert" statement this is not
    # stripped when the tests are run with "python -O"
    self.assertFalse(kind in constants.SSHK_ALL)
    self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
                      kind=kind, _homedir_fn=self._GetTempHomedir)

    self.assertEqual(os.listdir(self.tmpdir), [])

  def testNoSshDirectory(self):
    """With default arguments a missing .ssh raises and is not created."""
    for kind in constants.SSHK_ALL:
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
                        kind=kind, _homedir_fn=self._GetTempHomedir)
      self.assertEqual(os.listdir(self.tmpdir), [])

  def testSshIsFile(self):
    """A regular file named .ssh is rejected; nothing else is created."""
    utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
    for kind in constants.SSHK_ALL:
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
                        kind=kind, _homedir_fn=self._GetTempHomedir)
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])

  def testMakeSshDirectory(self):
    """With mkdir=True the .ssh directory is created with mode 0700."""
    import stat

    sshdir = os.path.join(self.tmpdir, ".ssh")
    # Only the rwx permission bits (0777); setuid/setgid/sticky are ignored.
    # Spelled with stat constants because "0777"-style octal literals are a
    # syntax error on Python 3.
    permmask = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO

    self.assertEqual(os.listdir(self.tmpdir), [])

    for kind in constants.SSHK_ALL:
      ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
                       _homedir_fn=self._GetTempHomedir)
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
      # Owner read/write/execute only (0700)
      self.assertEqual(os.stat(sshdir).st_mode & permmask, stat.S_IRWXU)

  def testFilenames(self):
    """Key and authorized_keys paths are derived from the key kind."""
    sshdir = os.path.join(self.tmpdir, ".ssh")

    # The directory must exist, otherwise the lookup fails its directory check
    os.mkdir(sshdir)

    for kind in constants.SSHK_ALL:
      result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
                                _homedir_fn=self._GetTempHomedir)
      self.assertEqual(result, [
        os.path.join(sshdir, "id_%s" % kind),
        os.path.join(sshdir, "id_%s.pub" % kind),
        os.path.join(sshdir, "authorized_keys"),
        ])

    # Only path names are computed; no files are created
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
    self.assertEqual(os.listdir(sshdir), [])

  def testNoDirCheck(self):
    """With dircheck=False a missing .ssh is not an error and not created."""
    self.assertEqual(os.listdir(self.tmpdir), [])

    for kind in constants.SSHK_ALL:
      ssh.GetUserFiles("example14528", mkdir=False, dircheck=False, kind=kind,
                       _homedir_fn=self._GetTempHomedir)
      self.assertEqual(os.listdir(self.tmpdir), [])

  def testGetAllUserFiles(self):
    """authorized_keys is returned together with a key pair for each kind."""
    sshdir = os.path.join(self.tmpdir, ".ssh")
    result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False,
                                 _homedir_fn=self._GetTempHomedir)
    # Expected (private, public) key paths per kind; derived from SSHK_ALL
    # like the other tests so it stays in sync when key kinds are added
    keyfiles = dict((kind, (os.path.join(sshdir, "id_%s" % kind),
                            os.path.join(sshdir, "id_%s.pub" % kind)))
                    for kind in constants.SSHK_ALL)
    self.assertEqual(result,
                     (os.path.join(sshdir, "authorized_keys"), keyfiles))
    self.assertEqual(os.listdir(self.tmpdir), [])

  def testGetAllUserFilesNoDirectoryNoMkdir(self):
    """With dircheck=True and mkdir=False a missing .ssh raises."""
    self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles,
                      "example17270", mkdir=False, dircheck=True,
                      _homedir_fn=self._GetTempHomedir)
    self.assertEqual(os.listdir(self.tmpdir), [])
if __name__ == "__main__":
  # Executed directly: hand control to Ganeti's test runner wrapper
  # (presumably a unittest.main-style program; see testutils)
  testutils.GanetiTestProgram()