# Copyright (C) 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Script for testing ganeti.rpc"""
import unittest

from ganeti import constants
from ganeti import compat
from ganeti import rpc
from ganeti import http
from ganeti import errors
from ganeti import serializer
from ganeti import objects

import testutils
class _FakeRequestProcessor:
    """Stand-in for the HTTP request processor used by the RPC tests.

    Each request handed to the instance is passed to ``response_fn``, which
    is expected to fill in the response on the request object. The number of
    requests seen so far is kept in ``reqcount`` so tests can check how many
    requests were actually issued.
    """

    def __init__(self, response_fn):
        """Remember the response callback and reset the request counter.

        @param response_fn: callable invoked with each request object

        """
        self._response_fn = response_fn
        # Read by the tests after the call (``http_proc.reqcount``).
        self.reqcount = 0

    def __call__(self, reqs, lock_monitor_cb=None):
        """Process all requests in ``reqs`` in order.

        @param reqs: iterable of request objects
        @param lock_monitor_cb: optional callable; only validated, never
          invoked by this fake

        """
        assert lock_monitor_cb is None or callable(lock_monitor_cb)
        for req in reqs:
            self.reqcount += 1
            self._response_fn(req)
def GetFakeSimpleStoreClass(fn):
    """Build a fake simple-store class whose node/IP list comes from ``fn``.

    @param fn: function installed as the ``GetNodePrimaryIPList`` method; as
      a method it receives the store instance as its only argument
    @return: the newly created ``FakeSimpleStore`` class (not an instance)

    """
    class FakeSimpleStore:
        # Installed as methods, hence both take the instance as argument.
        GetNodePrimaryIPList = fn
        GetPrimaryIPFamily = lambda _: None

    return FakeSimpleStore
