4 # Copyright (C) 2010, 2011 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Script for testing ganeti.rpc"""
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import rpc_defs
32 from ganeti import http
33 from ganeti import errors
34 from ganeti import serializer
35 from ganeti import objects
40 class _FakeRequestProcessor:
41 def __init__(self, response_fn):
42 self._response_fn = response_fn
45 def __call__(self, reqs, lock_monitor_cb=None):
46 assert lock_monitor_cb is None or callable(lock_monitor_cb)
49 self._response_fn(req)
def GetFakeSimpleStoreClass(fn):
  """Builds a fake simple-store class for resolver tests.

  @param fn: function installed as the C{GetNodePrimaryIPList} method; it
    receives the store instance as its only argument
  @return: a class whose instances return C{fn(self)} from
    C{GetNodePrimaryIPList} and C{None} from C{GetPrimaryIPFamily}

  """
  class FakeSimpleStore:
    def GetPrimaryIPFamily(self):
      # The fake store never reports an address family
      return None

  # Attached after class creation; a plain function assigned to a class
  # attribute still becomes an ordinary method
  FakeSimpleStore.GetNodePrimaryIPList = fn

  return FakeSimpleStore
# Tests for rpc._RpcProcessor. Each test injects a _FakeRequestProcessor via
# the _req_process_fn argument, so the responses come from the _Get*/_Read*
# callbacks defined in this class.
# NOTE(review): this copy of the class appears to have lost lines: the
# indentation, the `for name in nodes:` loops, the `host`/`body` assignments,
# several `http_proc = ...` statement heads and some branch keywords are
# absent. Restore from version control before relying on it.
# NOTE(review): comparisons such as `result.keys() == [host]` rely on
# Python 2, where dict.keys() returns a list.
60 class TestRpcProcessor(unittest.TestCase):
# Returns a function mapping each entry of node_list to map.get(node);
# not referenced by the tests visible here. The parameter name shadows the
# builtin `map`.
61 def _FakeAddressLookup(self, map):
62 return lambda node_list: [map.get(node) for node in node_list]
# Response callback: checks that the request targets 127.0.0.1:24094/version
# with rpc._TMO_URGENT as read timeout, then answers HTTP 200 with the body
# (True, 123).
64 def _GetVersionResponse(self, req):
65 self.assertEqual(req.host, "127.0.0.1")
66 self.assertEqual(req.port, 24094)
67 self.assertEqual(req.path, "/version")
68 self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
70 req.resp_status_code = http.HTTP_OK
71 req.resp_body = serializer.DumpJson((True, 123))
# A single reachable node answering "version" yields a successful result
# with payload 123, and exactly one request is made.
# NOTE(review): the call passes a timeout of 60 while _GetVersionResponse
# expects rpc._TMO_URGENT; this presumably relies on _TMO_URGENT being 60 --
# confirm.
73 def testVersionSuccess(self):
74 resolver = rpc._StaticResolver(["127.0.0.1"])
75 http_proc = _FakeRequestProcessor(self._GetVersionResponse)
76 proc = rpc._RpcProcessor(resolver, 24094)
77 result = proc(["localhost"], "version", {"localhost": ""}, 60,
78 NotImplemented, _req_process_fn=http_proc)
79 self.assertEqual(result.keys(), ["localhost"])
80 lhresp = result["localhost"]
81 self.assertFalse(lhresp.offline)
82 self.assertEqual(lhresp.node, "localhost")
83 self.assertFalse(lhresp.fail_msg)
84 self.assertEqual(lhresp.payload, 123)
85 self.assertEqual(lhresp.call, "version")
86 lhresp.Raise("should not raise")
87 self.assertEqual(http_proc.reqcount, 1)
# Response callback: checks that the timeout given to the processor (12356)
# arrives unchanged as req.read_timeout, then answers with payload -1.
89 def _ReadTimeoutResponse(self, req):
90 self.assertEqual(req.host, "192.0.2.13")
91 self.assertEqual(req.port, 19176)
92 self.assertEqual(req.path, "/version")
93 self.assertEqual(req.read_timeout, 12356)
95 req.resp_status_code = http.HTTP_OK
96 req.resp_body = serializer.DumpJson((True, -1))
# A custom timeout is passed through to the HTTP request.
# NOTE(review): `host` and `body` are used below, but their assignments are
# missing from this copy.
98 def testReadTimeout(self):
99 resolver = rpc._StaticResolver(["192.0.2.13"])
100 http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
101 proc = rpc._RpcProcessor(resolver, 19176)
104 result = proc([host], "version", body, 12356, NotImplemented,
105 _req_process_fn=http_proc)
106 self.assertEqual(result.keys(), [host])
107 lhresp = result[host]
108 self.assertFalse(lhresp.offline)
109 self.assertEqual(lhresp.node, host)
110 self.assertFalse(lhresp.fail_msg)
111 self.assertEqual(lhresp.payload, -1)
112 self.assertEqual(lhresp.call, "version")
113 lhresp.Raise("should not raise")
114 self.assertEqual(http_proc.reqcount, 1)
# A node resolved to rpc._OFFLINE is reported offline with a failure message
# and no payload. It is never contacted (reqcount stays 0), and Raise fails
# both with and without a message.
# NOTE(review): `host` and `body` assignments are missing from this copy.
116 def testOfflineNode(self):
117 resolver = rpc._StaticResolver([rpc._OFFLINE])
118 http_proc = _FakeRequestProcessor(NotImplemented)
119 proc = rpc._RpcProcessor(resolver, 30668)
122 result = proc([host], "version", body, 60, NotImplemented,
123 _req_process_fn=http_proc)
124 self.assertEqual(result.keys(), [host])
125 lhresp = result[host]
126 self.assertTrue(lhresp.offline)
127 self.assertEqual(lhresp.node, host)
128 self.assertTrue(lhresp.fail_msg)
129 self.assertFalse(lhresp.payload)
130 self.assertEqual(lhresp.call, "version")
133 self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
136 self.assertRaises(errors.OpExecError, lhresp.Raise, None)
138 self.assertEqual(http_proc.reqcount, 0)
# Response callback: accepts any host whose name starts with "node" on port
# 23245 and answers with payload 987.
140 def _GetMultiVersionResponse(self, req):
141 self.assert_(req.host.startswith("node"))
142 self.assertEqual(req.port, 23245)
143 self.assertEqual(req.path, "/version")
145 req.resp_status_code = http.HTTP_OK
146 req.resp_body = serializer.DumpJson((True, 987))
# All 50 nodes are queried exactly once and all succeed.
# NOTE(review): the loop binding `name` before `lhresp = result[name]` is
# missing from this copy.
148 def testMultiVersionSuccess(self):
149 nodes = ["node%s" % i for i in range(50)]
150 body = dict((n, "") for n in nodes)
151 resolver = rpc._StaticResolver(nodes)
152 http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
153 proc = rpc._RpcProcessor(resolver, 23245)
154 result = proc(nodes, "version", body, 60, NotImplemented,
155 _req_process_fn=http_proc)
156 self.assertEqual(sorted(result.keys()), sorted(nodes))
159 lhresp = result[name]
160 self.assertFalse(lhresp.offline)
161 self.assertEqual(lhresp.node, name)
162 self.assertFalse(lhresp.fail_msg)
163 self.assertEqual(lhresp.payload, 987)
164 self.assertEqual(lhresp.call, "version")
165 lhresp.Raise("should not raise")
167 self.assertEqual(http_proc.reqcount, len(nodes))
# Response callback: HTTP 200, but the body (False, errinfo) reports a
# failure on the remote side.
169 def _GetVersionResponseFail(self, errinfo, req):
170 self.assertEqual(req.path, "/version")
172 req.resp_status_code = http.HTTP_OK
173 req.resp_body = serializer.DumpJson((False, errinfo))
# Both a None and a string error info must give a result with fail_msg set,
# no payload, and a Raise that throws OpExecError.
# NOTE(review): the `http_proc = ...` head, the closing of the partial(...)
# call and the `body` assignment are missing from this copy.
175 def testVersionFailure(self):
176 resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
177 proc = rpc._RpcProcessor(resolver, 5903)
178 for errinfo in [None, "Unknown error"]:
180 _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
182 host = "aef9ur4i.example.com"
184 result = proc(body.keys(), "version", body, 60, NotImplemented,
185 _req_process_fn=http_proc)
186 self.assertEqual(result.keys(), [host])
187 lhresp = result[host]
188 self.assertFalse(lhresp.offline)
189 self.assertEqual(lhresp.node, host)
190 self.assert_(lhresp.fail_msg)
191 self.assertFalse(lhresp.payload)
192 self.assertEqual(lhresp.call, "version")
193 self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
194 self.assertEqual(http_proc.reqcount, 1)
# Response callback with three behaviours keyed by host:
# - hosts in httperrnodes get a request-level error (req.error is set);
# - hosts in failnodes get an HTTP 404 with a JSON error body;
# - every other host succeeds with payload hash(req.host).
# NOTE(review): part of this method (dict entries, closing brackets and the
# final `else:`) is missing from this copy.
196 def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
197 self.assertEqual(req.path, "/vg_list")
198 self.assertEqual(req.port, 15165)
200 if req.host in httperrnodes:
202 req.error = "Node set up for HTTP errors"
204 elif req.host in failnodes:
206 req.resp_status_code = 404
207 req.resp_body = serializer.DumpJson({
209 "message": "Method not found",
210 "explain": "Explanation goes here",
214 req.resp_status_code = http.HTTP_OK
215 req.resp_body = serializer.DumpJson((True, hash(req.host)))
# Of 50 nodes, every 7th starting at index 1 (7 nodes) gets a request error
# and every 3rd starting at index 2, minus those, gets an HTTP error
# (14 nodes); the other 29 succeed. Request errors make Raise throw
# OpExecError. HTTP errors make Raise(prereq=True, ...) throw OpPrereqError.
# NOTE(review): the `http_proc = ...` head, the `for name in nodes:` loop
# and the final `else:` are missing from this copy.
217 def testHttpError(self):
218 nodes = ["uaf6pbbv%s" % i for i in range(50)]
219 body = dict((n, "") for n in nodes)
220 resolver = rpc._StaticResolver(nodes)
222 httperrnodes = set(nodes[1::7])
223 self.assertEqual(len(httperrnodes), 7)
225 failnodes = set(nodes[2::3]) - httperrnodes
226 self.assertEqual(len(failnodes), 14)
228 self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
230 proc = rpc._RpcProcessor(resolver, 15165)
232 _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
233 httperrnodes, failnodes))
234 result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
235 _req_process_fn=http_proc)
236 self.assertEqual(sorted(result.keys()), sorted(nodes))
239 lhresp = result[name]
240 self.assertFalse(lhresp.offline)
241 self.assertEqual(lhresp.node, name)
242 self.assertEqual(lhresp.call, "vg_list")
244 if name in httperrnodes:
245 self.assert_(lhresp.fail_msg)
246 self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
247 elif name in failnodes:
248 self.assert_(lhresp.fail_msg)
249 self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
250 prereq=True, ecode=errors.ECODE_INVAL)
252 self.assertFalse(lhresp.fail_msg)
253 self.assertEqual(lhresp.payload, hash(name))
254 lhresp.Raise("should not raise")
256 self.assertEqual(http_proc.reqcount, len(nodes))
# Response callback: body is a 9-element sequence instead of the usual
# (status, payload) pair.
258 def _GetInvalidResponseA(self, req):
259 self.assertEqual(req.path, "/version")
261 req.resp_status_code = http.HTTP_OK
262 req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
263 "response", "!", 1, 2, 3))
# Response callback: body is a plain JSON string instead of a pair.
265 def _GetInvalidResponseB(self, req):
266 self.assertEqual(req.path, "/version")
268 req.resp_status_code = http.HTTP_OK
269 req.resp_body = serializer.DumpJson("invalid response")
# Malformed response bodies must produce a failed result (fail_msg set,
# Raise throws OpExecError) instead of propagating an exception.
# NOTE(review): the `body` assignment is missing from this copy.
271 def testInvalidResponse(self):
272 resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
273 proc = rpc._RpcProcessor(resolver, 19978)
275 for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
276 http_proc = _FakeRequestProcessor(fn)
277 host = "oqo7lanhly.example.com"
279 result = proc([host], "version", body, 60, NotImplemented,
280 _req_process_fn=http_proc)
281 self.assertEqual(result.keys(), [host])
282 lhresp = result[host]
283 self.assertFalse(lhresp.offline)
284 self.assertEqual(lhresp.node, host)
285 self.assert_(lhresp.fail_msg)
286 self.assertFalse(lhresp.payload)
287 self.assertEqual(lhresp.call, "version")
288 self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
289 self.assertEqual(http_proc.reqcount, 1)
# Response callback: checks that the per-node body was sent as the POST
# data (its JSON decoding equals test_data), then answers (True, None).
291 def _GetBodyTestResponse(self, test_data, req):
292 self.assertEqual(req.host, "192.0.2.84")
293 self.assertEqual(req.port, 18700)
294 self.assertEqual(req.path, "/upload_file")
295 self.assertEqual(serializer.LoadJson(req.post_data), test_data)
297 req.resp_status_code = http.HTTP_OK
298 req.resp_body = serializer.DumpJson((True, None))
# The request body given for a node is forwarded to that node unchanged.
# NOTE(review): the `test_data` and `host` assignments and the closing of
# the partial(...) call are missing from this copy.
300 def testResponseBody(self):
305 resolver = rpc._StaticResolver(["192.0.2.84"])
306 http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
308 proc = rpc._RpcProcessor(resolver, 18700)
310 body = {host: serializer.DumpJson(test_data)}
311 result = proc([host], "upload_file", body, 30, NotImplemented,
312 _req_process_fn=http_proc)
313 self.assertEqual(result.keys(), [host])
314 lhresp = result[host]
315 self.assertFalse(lhresp.offline)
316 self.assertEqual(lhresp.node, host)
317 self.assertFalse(lhresp.fail_msg)
318 self.assertEqual(lhresp.payload, None)
319 self.assertEqual(lhresp.call, "upload_file")
320 lhresp.Raise("should not raise")
321 self.assertEqual(http_proc.reqcount, 1)
class TestSsconfResolver(unittest.TestCase):
  """Tests for L{rpc._SsconfResolver}.

  Addresses come either from a fake simple store, which lists
  C{"name address"} entries, or from a fake name lookup function.

  """
  @staticmethod
  def _MakeNodesAndAddresses(addr_fmt, step):
    """Builds matching lists of node names and addresses.

    @type addr_fmt: string
    @param addr_fmt: format string for an address, given the node index
    @type step: int
    @param step: stride used when walking C{range(0, 255)}
    @rtype: tuple
    @return: C{(node_list, addr_list)}; the i-th node owns the i-th address

    """
    indices = range(0, 255, step)
    node_list = ["node%d.example.com" % n for n in indices]
    addr_list = [addr_fmt % n for n in indices]
    return (node_list, addr_list)

  def _CheckResult(self, result, node_list, addr_list):
    """Checks that C{result} pairs every node with its address, in order."""
    # list() on both sides keeps the comparison independent of whether a
    # sequence or an iterator is involved (zip is lazy on Python 3)
    self.assertEqual(list(result), list(zip(node_list, addr_list)))

  def testSsconfLookup(self):
    # Every node is listed by the store. nslookup_fn is NotImplemented, so
    # any attempt to call it would fail the test
    (node_list, addr_list) = self._MakeNodesAndAddresses("192.0.2.%d", 13)
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self._CheckResult(result, node_list, addr_list)

  def testNsLookup(self):
    # The store lists no nodes, so every address comes from nslookup_fn
    (node_list, addr_list) = self._MakeNodesAndAddresses("192.0.2.%d", 13)
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self._CheckResult(result, node_list, addr_list)

  def testBothLookups(self):
    # The second half of the nodes is listed by the store; the first half is
    # only resolvable through nslookup_fn
    (node_list, addr_list) = self._MakeNodesAndAddresses("192.0.2.%d", 13)
    # Floor division keeps the slice index an int on Python 3 as well
    n = len(addr_list) // 2
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self._CheckResult(result, node_list, addr_list)

  def testAddressLookupIPv6(self):
    # IPv6 addresses taken from the store are returned unchanged
    (node_list, addr_list) = self._MakeNodesAndAddresses("2001:db8::%d", 11)
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self._CheckResult(result, node_list, addr_list)
class TestStaticResolver(unittest.TestCase):
  """Tests for L{rpc._StaticResolver}."""

  def test(self):
    # Each node is paired, in order, with the address at the same position
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # list() on both sides keeps the comparison independent of whether a
    # sequence or an iterator is involved (zip is lazy on Python 3)
    self.assertEqual(list(res(nodes, NotImplemented)),
                     list(zip(nodes, addresses)))

  def testWrongLength(self):
    # A node list whose length differs from the address list must trigger
    # an assertion
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
# Tests for rpc._NodeConfigResolver. The first argument is a single-node
# lookup (name -> objects.Node or None) and the fourth an options value
# (None or rpc_defs.ACCEPT_OFFLINE_NODE). The second argument is passed as
# NotImplemented in the single-node tests; presumably it is a lookup for
# all nodes -- TODO confirm.
# NOTE(review): this copy appears to have lost lines. The @staticmethod
# decorators on the _GetSingle* helpers (they take no `self` but are called
# via `self.`) are missing, as are the second-argument lines in
# testSingleOnline/testSingleOffline and most of testMultiEmpty and
# testMultiSomeOffline.
378 class TestNodeConfigResolver(unittest.TestCase):
# Lookup stub: only accepts node90 and returns it online with primary IP
# 192.0.2.90.
380 def _GetSingleOnlineNode(name):
381 assert name == "node90.example.com"
382 return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
# Lookup stub: only accepts node100 and returns it offline with primary IP
# 192.0.2.100.
385 def _GetSingleOfflineNode(name):
386 assert name == "node100.example.com"
387 return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
# An online node resolves to its primary IP.
389 def testSingleOnline(self):
390 self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
392 ["node90.example.com"], None),
393 [("node90.example.com", "192.0.2.90")])
# Without options, an offline node resolves to rpc._OFFLINE.
395 def testSingleOffline(self):
396 self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
398 ["node100.example.com"], None),
399 [("node100.example.com", rpc._OFFLINE)])
# With rpc_defs.ACCEPT_OFFLINE_NODE, the offline node resolves to its real
# IP. Each of the other option values listed below raises AssertionError.
401 def testSingleOfflineWithAcceptOffline(self):
402 fn = self._GetSingleOfflineNode
403 assert fn("node100.example.com").offline
404 self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
405 ["node100.example.com"],
406 rpc_defs.ACCEPT_OFFLINE_NODE),
407 [("node100.example.com", "192.0.2.100")])
408 for i in [False, True, "", "Hello", 0, 1]:
409 self.assertRaises(AssertionError, rpc._NodeConfigResolver,
410 fn, NotImplemented, ["node100.example.com"], i)
# A name the lookup does not know (it returns None) resolves to the name
# itself.
412 def testUnknownSingleNode(self):
413 self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
414 ["node110.example.com"], None),
415 [("node110.example.com", "node110.example.com")])
# NOTE(review): the arguments and the expected value of this assertion are
# missing from this copy.
417 def testMultiEmpty(self):
418 self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
# Builds node1..node254; every node whose index is divisible by 3 is offline.
# Offline nodes resolve to rpc._OFFLINE, online ones to their primary IP and
# unknown names to themselves.
# NOTE(review): the remaining arguments of both assertions are missing from
# this copy.
423 def testMultiSomeOffline(self):
424 nodes = dict(("node%s.example.com" % i,
425 objects.Node(name="node%s.example.com" % i,
426 offline=((i % 3) == 0),
427 primary_ip="192.0.2.%s" % i))
428 for i in range(1, 255))
431 self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
436 # Offline, online and unknown hosts
437 self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
439 ["node3.example.com",
440 "node92.example.com",
441 "node54.example.com",
442 "unknown.example.com",],
444 ("node3.example.com", rpc._OFFLINE),
445 ("node92.example.com", "192.0.2.92"),
446 ("node54.example.com", rpc._OFFLINE),
447 ("unknown.example.com", "unknown.example.com"),
if __name__ == "__main__":
  # Run this module's test cases via testutils.GanetiTestProgram
  testutils.GanetiTestProgram()