root / test / ganeti.tools.prepare_node_join_unittest.py @ 0602cef3
History | View | Annotate | Download (8.7 kB)
1 |
#!/usr/bin/python
|
---|---|
2 |
#
|
3 |
|
4 |
# Copyright (C) 2012 Google Inc.
|
5 |
#
|
6 |
# This program is free software; you can redistribute it and/or modify
|
7 |
# it under the terms of the GNU General Public License as published by
|
8 |
# the Free Software Foundation; either version 2 of the License, or
|
9 |
# (at your option) any later version.
|
10 |
#
|
11 |
# This program is distributed in the hope that it will be useful, but
|
12 |
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
13 |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
14 |
# General Public License for more details.
|
15 |
#
|
16 |
# You should have received a copy of the GNU General Public License
|
17 |
# along with this program; if not, write to the Free Software
|
18 |
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
|
19 |
# 02110-1301, USA.
|
20 |
|
21 |
|
22 |
"""Script for testing ganeti.tools.prepare_node_join"""
|
23 |
|
24 |
import unittest |
25 |
import shutil |
26 |
import tempfile |
27 |
import os.path |
28 |
import OpenSSL |
29 |
|
30 |
from ganeti import errors |
31 |
from ganeti import constants |
32 |
from ganeti import serializer |
33 |
from ganeti import pathutils |
34 |
from ganeti import compat |
35 |
from ganeti import utils |
36 |
from ganeti.tools import prepare_node_join |
37 |
|
38 |
import testutils |
39 |
|
40 |
|
41 |
_JoinError = prepare_node_join.JoinError |
42 |
|
43 |
|
44 |
class TestLoadData(unittest.TestCase): |
45 |
def testNoJson(self): |
46 |
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "") |
47 |
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}") |
48 |
|
49 |
def testInvalidDataStructure(self): |
50 |
raw = serializer.DumpJson({ |
51 |
"some other thing": False, |
52 |
}) |
53 |
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
|
54 |
|
55 |
raw = serializer.DumpJson([]) |
56 |
self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
|
57 |
|
58 |
def testValidData(self): |
59 |
raw = serializer.DumpJson({}) |
60 |
self.assertEqual(prepare_node_join.LoadData(raw), {})
|
61 |
|
62 |
|
63 |
class TestVerifyCertificate(testutils.GanetiTestCase): |
64 |
def setUp(self): |
65 |
testutils.GanetiTestCase.setUp(self)
|
66 |
self.tmpdir = tempfile.mkdtemp()
|
67 |
|
68 |
def tearDown(self): |
69 |
testutils.GanetiTestCase.tearDown(self)
|
70 |
shutil.rmtree(self.tmpdir)
|
71 |
|
72 |
def testNoCert(self): |
73 |
prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)
|
74 |
|
75 |
def testGivenPrivateKey(self): |
76 |
cert_filename = self._TestDataFilename("cert2.pem") |
77 |
cert_pem = utils.ReadFile(cert_filename) |
78 |
|
79 |
self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
|
80 |
cert_pem, _check_fn=NotImplemented)
|
81 |
|
82 |
def testInvalidCertificate(self): |
83 |
self.assertRaises(errors.X509CertError,
|
84 |
prepare_node_join._VerifyCertificate, |
85 |
"Something that's not a certificate",
|
86 |
_check_fn=NotImplemented)
|
87 |
|
88 |
@staticmethod
|
89 |
def _Check(cert): |
90 |
assert cert.get_subject()
|
91 |
|
92 |
def testSuccessfulCheck(self): |
93 |
cert_filename = self._TestDataFilename("cert1.pem") |
94 |
cert_pem = utils.ReadFile(cert_filename) |
95 |
prepare_node_join._VerifyCertificate(cert_pem, _check_fn=self._Check)
|
96 |
|
97 |
|
98 |
class TestVerifyClusterName(unittest.TestCase): |
99 |
def setUp(self): |
100 |
unittest.TestCase.setUp(self)
|
101 |
self.tmpdir = tempfile.mkdtemp()
|
102 |
|
103 |
def tearDown(self): |
104 |
unittest.TestCase.tearDown(self)
|
105 |
shutil.rmtree(self.tmpdir)
|
106 |
|
107 |
def testNoName(self): |
108 |
self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
|
109 |
{}, _verify_fn=NotImplemented)
|
110 |
|
111 |
@staticmethod
|
112 |
def _FailingVerify(name): |
113 |
assert name == "cluster.example.com" |
114 |
raise errors.GenericError()
|
115 |
|
116 |
def testFailingVerification(self): |
117 |
data = { |
118 |
constants.SSHS_CLUSTER_NAME: "cluster.example.com",
|
119 |
} |
120 |
|
121 |
self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
|
122 |
data, _verify_fn=self._FailingVerify)
|
123 |
|
124 |
|
125 |
class TestUpdateSshDaemon(unittest.TestCase): |
126 |
def setUp(self): |
127 |
unittest.TestCase.setUp(self)
|
128 |
self.tmpdir = tempfile.mkdtemp()
|
129 |
|
130 |
self.keyfiles = {
|
131 |
constants.SSHK_RSA: |
132 |
(utils.PathJoin(self.tmpdir, "rsa.private"), |
133 |
utils.PathJoin(self.tmpdir, "rsa.public")), |
134 |
constants.SSHK_DSA: |
135 |
(utils.PathJoin(self.tmpdir, "dsa.private"), |
136 |
utils.PathJoin(self.tmpdir, "dsa.public")), |
137 |
} |
138 |
|
139 |
def tearDown(self): |
140 |
unittest.TestCase.tearDown(self)
|
141 |
shutil.rmtree(self.tmpdir)
|
142 |
|
143 |
def testNoKeys(self): |
144 |
data_empty_keys = { |
145 |
constants.SSHS_SSH_HOST_KEY: [], |
146 |
} |
147 |
|
148 |
for data in [{}, data_empty_keys]: |
149 |
for dry_run in [False, True]: |
150 |
prepare_node_join.UpdateSshDaemon(data, dry_run, |
151 |
_runcmd_fn=NotImplemented,
|
152 |
_keyfiles=NotImplemented)
|
153 |
self.assertEqual(os.listdir(self.tmpdir), []) |
154 |
|
155 |
def _TestDryRun(self, data): |
156 |
prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented, |
157 |
_keyfiles=self.keyfiles)
|
158 |
self.assertEqual(os.listdir(self.tmpdir), []) |
159 |
|
160 |
def testDryRunRsa(self): |
161 |
self._TestDryRun({
|
162 |
constants.SSHS_SSH_HOST_KEY: [ |
163 |
(constants.SSHK_RSA, "rsapriv", "rsapub"), |
164 |
], |
165 |
}) |
166 |
|
167 |
def testDryRunDsa(self): |
168 |
self._TestDryRun({
|
169 |
constants.SSHS_SSH_HOST_KEY: [ |
170 |
(constants.SSHK_DSA, "dsapriv", "dsapub"), |
171 |
], |
172 |
}) |
173 |
|
174 |
def _RunCmd(self, fail, cmd, interactive=NotImplemented): |
175 |
self.assertTrue(interactive)
|
176 |
self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"]) |
177 |
if fail:
|
178 |
exit_code = constants.EXIT_FAILURE |
179 |
else:
|
180 |
exit_code = constants.EXIT_SUCCESS |
181 |
return utils.RunResult(exit_code, None, "stdout", "stderr", |
182 |
utils.ShellQuoteArgs(cmd), |
183 |
NotImplemented, NotImplemented) |
184 |
|
185 |
def _TestUpdate(self, failcmd): |
186 |
data = { |
187 |
constants.SSHS_SSH_HOST_KEY: [ |
188 |
(constants.SSHK_DSA, "dsapriv", "dsapub"), |
189 |
(constants.SSHK_RSA, "rsapriv", "rsapub"), |
190 |
], |
191 |
} |
192 |
runcmd_fn = compat.partial(self._RunCmd, failcmd)
|
193 |
if failcmd:
|
194 |
self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon,
|
195 |
data, False, _runcmd_fn=runcmd_fn,
|
196 |
_keyfiles=self.keyfiles)
|
197 |
else:
|
198 |
prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn,
|
199 |
_keyfiles=self.keyfiles)
|
200 |
self.assertEqual(sorted(os.listdir(self.tmpdir)), sorted([ |
201 |
"rsa.public", "rsa.private", |
202 |
"dsa.public", "dsa.private", |
203 |
])) |
204 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.public")), |
205 |
"rsapub")
|
206 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.private")), |
207 |
"rsapriv")
|
208 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.public")), |
209 |
"dsapub")
|
210 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.private")), |
211 |
"dsapriv")
|
212 |
|
213 |
def testSuccess(self): |
214 |
self._TestUpdate(False) |
215 |
|
216 |
def testFailure(self): |
217 |
self._TestUpdate(True) |
218 |
|
219 |
|
220 |
class TestUpdateSshRoot(unittest.TestCase): |
221 |
def setUp(self): |
222 |
unittest.TestCase.setUp(self)
|
223 |
self.tmpdir = tempfile.mkdtemp()
|
224 |
self.sshdir = utils.PathJoin(self.tmpdir, ".ssh") |
225 |
|
226 |
def tearDown(self): |
227 |
unittest.TestCase.tearDown(self)
|
228 |
shutil.rmtree(self.tmpdir)
|
229 |
|
230 |
def _GetHomeDir(self, user): |
231 |
self.assertEqual(user, constants.SSH_LOGIN_USER)
|
232 |
return self.tmpdir |
233 |
|
234 |
def testNoKeys(self): |
235 |
data_empty_keys = { |
236 |
constants.SSHS_SSH_ROOT_KEY: [], |
237 |
} |
238 |
|
239 |
for data in [{}, data_empty_keys]: |
240 |
for dry_run in [False, True]: |
241 |
prepare_node_join.UpdateSshRoot(data, dry_run, |
242 |
_homedir_fn=NotImplemented)
|
243 |
self.assertEqual(os.listdir(self.tmpdir), []) |
244 |
|
245 |
def testDryRun(self): |
246 |
data = { |
247 |
constants.SSHS_SSH_ROOT_KEY: [ |
248 |
(constants.SSHK_RSA, "aaa", "bbb"), |
249 |
] |
250 |
} |
251 |
|
252 |
prepare_node_join.UpdateSshRoot(data, True,
|
253 |
_homedir_fn=self._GetHomeDir)
|
254 |
self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) |
255 |
self.assertEqual(os.listdir(self.sshdir), []) |
256 |
|
257 |
def testUpdate(self): |
258 |
data = { |
259 |
constants.SSHS_SSH_ROOT_KEY: [ |
260 |
(constants.SSHK_DSA, "privatedsa", "ssh-dss pubdsa"), |
261 |
] |
262 |
} |
263 |
|
264 |
prepare_node_join.UpdateSshRoot(data, False,
|
265 |
_homedir_fn=self._GetHomeDir)
|
266 |
self.assertEqual(os.listdir(self.tmpdir), [".ssh"]) |
267 |
self.assertEqual(sorted(os.listdir(self.sshdir)), |
268 |
sorted(["authorized_keys", "id_dsa", "id_dsa.pub"])) |
269 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa")), |
270 |
"privatedsa")
|
271 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub")), |
272 |
"ssh-dss pubdsa")
|
273 |
self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, |
274 |
"authorized_keys")),
|
275 |
"ssh-dss pubdsa\n")
|
276 |
|
277 |
|
278 |
if __name__ == "__main__": |
279 |
testutils.GanetiTestProgram() |