Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.tools.node_daemon_setup_unittest.py @ e712e5b8

History | View | Annotate | Download (6.6 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.tools.node_daemon_setup"""
23

    
24
import unittest
25
import shutil
26
import tempfile
27
import os.path
28
import OpenSSL
29

    
30
from ganeti import errors
31
from ganeti import constants
32
from ganeti import serializer
33
from ganeti import pathutils
34
from ganeti import compat
35
from ganeti import utils
36
from ganeti.tools import node_daemon_setup
37

    
38
import testutils
39

    
40

    
41
_SetupError = node_daemon_setup.SetupError
42

    
43

    
44
class TestLoadData(unittest.TestCase):
45
  def testNoJson(self):
46
    for data in ["", "{", "}"]:
47
      self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, data)
48

    
49
  def testInvalidDataStructure(self):
50
    raw = serializer.DumpJson({
51
      "some other thing": False,
52
      })
53
    self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, raw)
54

    
55
    raw = serializer.DumpJson([])
56
    self.assertRaises(errors.ParseError, node_daemon_setup.LoadData, raw)
57

    
58
  def testValidData(self):
59
    raw = serializer.DumpJson({})
60
    self.assertEqual(node_daemon_setup.LoadData(raw), {})
61

    
62

    
63
class TestVerifyCertificate(testutils.GanetiTestCase):
64
  def setUp(self):
65
    testutils.GanetiTestCase.setUp(self)
66
    self.tmpdir = tempfile.mkdtemp()
67

    
68
  def tearDown(self):
69
    testutils.GanetiTestCase.tearDown(self)
70
    shutil.rmtree(self.tmpdir)
71

    
72
  def testNoCert(self):
73
    self.assertRaises(_SetupError, node_daemon_setup.VerifyCertificate,
74
                      {}, _verify_fn=NotImplemented)
75

    
76
  def testVerificationSuccessWithCert(self):
77
    node_daemon_setup.VerifyCertificate({
78
      constants.NDS_NODE_DAEMON_CERTIFICATE: "something",
79
      }, _verify_fn=lambda _: None)
80

    
81
  def testNoPrivateKey(self):
82
    cert_filename = self._TestDataFilename("cert1.pem")
83
    cert_pem = utils.ReadFile(cert_filename)
84

    
85
    self.assertRaises(errors.X509CertError,
86
                      node_daemon_setup._VerifyCertificate,
87
                      cert_pem, _check_fn=NotImplemented)
88

    
89
  def testInvalidCertificate(self):
90
    self.assertRaises(errors.X509CertError,
91
                      node_daemon_setup._VerifyCertificate,
92
                      "Something that's not a certificate",
93
                      _check_fn=NotImplemented)
94

    
95
  @staticmethod
96
  def _Check(cert):
97
    assert cert.get_subject()
98

    
99
  def testSuccessfulCheck(self):
100
    cert_filename = self._TestDataFilename("cert2.pem")
101
    cert_pem = utils.ReadFile(cert_filename)
102
    result = \
103
      node_daemon_setup._VerifyCertificate(cert_pem, _check_fn=self._Check)
104

    
105
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, result)
106
    self.assertTrue(cert)
107

    
108
    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, result)
109
    self.assertTrue(key)
110

    
111
  def testMismatchingKey(self):
112
    cert1_path = self._TestDataFilename("cert1.pem")
113
    cert2_path = self._TestDataFilename("cert2.pem")
114

    
115
    # Extract certificate
116
    cert1 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
117
                                            utils.ReadFile(cert1_path))
118
    cert1_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
119
                                                cert1)
120

    
121
    # Extract mismatching key
122
    key2 = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
123
                                          utils.ReadFile(cert2_path))
124
    key2_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM,
125
                                              key2)
126

    
127
    try:
128
      node_daemon_setup._VerifyCertificate(cert1_pem + key2_pem,
129
                                           _check_fn=NotImplemented)
130
    except errors.X509CertError, err:
131
      self.assertEqual(err.args,
132
                       ("(stdin)", "Certificate is not signed with given key"))
133
    else:
134
      self.fail("Exception was not raised")
135

    
136

    
137
class TestVerifyClusterName(unittest.TestCase):
138
  def setUp(self):
139
    unittest.TestCase.setUp(self)
140
    self.tmpdir = tempfile.mkdtemp()
141

    
142
  def tearDown(self):
143
    unittest.TestCase.tearDown(self)
144
    shutil.rmtree(self.tmpdir)
145

    
146
  def testNoName(self):
147
    self.assertRaises(_SetupError, node_daemon_setup.VerifyClusterName,
148
                      {}, _verify_fn=NotImplemented)
149

    
150
  @staticmethod
151
  def _FailingVerify(name):
152
    assert name == "somecluster.example.com"
153
    raise errors.GenericError()
154

    
155
  def testFailingVerification(self):
156
    data = {
157
      constants.NDS_CLUSTER_NAME: "somecluster.example.com",
158
      }
159

    
160
    self.assertRaises(errors.GenericError, node_daemon_setup.VerifyClusterName,
161
                      data, _verify_fn=self._FailingVerify)
162

    
163
  def testSuccess(self):
164
    data = {
165
      constants.NDS_CLUSTER_NAME: "cluster.example.com",
166
      }
167

    
168
    result = \
169
      node_daemon_setup.VerifyClusterName(data, _verify_fn=lambda _: None)
170

    
171
    self.assertEqual(result, "cluster.example.com")
172

    
173

    
174
class TestVerifySsconf(unittest.TestCase):
175
  def testNoSsconf(self):
176
    self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf,
177
                      {}, NotImplemented, _verify_fn=NotImplemented)
178

    
179
    for items in [None, {}]:
180
      self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf, {
181
        constants.NDS_SSCONF: items,
182
        }, NotImplemented, _verify_fn=NotImplemented)
183

    
184
  def _Check(self, names):
185
    self.assertEqual(frozenset(names), frozenset([
186
      constants.SS_CLUSTER_NAME,
187
      constants.SS_INSTANCE_LIST,
188
      ]))
189

    
190
  def testSuccess(self):
191
    ssdata = {
192
      constants.SS_CLUSTER_NAME: "cluster.example.com",
193
      constants.SS_INSTANCE_LIST: [],
194
      }
195

    
196
    result = node_daemon_setup.VerifySsconf({
197
      constants.NDS_SSCONF: ssdata,
198
      }, "cluster.example.com", _verify_fn=self._Check)
199

    
200
    self.assertEqual(result, ssdata)
201

    
202
    self.assertRaises(_SetupError, node_daemon_setup.VerifySsconf, {
203
      constants.NDS_SSCONF: ssdata,
204
      }, "wrong.example.com", _verify_fn=self._Check)
205

    
206
  def testInvalidKey(self):
207
    self.assertRaises(errors.GenericError, node_daemon_setup.VerifySsconf, {
208
      constants.NDS_SSCONF: {
209
        "no-valid-ssconf-key": "value",
210
        },
211
      }, NotImplemented)
212

    
213

    
214
if __name__ == "__main__":
215
  testutils.GanetiTestProgram()