4 # Copyright (C) 2010 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Script for testing ganeti.rpc"""
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34 from ganeti import objects
class _FakeRequestProcessor:
  """Fake replacement for the HTTP request processor of rpc._RpcProcessor.

  The tests pass an instance as _req_process_fn. Nothing is sent over the
  network: every request object is handed to response_fn, which inspects it
  and fills in the response attributes. The number of requests handled so
  far is kept in reqcount.

  """
  def __init__(self, response_fn):
    self._response_fn = response_fn
    # Tests compare this with the number of requests they expect to have
    # been issued (e.g. 0 for an offline node, len(nodes) for a fan-out).
    self.reqcount = 0

  def __call__(self, reqs, lock_monitor_cb=None):
    """Process every request in reqs by calling the response function.

    @param reqs: iterable of request objects
    @param lock_monitor_cb: optional callback; only checked for being
      callable, never invoked

    """
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
    for req in reqs:
      self.reqcount += 1
      self._response_fn(req)
def GetFakeSimpleStoreClass(fn):
  """Build a fake ssconf store class for the rpc._SsconfResolver tests.

  The returned class is meant to be passed as ssc to rpc._SsconfResolver.
  fn is stored on the class as GetNodePrimaryIPList, so a plain function
  is bound like a method and receives the store instance as its first
  argument. GetPrimaryIPFamily always returns None.

  A new class object is created on every call.

  """
  class FakeSimpleStore:
    GetNodePrimaryIPList = fn

    def GetPrimaryIPFamily(self):
      return None

  return FakeSimpleStore
