Add initial implementation of prepare-node-join
[ganeti-local] / test / ganeti.tools.prepare_node_join_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2012 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.tools.prepare_node_join"""
23
24 import unittest
25 import shutil
26 import tempfile
27 import os.path
28 import OpenSSL
29
30 from ganeti import errors
31 from ganeti import constants
32 from ganeti import serializer
33 from ganeti import pathutils
34 from ganeti import compat
35 from ganeti import utils
36 from ganeti.tools import prepare_node_join
37
38 import testutils
39
40
# Short module-level alias for the exception class the tests below expect
# prepare_node_join to raise on join failures
_JoinError = prepare_node_join.JoinError
42
43
class TestLoadData(unittest.TestCase):
  """Tests for L{prepare_node_join.LoadData}."""

  def testNoJson(self):
    # Input that is not valid JSON at all must be rejected
    for raw in ["", "}"]:
      self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)

  def testInvalidDataStructure(self):
    # Well-formed JSON that LoadData still rejects: a dictionary with an
    # unexpected key, and a list instead of a dictionary
    for value in [{"some other thing": False}, []]:
      serialized = serializer.DumpJson(value)
      self.assertRaises(errors.ParseError, prepare_node_join.LoadData,
                        serialized)

  def testValidData(self):
    # An empty dictionary is accepted and returned unchanged
    self.assertEqual({}, prepare_node_join.LoadData(serializer.DumpJson({})))
60     self.assertEqual(prepare_node_join.LoadData(raw), {})
61
62
class TestVerifyCertificate(testutils.GanetiTestCase):
  """Tests for certificate verification in L{prepare_node_join}."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testNoCert(self):
    # Without a certificate in the data the verification function must not
    # be called at all (calling NotImplemented would raise)
    prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)

  def testMismatchingKey(self):
    # The given certificate differs from the node's own certificate
    given_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    noded_cert = self._TestDataFilename("cert2.pem")

    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
                      given_pem, _noded_cert_file=noded_cert)

  def testGivenPrivateKey(self):
    # cert2.pem is passed verbatim; per this test's name it contains a
    # private key, which must be refused
    pem_path = self._TestDataFilename("cert2.pem")

    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
                      utils.ReadFile(pem_path), _noded_cert_file=pem_path)

  def testMatchingKey(self):
    pem_path = self._TestDataFilename("cert2.pem")
    crypto = OpenSSL.crypto

    # Round-trip through pyOpenSSL so that only the certificate itself (and
    # not any private key stored in the same file) is passed for verification
    x509 = crypto.load_certificate(crypto.FILETYPE_PEM,
                                   utils.ReadFile(pem_path))
    cert_only_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, x509)

    prepare_node_join._VerifyCertificate(cert_only_pem,
                                         _noded_cert_file=pem_path)

  def testMissingFile(self):
    # If the node's certificate file does not exist, verification passes
    given_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    missing = utils.PathJoin(self.tmpdir, "does-not-exist")
    prepare_node_join._VerifyCertificate(given_pem, _noded_cert_file=missing)

  def testInvalidCertificate(self):
    # Unparseable input is rejected; the node certificate file argument is
    # never needed (NotImplemented)
    self.assertRaises(errors.X509CertError,
                      prepare_node_join._VerifyCertificate,
                      "Something that's not a certificate",
                      _noded_cert_file=NotImplemented)

  def testNoPrivateKey(self):
    # cert1.pem, used here as the node certificate file, lacks a private key
    cert_path = self._TestDataFilename("cert1.pem")
    self.assertRaises(errors.X509CertError,
                      prepare_node_join._VerifyCertificate,
                      utils.ReadFile(cert_path), _noded_cert_file=cert_path)
