root / test / ganeti.tools.prepare_node_join_unittest.py @ d12b9f66
History | View | Annotate | Download (10.3 kB)
1 |
#!/usr/bin/python
|
---|---|
2 |
#
|
3 |
|
4 |
# Copyright (C) 2012 Google Inc.
|
5 |
#
|
6 |
# This program is free software; you can redistribute it and/or modify
|
7 |
# it under the terms of the GNU General Public License as published by
|
8 |
# the Free Software Foundation; either version 2 of the License, or
|
9 |
# (at your option) any later version.
|
10 |
#
|
11 |
# This program is distributed in the hope that it will be useful, but
|
12 |
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
13 |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
14 |
# General Public License for more details.
|
15 |
#
|
16 |
# You should have received a copy of the GNU General Public License
|
17 |
# along with this program; if not, write to the Free Software
|
18 |
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
|
19 |
# 02110-1301, USA.
|
20 |
|
21 |
|
22 |
"""Script for testing ganeti.tools.prepare_node_join"""
|
23 |
|
24 |
import unittest |
25 |
import shutil |
26 |
import tempfile |
27 |
import os.path |
28 |
import OpenSSL |
29 |
|
30 |
from ganeti import errors |
31 |
from ganeti import constants |
32 |
from ganeti import serializer |
33 |
from ganeti import pathutils |
34 |
from ganeti import compat |
35 |
from ganeti import utils |
36 |
from ganeti.tools import prepare_node_join |
37 |
|
38 |
import testutils |
39 |
|
40 |
|
41 |
_JoinError = prepare_node_join.JoinError |
42 |
|
43 |
|
44 |
class TestLoadData(unittest.TestCase): |
45 |
def testNoJson(self): |
46 |
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "") |
47 |
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}") |
48 |
|
49 |
def testInvalidDataStructure(self): |
50 |
raw = serializer.DumpJson({ |
51 |
"some other thing": False, |
52 |
}) |
53 |
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
|
54 |
|
55 |
raw = serializer.DumpJson([]) |
56 |
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
|
57 |
|
58 |
def testValidData(self): |
59 |
raw = serializer.DumpJson({}) |
60 |
self.assertEqual(prepare_node_join.LoadData(raw), {})
|
61 |
|
62 |
|
63 |
class TestVerifyCertificate(testutils.GanetiTestCase): |
64 |
def setUp(self): |
65 |
testutils.GanetiTestCase.setUp(self)
|
66 |
self.tmpdir = tempfile.mkdtemp()
|
67 |
|
68 |
def tearDown(self): |
69 |
testutils.GanetiTestCase.tearDown(self)
|
70 |
shutil.rmtree(self.tmpdir)
|
71 |
|
72 |
def testNoCert(self): |
73 |
prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)
|
74 |
|
75 |
def testMismatchingKey(self): |
76 |
other_cert = self._TestDataFilename("cert1.pem") |
77 |
node_cert = self._TestDataFilename("cert2.pem") |
78 |
|
79 |
self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
|
80 |
utils.ReadFile(other_cert), _noded_cert_file=node_cert) |
81 |
|
82 |
def testGivenPrivateKey(self): |
83 |
cert_filename = self._TestDataFilename("cert2.pem") |
84 |
cert_pem = utils.ReadFile(cert_filename) |
85 |
|
86 |
self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
|
87 |
cert_pem, _noded_cert_file=cert_filename) |
88 |
|
89 |
def testMatchingKey(self): |
90 |
cert_filename = self._TestDataFilename("cert2.pem") |
91 |
|
92 |
# Extract certificate
|
93 |
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, |
94 |
utils.ReadFile(cert_filename)) |
95 |
cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, |
96 |
cert) |
97 |
|
98 |
prepare_node_join._VerifyCertificate(cert_pem, |
99 |
_noded_cert_file=cert_filename) |
100 |
|
101 |
def testMissingFile(self): |
102 |
cert = self._TestDataFilename("cert1.pem") |
103 |
nodecert = utils.PathJoin(self.tmpdir, "does-not-exist") |
104 |
prepare_node_join._VerifyCertificate(utils.ReadFile(cert), |
105 |
_noded_cert_file=nodecert) |
106 |
|
107 |
def testInvalidCertificate(self): |
108 |
self.assertRaises(errors.X509CertError,
|
109 |
prepare_node_join._VerifyCertificate, |
110 |
"Something that's not a certificate",
|
111 |
_noded_cert_file=NotImplemented)
|
112 |
|
113 |
def testNoPrivateKey(self): |
114 |
cert = self._TestDataFilename("cert1.pem") |
115 |
self.assertRaises(errors.X509CertError,
|
116 |
prepare_node_join._VerifyCertificate, |
117 |
utils.ReadFile(cert), _noded_cert_file=cert) |
118 |
|
119 |
|
120 |
class TestVerifyClusterName(unittest.TestCase): |
121 |
def setUp(self): |
122 |
unittest.TestCase.setUp(self)
|
123 |
self.tmpdir = tempfile.mkdtemp()
|
124 |
|
125 |
def tearDown(self): |
126 |
unittest.TestCase.tearDown(self)
|
127 |
shutil.rmtree(self.tmpdir)
|
128 |
|
129 |
def testNoName(self): |
130 |
self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
|
131 |
{}, _verify_fn=NotImplemented)
|
132 |
|
133 |
def testMissingFile(self): |
134 |
tmpfile = utils.PathJoin(self.tmpdir, "does-not-exist") |
135 |
prepare_node_join._VerifyClusterName(NotImplemented,
|
136 |
_ss_cluster_name_file=tmpfile) |
137 |
|
138 |
def testMatchingName(self): |
139 |
tmpfile = utils.PathJoin(self.tmpdir, "cluster_name") |
140 |
|
141 |
for content in ["cluster.example.com", "cluster.example.com\n\n"]: |
142 |
utils.WriteFile(tmpfile, data=content) |
143 |
prepare_node_join._VerifyClusterName("cluster.example.com",
|
144 |
_ss_cluster_name_file=tmpfile) |
145 |
|
146 |
def testNameMismatch(self): |
147 |
tmpfile = utils.PathJoin(self.tmpdir, "cluster_name") |
148 |
|
149 |
for content in ["something.example.com", "foobar\n\ncluster.example.com"]: |
150 |
utils.WriteFile(tmpfile, data=content) |
151 |
self.assertRaises(_JoinError, prepare_node_join._VerifyClusterName,
|
152 |
"cluster.example.com", _ss_cluster_name_file=tmpfile)
|
153 |
|
154 |
|
155 |
class TestUpdateSshDaemon(unittest.TestCase): |
156 |
def setUp(self): |
157 |
unittest.TestCase.setUp(self)
|
158 |
self.tmpdir = tempfile.mkdtemp()
|
159 |
|
160 |
self.keyfiles = {
|
161 |
constants.SSHK_RSA: |
162 |
(utils.PathJoin(self.tmpdir, "rsa.public"), |
163 |
utils.PathJoin(self.tmpdir, "rsa.private")), |
164 |
constants.SSHK_DSA: |
165 |
(utils.PathJoin(self.tmpdir, "dsa.public"), |
166 |
utils.PathJoin(self.tmpdir, "dsa.private")), |
167 |
} |
168 |
|
169 |
def tearDown(self): |
170 |
unittest.TestCase.tearDown(self)
|
171 |
shutil.rmtree(self.tmpdir)
|
172 |
|
173 |
def testNoKeys(self): |
174 |
data_empty_keys = { |
175 |
constants.SSHS_SSH_HOST_KEY: [], |
176 |
} |
177 |
|
178 |
for data in [{}, data_empty_keys]: |
179 |
for dry_run in [False, True]: |
180 |
prepare_node_join.UpdateSshDaemon(data, dry_run, |
181 |
_runcmd_fn=NotImplemented,
|
182 |
_keyfiles=NotImplemented)
|
183 |
self.assertEqual(os.listdir(self.tmpdir), []) |
184 |
|
185 |
def _TestDryRun(self, data): |
186 |
prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented, |
187 |
_keyfiles=self.keyfiles)
|
188 |
self.assertEqual(os.listdir(self.tmpdir), []) |
189 |
|
190 |
def testDryRunRsa(self): |
191 |
self._TestDryRun({
|
192 |
constants.SSHS_SSH_HOST_KEY: [ |
193 |
(constants.SSHK_RSA, "rsapub", "rsapriv"), |
194 |
], |
195 |
}) |
196 |
|
197 |
def testDryRunDsa(self): |
198 |
self._TestDryRun({
|
199 |
constants.SSHS_SSH_HOST_KEY: [ |
200 |
(constants.SSHK_DSA, "dsapub", "dsapriv"), |
201 |
], |
202 |
}) |
203 |
|
204 |
def _RunCmd(self, fail, cmd, interactive=NotImplemented): |
205 |
self.assertTrue(interactive)
|
206 |
self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"]) |
207 |
if fail:
|
208 |
exit_code = constants.EXIT_FAILURE |
209 |
else:
|
210 |
exit_code = constants.EXIT_SUCCESS |
211 |
return utils.RunResult(exit_code, None, "stdout", "stderr", |
212 |
utils.ShellQuoteArgs(cmd), |
213 |
NotImplemented, NotImplemented) |
214 |
|
215 |
def _TestUpdate(self, failcmd): |
216 |
data = { |
217 |
constants.SSHS_SSH_HOST_KEY: [ |
218 |
(constants.SSHK_DSA, "dsapub", "dsapriv"), |
219 |
(constants.SSHK_RSA, "rsapub", "rsapriv"), |
220 |
], |
221 |
} |
222 |
runcmd_fn = compat.partial(self._RunCmd, failcmd)
|
223 |
if failcmd:
|
224 |
self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon,
|
225 |
data, False, _runcmd_fn=runcmd_fn,
|
226 |
_keyfiles=self.keyfiles)
|
227 |
else:
|
228 |
prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn,
|
229 |
_keyfiles=self.keyfiles)
|
230 |
self.assertEqual(sorted(os.listdir(self.tmpdir)), sorted([ |
231 |
"rsa.private", "rsa.public", |
232 |
"dsa.private", "dsa.public", |
233 |
])) |
234 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.public")), |
235 |
"rsapub")
|
236 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.private")), |
237 |
"rsapriv")
|
238 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.public")), |
239 |
"dsapub")
|
240 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.private")), |
241 |
"dsapriv")
|
242 |
|
243 |
def testSuccess(self): |
244 |
self._TestUpdate(False) |
245 |
|
246 |
def testFailure(self): |
247 |
self._TestUpdate(True) |
248 |
|
249 |
|
250 |
class TestUpdateSshRoot(unittest.TestCase): |
251 |
def setUp(self): |
252 |
unittest.TestCase.setUp(self)
|
253 |
self.tmpdir = tempfile.mkdtemp()
|
254 |
self.sshdir = utils.PathJoin(self.tmpdir, ".ssh") |
255 |
|
256 |
def tearDown(self): |
257 |
unittest.TestCase.tearDown(self)
|
258 |
shutil.rmtree(self.tmpdir)
|
259 |
|
260 |
def _GetHomeDir(self, user): |
261 |
self.assertEqual(user, constants.SSH_LOGIN_USER)
|
262 |
return self.tmpdir |
263 |
|
264 |
def testNoKeys(self): |
265 |
data_empty_keys = { |
266 |
constants.SSHS_SSH_ROOT_KEY: [], |
267 |
} |
268 |
|
269 |
for data in [{}, data_empty_keys]: |
270 |
for dry_run in [False, True]: |
271 |
prepare_node_join.UpdateSshRoot(data, dry_run, |
272 |
_homedir_fn=NotImplemented)
|
273 |
self.assertEqual(os.listdir(self.tmpdir), []) |
274 |
|
275 |
def testDryRun(self): |
276 |
data = { |
277 |
constants.SSHS_SSH_ROOT_KEY: [ |
278 |
(constants.SSHK_RSA, "aaa", "bbb"), |
279 |
] |
280 |
} |
281 |
|
282 |
prepare_node_join.UpdateSshRoot(data, True,
|
283 |
_homedir_fn=self._GetHomeDir)
|
284 |
self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) |
285 |
self.assertEqual(os.listdir(self.sshdir), []) |
286 |
|
287 |
def testUpdate(self): |
288 |
data = { |
289 |
constants.SSHS_SSH_ROOT_KEY: [ |
290 |
(constants.SSHK_DSA, "pubdsa", "privatedsa"), |
291 |
] |
292 |
} |
293 |
|
294 |
prepare_node_join.UpdateSshRoot(data, False,
|
295 |
_homedir_fn=self._GetHomeDir)
|
296 |
self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) |
297 |
self.assertEqual(sorted(os.listdir(self.sshdir)), |
298 |
sorted(["authorized_keys", "id_dsa", "id_dsa.pub"])) |
299 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa")), |
300 |
"privatedsa")
|
301 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub")), |
302 |
"pubdsa")
|
303 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, |
304 |
"authorized_keys")),
|
305 |
"ssh-dss pubdsa\n")
|
306 |
|
307 |
|
308 |
if __name__ == "__main__": |
309 |
testutils.GanetiTestProgram() |