Add the gnt-storage client
[ganeti-local] / test / ganeti.ssh_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007, 2008 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the ssh module"""
23
24 import os
25 import tempfile
26 import unittest
27 import shutil
28
29 import testutils
30 import mocks
31
32 from ganeti import constants
33 from ganeti import utils
34 from ganeti import ssh
35 from ganeti import errors
36
37
38 class TestKnownHosts(testutils.GanetiTestCase):
39   """Test case for function writing the known_hosts file"""
40
41   def setUp(self):
42     testutils.GanetiTestCase.setUp(self)
43     self.tmpfile = self._CreateTempFile()
44
45   def test(self):
46     cfg = mocks.FakeConfig()
47     ssh.WriteKnownHostsFile(cfg, self.tmpfile)
48     self.assertFileContent(self.tmpfile,
49         "%s ssh-rsa %s\n" % (cfg.GetClusterName(),
50                              mocks.FAKE_CLUSTER_KEY))
51
52
53 class TestGetUserFiles(unittest.TestCase):
54   def setUp(self):
55     self.tmpdir = tempfile.mkdtemp()
56
57   def tearDown(self):
58     shutil.rmtree(self.tmpdir)
59
60   @staticmethod
61   def _GetNoHomedir(_):
62     return None
63
64   def _GetTempHomedir(self, _):
65     return self.tmpdir
66
67   def testNonExistantUser(self):
68     for kind in constants.SSHK_ALL:
69       self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
70                         kind=kind, _homedir_fn=self._GetNoHomedir)
71
72   def testUnknownKind(self):
73     kind = "something-else"
74     assert kind not in constants.SSHK_ALL
75     self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
76                       kind=kind, _homedir_fn=self._GetTempHomedir)
77
78     self.assertEqual(os.listdir(self.tmpdir), [])
79
80   def testNoSshDirectory(self):
81     for kind in constants.SSHK_ALL:
82       self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
83                         kind=kind, _homedir_fn=self._GetTempHomedir)
84       self.assertEqual(os.listdir(self.tmpdir), [])
85
86   def testSshIsFile(self):
87     utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
88     for kind in constants.SSHK_ALL:
89       self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
90                         kind=kind, _homedir_fn=self._GetTempHomedir)
91       self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
92
93   def testMakeSshDirectory(self):
94     sshdir = os.path.join(self.tmpdir, ".ssh")
95
96     self.assertEqual(os.listdir(self.tmpdir), [])
97
98     for kind in constants.SSHK_ALL:
99       ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
100                        _homedir_fn=self._GetTempHomedir)
101       self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
102       self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
103
104   def testFilenames(self):
105     sshdir = os.path.join(self.tmpdir, ".ssh")
106
107     os.mkdir(sshdir)
108
109     for kind in constants.SSHK_ALL:
110       result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
111                                 _homedir_fn=self._GetTempHomedir)
112       self.assertEqual(result, [
113         os.path.join(self.tmpdir, ".ssh", "id_%s" % kind),
114         os.path.join(self.tmpdir, ".ssh", "id_%s.pub" % kind),
115         os.path.join(self.tmpdir, ".ssh", "authorized_keys"),
116         ])
117
118       self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
119       self.assertEqual(os.listdir(sshdir), [])
120
121   def testNoDirCheck(self):
122     self.assertEqual(os.listdir(self.tmpdir), [])
123
124     for kind in constants.SSHK_ALL:
125       ssh.GetUserFiles("example14528", mkdir=False, dircheck=False, kind=kind,
126                        _homedir_fn=self._GetTempHomedir)
127       self.assertEqual(os.listdir(self.tmpdir), [])
128
129   def testGetAllUserFiles(self):
130     result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False,
131                                  _homedir_fn=self._GetTempHomedir)
132     self.assertEqual(result,
133       (os.path.join(self.tmpdir, ".ssh", "authorized_keys"), {
134         constants.SSHK_RSA:
135           (os.path.join(self.tmpdir, ".ssh", "id_rsa"),
136            os.path.join(self.tmpdir, ".ssh", "id_rsa.pub")),
137         constants.SSHK_DSA:
138           (os.path.join(self.tmpdir, ".ssh", "id_dsa"),
139            os.path.join(self.tmpdir, ".ssh", "id_dsa.pub")),
140       }))
141     self.assertEqual(os.listdir(self.tmpdir), [])
142
143   def testGetAllUserFilesNoDirectoryNoMkdir(self):
144     self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles,
145                       "example17270", mkdir=False, dircheck=True,
146                       _homedir_fn=self._GetTempHomedir)
147     self.assertEqual(os.listdir(self.tmpdir), [])
148
149
150 if __name__ == "__main__":
151   testutils.GanetiTestProgram()