Move cluster verification out of prepare-node-join
[ganeti-local] / test / ganeti.tools.prepare_node_join_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2012 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.tools.prepare_node_join"""
23
24 import unittest
25 import shutil
26 import tempfile
27 import os.path
28 import OpenSSL
29
30 from ganeti import errors
31 from ganeti import constants
32 from ganeti import serializer
33 from ganeti import pathutils
34 from ganeti import compat
35 from ganeti import utils
36 from ganeti.tools import prepare_node_join
37
38 import testutils
39
40
# Short alias for the exception class that the prepare_node_join helpers are
# expected to raise (used with assertRaises throughout this module)
_JoinError = prepare_node_join.JoinError
42
43
class TestLoadData(unittest.TestCase):
  """Tests for L{prepare_node_join.LoadData}."""

  def testNoJson(self):
    # Neither an empty string nor a lone closing brace is valid JSON
    for text in ["", "}"]:
      self.assertRaises(errors.ParseError, prepare_node_join.LoadData, text)

  def testInvalidDataStructure(self):
    # Well-formed JSON whose structure is not accepted must be rejected too
    bad_values = [
      {"some other thing": False, },
      [],
      ]
    for value in bad_values:
      encoded = serializer.DumpJson(value)
      self.assertRaises(errors.ParseError, prepare_node_join.LoadData,
                        encoded)

  def testValidData(self):
    encoded = serializer.DumpJson({})
    self.assertEqual(prepare_node_join.LoadData(encoded), {})
61
62
class TestVerifyCertificate(testutils.GanetiTestCase):
  """Tests for certificate verification in prepare_node_join."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testNoCert(self):
    # No certificate in the data: the (non-callable) verification function
    # must not be invoked at all
    prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)

  def testMismatchingKey(self):
    # cert1.pem is offered while cert2.pem is passed as _noded_cert_file
    existing_file = self._TestDataFilename("cert2.pem")
    offered_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))

    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
                      offered_pem, _noded_cert_file=existing_file)

  def testGivenPrivateKey(self):
    # The raw contents of cert2.pem are passed on unmodified; this must be
    # refused with a JoinError
    pem_path = self._TestDataFilename("cert2.pem")
    raw_contents = utils.ReadFile(pem_path)

    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
                      raw_contents, _noded_cert_file=pem_path)

  def testMatchingKey(self):
    pem_path = self._TestDataFilename("cert2.pem")
    pem_type = OpenSSL.crypto.FILETYPE_PEM

    # Re-serialize only the certificate portion of the file
    parsed = OpenSSL.crypto.load_certificate(pem_type,
                                             utils.ReadFile(pem_path))
    cert_only = OpenSSL.crypto.dump_certificate(pem_type, parsed)

    prepare_node_join._VerifyCertificate(cert_only,
                                         _noded_cert_file=pem_path)

  def testMissingFile(self):
    # A node certificate file that does not exist is not an error
    offered_file = self._TestDataFilename("cert1.pem")
    missing_file = utils.PathJoin(self.tmpdir, "does-not-exist")

    prepare_node_join._VerifyCertificate(utils.ReadFile(offered_file),
                                         _noded_cert_file=missing_file)

  def testInvalidCertificate(self):
    garbage = "Something that's not a certificate"
    self.assertRaises(errors.X509CertError,
                      prepare_node_join._VerifyCertificate,
                      garbage, _noded_cert_file=NotImplemented)

  def testNoPrivateKey(self):
    # cert1.pem serves both as the offered certificate and as the node
    # certificate file
    cert_file = self._TestDataFilename("cert1.pem")
    offered_pem = utils.ReadFile(cert_file)

    self.assertRaises(errors.X509CertError,
                      prepare_node_join._VerifyCertificate,
                      offered_pem, _noded_cert_file=cert_file)
118
119
class TestVerifyClusterName(unittest.TestCase):
  """Tests for L{prepare_node_join.VerifyClusterName}.

  None of these tests touch the filesystem, so no temporary directory
  fixture is set up.

  """
  def testNoName(self):
    # A missing cluster name is rejected before the (non-callable)
    # verification function could be invoked
    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
                      {}, _verify_fn=NotImplemented)

  def _FailingVerify(self, name):
    # Use a unittest assertion rather than a bare "assert" statement so the
    # check is not stripped when running under "python -O"
    self.assertEqual(name, "cluster.example.com")
    raise errors.GenericError()

  def testFailingVerification(self):
    data = {
      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
      }

    self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
                      data, _verify_fn=self._FailingVerify)

  def testSuccessfulVerification(self):
    # The verification function must be called exactly once with the name
    # from the data; if it does not raise, neither does VerifyClusterName
    verified = []
    data = {
      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
      }

    prepare_node_join.VerifyClusterName(data, _verify_fn=verified.append)
    self.assertEqual(verified, ["cluster.example.com"])
