4 # Copyright (C) 2012 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Script for testing ganeti.tools.prepare_node_join"""
import os
import shutil
import tempfile
import unittest

import OpenSSL

from ganeti import errors
from ganeti import constants
from ganeti import serializer
from ganeti import pathutils
from ganeti import compat
from ganeti import utils
from ganeti.tools import prepare_node_join

import testutils
# Short alias for the exception type that the prepare_node_join helpers are
# expected to raise in the assertRaises checks of this module
_JoinError = prepare_node_join.JoinError
class TestLoadData(unittest.TestCase):
  """Tests for L{prepare_node_join.LoadData}."""

  def testNoJson(self):
    """Input that is not valid JSON raises L{errors.ParseError}."""
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")

  def testInvalidDataStructure(self):
    """Well-formed JSON of the wrong shape raises L{errors.ParseError}."""
    # A dictionary containing an unexpected key
    raw = serializer.DumpJson({
      "some other thing": False,
      })
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)

    # A list where a dictionary is expected
    raw = serializer.DumpJson([])
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)

  def testValidData(self):
    """An empty dictionary is parsed back as an empty dictionary."""
    raw = serializer.DumpJson({})
    self.assertEqual(prepare_node_join.LoadData(raw), {})
class TestVerifyCertificate(testutils.GanetiTestCase):
  """Tests for the certificate checks in L{prepare_node_join}."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Scratch directory; testMissingFile points the node certificate path here
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testNoCert(self):
    """Data without a certificate is accepted without calling _verify_fn."""
    # NotImplemented is not callable, so any call would raise TypeError
    prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)

  def testMismatchingKey(self):
    """A certificate differing from the node's certificate raises JoinError."""
    other_cert = self._TestDataFilename("cert1.pem")
    node_cert = self._TestDataFilename("cert2.pem")

    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
                      utils.ReadFile(other_cert), _noded_cert_file=node_cert)

  def testGivenPrivateKey(self):
    """Passing the full PEM file (presumably incl. its key) raises JoinError."""
    cert_filename = self._TestDataFilename("cert2.pem")
    cert_pem = utils.ReadFile(cert_filename)

    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
                      cert_pem, _noded_cert_file=cert_filename)

  def testMatchingKey(self):
    """The certificate part of the node's own certificate file is accepted."""
    cert_filename = self._TestDataFilename("cert2.pem")

    # Re-serialize only the certificate, leaving out anything else in the file
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           utils.ReadFile(cert_filename))
    cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                               cert)

    prepare_node_join._VerifyCertificate(cert_pem,
                                         _noded_cert_file=cert_filename)

  def testMissingFile(self):
    """No error is raised when the node certificate file does not exist."""
    cert = self._TestDataFilename("cert1.pem")
    nodecert = utils.PathJoin(self.tmpdir, "does-not-exist")
    prepare_node_join._VerifyCertificate(utils.ReadFile(cert),
                                         _noded_cert_file=nodecert)

  def testInvalidCertificate(self):
    """Data that is not a certificate raises L{errors.X509CertError}."""
    self.assertRaises(errors.X509CertError,
                      prepare_node_join._VerifyCertificate,
                      "Something that's not a certificate",
                      _noded_cert_file=NotImplemented)

  def testNoPrivateKey(self):
    """A node certificate file without a key raises X509CertError.

    cert1.pem is assumed to contain no private key (see the test name).

    """
    cert = self._TestDataFilename("cert1.pem")
    self.assertRaises(errors.X509CertError,
                      prepare_node_join._VerifyCertificate,
                      utils.ReadFile(cert), _noded_cert_file=cert)
class TestVerifyClusterName(unittest.TestCase):
  """Tests for L{prepare_node_join.VerifyClusterName}."""

  def setUp(self):
    unittest.TestCase.setUp(self)
    # NOTE(review): self.tmpdir is not used by any test in this class
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testNoName(self):
    """Data without a cluster name raises JoinError."""
    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
                      {}, _verify_fn=NotImplemented)

  @staticmethod
  def _FailingVerify(name):
    """Verification stub: checks the name it is given, then always fails.

    Must be a staticmethod so that C{self._FailingVerify(name)} passes the
    cluster name as C{name} instead of binding the test instance to it.

    """
    assert name == "cluster.example.com"
    raise errors.GenericError()

  def testFailingVerification(self):
    """Errors raised by the verification function reach the caller."""
    data = {
      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
      }

    self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
                      data, _verify_fn=self._FailingVerify)