118
119
class TestVerifyClusterName(unittest.TestCase):
  """Tests for cluster name verification in L{prepare_node_join}."""

  def setUp(self):
    unittest.TestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _WriteNameFile(self, content):
    """Writes C{content} to the cluster name file and returns its path."""
    path = utils.PathJoin(self.tmpdir, "cluster_name")
    utils.WriteFile(path, data=content)
    return path

  def testNoName(self):
    # Data without a cluster name is rejected outright, before the
    # verification function (NotImplemented) could be called
    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
                      {}, _verify_fn=NotImplemented)

  def testMissingFile(self):
    # A missing cluster name file is not an error
    missing = utils.PathJoin(self.tmpdir, "does-not-exist")
    prepare_node_join._VerifyClusterName(NotImplemented,
                                         _ss_cluster_name_file=missing)

  def testMatchingName(self):
    # Trailing newlines in the file do not prevent a match
    for content in ("cluster.example.com", "cluster.example.com\n\n"):
      name_file = self._WriteNameFile(content)
      prepare_node_join._VerifyClusterName("cluster.example.com",
                                           _ss_cluster_name_file=name_file)

  def testNameMismatch(self):
    # A different name, or the right name only on a later line, must fail
    for content in ("something.example.com", "foobar\n\ncluster.example.com"):
      name_file = self._WriteNameFile(content)
      self.assertRaises(_JoinError, prepare_node_join._VerifyClusterName,
                        "cluster.example.com",
                        _ss_cluster_name_file=name_file)
152                         "cluster.example.com", _ss_cluster_name_file=tmpfile)
153
154
class TestUpdateSshDaemon(unittest.TestCase):
  """Tests for L{prepare_node_join.UpdateSshDaemon}."""

  def setUp(self):
    unittest.TestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

    # Public/private key file paths per key type, all inside the temporary
    # directory so tests can inspect what was written
    self.keyfiles = {
      constants.SSHK_RSA:
        (utils.PathJoin(self.tmpdir, "rsa.public"),
         utils.PathJoin(self.tmpdir, "rsa.private")),
      constants.SSHK_DSA:
        (utils.PathJoin(self.tmpdir, "dsa.public"),
         utils.PathJoin(self.tmpdir, "dsa.private")),
      }

    # Commands received by _RunCmd; used to check the reload really happened
    self.runcmd_calls = []

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testNoKeys(self):
    data_empty_keys = {
      constants.SSHS_SSH_HOST_KEY: [],
      }

    # Without host keys neither the command runner nor the key file mapping
    # may be used (calling or indexing NotImplemented would raise)
    for data in [{}, data_empty_keys]:
      for dry_run in [False, True]:
        prepare_node_join.UpdateSshDaemon(data, dry_run,
                                          _runcmd_fn=NotImplemented,
                                          _keyfiles=NotImplemented)
    self.assertEqual(os.listdir(self.tmpdir), [])

  def _TestDryRun(self, data):
    # A dry run must neither write key files nor run any command
    prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented,
                                      _keyfiles=self.keyfiles)
    self.assertEqual(os.listdir(self.tmpdir), [])

  def testDryRunRsa(self):
    self._TestDryRun({
      constants.SSHS_SSH_HOST_KEY: [
        (constants.SSHK_RSA, "rsapub", "rsapriv"),
        ],
      })

  def testDryRunDsa(self):
    self._TestDryRun({
      constants.SSHS_SSH_HOST_KEY: [
        (constants.SSHK_DSA, "dsapub", "dsapriv"),
        ],
      })

  def _RunCmd(self, fail, cmd, interactive=False):
    """Fake command runner passed to UpdateSshDaemon as C{_runcmd_fn}.

    @param fail: whether to report the command as failed
    @param cmd: the command (argument list) being "run"
    @param interactive: must be passed as a true value by the caller; the
      default is deliberately false (C{NotImplemented} would be truthy) so
      that omitting the argument makes the test fail
    @return: a L{utils.RunResult} with the requested exit code

    """
    self.assertTrue(interactive)
    self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"])
    self.runcmd_calls.append(cmd)

    if fail:
      exit_code = constants.EXIT_FAILURE
    else:
      exit_code = constants.EXIT_SUCCESS

    return utils.RunResult(exit_code, None, "stdout", "stderr",
                           utils.ShellQuoteArgs(cmd),
                           NotImplemented, NotImplemented)

  def _TestUpdate(self, failcmd):
    data = {
      constants.SSHS_SSH_HOST_KEY: [
        (constants.SSHK_DSA, "dsapub", "dsapriv"),
        (constants.SSHK_RSA, "rsapub", "rsapriv"),
        ],
      }
    runcmd_fn = compat.partial(self._RunCmd, failcmd)
    if failcmd:
      self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon,
                        data, False, _runcmd_fn=runcmd_fn,
                        _keyfiles=self.keyfiles)
    else:
      prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn,
                                        _keyfiles=self.keyfiles)

    # The daemon reload must actually have been requested; without this an
    # implementation that never calls _runcmd_fn would pass when the command
    # is not set to fail
    self.assertTrue(self.runcmd_calls)

    # Key files are written whether or not the reload succeeded
    expected = {
      "rsa.public": "rsapub",
      "rsa.private": "rsapriv",
      "dsa.public": "dsapub",
      "dsa.private": "dsapriv",
      }
    self.assertEqual(sorted(os.listdir(self.tmpdir)), sorted(expected))
    for (filename, content) in sorted(expected.items()):
      self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, filename)),
                       content)

  def testSuccess(self):
    self._TestUpdate(False)

  def testFailure(self):
    self._TestUpdate(True)
