4 # Copyright (C) 2010, 2011 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
22 """Script for testing ganeti.rpc"""
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34 from ganeti import objects
39 class _FakeRequestProcessor:
40 def __init__(self, response_fn):
41 self._response_fn = response_fn
44 def __call__(self, reqs, lock_monitor_cb=None):
45 assert lock_monitor_cb is None or callable(lock_monitor_cb)
48 self._response_fn(req)
def GetFakeSimpleStoreClass(fn):
  """Builds a fake ssconf store class for the address resolver tests.

  @type fn: callable
  @param fn: installed as C{GetNodePrimaryIPList}; since it becomes a
    method of the returned class, it receives the store instance as its
    only argument
  @return: a class whose instances answer C{GetNodePrimaryIPList} through
    C{fn} and always return C{None} from C{GetPrimaryIPFamily}

  """
  class FakeSimpleStore:
    GetNodePrimaryIPList = fn

    def GetPrimaryIPFamily(self):
      return None

  return FakeSimpleStore
class TestRpcProcessor(unittest.TestCase):
  """Tests for L{rpc._RpcProcessor} driven by L{_FakeRequestProcessor}.

  The C{_*Response} helpers play the part of the remote node daemon: they
  check the request built by the processor and fill in C{success} plus
  either C{resp_status_code}/C{resp_body} or C{error} on it.

  """
  def _FakeAddressLookup(self, map):
    """Returns a lookup function resolving node names through C{map}.

    Names not present in C{map} resolve to C{None}.

    """
    return lambda node_list: [map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Answers a /version request to 127.0.0.1:24094 with payload 123."""
    self.assertEqual(req.host, "127.0.0.1")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    resolver = rpc._StaticResolver(["127.0.0.1"])
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
    proc = rpc._RpcProcessor(resolver, 24094)
    result = proc(["localhost"], "version", {"localhost": ""},
                  _req_process_fn=http_proc, read_timeout=60)
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def _ReadTimeoutResponse(self, req):
    """Answers /version, checking the caller-supplied read timeout."""
    self.assertEqual(req.host, "192.0.2.13")
    self.assertEqual(req.port, 19176)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, 12356)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, -1))

  def testReadTimeout(self):
    resolver = rpc._StaticResolver(["192.0.2.13"])
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
    proc = rpc._RpcProcessor(resolver, 19176)
    host = "node31856"
    body = {host: ""}
    result = proc([host], "version", body, _req_process_fn=http_proc,
                  read_timeout=12356)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, -1)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def testOfflineNode(self):
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    # The response function must never be called for an offline node
    http_proc = _FakeRequestProcessor(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    host = "n17296"
    body = {host: ""}
    result = proc([host], "version", body, _req_process_fn=http_proc,
                  read_timeout=60)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    self.assertEqual(http_proc.reqcount, 0)

  def _GetMultiVersionResponse(self, req):
    """Answers /version on port 23245 for any "node*" host with 987."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    nodes = ["node%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
    proc = rpc._RpcProcessor(resolver, 23245)
    result = proc(nodes, "version", body, _req_process_fn=http_proc,
                  read_timeout=60)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetVersionResponseFail(self, errinfo, req):
    """Answers /version with an RPC-level failure carrying C{errinfo}."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, errinfo))

  def testVersionFailure(self):
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
    proc = rpc._RpcProcessor(resolver, 5903)
    for errinfo in [None, "Unknown error"]:
      http_proc = \
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
                                             errinfo))
      host = "aef9ur4i.example.com"
      body = {host: ""}
      result = proc(body.keys(), "version", body,
                    _req_process_fn=http_proc, read_timeout=60)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answers /vg_list with a transport error, an HTTP 404 or success.

    @param httperrnodes: hosts whose request fails at the HTTP layer
    @param failnodes: hosts that receive an HTTP 404 error document
    @param req: the request; any other host gets C{(True, hash(host))}

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)

    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    http_proc = \
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                           httperrnodes, failnodes))
    result = proc(nodes, "vg_list", body, _req_process_fn=http_proc,
                  read_timeout=rpc._TMO_URGENT)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answers /version with a tuple of the wrong length."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answers /version with a bare string instead of a tuple."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      http_proc = _FakeRequestProcessor(fn)
      host = "oqo7lanhly.example.com"
      body = {host: ""}
      result = proc([host], "version", body,
                    _req_process_fn=http_proc, read_timeout=60)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetBodyTestResponse(self, test_data, req):
    """Answers /upload_file after checking the POST body is C{test_data}."""
    self.assertEqual(req.host, "192.0.2.84")
    self.assertEqual(req.port, 18700)
    self.assertEqual(req.path, "/upload_file")
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, None))

  def testResponseBody(self):
    test_data = {
      "Hello": "World",
      "xyz": list(range(10)),
      }
    resolver = rpc._StaticResolver(["192.0.2.84"])
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                     test_data))
    proc = rpc._RpcProcessor(resolver, 18700)
    host = "node19759"
    body = {host: serializer.DumpJson(test_data)}
    result = proc([host], "upload_file", body, _req_process_fn=http_proc,
                  read_timeout=30)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, None)
    self.assertEqual(lhresp.call, "upload_file")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)
class TestSsconfResolver(unittest.TestCase):
  """Tests for L{rpc._SsconfResolver}.

  Addresses come from a fake ssconf store (see L{GetFakeSimpleStoreClass})
  and/or a fake C{nslookup_fn}. Results are compared as a list of
  C{(name, address)} pairs in the order the names were passed in.

  """
  def testSsconfLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    # nslookup_fn is NotImplemented: every name must be found in ssconf
    result = rpc._SsconfResolver(node_list, ssc=ssc,
                                 nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Empty ssconf list, so every name goes through nslookup_fn
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division keeps the slice index an int on Python 3 as well
    n = len(addr_list) // 2
    # Second half is known to ssconf, first half only to nslookup_fn
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, ssc=ssc,
                                 nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
class TestStaticResolver(unittest.TestCase):
  """Tests for L{rpc._StaticResolver}."""
  def test(self):
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # Names are paired positionally with the configured addresses
    self.assertEqual(res(nodes), list(zip(nodes, addresses)))

  def testWrongLength(self):
    # More names than configured addresses must be rejected
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"])
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for L{rpc._NodeConfigResolver}.

  The resolver is called as C{(single_node_fn, all_nodes_fn, names)}:
  C{single_node_fn} is used for a single name, C{all_nodes_fn} otherwise.
  The unused one is passed as C{NotImplemented} so that calling it fails.

  """
  @staticmethod
  def _GetSingleOnlineNode(name):
    assert name == "node90.example.com"
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")

  @staticmethod
  def _GetSingleOfflineNode(name):
    assert name == "node100.example.com"
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")

  def testSingleOnline(self):
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                             NotImplemented,
                                             ["node90.example.com"]),
                     [("node90.example.com", "192.0.2.90")])

  def testSingleOffline(self):
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                             NotImplemented,
                                             ["node100.example.com"]),
                     [("node100.example.com", rpc._OFFLINE)])

  def testUnknownSingleNode(self):
    # An unknown node resolves to its own name
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                             ["node110.example.com"]),
                     [("node110.example.com", "node110.example.com")])

  def testMultiEmpty(self):
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: {},
                                             []),
                     [])

  def testMultiSomeOffline(self):
    # Every node whose number is divisible by 3 is offline
    nodes = dict(("node%s.example.com" % i,
                  objects.Node(name="node%s.example.com" % i,
                               offline=((i % 3) == 0),
                               primary_ip="192.0.2.%s" % i))
                 for i in range(1, 255))

    # Resolve no names
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: nodes,
                                             []),
                     [])

    # Offline, online and unknown hosts
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: nodes,
                                             ["node3.example.com",
                                              "node92.example.com",
                                              "node54.example.com",
                                              "unknown.example.com",]), [
      ("node3.example.com", rpc._OFFLINE),
      ("node92.example.com", "192.0.2.92"),
      ("node54.example.com", rpc._OFFLINE),
      ("unknown.example.com", "unknown.example.com"),
      ])
if __name__ == "__main__":
  # When run as a script, hand the test cases in this module to
  # testutils.GanetiTestProgram
  testutils.GanetiTestProgram()