class TestUpdateSshDaemon(unittest.TestCase):
  """Tests for L{prepare_node_join.UpdateSshDaemon}."""

  def setUp(self):
    unittest.TestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

    # Per key type: (private key file, public key file), all inside the
    # scratch directory; passed as _keyfiles so writes stay in the test area
    self.keyfiles = {
      constants.SSHK_RSA:
        (utils.PathJoin(self.tmpdir, "rsa.private"),
         utils.PathJoin(self.tmpdir, "rsa.public")),
      constants.SSHK_DSA:
        (utils.PathJoin(self.tmpdir, "dsa.private"),
         utils.PathJoin(self.tmpdir, "dsa.public")),
      }

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testNoKeys(self):
    """Nothing is written or run when the data contains no host keys."""
    data_empty_keys = {
      constants.SSHS_SSH_HOST_KEY: [],
      }

    for data in [{}, data_empty_keys]:
      for dry_run in [False, True]:
        # NotImplemented for both hooks: using either would raise TypeError
        prepare_node_join.UpdateSshDaemon(data, dry_run,
                                          _runcmd_fn=NotImplemented,
                                          _keyfiles=NotImplemented)
    self.assertEqual(os.listdir(self.tmpdir), [])

  def _TestDryRun(self, data):
    """Run in dry-run mode and check that no key file was written."""
    prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented,
                                      _keyfiles=self.keyfiles)
    self.assertEqual(os.listdir(self.tmpdir), [])

  def testDryRunRsa(self):
    self._TestDryRun({
      constants.SSHS_SSH_HOST_KEY: [
        (constants.SSHK_RSA, "rsapriv", "rsapub"),
        ],
      })

  def testDryRunDsa(self):
    self._TestDryRun({
      constants.SSHS_SSH_HOST_KEY: [
        (constants.SSHK_DSA, "dsapriv", "dsapub"),
        ],
      })

  def _RunCmd(self, fail, cmd, interactive=False):
    """Stand-in for the command runner used to reload the SSH keys.

    @param fail: whether to report the command as failed
    @param cmd: the command line that would be run
    @param interactive: must be passed as a true value by the caller
    @return: a L{utils.RunResult} with a failure or success exit code

    """
    # A false default (rather than a truthy sentinel such as NotImplemented)
    # makes this assertion fail when the caller omits "interactive"
    self.assertTrue(interactive)
    self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"])
    if fail:
      exit_code = constants.EXIT_FAILURE
    else:
      exit_code = constants.EXIT_SUCCESS
    return utils.RunResult(exit_code, None, "stdout", "stderr",
                           utils.ShellQuoteArgs(cmd),
                           NotImplemented, NotImplemented)

  def _TestUpdate(self, failcmd):
    """Write both host keys and check the files.

    @param failcmd: whether the reload command should fail; if so,
      JoinError is expected, but the key files must still be written

    """
    data = {
      constants.SSHS_SSH_HOST_KEY: [
        (constants.SSHK_DSA, "dsapriv", "dsapub"),
        (constants.SSHK_RSA, "rsapriv", "rsapub"),
        ],
      }
    runcmd_fn = compat.partial(self._RunCmd, failcmd)
    if failcmd:
      self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon,
                        data, False, _runcmd_fn=runcmd_fn,
                        _keyfiles=self.keyfiles)
    else:
      prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn,
                                        _keyfiles=self.keyfiles)
    self.assertEqual(sorted(os.listdir(self.tmpdir)), sorted([
      "rsa.public", "rsa.private",
      "dsa.public", "dsa.private",
      ]))
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.public")),
                     "rsapub")
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.private")),
                     "rsapriv")
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.public")),
                     "dsapub")
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.private")),
                     "dsapriv")

  def testSuccess(self):
    self._TestUpdate(False)

  def testFailure(self):
    self._TestUpdate(True)
class TestUpdateSshRoot(unittest.TestCase):
  """Tests for L{prepare_node_join.UpdateSshRoot}."""

  def setUp(self):
    unittest.TestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()
    self.sshdir = utils.PathJoin(self.tmpdir, ".ssh")

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _GetHomeDir(self, user):
    """Home directory stub: checks the user and returns the scratch dir."""
    self.assertEqual(user, constants.SSH_LOGIN_USER)
    return self.tmpdir

  def testNoKeys(self):
    """Nothing is written when the data contains no root keys."""
    data_empty_keys = {
      constants.SSHS_SSH_ROOT_KEY: [],
      }

    for data in [{}, data_empty_keys]:
      for dry_run in [False, True]:
        # NotImplemented: looking up the home directory would raise TypeError
        prepare_node_join.UpdateSshRoot(data, dry_run,
                                        _homedir_fn=NotImplemented)
    self.assertEqual(os.listdir(self.tmpdir), [])

  def testDryRun(self):
    """A dry run creates the .ssh directory but no key files in it."""
    data = {
      constants.SSHS_SSH_ROOT_KEY: [
        (constants.SSHK_RSA, "aaa", "bbb"),
        ],
      }

    prepare_node_join.UpdateSshRoot(data, True,
                                    _homedir_fn=self._GetHomeDir)
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
    self.assertEqual(os.listdir(self.sshdir), [])

  def testUpdate(self):
    """Key files are written to .ssh and the public key is authorized."""
    data = {
      constants.SSHS_SSH_ROOT_KEY: [
        (constants.SSHK_DSA, "privatedsa", "ssh-dss pubdsa"),
        ],
      }

    prepare_node_join.UpdateSshRoot(data, False,
                                    _homedir_fn=self._GetHomeDir)
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
    self.assertEqual(sorted(os.listdir(self.sshdir)),
                     sorted(["authorized_keys", "id_dsa", "id_dsa.pub"]))
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa")),
                     "privatedsa")
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub")),
                     "ssh-dss pubdsa")
    # The public key is expected as one newline-terminated authorized_keys line
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
                                                   "authorized_keys")),
                     "ssh-dss pubdsa\n")
# When executed directly, run the tests via testutils.GanetiTestProgram
if __name__ == "__main__":
  testutils.GanetiTestProgram()