4 # Copyright (C) 2010 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Script for testing ganeti.rpc"""
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34 from ganeti import objects
class TestTimeouts(unittest.TestCase):
  """Checks the RPC timeout table against the RpcRunner call methods."""

  def testTimeouts(self):
    names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
             if name.startswith("call_")]

    # Every call_* method needs exactly one entry in the timeout table, and
    # no stale entries may remain; comparing sorted names (rather than only
    # the lengths) makes a mismatch show which names differ.
    self.assertEqual(sorted(names), sorted(rpc._TIMEOUTS))

    # Each timeout must be either None or a positive number
    self.assertFalse([name for (name, timeout) in rpc._TIMEOUTS.items()
                      if not (timeout is None or timeout > 0)])
class FakeHttpPool:
  """Fake HTTP client pool that answers requests through a callback.

  Instead of sending anything over the network, every request passed to
  L{ProcessRequests} is handed to C{response_fn}, which is expected to fill
  in the response attributes on the request object. The number of requests
  seen so far is kept in C{reqcount} so tests can check how many calls
  were made.

  """
  def __init__(self, response_fn):
    """Initializes this class.

    @param response_fn: callable taking a single request object

    """
    self._response_fn = response_fn
    self.reqcount = 0

  def ProcessRequests(self, reqs):
    """Passes each request to the response function, counting them.

    @param reqs: iterable of request objects

    """
    for req in reqs:
      self.reqcount += 1
      self._response_fn(req)
def GetFakeSimpleStoreClass(fn):
  """Builds a fake SimpleStore class for passing as C{ssc} to the resolver.

  @param fn: callable installed as the C{GetNodePrimaryIPList} method; like
      any method it receives the instance as its only argument
  @return: a class named C{FakeSimpleStore} whose C{GetPrimaryIPFamily}
      method always returns None

  """
  class FakeSimpleStore:
    GetNodePrimaryIPList = fn

    def GetPrimaryIPFamily(self):
      return None

  return FakeSimpleStore
class TestRpcProcessor(unittest.TestCase):
  """Tests for rpc._RpcProcessor, driven through L{FakeHttpPool}."""

  def _FakeAddressLookup(self, addr_map):
    """Returns a lookup function mapping node names through C{addr_map}.

    Names missing from C{addr_map} map to None.

    """
    return lambda node_list: [addr_map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Answers a "version" call to 127.0.0.1:24094 with payload 123."""
    self.assertEqual(req.host, "127.0.0.1")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    resolver = rpc._StaticResolver(["127.0.0.1"])
    pool = FakeHttpPool(self._GetVersionResponse)
    proc = rpc._RpcProcessor(resolver, 24094)
    result = proc(["localhost"], "version", None, http_pool=pool)
    # list() keeps the comparison valid where keys() returns a view
    self.assertEqual(list(result), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)

  def _ReadTimeoutResponse(self, req):
    """Answers a "version" call, checking the custom read timeout."""
    self.assertEqual(req.host, "192.0.2.13")
    self.assertEqual(req.port, 19176)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, 12356)

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, -1))

  def testReadTimeout(self):
    resolver = rpc._StaticResolver(["192.0.2.13"])
    pool = FakeHttpPool(self._ReadTimeoutResponse)
    proc = rpc._RpcProcessor(resolver, 19176)
    # The explicit timeout must reach the request (see _ReadTimeoutResponse)
    result = proc(["node31856"], "version", None, http_pool=pool,
                  read_timeout=12356)
    self.assertEqual(list(result), ["node31856"])
    lhresp = result["node31856"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "node31856")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, -1)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)

  def testOfflineNode(self):
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    # Offline nodes must never be contacted, so the pool has no responder
    pool = FakeHttpPool(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    result = proc(["n17296"], "version", None, http_pool=pool)
    self.assertEqual(list(result), ["n17296"])
    lhresp = result["n17296"]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, "n17296")
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    self.assertEqual(pool.reqcount, 0)

  def _GetMultiVersionResponse(self, req):
    """Answers a "version" call from any "node*" host with payload 987."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    nodes = ["node%s" % i for i in range(50)]
    resolver = rpc._StaticResolver(nodes)
    pool = FakeHttpPool(self._GetMultiVersionResponse)
    proc = rpc._RpcProcessor(resolver, 23245)
    result = proc(nodes, "version", None, http_pool=pool)
    self.assertEqual(sorted(result), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetVersionResponseFail(self, errinfo, req):
    """Answers a "version" call with an RPC-level failure and C{errinfo}."""
    self.assertEqual(req.path, "/version")

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, errinfo))

  def testVersionFailure(self):
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
    proc = rpc._RpcProcessor(resolver, 5903)
    for errinfo in [None, "Unknown error"]:
      pool = FakeHttpPool(compat.partial(self._GetVersionResponseFail, errinfo))
      result = proc(["aef9ur4i.example.com"], "version", None, http_pool=pool)
      self.assertEqual(list(result), ["aef9ur4i.example.com"])
      lhresp = result["aef9ur4i.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(pool.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answers "vg_list" differently depending on the node's group.

    Nodes in C{httperrnodes} get a request error (C{req.error} set), nodes
    in C{failnodes} an HTTP 404 with a JSON error body, and all others a
    successful result whose payload is C{hash(req.host)}.

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })

    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    resolver = rpc._StaticResolver(nodes)

    # Every seventh node, starting at index 1, fails with a request error
    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    # Every third node, starting at index 2, gets an HTTP 404, unless it is
    # already in the request-error group
    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
                                       httperrnodes, failnodes))
    result = proc(nodes, "vg_list", None, http_pool=pool)
    self.assertEqual(sorted(result), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answers "version" with a tuple of the wrong shape."""
    self.assertEqual(req.path, "/version")

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answers "version" with a bare string instead of a tuple."""
    self.assertEqual(req.path, "/version")

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      pool = FakeHttpPool(fn)
      result = proc(["oqo7lanhly.example.com"], "version", None, http_pool=pool)
      self.assertEqual(list(result), ["oqo7lanhly.example.com"])
      lhresp = result["oqo7lanhly.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(pool.reqcount, 1)

  def _GetBodyTestResponse(self, test_data, req):
    """Answers "upload_file" after checking the posted body equals test_data."""
    self.assertEqual(req.host, "192.0.2.84")
    self.assertEqual(req.port, 18700)
    self.assertEqual(req.path, "/upload_file")
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, None))

  def testResponseBody(self):
    # Arbitrary payload; it must survive a DumpJson/LoadJson round trip
    # unchanged for the comparison in _GetBodyTestResponse to hold, hence
    # list(range(...)) rather than a lazy range object.
    test_data = {
      "Hello": "World",
      "xyz": list(range(10)),
      }
    resolver = rpc._StaticResolver(["192.0.2.84"])
    pool = FakeHttpPool(compat.partial(self._GetBodyTestResponse, test_data))
    proc = rpc._RpcProcessor(resolver, 18700)
    body = serializer.DumpJson(test_data)
    result = proc(["node19759"], "upload_file", body, http_pool=pool)
    self.assertEqual(list(result), ["node19759"])
    lhresp = result["node19759"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "node19759")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, None)
    self.assertEqual(lhresp.call, "upload_file")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)