class TestRpcProcessor(unittest.TestCase):
    """Tests for ``rpc._RpcProcessor`` driven by a fake request processor.

    Every test builds a static resolver, wraps one of the ``_Get*Response``
    helpers in a ``_FakeRequestProcessor`` and checks the per-node results.

    Several statements of this class were missing from the damaged source
    and have been reconstructed from context; values that could not be
    derived from the remaining code are marked ``NOTE(review)``.
    """

    def _FakeAddressLookup(self, map):
        """Return a callable resolving node names through ``map``.

        The callable takes a list of node names and returns the list of
        ``map.get(node)`` values, i.e. ``None`` for unknown names.
        """
        return lambda node_list: [map.get(node) for node in node_list]

    def _GetVersionResponse(self, req):
        """Verify the request of testVersionSuccess and answer with 123."""
        self.assertEqual(req.host, "127.0.0.1")
        self.assertEqual(req.port, 24094)
        self.assertEqual(req.path, "/version")
        self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
        # NOTE(review): restored line; confirm the success flag against the
        # ganeti HTTP client request object.
        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson((True, 123))

    def testVersionSuccess(self):
        """A single node answering successfully yields its payload."""
        resolver = rpc._StaticResolver(["127.0.0.1"])
        http_proc = _FakeRequestProcessor(self._GetVersionResponse)
        proc = rpc._RpcProcessor(resolver, 24094)
        # The timeout must match what _GetVersionResponse asserts.
        result = proc(["localhost"], "version", None,
                      _req_process_fn=http_proc,
                      read_timeout=rpc._TMO_URGENT)
        self.assertEqual(list(result.keys()), ["localhost"])
        lhresp = result["localhost"]
        self.assertFalse(lhresp.offline)
        self.assertEqual(lhresp.node, "localhost")
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, 123)
        self.assertEqual(lhresp.call, "version")
        lhresp.Raise("should not raise")
        self.assertEqual(http_proc.reqcount, 1)

    def _ReadTimeoutResponse(self, req):
        """Verify the request of testReadTimeout and answer with -1."""
        self.assertEqual(req.host, "192.0.2.13")
        self.assertEqual(req.port, 19176)
        self.assertEqual(req.path, "/version")
        self.assertEqual(req.read_timeout, 12356)
        # NOTE(review): restored line, see _GetVersionResponse.
        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson((True, -1))

    def testReadTimeout(self):
        """The read timeout given to the processor reaches the request."""
        resolver = rpc._StaticResolver(["192.0.2.13"])
        http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
        proc = rpc._RpcProcessor(resolver, 19176)
        # The timeout must match what _ReadTimeoutResponse asserts.
        result = proc(["node31856"], "version", None,
                      _req_process_fn=http_proc,
                      read_timeout=12356)
        self.assertEqual(list(result.keys()), ["node31856"])
        lhresp = result["node31856"]
        self.assertFalse(lhresp.offline)
        self.assertEqual(lhresp.node, "node31856")
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, -1)
        self.assertEqual(lhresp.call, "version")
        lhresp.Raise("should not raise")
        self.assertEqual(http_proc.reqcount, 1)

    def testOfflineNode(self):
        """A node resolved as offline fails without any request being sent."""
        resolver = rpc._StaticResolver([rpc._OFFLINE])
        # NotImplemented is not callable; the fake must never be asked to
        # answer a request in this test (reqcount is checked to be 0 below).
        http_proc = _FakeRequestProcessor(NotImplemented)
        proc = rpc._RpcProcessor(resolver, 30668)
        # NOTE(review): the original timeout argument was lost; 60 is the
        # value used by the other tests in this class -- confirm.
        result = proc(["n17296"], "version", None,
                      _req_process_fn=http_proc,
                      read_timeout=60)
        self.assertEqual(list(result.keys()), ["n17296"])
        lhresp = result["n17296"]
        self.assertTrue(lhresp.offline)
        self.assertEqual(lhresp.node, "n17296")
        self.assertTrue(lhresp.fail_msg)
        self.assertFalse(lhresp.payload)
        self.assertEqual(lhresp.call, "version")

        # With a message
        self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

        # No message
        self.assertRaises(errors.OpExecError, lhresp.Raise, None)

        self.assertEqual(http_proc.reqcount, 0)

    def _GetMultiVersionResponse(self, req):
        """Verify a request of testMultiVersionSuccess; answer with 987."""
        self.assertTrue(req.host.startswith("node"))
        self.assertEqual(req.port, 23245)
        self.assertEqual(req.path, "/version")
        # NOTE(review): restored line, see _GetVersionResponse.
        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson((True, 987))

    def testMultiVersionSuccess(self):
        """Fifty nodes are queried at once and all of them succeed."""
        nodes = ["node%s" % i for i in range(50)]
        # Node names double as their own addresses here.
        resolver = rpc._StaticResolver(nodes)
        http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
        proc = rpc._RpcProcessor(resolver, 23245)
        # NOTE(review): the original timeout argument was lost; the response
        # helper does not check it -- confirm the intended value.
        result = proc(nodes, "version", None,
                      _req_process_fn=http_proc,
                      read_timeout=60)
        self.assertEqual(sorted(result.keys()), sorted(nodes))

        for name in nodes:
            lhresp = result[name]
            self.assertFalse(lhresp.offline)
            self.assertEqual(lhresp.node, name)
            self.assertFalse(lhresp.fail_msg)
            self.assertEqual(lhresp.payload, 987)
            self.assertEqual(lhresp.call, "version")
            lhresp.Raise("should not raise")

        self.assertEqual(http_proc.reqcount, len(nodes))

    def _GetVersionResponseFail(self, errinfo, req):
        """Answer with a well-formed failure result carrying ``errinfo``."""
        self.assertEqual(req.path, "/version")
        # NOTE(review): restored line, see _GetVersionResponse.
        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson((False, errinfo))

    def testVersionFailure(self):
        """A ``(False, errinfo)`` result is reported as a failed call."""
        resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
        proc = rpc._RpcProcessor(resolver, 5903)
        for errinfo in [None, "Unknown error"]:
            http_proc = _FakeRequestProcessor(
                compat.partial(self._GetVersionResponseFail, errinfo))
            result = proc(["aef9ur4i.example.com"], "version", None,
                          _req_process_fn=http_proc, read_timeout=60)
            self.assertEqual(list(result.keys()), ["aef9ur4i.example.com"])
            lhresp = result["aef9ur4i.example.com"]
            self.assertFalse(lhresp.offline)
            self.assertEqual(lhresp.node, "aef9ur4i.example.com")
            self.assertTrue(lhresp.fail_msg)
            self.assertFalse(lhresp.payload)
            self.assertEqual(lhresp.call, "version")
            self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
            self.assertEqual(http_proc.reqcount, 1)

    def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
        """Answer depending on which node set ``req.host`` belongs to.

        Hosts in ``httperrnodes`` get a transport-level error, hosts in
        ``failnodes`` get an HTTP 404 response and all other hosts get a
        successful result whose payload is ``hash(req.host)``.
        """
        self.assertEqual(req.path, "/vg_list")
        self.assertEqual(req.port, 15165)

        if req.host in httperrnodes:
            # NOTE(review): restored line; presumably the request is flagged
            # as failed when an error message is set -- confirm.
            req.success = False
            req.error = "Node set up for HTTP errors"

        elif req.host in failnodes:
            # NOTE(review): restored line, see above.
            req.success = True
            req.resp_status_code = 404
            req.resp_body = serializer.DumpJson({
                # NOTE(review): this entry was lost in the source; "code"
                # mirroring the status code is an assumption -- confirm.
                "code": 404,
                "message": "Method not found",
                "explain": "Explanation goes here",
                })
        else:
            # NOTE(review): restored line, see above.
            req.success = True
            req.resp_status_code = http.HTTP_OK
            req.resp_body = serializer.DumpJson((True, hash(req.host)))

    def testHttpError(self):
        """Transport errors, HTTP errors and successes mixed in one call."""
        nodes = ["uaf6pbbv%s" % i for i in range(50)]
        resolver = rpc._StaticResolver(nodes)

        httperrnodes = set(nodes[1::7])
        self.assertEqual(len(httperrnodes), 7)

        failnodes = set(nodes[2::3]) - httperrnodes
        self.assertEqual(len(failnodes), 14)

        self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

        proc = rpc._RpcProcessor(resolver, 15165)
        http_proc = _FakeRequestProcessor(
            compat.partial(self._GetHttpErrorResponse,
                           httperrnodes, failnodes))
        result = proc(nodes, "vg_list", None, _req_process_fn=http_proc,
                      read_timeout=rpc._TMO_URGENT)
        self.assertEqual(sorted(result.keys()), sorted(nodes))

        for name in nodes:
            lhresp = result[name]
            self.assertFalse(lhresp.offline)
            self.assertEqual(lhresp.node, name)
            self.assertEqual(lhresp.call, "vg_list")

            if name in httperrnodes:
                self.assertTrue(lhresp.fail_msg)
                self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
            elif name in failnodes:
                self.assertTrue(lhresp.fail_msg)
                self.assertRaises(errors.OpPrereqError, lhresp.Raise,
                                  "failed", prereq=True,
                                  ecode=errors.ECODE_INVAL)
            else:
                self.assertFalse(lhresp.fail_msg)
                self.assertEqual(lhresp.payload, hash(name))
                lhresp.Raise("should not raise")

        self.assertEqual(http_proc.reqcount, len(nodes))

    def _GetInvalidResponseA(self, req):
        """Answer HTTP OK with a body that is a nine-element sequence."""
        self.assertEqual(req.path, "/version")
        # NOTE(review): restored line, see _GetVersionResponse.
        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                             "response", "!", 1, 2, 3))

    def _GetInvalidResponseB(self, req):
        """Answer HTTP OK with a body that is a plain string."""
        self.assertEqual(req.path, "/version")
        # NOTE(review): restored line, see _GetVersionResponse.
        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson("invalid response")

    def testInvalidResponse(self):
        """Malformed result bodies are reported as failed calls."""
        resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
        proc = rpc._RpcProcessor(resolver, 19978)

        for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
            http_proc = _FakeRequestProcessor(fn)
            result = proc(["oqo7lanhly.example.com"], "version", None,
                          _req_process_fn=http_proc, read_timeout=60)
            self.assertEqual(list(result.keys()),
                             ["oqo7lanhly.example.com"])
            lhresp = result["oqo7lanhly.example.com"]
            self.assertFalse(lhresp.offline)
            self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
            self.assertTrue(lhresp.fail_msg)
            self.assertFalse(lhresp.payload)
            self.assertEqual(lhresp.call, "version")
            self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
            self.assertEqual(http_proc.reqcount, 1)

    def _GetBodyTestResponse(self, test_data, req):
        """Verify that the request body decodes to ``test_data``."""
        self.assertEqual(req.host, "192.0.2.84")
        self.assertEqual(req.port, 18700)
        self.assertEqual(req.path, "/upload_file")
        self.assertEqual(serializer.LoadJson(req.post_data), test_data)
        # NOTE(review): restored line, see _GetVersionResponse.
        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson((True, None))

    def testResponseBody(self):
        """The serialized body passed to the processor reaches the request."""
        # NOTE(review): the original fixture was lost; any JSON-serializable
        # value works because it is only compared after a round trip.
        test_data = {
            "Hello": "World",
            "xyz": list(range(10)),
            }
        resolver = rpc._StaticResolver(["192.0.2.84"])
        http_proc = _FakeRequestProcessor(
            compat.partial(self._GetBodyTestResponse, test_data))
        proc = rpc._RpcProcessor(resolver, 18700)
        body = serializer.DumpJson(test_data)
        # NOTE(review): the original timeout argument was lost -- confirm.
        result = proc(["node19759"], "upload_file", body,
                      _req_process_fn=http_proc,
                      read_timeout=60)
        self.assertEqual(list(result.keys()), ["node19759"])
        lhresp = result["node19759"]
        self.assertFalse(lhresp.offline)
        self.assertEqual(lhresp.node, "node19759")
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, None)
        self.assertEqual(lhresp.call, "upload_file")
        lhresp.Raise("should not raise")
        self.assertEqual(http_proc.reqcount, 1)