145
146
class TestUpdateSshDaemon(unittest.TestCase):
  """Tests for L{prepare_node_join.UpdateSshDaemon}."""

  def setUp(self):
    unittest.TestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

    # (private, public) key file paths per key type, all inside the
    # temporary directory
    self.keyfiles = {
      constants.SSHK_RSA:
        (utils.PathJoin(self.tmpdir, "rsa.private"),
         utils.PathJoin(self.tmpdir, "rsa.public")),
      constants.SSHK_DSA:
        (utils.PathJoin(self.tmpdir, "dsa.private"),
         utils.PathJoin(self.tmpdir, "dsa.public")),
      }

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testNoKeys(self):
    data_empty_keys = {
      constants.SSHS_SSH_HOST_KEY: [],
      }

    # Without keys neither the command runner nor the key file mapping
    # (both non-usable here) may be touched, and nothing is written
    for data in [{}, data_empty_keys]:
      for dry_run in [False, True]:
        prepare_node_join.UpdateSshDaemon(data, dry_run,
                                          _runcmd_fn=NotImplemented,
                                          _keyfiles=NotImplemented)
    self.assertEqual(os.listdir(self.tmpdir), [])

  def _TestDryRun(self, data):
    # A dry run must neither run a command nor write any key file
    prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented,
                                      _keyfiles=self.keyfiles)
    self.assertEqual(os.listdir(self.tmpdir), [])

  def testDryRunRsa(self):
    self._TestDryRun({
      constants.SSHS_SSH_HOST_KEY: [
        (constants.SSHK_RSA, "rsapriv", "rsapub"),
        ],
      })

  def testDryRunDsa(self):
    self._TestDryRun({
      constants.SSHS_SSH_HOST_KEY: [
        (constants.SSHK_DSA, "dsapriv", "dsapub"),
        ],
      })

  def _RunCmd(self, fail, cmd, interactive=False):
    # The default must be falsy: with the previous default (NotImplemented,
    # which is truthy) this assertion passed even when the caller did not
    # request an interactive run at all
    self.assertTrue(interactive)
    self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"])
    if fail:
      exit_code = constants.EXIT_FAILURE
    else:
      exit_code = constants.EXIT_SUCCESS
    return utils.RunResult(exit_code, None, "stdout", "stderr",
                           utils.ShellQuoteArgs(cmd),
                           NotImplemented, NotImplemented)

  def _TestUpdate(self, failcmd):
    data = {
      constants.SSHS_SSH_HOST_KEY: [
        (constants.SSHK_DSA, "dsapriv", "dsapub"),
        (constants.SSHK_RSA, "rsapriv", "rsapub"),
        ],
      }
    runcmd_fn = compat.partial(self._RunCmd, failcmd)
    if failcmd:
      self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon,
                        data, False, _runcmd_fn=runcmd_fn,
                        _keyfiles=self.keyfiles)
    else:
      prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn,
                                        _keyfiles=self.keyfiles)

    # Key files are expected to be written whether or not the reload
    # command succeeded
    expected = [
      ("rsa.public", "rsapub"),
      ("rsa.private", "rsapriv"),
      ("dsa.public", "dsapub"),
      ("dsa.private", "dsapriv"),
      ]
    self.assertEqual(sorted(os.listdir(self.tmpdir)),
                     sorted(name for (name, _) in expected))
    for (name, content) in expected:
      self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, name)),
                       content)

  def testSuccess(self):
    self._TestUpdate(False)

  def testFailure(self):
    self._TestUpdate(True)
240
241
class TestUpdateSshRoot(unittest.TestCase):
  """Tests for L{prepare_node_join.UpdateSshRoot}."""

  def setUp(self):
    unittest.TestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()
    # ".ssh" inside the fake home directory returned by _GetHomeDir
    self.sshdir = utils.PathJoin(self.tmpdir, ".ssh")

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _GetHomeDir(self, user):
    # Only the SSH login user's home directory may be looked up
    self.assertEqual(user, constants.SSH_LOGIN_USER)
    return self.tmpdir

  def testNoKeys(self):
    without_keys = [
      {},
      {constants.SSHS_SSH_ROOT_KEY: []},
      ]

    # Without keys the home directory lookup (non-callable here) must not
    # happen and nothing may be written
    for data in without_keys:
      for dry_run in (False, True):
        prepare_node_join.UpdateSshRoot(data, dry_run,
                                        _homedir_fn=NotImplemented)
    self.assertEqual(os.listdir(self.tmpdir), [])

  def testDryRun(self):
    data = {
      constants.SSHS_SSH_ROOT_KEY: [(constants.SSHK_RSA, "aaa", "bbb")],
      }

    prepare_node_join.UpdateSshRoot(data, True,
                                    _homedir_fn=self._GetHomeDir)

    # The .ssh directory exists afterwards, but no key files are in it
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
    self.assertEqual(os.listdir(self.sshdir), [])

  def testUpdate(self):
    pubkey = "ssh-dss pubdsa"
    data = {
      constants.SSHS_SSH_ROOT_KEY: [
        (constants.SSHK_DSA, "privatedsa", pubkey),
        ],
      }

    prepare_node_join.UpdateSshRoot(data, False,
                                    _homedir_fn=self._GetHomeDir)
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])

    # (file name, expected content) in the order they are checked
    expected = [
      ("id_dsa", "privatedsa"),
      ("id_dsa.pub", pubkey),
      ("authorized_keys", pubkey + "\n"),
      ]
    self.assertEqual(sorted(os.listdir(self.sshdir)),
                     sorted(name for (name, _) in expected))
    for (name, content) in expected:
      self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, name)),
                       content)
298
299
if __name__ == "__main__":
  # Run this module's tests through Ganeti's own test program wrapper
  testutils.GanetiTestProgram()