Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.tools.node_daemon_setup_unittest.py @ 69e5fefc

History | View | Annotate | Download (6.6 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.tools.node_daemon_setup"""
23

    
24
import unittest
25
import shutil
26
import tempfile
27
import os.path
28
import OpenSSL
29

    
30
from ganeti import errors
31
from ganeti import constants
32
from ganeti import serializer
33
from ganeti import pathutils
34
from ganeti import compat
35
from ganeti import utils
36
from ganeti.tools import node_daemon_setup
37

    
38
import testutils
39

    
40

    
41
_SetupError = node_daemon_setup.SetupError
42

    
43

    
44
class TestLoadData(unittest.TestCase):
45
  def testNoJson(self):
46
    for data in ["", "{", "}"]:
47
      self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, data)
48

    
49
  def testInvalidDataStructure(self):
50
    raw = serializer.DumpJson({
51
      "some other thing": False,
52
      })
53
    self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, raw)
54

    
55
    raw = serializer.DumpJson([])
56
    self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, raw)
57

    
58
  def testValidData(self):
59
    raw = serializer.DumpJson({})
60
    self.assertEqual(node_daemon_setup.LoadData(raw), {})
61

    
62

    
63
class TestVerifyCertificate(testutils.GanetiTestCase):
64
  def setUp(self):
65
    testutils.GanetiTestCase.setUp(self)
66
    self.tmpdir = tempfile.mkdtemp()
67

    
68
  def tearDown(self):
69
    testutils.GanetiTestCase.tearDown(self)
70
    shutil.rmtree(self.tmpdir)
71

    
72
  def testNoCert(self):
73
    self.assertRaises(_SetupError, node_daemon_setup.VerifyCertificate,
74
                      {}, _verify_fn=NotImplemented)
75

    
76
  def testVerificationSuccessWithCert(self):
77
    node_daemon_setup.VerifyCertificate({
78
      constants.NDS_NODE_DAEMON_CERTIFICATE: "something",
79
      }, _verify_fn=lambda _: None)
80

    
81
  def testNoPrivateKey(self):
82
    cert_filename = self._TestDataFilename("cert1.pem")
83
    cert_pem = utils.ReadFile(cert_filename)
84

    
85
    self.assertRaises(errors.X509CertError,
86
                      node_daemon_setup._VerifyCertificate,
87
                      cert_pem, _check_fn=NotImplemented)
88

    
89
  def testInvalidCertificate(self):
90
    self.assertRaises(errors.X509CertError,
91
                      node_daemon_setup._VerifyCertificate,
92
                      "Something that's not a certificate",
93
                      _check_fn=NotImplemented)
94

    
95
  @staticmethod
96
  def _Check(cert):
97
    assert cert.get_subject()
98

    
99
  def testSuccessfulCheck(self):
100
    cert_filename = self._TestDataFilename("cert2.pem")
101
    cert_pem = utils.ReadFile(cert_filename)
102
    result = \
103
      node_daemon_setup._VerifyCertificate(cert_pem, _check_fn=self._Check)
104
    self.assertTrue("-----BEGIN PRIVATE KEY-----" in result)
105
    self.assertTrue("-----BEGIN CERTIFICATE-----" in result)
106

    
107
  def testMismatchingKey(self):
108
    cert1_path = self._TestDataFilename("cert1.pem")
109
    cert2_path = self._TestDataFilename("cert2.pem")
110

    
111
    # Extract certificate
112
    cert1 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
113
                                            utils.ReadFile(cert1_path))
114
    cert1_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
115
                                                cert1)
116

    
117
    # Extract mismatching key
118
    key2 = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
119
                                          utils.ReadFile(cert2_path))
120
    key2_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM,
121
                                              key2)
122

    
123
    try:
124
      node_daemon_setup._VerifyCertificate(cert1_pem + key2_pem,
125
                                           _check_fn=NotImplemented)
126
    except errors.X509CertError, err:
127
      self.assertEqual(err.args,
128
                       ("(stdin)", "Certificate is not signed with given key"))
129
    else:
130
      self.fail("Exception was not raised")
131

    
132

    
133
class TestVerifyClusterName(unittest.TestCase):
134
  def setUp(self):
135
    unittest.TestCase.setUp(self)
136
    self.tmpdir = tempfile.mkdtemp()
137

    
138
  def tearDown(self):
139
    unittest.TestCase.tearDown(self)
140
    shutil.rmtree(self.tmpdir)
141

    
142
  def testNoName(self):
143
    self.assertRaises(_SetupError, node_daemon_setup.VerifyClusterName,
144
                      {}, _verify_fn=NotImplemented)
145

    
146
  @staticmethod
147
  def _FailingVerify(name):
148
    assert name == "somecluster.example.com"
149
    raise errors.GenericError()
150

    
151
  def testFailingVerification(self):
152
    data = {
153
      constants.NDS_CLUSTER_NAME: "somecluster.example.com",
154
      }
155

    
156
    self.assertRaises(errors.GenericError, node_daemon_setup.VerifyClusterName,
157
                      data, _verify_fn=self._FailingVerify)
158

    
159
  def testSuccess(self):
160
    data = {
161
      constants.NDS_CLUSTER_NAME: "cluster.example.com",
162
      }
163

    
164
    result = \
165
      node_daemon_setup.VerifyClusterName(data, _verify_fn=lambda _: None)
166

    
167
    self.assertEqual(result, "cluster.example.com")
168

    
169

    
170
class TestVerifySsconf(unittest.TestCase):
171
  def testNoSsconf(self):
172
    self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf,
173
                      {}, NotImplemented, _verify_fn=NotImplemented)
174

    
175
    for items in [None, {}]:
176
      self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf, {
177
        constants.NDS_SSCONF: items,
178
        }, NotImplemented, _verify_fn=NotImplemented)
179

    
180
  def _Check(self, names):
181
    self.assertEqual(frozenset(names), frozenset([
182
      constants.SS_CLUSTER_NAME,
183
      constants.SS_INSTANCE_LIST,
184
      ]))
185

    
186
  def testSuccess(self):
187
    ssdata = {
188
      constants.SS_CLUSTER_NAME: "cluster.example.com",
189
      constants.SS_INSTANCE_LIST: [],
190
      }
191

    
192
    result = node_daemon_setup.VerifySsconf({
193
      constants.NDS_SSCONF: ssdata,
194
      }, "cluster.example.com", _verify_fn=self._Check)
195

    
196
    self.assertEqual(result, ssdata)
197

    
198
    self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf, {
199
      constants.NDS_SSCONF: ssdata,
200
      }, "wrong.example.com", _verify_fn=self._Check)
201

    
202
  def testInvalidKey(self):
203
    self.assertRaises(errors.GenericError, node_daemon_setup.VerifySsconf, {
204
      constants.NDS_SSCONF: {
205
        "no-valid-ssconf-key": "value",
206
        },
207
      }, NotImplemented)
208

    
209

    
210
if __name__ == "__main__":
211
  testutils.GanetiTestProgram()