Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.tools.prepare_node_join_unittest.py @ 0602cef3

History | View | Annotate | Download (8.7 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.tools.prepare_node_join"""
23

    
24
import unittest
25
import shutil
26
import tempfile
27
import os.path
28
import OpenSSL
29

    
30
from ganeti import errors
31
from ganeti import constants
32
from ganeti import serializer
33
from ganeti import pathutils
34
from ganeti import compat
35
from ganeti import utils
36
from ganeti.tools import prepare_node_join
37

    
38
import testutils
39

    
40

    
41
_JoinError = prepare_node_join.JoinError
42

    
43

    
44
class TestLoadData(unittest.TestCase):
45
  def testNoJson(self):
46
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
47
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
48

    
49
  def testInvalidDataStructure(self):
50
    raw = serializer.DumpJson({
51
      "some other thing": False,
52
      })
53
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
54

    
55
    raw = serializer.DumpJson([])
56
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
57

    
58
  def testValidData(self):
59
    raw = serializer.DumpJson({})
60
    self.assertEqual(prepare_node_join.LoadData(raw), {})
61

    
62

    
63
class TestVerifyCertificate(testutils.GanetiTestCase):
64
  def setUp(self):
65
    testutils.GanetiTestCase.setUp(self)
66
    self.tmpdir = tempfile.mkdtemp()
67

    
68
  def tearDown(self):
69
    testutils.GanetiTestCase.tearDown(self)
70
    shutil.rmtree(self.tmpdir)
71

    
72
  def testNoCert(self):
73
    prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)
74

    
75
  def testGivenPrivateKey(self):
76
    cert_filename = self._TestDataFilename("cert2.pem")
77
    cert_pem = utils.ReadFile(cert_filename)
78

    
79
    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
80
                      cert_pem, _check_fn=NotImplemented)
81

    
82
  def testInvalidCertificate(self):
83
    self.assertRaises(errors.X509CertError,
84
                      prepare_node_join._VerifyCertificate,
85
                      "Something that's not a certificate",
86
                      _check_fn=NotImplemented)
87

    
88
  @staticmethod
89
  def _Check(cert):
90
    assert cert.get_subject()
91

    
92
  def testSuccessfulCheck(self):
93
    cert_filename = self._TestDataFilename("cert1.pem")
94
    cert_pem = utils.ReadFile(cert_filename)
95
    prepare_node_join._VerifyCertificate(cert_pem, _check_fn=self._Check)
96

    
97

    
98
class TestVerifyClusterName(unittest.TestCase):
99
  def setUp(self):
100
    unittest.TestCase.setUp(self)
101
    self.tmpdir = tempfile.mkdtemp()
102

    
103
  def tearDown(self):
104
    unittest.TestCase.tearDown(self)
105
    shutil.rmtree(self.tmpdir)
106

    
107
  def testNoName(self):
108
    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
109
                      {}, _verify_fn=NotImplemented)
110

    
111
  @staticmethod
112
  def _FailingVerify(name):
113
    assert name == "cluster.example.com"
114
    raise errors.GenericError()
115

    
116
  def testFailingVerification(self):
117
    data = {
118
      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
119
      }
120

    
121
    self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
122
                      data, _verify_fn=self._FailingVerify)
123

    
124

    
125
class TestUpdateSshDaemon(unittest.TestCase):
126
  def setUp(self):
127
    unittest.TestCase.setUp(self)
128
    self.tmpdir = tempfile.mkdtemp()
129

    
130
    self.keyfiles = {
131
      constants.SSHK_RSA:
132
        (utils.PathJoin(self.tmpdir, "rsa.private"),
133
         utils.PathJoin(self.tmpdir, "rsa.public")),
134
      constants.SSHK_DSA:
135
        (utils.PathJoin(self.tmpdir, "dsa.private"),
136
         utils.PathJoin(self.tmpdir, "dsa.public")),
137
      }
138

    
139
  def tearDown(self):
140
    unittest.TestCase.tearDown(self)
141
    shutil.rmtree(self.tmpdir)
142

    
143
  def testNoKeys(self):
144
    data_empty_keys = {
145
      constants.SSHS_SSH_HOST_KEY: [],
146
      }
147

    
148
    for data in [{}, data_empty_keys]:
149
      for dry_run in [False, True]:
150
        prepare_node_join.UpdateSshDaemon(data, dry_run,
151
                                          _runcmd_fn=NotImplemented,
152
                                          _keyfiles=NotImplemented)
153
    self.assertEqual(os.listdir(self.tmpdir), [])
154

    
155
  def _TestDryRun(self, data):
156
    prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented,
157
                                      _keyfiles=self.keyfiles)
158
    self.assertEqual(os.listdir(self.tmpdir), [])
159

    
160
  def testDryRunRsa(self):
161
    self._TestDryRun({
162
      constants.SSHS_SSH_HOST_KEY: [
163
        (constants.SSHK_RSA, "rsapriv", "rsapub"),
164
        ],
165
      })
166

    
167
  def testDryRunDsa(self):
168
    self._TestDryRun({
169
      constants.SSHS_SSH_HOST_KEY: [
170
        (constants.SSHK_DSA, "dsapriv", "dsapub"),
171
        ],
172
      })
173

    
174
  def _RunCmd(self, fail, cmd, interactive=NotImplemented):
175
    self.assertTrue(interactive)
176
    self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"])
177
    if fail:
178
      exit_code = constants.EXIT_FAILURE
179
    else:
180
      exit_code = constants.EXIT_SUCCESS
181
    return utils.RunResult(exit_code, None, "stdout", "stderr",
182
                           utils.ShellQuoteArgs(cmd),
183
                           NotImplemented, NotImplemented)
184

    
185
  def _TestUpdate(self, failcmd):
186
    data = {
187
      constants.SSHS_SSH_HOST_KEY: [
188
        (constants.SSHK_DSA, "dsapriv", "dsapub"),
189
        (constants.SSHK_RSA, "rsapriv", "rsapub"),
190
        ],
191
      }
192
    runcmd_fn = compat.partial(self._RunCmd, failcmd)
193
    if failcmd:
194
      self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon,
195
                        data, False, _runcmd_fn=runcmd_fn,
196
                        _keyfiles=self.keyfiles)
197
    else:
198
      prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn,
199
                                        _keyfiles=self.keyfiles)
200
    self.assertEqual(sorted(os.listdir(self.tmpdir)), sorted([
201
      "rsa.public", "rsa.private",
202
      "dsa.public", "dsa.private",
203
      ]))
204
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.public")),
205
                     "rsapub")
206
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.private")),
207
                     "rsapriv")
208
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.public")),
209
                     "dsapub")
210
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.private")),
211
                     "dsapriv")
212

    
213
  def testSuccess(self):
214
    self._TestUpdate(False)
215

    
216
  def testFailure(self):
217
    self._TestUpdate(True)
218

    
219

    
220
class TestUpdateSshRoot(unittest.TestCase):
221
  def setUp(self):
222
    unittest.TestCase.setUp(self)
223
    self.tmpdir = tempfile.mkdtemp()
224
    self.sshdir = utils.PathJoin(self.tmpdir, ".ssh")
225

    
226
  def tearDown(self):
227
    unittest.TestCase.tearDown(self)
228
    shutil.rmtree(self.tmpdir)
229

    
230
  def _GetHomeDir(self, user):
231
    self.assertEqual(user, constants.SSH_LOGIN_USER)
232
    return self.tmpdir
233

    
234
  def testNoKeys(self):
235
    data_empty_keys = {
236
      constants.SSHS_SSH_ROOT_KEY: [],
237
      }
238

    
239
    for data in [{}, data_empty_keys]:
240
      for dry_run in [False, True]:
241
        prepare_node_join.UpdateSshRoot(data, dry_run,
242
                                        _homedir_fn=NotImplemented)
243
    self.assertEqual(os.listdir(self.tmpdir), [])
244

    
245
  def testDryRun(self):
246
    data = {
247
      constants.SSHS_SSH_ROOT_KEY: [
248
        (constants.SSHK_RSA, "aaa", "bbb"),
249
        ]
250
      }
251

    
252
    prepare_node_join.UpdateSshRoot(data, True,
253
                                    _homedir_fn=self._GetHomeDir)
254
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
255
    self.assertEqual(os.listdir(self.sshdir), [])
256

    
257
  def testUpdate(self):
258
    data = {
259
      constants.SSHS_SSH_ROOT_KEY: [
260
        (constants.SSHK_DSA, "privatedsa", "ssh-dss pubdsa"),
261
        ]
262
      }
263

    
264
    prepare_node_join.UpdateSshRoot(data, False,
265
                                    _homedir_fn=self._GetHomeDir)
266
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
267
    self.assertEqual(sorted(os.listdir(self.sshdir)),
268
                     sorted(["authorized_keys", "id_dsa", "id_dsa.pub"]))
269
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa")),
270
                     "privatedsa")
271
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub")),
272
                     "ssh-dss pubdsa")
273
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
274
                                                   "authorized_keys")),
275
                     "ssh-dss pubdsa\n")
276

    
277

    
278
if __name__ == "__main__":
279
  testutils.GanetiTestProgram()