class TestSsconfResolver(unittest.TestCase):
    """Tests for ``rpc._SsconfResolver`` using a fake simple-store class.

    In every test the expected result is the list of ``(node, address)``
    pairs in the order of ``node_list``.
    """

    def testSsconfLookup(self):
        """All addresses are taken from the ssconf node/IP list."""
        addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
        node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
        node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
        ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
        # NotImplemented is not callable, so the test errors out should the
        # resolver attempt a name lookup.
        result = rpc._SsconfResolver(node_list, ssc=ssc,
                                     nslookup_fn=NotImplemented)
        self.assertEqual(result, list(zip(node_list, addr_list)))

    def testNsLookup(self):
        """With an empty ssconf list every address comes from the lookup."""
        addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
        node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
        ssc = GetFakeSimpleStoreClass(lambda _: [])
        node_addr_map = dict(zip(node_list, addr_list))
        nslookup_fn = lambda name, family=None: node_addr_map.get(name)
        result = rpc._SsconfResolver(node_list, ssc=ssc,
                                     nslookup_fn=nslookup_fn)
        self.assertEqual(result, list(zip(node_list, addr_list)))

    def testBothLookups(self):
        """Second half of the nodes is in ssconf, first half needs lookup."""
        addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
        node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
        # Floor division: the value is used as a slice index below.
        n = len(addr_list) // 2
        node_addr_list = [" ".join(t)
                          for t in zip(node_list[n:], addr_list[n:])]
        ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
        node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
        nslookup_fn = lambda name, family=None: node_addr_map.get(name)
        result = rpc._SsconfResolver(node_list, ssc=ssc,
                                     nslookup_fn=nslookup_fn)
        self.assertEqual(result, list(zip(node_list, addr_list)))

    def testAddressLookupIPv6(self):
        """Same as testSsconfLookup, with IPv6 addresses in the list."""
        addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
        node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
        node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
        ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
        result = rpc._SsconfResolver(node_list, ssc=ssc,
                                     nslookup_fn=NotImplemented)
        self.assertEqual(result, list(zip(node_list, addr_list)))