247     self._TestUpdate(True)
248
249
class TestUpdateSshRoot(unittest.TestCase):
  """Tests for L{prepare_node_join.UpdateSshRoot}."""

  def setUp(self):
    unittest.TestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()
    self.sshdir = utils.PathJoin(self.tmpdir, ".ssh")

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _GetHomeDir(self, user):
    """Fake home directory lookup; only the SSH login user is expected."""
    self.assertEqual(user, constants.SSH_LOGIN_USER)
    return self.tmpdir

  def testNoKeys(self):
    # Without root keys the home directory lookup (NotImplemented) must not
    # be used and nothing may be created
    for data in ({}, {constants.SSHS_SSH_ROOT_KEY: []}):
      for dry_run in (False, True):
        prepare_node_join.UpdateSshRoot(data, dry_run,
                                        _homedir_fn=NotImplemented)
    self.assertEqual(os.listdir(self.tmpdir), [])

  def testDryRun(self):
    data = {
      constants.SSHS_SSH_ROOT_KEY: [
        (constants.SSHK_RSA, "aaa", "bbb"),
        ],
      }

    prepare_node_join.UpdateSshRoot(data, True, _homedir_fn=self._GetHomeDir)

    # The .ssh directory exists afterwards, but no files were written into it
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
    self.assertEqual(os.listdir(self.sshdir), [])

  def testUpdate(self):
    data = {
      constants.SSHS_SSH_ROOT_KEY: [
        (constants.SSHK_DSA, "pubdsa", "privatedsa"),
        ],
      }

    prepare_node_join.UpdateSshRoot(data, False, _homedir_fn=self._GetHomeDir)

    # Expected contents of every file in .ssh; the public key also ends up in
    # authorized_keys with its key type prefix
    expected = {
      "id_dsa": "privatedsa",
      "id_dsa.pub": "pubdsa",
      "authorized_keys": "ssh-dss pubdsa\n",
      }
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
    self.assertEqual(sorted(os.listdir(self.sshdir)), sorted(expected))
    for (filename, content) in sorted(expected.items()):
      self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, filename)),
                       content)
305                      "ssh-dss pubdsa\n")
306
307
if __name__ == "__main__":
  # Entry point when executed directly: hand control to
  # testutils.GanetiTestProgram, presumably Ganeti's wrapper around the
  # unittest runner for the test cases above
  testutils.GanetiTestProgram()