class TestSsconfResolver(unittest.TestCase):
  """Tests for rpc._SsconfResolver."""

  def testSsconfLookup(self):
    # All addresses come from ssconf; name lookups must not happen
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    # list() because zip() returns an iterator on Python 3
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    # ssconf knows nothing, so every address comes from nslookup_fn
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    # Second half of the nodes is in ssconf, first half needs nslookup_fn
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division keeps the split point an integer on Python 3
    n = len(addr_list) // 2
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
class TestStaticResolver(unittest.TestCase):
  """Tests for rpc._StaticResolver."""

  def test(self):
    # Names are paired with the configured addresses in order
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # list() because zip() returns an iterator on Python 3
    self.assertEqual(res(nodes), list(zip(nodes, addresses)))

  def testWrongLength(self):
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"])
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for rpc._NodeConfigResolver.

  The resolver is called as (single-node function, all-nodes function,
  host names); the argument a given test does not need is NotImplemented.

  """
  # staticmethod: these are called as self._GetSingle...(name) and must not
  # receive the instance as an extra argument
  @staticmethod
  def _GetSingleOnlineNode(name):
    assert name == "node90.example.com"
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")

  @staticmethod
  def _GetSingleOfflineNode(name):
    assert name == "node100.example.com"
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")

  def testSingleOnline(self):
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                             NotImplemented,
                                             ["node90.example.com"]),
                     [("node90.example.com", "192.0.2.90")])

  def testSingleOffline(self):
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                             NotImplemented,
                                             ["node100.example.com"]),
                     [("node100.example.com", rpc._OFFLINE)])

  def testUnknownSingleNode(self):
    # Names unknown to the configuration resolve to themselves
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                             ["node110.example.com"]),
                     [("node110.example.com", "node110.example.com")])

  def testMultiEmpty(self):
    # NOTE(review): the all-nodes argument is presumed to be a zero-argument
    # callable returning a name->Node dict; confirm against the resolver.
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: {},
                                             []),
                     [])

  def testMultiSomeOffline(self):
    # Every node whose index is divisible by three is offline
    nodes = dict(("node%s.example.com" % i,
                  objects.Node(name="node%s.example.com" % i,
                               offline=((i % 3) == 0),
                               primary_ip="192.0.2.%s" % i))
                 for i in range(1, 255))

    # Resolve no names
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: nodes,
                                             []),
                     [])

    # Offline, online and unknown hosts
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: nodes,
                                             ["node3.example.com",
                                              "node92.example.com",
                                              "node54.example.com",
                                              "unknown.example.com",]), [
      ("node3.example.com", rpc._OFFLINE),
      ("node92.example.com", "192.0.2.92"),
      ("node54.example.com", rpc._OFFLINE),
      ("unknown.example.com", "unknown.example.com"),
      ])
# When executed as a script, run this module's tests through
# testutils.GanetiTestProgram.
if __name__ == "__main__":
  testutils.GanetiTestProgram()