class TestStaticResolver(unittest.TestCase):
    """Tests for ``rpc._StaticResolver``."""

    # NOTE(review): the method header was lost in the source; the name
    # "test" is reconstructed.
    def test(self):
        """Nodes are paired with the preconfigured addresses in order."""
        addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
        nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
        res = rpc._StaticResolver(addresses)
        self.assertEqual(res(nodes), list(zip(nodes, addresses)))

    def testWrongLength(self):
        """Resolving more nodes than addresses raises AssertionError."""
        res = rpc._StaticResolver([])
        self.assertRaises(AssertionError, res, ["abc"])
class TestNodeConfigResolver(unittest.TestCase):
    """Tests for ``rpc._NodeConfigResolver``.

    The resolver is called as ``(single_node_fn, all_nodes_fn, hosts)``.
    The callback a test does not expect to be used is replaced by the
    non-callable ``NotImplemented``, so using it makes the test error out.

    Some arguments were missing from the damaged source and have been
    reconstructed; they are marked ``NOTE(review)``.
    """

    # The helpers take no ``self`` but are accessed through the instance,
    # hence they have to be static methods.
    @staticmethod
    def _GetSingleOnlineNode(name):
        """Return the online node object for ``node90.example.com``."""
        assert name == "node90.example.com"
        return objects.Node(name=name, offline=False,
                            primary_ip="192.0.2.90")

    @staticmethod
    def _GetSingleOfflineNode(name):
        """Return the offline node object for ``node100.example.com``."""
        assert name == "node100.example.com"
        return objects.Node(name=name, offline=True,
                            primary_ip="192.0.2.100")

    def testSingleOnline(self):
        """A single online node resolves to its primary IP."""
        self.assertEqual(
            rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                    NotImplemented,
                                    ["node90.example.com"]),
            [("node90.example.com", "192.0.2.90")])

    def testSingleOffline(self):
        """A single offline node resolves to the offline marker."""
        self.assertEqual(
            rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                    NotImplemented,
                                    ["node100.example.com"]),
            [("node100.example.com", rpc._OFFLINE)])

    def testUnknownSingleNode(self):
        """A node unknown to the configuration resolves to its own name."""
        self.assertEqual(
            rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                    ["node110.example.com"]),
            [("node110.example.com", "node110.example.com")])

    def testMultiEmpty(self):
        """Resolving an empty host list yields an empty result."""
        # NOTE(review): second and third argument and the expected value
        # were lost in the source; reconstructed -- confirm.
        self.assertEqual(
            rpc._NodeConfigResolver(NotImplemented,
                                    lambda: {},
                                    []),
            [])

    def testMultiSomeOffline(self):
        """Offline, online and unknown hosts are resolved in one call."""
        # Every third node (i divisible by 3) is marked offline.
        nodes = dict(("node%s.example.com" % i,
                      objects.Node(name="node%s.example.com" % i,
                                   offline=((i % 3) == 0),
                                   primary_ip="192.0.2.%s" % i))
                     for i in range(1, 255))

        # Resolve no names
        # NOTE(review): the all-nodes callback, host list and expected
        # value were lost in the source; reconstructed -- confirm.
        self.assertEqual(
            rpc._NodeConfigResolver(NotImplemented,
                                    lambda: nodes,
                                    []),
            [])

        # Offline, online and unknown hosts
        # NOTE(review): the all-nodes callback was lost in the source;
        # reconstructed -- confirm.
        self.assertEqual(
            rpc._NodeConfigResolver(NotImplemented,
                                    lambda: nodes,
                                    ["node3.example.com",
                                     "node92.example.com",
                                     "node54.example.com",
                                     "unknown.example.com"]),
            [
                ("node3.example.com", rpc._OFFLINE),
                ("node92.example.com", "192.0.2.92"),
                ("node54.example.com", rpc._OFFLINE),
                ("unknown.example.com", "unknown.example.com"),
            ])
if __name__ == "__main__":
    # Hand control to the test program provided by the testutils module.
    testutils.GanetiTestProgram()