Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.utils.x509_unittest.py @ ec3a7362

History | View | Annotate | Download (13.5 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010, 2011, 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.utils.x509"""
23

    
24
import os
25
import tempfile
26
import unittest
27
import shutil
28
import OpenSSL
29
import distutils.version
30
import string
31

    
32
from ganeti import constants
33
from ganeti import utils
34
from ganeti import errors
35

    
36
import testutils
37

    
38

    
39
class TestParseAsn1Generalizedtime(unittest.TestCase):
40
  def setUp(self):
41
    self._Parse = utils.x509._ParseAsn1Generalizedtime
42

    
43
  def test(self):
44
    # UTC
45
    self.assertEqual(self._Parse("19700101000000Z"), 0)
46
    self.assertEqual(self._Parse("20100222174152Z"), 1266860512)
47
    self.assertEqual(self._Parse("20380119031407Z"), (2**31) - 1)
48

    
49
    # With offset
50
    self.assertEqual(self._Parse("20100222174152+0000"), 1266860512)
51
    self.assertEqual(self._Parse("20100223131652+0000"), 1266931012)
52
    self.assertEqual(self._Parse("20100223051808-0800"), 1266931088)
53
    self.assertEqual(self._Parse("20100224002135+1100"), 1266931295)
54
    self.assertEqual(self._Parse("19700101000000-0100"), 3600)
55

    
56
    # Leap seconds are not supported by datetime.datetime
57
    self.assertRaises(ValueError, self._Parse, "19841231235960+0000")
58
    self.assertRaises(ValueError, self._Parse, "19920630235960+0000")
59

    
60
    # Errors
61
    self.assertRaises(ValueError, self._Parse, "")
62
    self.assertRaises(ValueError, self._Parse, "invalid")
63
    self.assertRaises(ValueError, self._Parse, "20100222174152")
64
    self.assertRaises(ValueError, self._Parse, "Mon Feb 22 17:47:02 UTC 2010")
65
    self.assertRaises(ValueError, self._Parse, "2010-02-22 17:42:02")
66

    
67

    
68
class TestGetX509CertValidity(testutils.GanetiTestCase):
69
  def setUp(self):
70
    testutils.GanetiTestCase.setUp(self)
71

    
72
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
73

    
74
    # Test whether we have pyOpenSSL 0.7 or above
75
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
76

    
77
    if not self.pyopenssl0_7:
78
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
79
                    " function correctly")
80

    
81
  def _LoadCert(self, name):
82
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
83
                                           testutils.ReadTestData(name))
84

    
85
  def test(self):
86
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
87
    if self.pyopenssl0_7:
88
      self.assertEqual(validity, (1266919967, 1267524767))
89
    else:
90
      self.assertEqual(validity, (None, None))
91

    
92

    
93
class TestSignX509Certificate(unittest.TestCase):
94
  KEY = "My private key!"
95
  KEY_OTHER = "Another key"
96

    
97
  def test(self):
98
    # Generate certificate valid for 5 minutes
99
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300, 1)
100

    
101
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
102
                                           cert_pem)
103

    
104
    # No signature at all
105
    self.assertRaises(errors.GenericError,
106
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
107

    
108
    # Invalid input
109
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
110
                      "", self.KEY)
111
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
112
                      "X-Ganeti-Signature: \n", self.KEY)
113
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
114
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
115
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
116
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
117
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
118
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
119

    
120
    # Invalid salt
121
    for salt in list("-_@$,:;/\\ \t\n"):
122
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
123
                        cert_pem, self.KEY, "foo%sbar" % salt)
124

    
125
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
126
                 utils.GenerateSecret(numbytes=4),
127
                 utils.GenerateSecret(numbytes=16),
128
                 "{123:456}".encode("hex")]:
129
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
130

    
131
      self._Check(cert, salt, signed_pem)
132

    
133
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
134
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
135
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
136
                               "lines----\n------ at\nthe end!"))
137

    
138
  def _Check(self, cert, salt, pem):
139
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
140
    self.assertEqual(salt, salt2)
141
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
142

    
143
    # Other key
144
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
145
                      pem, self.KEY_OTHER)
146

    
147

    
148
class TestCertVerification(testutils.GanetiTestCase):
149
  def setUp(self):
150
    testutils.GanetiTestCase.setUp(self)
151

    
152
    self.tmpdir = tempfile.mkdtemp()
153

    
154
  def tearDown(self):
155
    shutil.rmtree(self.tmpdir)
156

    
157
  def testVerifyCertificate(self):
158
    cert_pem = testutils.ReadTestData("cert1.pem")
159
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
160
                                           cert_pem)
161

    
162
    # Not checking return value as this certificate is expired
163
    utils.VerifyX509Certificate(cert, 30, 7)
164

    
165
  @staticmethod
166
  def _GenCert(key, before, validity):
167
    # Urgh... mostly copied from x509.py :(
168

    
169
    # Create self-signed certificate
170
    cert = OpenSSL.crypto.X509()
171
    cert.set_serial_number(1)
172
    if before != 0:
173
      cert.gmtime_adj_notBefore(int(before))
174
    cert.gmtime_adj_notAfter(validity)
175
    cert.set_issuer(cert.get_subject())
176
    cert.set_pubkey(key)
177
    cert.sign(key, constants.X509_CERT_SIGN_DIGEST)
178
    return cert
179

    
180
  def testClockSkew(self):
181
    SKEW = constants.NODE_MAX_CLOCK_SKEW
182
    # Create private and public key
183
    key = OpenSSL.crypto.PKey()
184
    key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)
185

    
186
    validity = 7 * 86400
187
    # skew small enough, accepting cert; note that this is a timed
188
    # test, and could fail if the machine is so loaded that the next
189
    # few lines take more than NODE_MAX_CLOCK_SKEW / 2
190
    for before in [-1, 0, SKEW / 4, SKEW / 2]:
191
      cert = self._GenCert(key, before, validity)
192
      result = utils.VerifyX509Certificate(cert, 1, 2)
193
      self.assertEqual(result, (None, None))
194

    
195
    # skew too great, not accepting certs
196
    for before in [SKEW * 2, SKEW * 10]:
197
      cert = self._GenCert(key, before, validity)
198
      (status, msg) = utils.VerifyX509Certificate(cert, 1, 2)
199
      self.assertEqual(status, utils.CERT_WARNING)
200
      self.assertTrue(msg.startswith("Certificate not yet valid"))
201

    
202

    
203
class TestVerifyCertificateInner(unittest.TestCase):
204
  def test(self):
205
    vci = utils.x509._VerifyCertificateInner
206

    
207
    # Valid
208
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
209
                     (None, None))
210

    
211
    # Not yet valid
212
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
213
    self.assertEqual(errcode, utils.CERT_WARNING)
214

    
215
    # Expiring soon
216
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
217
    self.assertEqual(errcode, utils.CERT_ERROR)
218

    
219
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
220
    self.assertEqual(errcode, utils.CERT_WARNING)
221

    
222
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
223
    self.assertEqual(errcode, None)
224

    
225
    # Expired
226
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
227
    self.assertEqual(errcode, utils.CERT_ERROR)
228

    
229
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
230
    self.assertEqual(errcode, utils.CERT_ERROR)
231

    
232
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
233
    self.assertEqual(errcode, utils.CERT_ERROR)
234

    
235
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
236
    self.assertEqual(errcode, utils.CERT_ERROR)
237

    
238

    
239
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
240
  def setUp(self):
241
    self.tmpdir = tempfile.mkdtemp()
242

    
243
  def tearDown(self):
244
    shutil.rmtree(self.tmpdir)
245

    
246
  def _checkRsaPrivateKey(self, key):
247
    lines = key.splitlines()
248
    return (("-----BEGIN RSA PRIVATE KEY-----" in lines and
249
             "-----END RSA PRIVATE KEY-----" in lines) or
250
            ("-----BEGIN PRIVATE KEY-----" in lines and
251
             "-----END PRIVATE KEY-----" in lines))
252

    
253
  def _checkCertificate(self, cert):
254
    lines = cert.splitlines()
255
    return ("-----BEGIN CERTIFICATE-----" in lines and
256
            "-----END CERTIFICATE-----" in lines)
257

    
258
  def test(self):
259
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
260
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300,
261
                                                             1)
262
      self._checkRsaPrivateKey(key_pem)
263
      self._checkCertificate(cert_pem)
264

    
265
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
266
                                           key_pem)
267
      self.assert_(key.bits() >= 1024)
268
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
269
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
270

    
271
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
272
                                             cert_pem)
273
      self.failIf(x509.has_expired())
274
      self.assertEqual(x509.get_issuer().CN, common_name)
275
      self.assertEqual(x509.get_subject().CN, common_name)
276
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
277

    
278
  def testLegacy(self):
279
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
280

    
281
    utils.GenerateSelfSignedSslCert(cert1_filename, 1, validity=1)
282

    
283
    cert1 = utils.ReadFile(cert1_filename)
284

    
285
    self.assert_(self._checkRsaPrivateKey(cert1))
286
    self.assert_(self._checkCertificate(cert1))
287

    
288

    
289
class TestCheckNodeCertificate(testutils.GanetiTestCase):
290
  def setUp(self):
291
    testutils.GanetiTestCase.setUp(self)
292
    self.tmpdir = tempfile.mkdtemp()
293

    
294
  def tearDown(self):
295
    testutils.GanetiTestCase.tearDown(self)
296
    shutil.rmtree(self.tmpdir)
297

    
298
  def testMismatchingKey(self):
299
    other_cert = testutils.TestDataFilename("cert1.pem")
300
    node_cert = testutils.TestDataFilename("cert2.pem")
301

    
302
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
303
                                           utils.ReadFile(other_cert))
304

    
305
    try:
306
      utils.CheckNodeCertificate(cert, _noded_cert_file=node_cert)
307
    except errors.GenericError, err:
308
      self.assertEqual(str(err),
309
                       "Given cluster certificate does not match local key")
310
    else:
311
      self.fail("Exception was not raised")
312

    
313
  def testMatchingKey(self):
314
    cert_filename = testutils.TestDataFilename("cert2.pem")
315

    
316
    # Extract certificate
317
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
318
                                           utils.ReadFile(cert_filename))
319
    cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
320
                                               cert)
321

    
322
    utils.CheckNodeCertificate(cert, _noded_cert_file=cert_filename)
323

    
324
  def testMissingFile(self):
325
    cert_path = testutils.TestDataFilename("cert1.pem")
326
    nodecert = utils.PathJoin(self.tmpdir, "does-not-exist")
327

    
328
    utils.CheckNodeCertificate(NotImplemented, _noded_cert_file=nodecert)
329

    
330
    self.assertFalse(os.path.exists(nodecert))
331

    
332
  def testInvalidCertificate(self):
333
    tmpfile = utils.PathJoin(self.tmpdir, "cert")
334
    utils.WriteFile(tmpfile, data="not a certificate")
335

    
336
    self.assertRaises(errors.X509CertError, utils.CheckNodeCertificate,
337
                      NotImplemented, _noded_cert_file=tmpfile)
338

    
339
  def testNoPrivateKey(self):
340
    cert = testutils.TestDataFilename("cert1.pem")
341
    self.assertRaises(errors.X509CertError, utils.CheckNodeCertificate,
342
                      NotImplemented, _noded_cert_file=cert)
343

    
344
  def testMismatchInNodeCert(self):
345
    cert1_path = testutils.TestDataFilename("cert1.pem")
346
    cert2_path = testutils.TestDataFilename("cert2.pem")
347
    tmpfile = utils.PathJoin(self.tmpdir, "cert")
348

    
349
    # Extract certificate
350
    cert1 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
351
                                            utils.ReadFile(cert1_path))
352
    cert1_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
353
                                                cert1)
354

    
355
    # Extract mismatching key
356
    key2 = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
357
                                          utils.ReadFile(cert2_path))
358
    key2_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM,
359
                                              key2)
360

    
361
    # Write to file
362
    utils.WriteFile(tmpfile, data=cert1_pem + key2_pem)
363

    
364
    try:
365
      utils.CheckNodeCertificate(cert1, _noded_cert_file=tmpfile)
366
    except errors.X509CertError, err:
367
      self.assertEqual(err.args,
368
                       (tmpfile, "Certificate does not match with private key"))
369
    else:
370
      self.fail("Exception was not raised")
371

    
372

    
373
if __name__ == "__main__":
374
  testutils.GanetiTestProgram()