Merge branch 'stable-2.7' into stable-2.8
[ganeti-local] / test / py / ganeti.tools.node_daemon_setup_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2012 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.tools.node_daemon_setup"""
23
24 import unittest
25 import shutil
26 import tempfile
27 import os.path
28 import OpenSSL
29
30 from ganeti import errors
31 from ganeti import constants
32 from ganeti import serializer
33 from ganeti import pathutils
34 from ganeti import compat
35 from ganeti import utils
36 from ganeti.tools import node_daemon_setup
37
38 import testutils
39
40
# Short alias for the error class the tests below expect node_daemon_setup's
# validation functions (VerifyCertificate, VerifyClusterName, VerifySsconf)
# to raise on missing or mismatching input
_SetupError = node_daemon_setup.SetupError
42
43
class TestLoadData(unittest.TestCase):
  """Tests for L{node_daemon_setup.LoadData}."""

  def testNoJson(self):
    # Empty input and unbalanced braces are not parseable JSON documents
    broken_inputs = ("", "{", "}")
    for text in broken_inputs:
      self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, text)

  def testInvalidDataStructure(self):
    # Well-formed JSON whose structure LoadData must still reject: a dict with
    # an unexpected key, then a top-level list
    rejected_values = [
      {"some other thing": False, },
      [],
      ]
    for value in rejected_values:
      encoded = serializer.DumpJson(value)
      self.assertRaises(errors.ParseError, node_daemon_setup.LoadData,
                        encoded)

  def testValidData(self):
    # An empty JSON object round-trips to an empty dict
    encoded = serializer.DumpJson({})
    decoded = node_daemon_setup.LoadData(encoded)
    self.assertEqual(decoded, {})
61
62
class TestVerifyCertificate(testutils.GanetiTestCase):
  """Tests for node daemon certificate verification.

  Covers the public L{node_daemon_setup.VerifyCertificate} entry point as well
  as the lower-level L{node_daemon_setup._VerifyCertificate} helper.

  """

  def testNoCert(self):
    # A missing certificate must be rejected with a setup error; passing
    # NotImplemented as the verify function ensures it is never called
    # (calling it would raise TypeError instead)
    self.assertRaises(_SetupError, node_daemon_setup.VerifyCertificate,
                      {}, _verify_fn=NotImplemented)

  def testVerificationSuccessWithCert(self):
    node_daemon_setup.VerifyCertificate({
      constants.NDS_NODE_DAEMON_CERTIFICATE: "something",
      }, _verify_fn=lambda _: None)

  def testNoPrivateKey(self):
    # The test name implies cert1.pem carries no private key, which
    # _VerifyCertificate must reject before running the check function
    cert_filename = testutils.TestDataFilename("cert1.pem")
    cert_pem = utils.ReadFile(cert_filename)

    self.assertRaises(errors.X509CertError,
                      node_daemon_setup._VerifyCertificate,
                      cert_pem, _check_fn=NotImplemented)

  def testInvalidCertificate(self):
    self.assertRaises(errors.X509CertError,
                      node_daemon_setup._VerifyCertificate,
                      "Something that's not a certificate",
                      _check_fn=NotImplemented)

  @staticmethod
  def _Check(cert):
    """Check function passed to L{node_daemon_setup._VerifyCertificate}.

    Raises AssertionError explicitly (rather than via C{assert}) so the check
    is not stripped when Python runs with C{-O}.

    """
    if not cert.get_subject():
      raise AssertionError("Certificate has no subject")

  def testSuccessfulCheck(self):
    cert_filename = testutils.TestDataFilename("cert2.pem")
    cert_pem = utils.ReadFile(cert_filename)

    checked = []

    def _RecordingCheck(cert):
      self._Check(cert)
      checked.append(cert)

    result = \
      node_daemon_setup._VerifyCertificate(cert_pem, _check_fn=_RecordingCheck)

    # Without this, the assertion in _Check would be vacuous if the check
    # function were never invoked
    self.assertTrue(checked, msg="Check function was not called")

    # The returned PEM data must contain both a certificate and a private key
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, result)
    self.assertTrue(cert)

    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, result)
    self.assertTrue(key)

  def testMismatchingKey(self):
    cert1_path = testutils.TestDataFilename("cert1.pem")
    cert2_path = testutils.TestDataFilename("cert2.pem")

    # Extract certificate
    cert1 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                            utils.ReadFile(cert1_path))
    cert1_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                                cert1)

    # Extract mismatching key
    key2 = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
                                          utils.ReadFile(cert2_path))
    key2_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM,
                                              key2)

    # "except ... as" is accepted by Python 2.6+ and Python 3 alike
    try:
      node_daemon_setup._VerifyCertificate(cert1_pem + key2_pem,
                                           _check_fn=NotImplemented)
    except errors.X509CertError as err:
      self.assertEqual(err.args,
                       ("(stdin)", "Certificate is not signed with given key"))
    else:
      self.fail("Exception was not raised")
135
136
class TestVerifyClusterName(unittest.TestCase):
  """Tests for L{node_daemon_setup.VerifyClusterName}."""

  def testNoName(self):
    # A missing cluster name must be rejected; passing NotImplemented as the
    # verify function ensures it is never called
    self.assertRaises(_SetupError, node_daemon_setup.VerifyClusterName,
                      {}, _verify_fn=NotImplemented)

  @staticmethod
  def _FailingVerify(name):
    """Verify function that checks its argument, then always fails.

    The name check raises AssertionError explicitly (rather than via
    C{assert}) so it is not stripped when Python runs with C{-O}.

    """
    if name != "somecluster.example.com":
      raise AssertionError("Unexpected cluster name %r" % (name, ))
    raise errors.GenericError()

  def testFailingVerification(self):
    data = {
      constants.NDS_CLUSTER_NAME: "somecluster.example.com",
      }

    # Errors from the verify function must propagate to the caller
    self.assertRaises(errors.GenericError, node_daemon_setup.VerifyClusterName,
                      data, _verify_fn=self._FailingVerify)

  def testSuccess(self):
    data = {
      constants.NDS_CLUSTER_NAME: "cluster.example.com",
      }

    # list.append returns None, like the no-op verify function it replaces,
    # while also recording what was passed for verification
    verified = []
    result = \
      node_daemon_setup.VerifyClusterName(data, _verify_fn=verified.append)

    self.assertEqual(result, "cluster.example.com")
    self.assertEqual(verified, ["cluster.example.com"])
172
173
class TestVerifySsconf(unittest.TestCase):
  """Tests for L{node_daemon_setup.VerifySsconf}."""

  def testNoSsconf(self):
    # Absent ssconf data is rejected without calling the verify function
    self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf,
                      {}, NotImplemented, _verify_fn=NotImplemented)

    # So are None and an empty dict
    for empty_value in (None, {}):
      payload = {constants.NDS_SSCONF: empty_value}
      self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf,
                        payload, NotImplemented, _verify_fn=NotImplemented)

  def _Check(self, names):
    """Verify function asserting exactly the keys used in L{testSuccess}."""
    wanted = frozenset([
      constants.SS_CLUSTER_NAME,
      constants.SS_INSTANCE_LIST,
      ])
    self.assertEqual(frozenset(names), wanted)

  def testSuccess(self):
    ssdata = {
      constants.SS_CLUSTER_NAME: "cluster.example.com",
      constants.SS_INSTANCE_LIST: [],
      }

    # Matching cluster name: the ssconf values are handed back
    verified = node_daemon_setup.VerifySsconf(
      {constants.NDS_SSCONF: ssdata},
      "cluster.example.com",
      _verify_fn=self._Check)
    self.assertEqual(verified, ssdata)

    # Same values, but the expected cluster name differs from the one in
    # ssdata
    self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf,
                      {constants.NDS_SSCONF: ssdata},
                      "wrong.example.com",
                      _verify_fn=self._Check)

  def testInvalidKey(self):
    # No _verify_fn override here, so the module's default key verification
    # is exercised
    unknown_keys = {
      "no-valid-ssconf-key": "value",
      }
    self.assertRaises(errors.GenericError, node_daemon_setup.VerifySsconf,
                      {constants.NDS_SSCONF: unknown_keys}, NotImplemented)
212
213
if __name__ == "__main__":
  # Run this module's tests through Ganeti's test program wrapper
  testutils.GanetiTestProgram()