Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.tools.prepare_node_join_unittest.py @ 910ef222

History | View | Annotate | Download (10.3 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.tools.prepare_node_join"""
23

    
24
import unittest
25
import shutil
26
import tempfile
27
import os.path
28
import OpenSSL
29

    
30
from ganeti import errors
31
from ganeti import constants
32
from ganeti import serializer
33
from ganeti import pathutils
34
from ganeti import compat
35
from ganeti import utils
36
from ganeti.tools import prepare_node_join
37

    
38
import testutils
39

    
40

    
41
_JoinError = prepare_node_join.JoinError
42

    
43

    
44
class TestLoadData(unittest.TestCase):
45
  def testNoJson(self):
46
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
47
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
48

    
49
  def testInvalidDataStructure(self):
50
    raw = serializer.DumpJson({
51
      "some other thing": False,
52
      })
53
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
54

    
55
    raw = serializer.DumpJson([])
56
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
57

    
58
  def testValidData(self):
59
    raw = serializer.DumpJson({})
60
    self.assertEqual(prepare_node_join.LoadData(raw), {})
61

    
62

    
63
class TestVerifyCertificate(testutils.GanetiTestCase):
64
  def setUp(self):
65
    testutils.GanetiTestCase.setUp(self)
66
    self.tmpdir = tempfile.mkdtemp()
67

    
68
  def tearDown(self):
69
    testutils.GanetiTestCase.tearDown(self)
70
    shutil.rmtree(self.tmpdir)
71

    
72
  def testNoCert(self):
73
    prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)
74

    
75
  def testMismatchingKey(self):
76
    other_cert = self._TestDataFilename("cert1.pem")
77
    node_cert = self._TestDataFilename("cert2.pem")
78

    
79
    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
80
                      utils.ReadFile(other_cert), _noded_cert_file=node_cert)
81

    
82
  def testGivenPrivateKey(self):
83
    cert_filename = self._TestDataFilename("cert2.pem")
84
    cert_pem = utils.ReadFile(cert_filename)
85

    
86
    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
87
                      cert_pem, _noded_cert_file=cert_filename)
88

    
89
  def testMatchingKey(self):
90
    cert_filename = self._TestDataFilename("cert2.pem")
91

    
92
    # Extract certificate
93
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
94
                                           utils.ReadFile(cert_filename))
95
    cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
96
                                               cert)
97

    
98
    prepare_node_join._VerifyCertificate(cert_pem,
99
                                         _noded_cert_file=cert_filename)
100

    
101
  def testMissingFile(self):
102
    cert = self._TestDataFilename("cert1.pem")
103
    nodecert = utils.PathJoin(self.tmpdir, "does-not-exist")
104
    prepare_node_join._VerifyCertificate(utils.ReadFile(cert),
105
                                         _noded_cert_file=nodecert)
106

    
107
  def testInvalidCertificate(self):
108
    self.assertRaises(errors.X509CertError,
109
                      prepare_node_join._VerifyCertificate,
110
                      "Something that's not a certificate",
111
                      _noded_cert_file=NotImplemented)
112

    
113
  def testNoPrivateKey(self):
114
    cert = self._TestDataFilename("cert1.pem")
115
    self.assertRaises(errors.X509CertError,
116
                      prepare_node_join._VerifyCertificate,
117
                      utils.ReadFile(cert), _noded_cert_file=cert)
118

    
119

    
120
class TestVerifyClusterName(unittest.TestCase):
121
  def setUp(self):
122
    unittest.TestCase.setUp(self)
123
    self.tmpdir = tempfile.mkdtemp()
124

    
125
  def tearDown(self):
126
    unittest.TestCase.tearDown(self)
127
    shutil.rmtree(self.tmpdir)
128

    
129
  def testNoName(self):
130
    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
131
                      {}, _verify_fn=NotImplemented)
132

    
133
  def testMissingFile(self):
134
    tmpfile = utils.PathJoin(self.tmpdir, "does-not-exist")
135
    prepare_node_join._VerifyClusterName(NotImplemented,
136
                                         _ss_cluster_name_file=tmpfile)
137

    
138
  def testMatchingName(self):
139
    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
140

    
141
    for content in ["cluster.example.com", "cluster.example.com\n\n"]:
142
      utils.WriteFile(tmpfile, data=content)
143
      prepare_node_join._VerifyClusterName("cluster.example.com",
144
                                           _ss_cluster_name_file=tmpfile)
145

    
146
  def testNameMismatch(self):
147
    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
148

    
149
    for content in ["something.example.com", "foobar\n\ncluster.example.com"]:
150
      utils.WriteFile(tmpfile, data=content)
151
      self.assertRaises(_JoinError, prepare_node_join._VerifyClusterName,
152
                        "cluster.example.com", _ss_cluster_name_file=tmpfile)
153

    
154

    
155
class TestUpdateSshDaemon(unittest.TestCase):
156
  def setUp(self):
157
    unittest.TestCase.setUp(self)
158
    self.tmpdir = tempfile.mkdtemp()
159

    
160
    self.keyfiles = {
161
      constants.SSHK_RSA:
162
        (utils.PathJoin(self.tmpdir, "rsa.public"),
163
         utils.PathJoin(self.tmpdir, "rsa.private")),
164
      constants.SSHK_DSA:
165
        (utils.PathJoin(self.tmpdir, "dsa.public"),
166
         utils.PathJoin(self.tmpdir, "dsa.private")),
167
      }
168

    
169
  def tearDown(self):
170
    unittest.TestCase.tearDown(self)
171
    shutil.rmtree(self.tmpdir)
172

    
173
  def testNoKeys(self):
174
    data_empty_keys = {
175
      constants.SSHS_SSH_HOST_KEY: [],
176
      }
177

    
178
    for data in [{}, data_empty_keys]:
179
      for dry_run in [False, True]:
180
        prepare_node_join.UpdateSshDaemon(data, dry_run,
181
                                          _runcmd_fn=NotImplemented,
182
                                          _keyfiles=NotImplemented)
183
    self.assertEqual(os.listdir(self.tmpdir), [])
184

    
185
  def _TestDryRun(self, data):
186
    prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented,
187
                                      _keyfiles=self.keyfiles)
188
    self.assertEqual(os.listdir(self.tmpdir), [])
189

    
190
  def testDryRunRsa(self):
191
    self._TestDryRun({
192
      constants.SSHS_SSH_HOST_KEY: [
193
        (constants.SSHK_RSA, "rsapub", "rsapriv"),
194
        ],
195
      })
196

    
197
  def testDryRunDsa(self):
198
    self._TestDryRun({
199
      constants.SSHS_SSH_HOST_KEY: [
200
        (constants.SSHK_DSA, "dsapub", "dsapriv"),
201
        ],
202
      })
203

    
204
  def _RunCmd(self, fail, cmd, interactive=NotImplemented):
205
    self.assertTrue(interactive)
206
    self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"])
207
    if fail:
208
      exit_code = constants.EXIT_FAILURE
209
    else:
210
      exit_code = constants.EXIT_SUCCESS
211
    return utils.RunResult(exit_code, None, "stdout", "stderr",
212
                           utils.ShellQuoteArgs(cmd),
213
                           NotImplemented, NotImplemented)
214

    
215
  def _TestUpdate(self, failcmd):
216
    data = {
217
      constants.SSHS_SSH_HOST_KEY: [
218
        (constants.SSHK_DSA, "dsapub", "dsapriv"),
219
        (constants.SSHK_RSA, "rsapub", "rsapriv"),
220
        ],
221
      }
222
    runcmd_fn = compat.partial(self._RunCmd, failcmd)
223
    if failcmd:
224
      self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon,
225
                        data, False, _runcmd_fn=runcmd_fn,
226
                        _keyfiles=self.keyfiles)
227
    else:
228
      prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn,
229
                                        _keyfiles=self.keyfiles)
230
    self.assertEqual(sorted(os.listdir(self.tmpdir)), sorted([
231
      "rsa.private", "rsa.public",
232
      "dsa.private", "dsa.public",
233
      ]))
234
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.public")),
235
                     "rsapub")
236
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.private")),
237
                     "rsapriv")
238
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.public")),
239
                     "dsapub")
240
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.private")),
241
                     "dsapriv")
242

    
243
  def testSuccess(self):
244
    self._TestUpdate(False)
245

    
246
  def testFailure(self):
247
    self._TestUpdate(True)
248

    
249

    
250
class TestUpdateSshRoot(unittest.TestCase):
251
  def setUp(self):
252
    unittest.TestCase.setUp(self)
253
    self.tmpdir = tempfile.mkdtemp()
254
    self.sshdir = utils.PathJoin(self.tmpdir, ".ssh")
255

    
256
  def tearDown(self):
257
    unittest.TestCase.tearDown(self)
258
    shutil.rmtree(self.tmpdir)
259

    
260
  def _GetHomeDir(self, user):
261
    self.assertEqual(user, constants.SSH_LOGIN_USER)
262
    return self.tmpdir
263

    
264
  def testNoKeys(self):
265
    data_empty_keys = {
266
      constants.SSHS_SSH_ROOT_KEY: [],
267
      }
268

    
269
    for data in [{}, data_empty_keys]:
270
      for dry_run in [False, True]:
271
        prepare_node_join.UpdateSshRoot(data, dry_run,
272
                                        _homedir_fn=NotImplemented)
273
    self.assertEqual(os.listdir(self.tmpdir), [])
274

    
275
  def testDryRun(self):
276
    data = {
277
      constants.SSHS_SSH_ROOT_KEY: [
278
        (constants.SSHK_RSA, "aaa", "bbb"),
279
        ]
280
      }
281

    
282
    prepare_node_join.UpdateSshRoot(data, True,
283
                                    _homedir_fn=self._GetHomeDir)
284
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
285
    self.assertEqual(os.listdir(self.sshdir), [])
286

    
287
  def testUpdate(self):
288
    data = {
289
      constants.SSHS_SSH_ROOT_KEY: [
290
        (constants.SSHK_DSA, "ssh-dss pubdsa", "privatedsa"),
291
        ]
292
      }
293

    
294
    prepare_node_join.UpdateSshRoot(data, False,
295
                                    _homedir_fn=self._GetHomeDir)
296
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
297
    self.assertEqual(sorted(os.listdir(self.sshdir)),
298
                     sorted(["authorized_keys", "id_dsa", "id_dsa.pub"]))
299
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa")),
300
                     "privatedsa")
301
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub")),
302
                     "ssh-dss pubdsa")
303
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
304
                                                   "authorized_keys")),
305
                     "ssh-dss pubdsa\n")
306

    
307

    
308
if __name__ == "__main__":
309
  testutils.GanetiTestProgram()