Add unit tests for LUGroupSetParams
[ganeti-local] / test / py / ganeti.ssh_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007, 2008 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the ssh module"""
23
24 import os
25 import tempfile
26 import unittest
27 import shutil
28
29 import testutils
30 import mocks
31
32 from ganeti import constants
33 from ganeti import utils
34 from ganeti import ssh
35 from ganeti import errors
36
37
38 class TestKnownHosts(testutils.GanetiTestCase):
39   """Test case for function writing the known_hosts file"""
40
41   def setUp(self):
42     testutils.GanetiTestCase.setUp(self)
43     self.tmpfile = self._CreateTempFile()
44
45   def test(self):
46     cfg = mocks.FakeConfig()
47     ssh.WriteKnownHostsFile(cfg, self.tmpfile)
48     self.assertFileContent(self.tmpfile,
49         "%s ssh-rsa %s\n%s ssh-dss %s\n" %
50         (cfg.GetClusterName(), mocks.FAKE_CLUSTER_KEY,
51          cfg.GetClusterName(), mocks.FAKE_CLUSTER_KEY))
52
53
54 class TestGetUserFiles(unittest.TestCase):
55   def setUp(self):
56     self.tmpdir = tempfile.mkdtemp()
57
58   def tearDown(self):
59     shutil.rmtree(self.tmpdir)
60
61   @staticmethod
62   def _GetNoHomedir(_):
63     return None
64
65   def _GetTempHomedir(self, _):
66     return self.tmpdir
67
68   def testNonExistantUser(self):
69     for kind in constants.SSHK_ALL:
70       self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
71                         kind=kind, _homedir_fn=self._GetNoHomedir)
72
73   def testUnknownKind(self):
74     kind = "something-else"
75     assert kind not in constants.SSHK_ALL
76     self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
77                       kind=kind, _homedir_fn=self._GetTempHomedir)
78
79     self.assertEqual(os.listdir(self.tmpdir), [])
80
81   def testNoSshDirectory(self):
82     for kind in constants.SSHK_ALL:
83       self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
84                         kind=kind, _homedir_fn=self._GetTempHomedir)
85       self.assertEqual(os.listdir(self.tmpdir), [])
86
87   def testSshIsFile(self):
88     utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
89     for kind in constants.SSHK_ALL:
90       self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
91                         kind=kind, _homedir_fn=self._GetTempHomedir)
92       self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
93
94   def testMakeSshDirectory(self):
95     sshdir = os.path.join(self.tmpdir, ".ssh")
96
97     self.assertEqual(os.listdir(self.tmpdir), [])
98
99     for kind in constants.SSHK_ALL:
100       ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
101                        _homedir_fn=self._GetTempHomedir)
102       self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
103       self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
104
105   def testFilenames(self):
106     sshdir = os.path.join(self.tmpdir, ".ssh")
107
108     os.mkdir(sshdir)
109
110     for kind in constants.SSHK_ALL:
111       result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
112                                 _homedir_fn=self._GetTempHomedir)
113       self.assertEqual(result, [
114         os.path.join(self.tmpdir, ".ssh", "id_%s" % kind),
115         os.path.join(self.tmpdir, ".ssh", "id_%s.pub" % kind),
116         os.path.join(self.tmpdir, ".ssh", "authorized_keys"),
117         ])
118
119       self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
120       self.assertEqual(os.listdir(sshdir), [])
121
122   def testNoDirCheck(self):
123     self.assertEqual(os.listdir(self.tmpdir), [])
124
125     for kind in constants.SSHK_ALL:
126       ssh.GetUserFiles("example14528", mkdir=False, dircheck=False, kind=kind,
127                        _homedir_fn=self._GetTempHomedir)
128       self.assertEqual(os.listdir(self.tmpdir), [])
129
130   def testGetAllUserFiles(self):
131     result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False,
132                                  _homedir_fn=self._GetTempHomedir)
133     self.assertEqual(result,
134       (os.path.join(self.tmpdir, ".ssh", "authorized_keys"), {
135         constants.SSHK_RSA:
136           (os.path.join(self.tmpdir, ".ssh", "id_rsa"),
137            os.path.join(self.tmpdir, ".ssh", "id_rsa.pub")),
138         constants.SSHK_DSA:
139           (os.path.join(self.tmpdir, ".ssh", "id_dsa"),
140            os.path.join(self.tmpdir, ".ssh", "id_dsa.pub")),
141       }))
142     self.assertEqual(os.listdir(self.tmpdir), [])
143
144   def testGetAllUserFilesNoDirectoryNoMkdir(self):
145     self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles,
146                       "example17270", mkdir=False, dircheck=True,
147                       _homedir_fn=self._GetTempHomedir)
148     self.assertEqual(os.listdir(self.tmpdir), [])
149
150
151 if __name__ == "__main__":
152   testutils.GanetiTestProgram()