# Tests for rpc._RpcProcessor. Each test builds a resolver
# (rpc._StaticResolver), wraps one of the response functions defined in this
# class (they check the outgoing request and fill in resp_status_code and
# resp_body, or error) in a _FakeRequestProcessor, calls the processor and
# checks the per-node result objects it returns.
# NOTE(review): the number at the start of each line is its original line
# number, and indentation has been lost. Several numbers are skipped, so
# statements are missing from this copy; gaps that leave the code broken are
# flagged below. The response functions also each skip the line just before
# "req.resp_status_code = ..." (68, 92, 137, 162, 248, 255, 282); what it
# held cannot be told from here.
# NOTE(review): self.assert_ is a deprecated alias of assertTrue (removed in
# Python 3.12), and comparisons such as result.keys() == [...] assume
# Python 2, where keys() returns a list.
59 class TestRpcProcessor(unittest.TestCase):
# Return a callable mapping each name in node_list to map.get(name), i.e.
# None for names missing from map.
# NOTE(review): not used by any test visible here; "map" shadows the builtin.
60 def _FakeAddressLookup(self, map):
61 return lambda node_list: [map.get(node) for node in node_list]
# Response function for testVersionSuccess: expects /version on
# 127.0.0.1:24094 with the urgent read timeout and answers HTTP 200 with the
# JSON body (True, 123).
63 def _GetVersionResponse(self, req):
64 self.assertEqual(req.host, "127.0.0.1")
65 self.assertEqual(req.port, 24094)
66 self.assertEqual(req.path, "/version")
67 self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
69 req.resp_status_code = http.HTTP_OK
70 req.resp_body = serializer.DumpJson((True, 123))
# One online node answering "version": exactly one request is sent and the
# result for "localhost" carries payload 123 without a failure message.
72 def testVersionSuccess(self):
73 resolver = rpc._StaticResolver(["127.0.0.1"])
74 http_proc = _FakeRequestProcessor(self._GetVersionResponse)
75 proc = rpc._RpcProcessor(resolver, 24094)
76 result = proc(["localhost"], "version", None, _req_process_fn=http_proc)
77 self.assertEqual(result.keys(), ["localhost"])
78 lhresp = result["localhost"]
79 self.assertFalse(lhresp.offline)
80 self.assertEqual(lhresp.node, "localhost")
81 self.assertFalse(lhresp.fail_msg)
82 self.assertEqual(lhresp.payload, 123)
83 self.assertEqual(lhresp.call, "version")
84 lhresp.Raise("should not raise")
85 self.assertEqual(http_proc.reqcount, 1)
# Response function for testReadTimeout: expects /version on
# 192.0.2.13:19176 with read_timeout 12356 and answers (True, -1).
87 def _ReadTimeoutResponse(self, req):
88 self.assertEqual(req.host, "192.0.2.13")
89 self.assertEqual(req.port, 19176)
90 self.assertEqual(req.path, "/version")
91 self.assertEqual(req.read_timeout, 12356)
93 req.resp_status_code = http.HTTP_OK
94 req.resp_body = serializer.DumpJson((True, -1))
# Like testVersionSuccess, but with an explicit read timeout that the
# response function checks.
96 def testReadTimeout(self):
97 resolver = rpc._StaticResolver(["192.0.2.13"])
98 http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
99 proc = rpc._RpcProcessor(resolver, 19176)
100 result = proc(["node31856"], "version", None, _req_process_fn=http_proc,
# NOTE(review): original line 101, which finishes the proc(...) call above,
# is missing; since _ReadTimeoutResponse expects read_timeout == 12356 it
# presumably passed read_timeout=12356.
102 self.assertEqual(result.keys(), ["node31856"])
103 lhresp = result["node31856"]
104 self.assertFalse(lhresp.offline)
105 self.assertEqual(lhresp.node, "node31856")
106 self.assertFalse(lhresp.fail_msg)
107 self.assertEqual(lhresp.payload, -1)
108 self.assertEqual(lhresp.call, "version")
109 lhresp.Raise("should not raise")
110 self.assertEqual(http_proc.reqcount, 1)
# A node resolved to rpc._OFFLINE is reported as offline with a failure
# message and no payload; Raise() raises OpExecError with and without a
# message. No request is sent: the response function is NotImplemented and
# reqcount must stay 0.
112 def testOfflineNode(self):
113 resolver = rpc._StaticResolver([rpc._OFFLINE])
114 http_proc = _FakeRequestProcessor(NotImplemented)
115 proc = rpc._RpcProcessor(resolver, 30668)
116 result = proc(["n17296"], "version", None, _req_process_fn=http_proc)
117 self.assertEqual(result.keys(), ["n17296"])
118 lhresp = result["n17296"]
119 self.assertTrue(lhresp.offline)
120 self.assertEqual(lhresp.node, "n17296")
121 self.assertTrue(lhresp.fail_msg)
122 self.assertFalse(lhresp.payload)
123 self.assertEqual(lhresp.call, "version")
126 self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
129 self.assertRaises(errors.OpExecError, lhresp.Raise, None)
131 self.assertEqual(http_proc.reqcount, 0)
# Response function for testMultiVersionSuccess: any "node*" host on port
# 23245, path /version; answers (True, 987).
133 def _GetMultiVersionResponse(self, req):
134 self.assert_(req.host.startswith("node"))
135 self.assertEqual(req.port, 23245)
136 self.assertEqual(req.path, "/version")
138 req.resp_status_code = http.HTTP_OK
139 req.resp_body = serializer.DumpJson((True, 987))
# 50 nodes all answering successfully; exactly one request per node.
141 def testMultiVersionSuccess(self):
142 nodes = ["node%s" % i for i in range(50)]
143 resolver = rpc._StaticResolver(nodes)
144 http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
145 proc = rpc._RpcProcessor(resolver, 23245)
146 result = proc(nodes, "version", None, _req_process_fn=http_proc)
147 self.assertEqual(sorted(result.keys()), sorted(nodes))
# NOTE(review): original lines 148-149 are missing and nothing here binds
# "name"; presumably the per-node checks below ran in a loop over nodes.
150 lhresp = result[name]
151 self.assertFalse(lhresp.offline)
152 self.assertEqual(lhresp.node, name)
153 self.assertFalse(lhresp.fail_msg)
154 self.assertEqual(lhresp.payload, 987)
155 self.assertEqual(lhresp.call, "version")
156 lhresp.Raise("should not raise")
158 self.assertEqual(http_proc.reqcount, len(nodes))
# Response function for testVersionFailure: answers HTTP 200 with
# (False, errinfo), i.e. the node reports a failed call.
160 def _GetVersionResponseFail(self, errinfo, req):
161 self.assertEqual(req.path, "/version")
163 req.resp_status_code = http.HTTP_OK
164 req.resp_body = serializer.DumpJson((False, errinfo))
# For errinfo None and a message string alike, a (False, errinfo) answer
# gives an online result with fail_msg set and no payload, and Raise()
# raises OpExecError.
166 def testVersionFailure(self):
167 resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
168 proc = rpc._RpcProcessor(resolver, 5903)
169 for errinfo in [None, "Unknown error"]:
# NOTE(review): original lines 170 and 172 are missing. The next line is the
# right-hand side of an assignment to http_proc whose start is lost, and its
# compat.partial(...) call is never closed (presumably the remaining
# argument was errinfo).
171 _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
173 result = proc(["aef9ur4i.example.com"], "version", None,
174 _req_process_fn=http_proc)
175 self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
176 lhresp = result["aef9ur4i.example.com"]
177 self.assertFalse(lhresp.offline)
178 self.assertEqual(lhresp.node, "aef9ur4i.example.com")
179 self.assert_(lhresp.fail_msg)
180 self.assertFalse(lhresp.payload)
181 self.assertEqual(lhresp.call, "version")
182 self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
183 self.assertEqual(http_proc.reqcount, 1)
# Response function for testHttpError (path /vg_list, port 15165):
#  - hosts in httperrnodes only get req.error set;
#  - hosts in failnodes get HTTP 404 with a JSON error body;
#  - all other hosts get HTTP 200 with (True, hash(host)).
185 def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
186 self.assertEqual(req.path, "/vg_list")
187 self.assertEqual(req.port, 15165)
189 if req.host in httperrnodes:
191 req.error = "Node set up for HTTP errors"
193 elif req.host in failnodes:
195 req.resp_status_code = 404
196 req.resp_body = serializer.DumpJson({
# NOTE(review): original line 197 (presumably another entry of this dict) is
# missing.
198 "message": "Method not found",
199 "explain": "Explanation goes here",
# NOTE(review): original lines 200-202 are missing: the dict literal above is
# never closed, and the HTTP 200 answer below lacks the else: branch it
# presumably lived in.
203 req.resp_status_code = http.HTTP_OK
204 req.resp_body = serializer.DumpJson((True, hash(req.host)))
# 50 nodes split into three groups:
#  - 7 nodes with HTTP errors: Raise() raises OpExecError;
#  - 14 nodes answering 404: Raise() raises OpPrereqError when given
#    prereq=True;
#  - 29 healthy nodes: the payload is hash(name).
206 def testHttpError(self):
207 nodes = ["uaf6pbbv%s" % i for i in range(50)]
208 resolver = rpc._StaticResolver(nodes)
210 httperrnodes = set(nodes[1::7])
211 self.assertEqual(len(httperrnodes), 7)
213 failnodes = set(nodes[2::3]) - httperrnodes
214 self.assertEqual(len(failnodes), 14)
216 self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
218 proc = rpc._RpcProcessor(resolver, 15165)
# NOTE(review): original line 219 is missing; the next two lines are the
# right-hand side of the http_proc assignment.
220 _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
221 httperrnodes, failnodes))
222 result = proc(nodes, "vg_list", None, _req_process_fn=http_proc,
223 read_timeout=rpc._TMO_URGENT)
224 self.assertEqual(sorted(result.keys()), sorted(nodes))
# NOTE(review): original lines 225-226 are missing and nothing here binds
# "name"; presumably the checks below ran in a loop over nodes.
227 lhresp = result[name]
228 self.assertFalse(lhresp.offline)
229 self.assertEqual(lhresp.node, name)
230 self.assertEqual(lhresp.call, "vg_list")
232 if name in httperrnodes:
233 self.assert_(lhresp.fail_msg)
234 self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
235 elif name in failnodes:
236 self.assert_(lhresp.fail_msg)
237 self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
238 prereq=True, ecode=errors.ECODE_INVAL)
# NOTE(review): original line 239 is missing; presumably the else: for the
# healthy-node case below.
240 self.assertFalse(lhresp.fail_msg)
241 self.assertEqual(lhresp.payload, hash(name))
242 lhresp.Raise("should not raise")
244 self.assertEqual(http_proc.reqcount, len(nodes))
# Response functions for testInvalidResponse. Both answer HTTP 200 with a
# body that is not a two-element (status, payload) pair: a nine-element list
# here, a bare string in _GetInvalidResponseB.
246 def _GetInvalidResponseA(self, req):
247 self.assertEqual(req.path, "/version")
249 req.resp_status_code = http.HTTP_OK
250 req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
251 "response", "!", 1, 2, 3))
253 def _GetInvalidResponseB(self, req):
254 self.assertEqual(req.path, "/version")
256 req.resp_status_code = http.HTTP_OK
257 req.resp_body = serializer.DumpJson("invalid response")
# Each malformed answer must give an online result with fail_msg set and no
# payload, and Raise() must raise OpExecError.
259 def testInvalidResponse(self):
260 resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
261 proc = rpc._RpcProcessor(resolver, 19978)
263 for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
264 http_proc = _FakeRequestProcessor(fn)
265 result = proc(["oqo7lanhly.example.com"], "version", None,
266 _req_process_fn=http_proc)
267 self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
268 lhresp = result["oqo7lanhly.example.com"]
269 self.assertFalse(lhresp.offline)
270 self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
271 self.assert_(lhresp.fail_msg)
272 self.assertFalse(lhresp.payload)
273 self.assertEqual(lhresp.call, "version")
274 self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
275 self.assertEqual(http_proc.reqcount, 1)
# Response function for testResponseBody: expects /upload_file on
# 192.0.2.84:18700 whose POST data decodes to test_data; answers
# (True, None).
277 def _GetBodyTestResponse(self, test_data, req):
278 self.assertEqual(req.host, "192.0.2.84")
279 self.assertEqual(req.port, 18700)
280 self.assertEqual(req.path, "/upload_file")
281 self.assertEqual(serializer.LoadJson(req.post_data), test_data)
283 req.resp_status_code = http.HTTP_OK
284 req.resp_body = serializer.DumpJson((True, None))
# The body given to the processor must reach the node unchanged.
286 def testResponseBody(self):
# NOTE(review): original lines 287-290 are missing; test_data, used below,
# is never defined in this copy.
291 resolver = rpc._StaticResolver(["192.0.2.84"])
292 http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
# NOTE(review): original line 293 is missing; the compat.partial(...) call
# above is never closed (presumably the remaining argument was test_data).
294 proc = rpc._RpcProcessor(resolver, 18700)
295 body = serializer.DumpJson(test_data)
296 result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc)
297 self.assertEqual(result.keys(), ["node19759"])
298 lhresp = result["node19759"]
299 self.assertFalse(lhresp.offline)
300 self.assertEqual(lhresp.node, "node19759")
301 self.assertFalse(lhresp.fail_msg)
302 self.assertEqual(lhresp.payload, None)
303 self.assertEqual(lhresp.call, "upload_file")
304 lhresp.Raise("should not raise")
305 self.assertEqual(http_proc.reqcount, 1)
class TestSsconfResolver(unittest.TestCase):
  """Tests for rpc._SsconfResolver.

  Names listed in the (fake) ssconf store are resolved from it; names it does
  not list are resolved through nslookup_fn.

  """

  @staticmethod
  def _MakeNodes(addr_fmt, step):
    """Return parallel lists of node names and addresses.

    Node numbers are taken from range(0, 255, step). Node n is named
    "node<n>.example.com" and has the address addr_fmt % n.

    """
    numbers = range(0, 255, step)
    node_list = ["node%d.example.com" % n for n in numbers]
    addr_list = [addr_fmt % n for n in numbers]
    return (node_list, addr_list)

  @staticmethod
  def _FormatSsconfEntries(node_list, addr_list):
    """Return ssconf-style "<name> <address>" lines for the given pairs."""
    return [" ".join(t) for t in zip(node_list, addr_list)]

  def testSsconfLookup(self):
    # Every node is listed in ssconf. nslookup_fn is NotImplemented (not
    # callable), so any fallback lookup would raise TypeError.
    (node_list, addr_list) = self._MakeNodes("192.0.2.%d", 13)
    node_addr_list = self._FormatSsconfEntries(node_list, addr_list)
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    # ssconf lists no nodes, so every name goes through nslookup_fn.
    (node_list, addr_list) = self._MakeNodes("192.0.2.%d", 13)
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    # The second half of the nodes is listed in ssconf; the first half can
    # only be resolved through nslookup_fn.
    (node_list, addr_list) = self._MakeNodes("192.0.2.%d", 13)
    # Floor division: same result as Python 2's "/" on ints, and still an
    # int (usable as a slice index) under Python 3.
    n = len(addr_list) // 2
    node_addr_list = self._FormatSsconfEntries(node_list[n:], addr_list[n:])
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    # IPv6 addresses listed in ssconf come back unchanged.
    (node_list, addr_list) = self._MakeNodes("2001:db8::%d", 11)
    node_addr_list = self._FormatSsconfEntries(node_list, addr_list)
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
class TestStaticResolver(unittest.TestCase):
  """Tests for rpc._StaticResolver."""

  def testSimple(self):
    # Addresses are assigned to node names positionally.
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    self.assertEqual(res(nodes), list(zip(nodes, addresses)))

  def testWrongLength(self):
    # One name but no addresses must trip the resolver's assertion.
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"])
# Tests for rpc._NodeConfigResolver. The visible calls pass a single-node
# lookup function, a second argument and a list of host names, and expect
# (name, address) pairs back:
#  - online nodes map to their primary_ip;
#  - offline nodes map to rpc._OFFLINE;
#  - names the lookup does not know (returns None) map to the name itself.
# NOTE(review): the leading number on each line is its original line number,
# indentation has been lost, and skipped numbers mean lost lines; see the
# notes below.
358 class TestNodeConfigResolver(unittest.TestCase):
# Single-node lookup for testSingleOnline: returns an online node90 with
# primary_ip 192.0.2.90.
# NOTE(review): takes no self yet is used as self._GetSingleOnlineNode;
# original line 359 is missing and presumably held @staticmethod.
360 def _GetSingleOnlineNode(name):
361 assert name == "node90.example.com"
362 return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
# Single-node lookup for testSingleOffline: returns an offline node100.
# NOTE(review): original lines 363-364 are missing; presumably a blank line
# and @staticmethod, as for _GetSingleOnlineNode.
365 def _GetSingleOfflineNode(name):
366 assert name == "node100.example.com"
367 return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
# A single online node resolves to its primary IP.
369 def testSingleOnline(self):
370 self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
# NOTE(review): original line 371 (the second argument) is missing;
# testUnknownSingleNode passes NotImplemented in that position.
372 ["node90.example.com"]),
373 [("node90.example.com", "192.0.2.90")])
# A single offline node resolves to rpc._OFFLINE.
375 def testSingleOffline(self):
376 self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
# NOTE(review): original line 377 (the second argument) is missing;
# testUnknownSingleNode passes NotImplemented in that position.
378 ["node100.example.com"]),
379 [("node100.example.com", rpc._OFFLINE)])
# A name the lookup does not know (returns None) resolves to the name itself.
381 def testUnknownSingleNode(self):
382 self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
383 ["node110.example.com"]),
384 [("node110.example.com", "node110.example.com")])
# NOTE(review): original lines 388-391 are missing, so this call and its
# assertion are incomplete; judging by the name, it resolved an empty host
# list.
386 def testMultiEmpty(self):
387 self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
# 254 nodes (node1 .. node254); every node whose number is divisible by 3 is
# offline.
392 def testMultiSomeOffline(self):
393 nodes = dict(("node%s.example.com" % i,
394 objects.Node(name="node%s.example.com" % i,
395 offline=((i % 3) == 0),
396 primary_ip="192.0.2.%s" % i))
397 for i in range(1, 255))
# NOTE(review): original lines 398-399 and 401-404 are missing; the call on
# the next line is incomplete.
400 self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
405 # Offline, online and unknown hosts
406 self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
# NOTE(review): original line 407 (the second argument) is missing;
# presumably something providing the nodes dict built above.
408 ["node3.example.com",
409 "node92.example.com",
410 "node54.example.com",
411 "unknown.example.com",]), [
412 ("node3.example.com", rpc._OFFLINE),
413 ("node92.example.com", "192.0.2.92"),
414 ("node54.example.com", rpc._OFFLINE),
415 ("unknown.example.com", "unknown.example.com"),
# NOTE(review): original lines 416-418, which presumably closed the expected
# list and the assertEqual call, are missing.
if __name__ == "__main__":
  # Only when executed as a script: hand control to Ganeti's test program
  # wrapper (presumably a unittest.main-style runner).
  testutils.GanetiTestProgram()