rpc._NodeConfigResolver: Support resolving offline nodes
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2011 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import rpc_defs
32 from ganeti import http
33 from ganeti import errors
34 from ganeti import serializer
35 from ganeti import objects
36
37 import testutils
38
39
40 class _FakeRequestProcessor:
41   def __init__(self, response_fn):
42     self._response_fn = response_fn
43     self.reqcount = 0
44
45   def __call__(self, reqs, lock_monitor_cb=None):
46     assert lock_monitor_cb is None or callable(lock_monitor_cb)
47     for req in reqs:
48       self.reqcount += 1
49       self._response_fn(req)
50
51
def GetFakeSimpleStoreClass(fn):
  """Creates a stand-in for the ssconf store class used by the resolver.

  @param fn: function installed as the C{GetNodePrimaryIPList} method; being
    a class attribute, it receives the store instance as its first argument
  @return: a fresh class whose C{GetPrimaryIPFamily} method returns C{None}

  """
  class FakeSimpleStore:
    GetNodePrimaryIPList = fn

    def GetPrimaryIPFamily(self):
      # No address family is configured in the fake store
      return None

  return FakeSimpleStore
58
59
class TestRpcProcessor(unittest.TestCase):
  """Tests for L{rpc._RpcProcessor}.

  Each test wires a static resolver and a L{_FakeRequestProcessor} into the
  RPC processor. The C{_Get*Response} helpers play the role of the remote node
  daemon: they assert on the outgoing request and fill in the response.

  """
  def _FakeAddressLookup(self, map):
    # Not referenced elsewhere in this file. The parameter name shadows the
    # builtin "map" inside this method only; it is kept unchanged so the
    # signature stays stable.
    return lambda node_list: [map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    # Fake node for testVersionSuccess
    self.assertEqual(req.host, "127.0.0.1")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    resolver = rpc._StaticResolver(["127.0.0.1"])
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
    proc = rpc._RpcProcessor(resolver, 24094)
    # _GetVersionResponse compares read_timeout against rpc._TMO_URGENT, so
    # pass that constant instead of a hard-coded number which could drift
    result = proc(["localhost"], "version", {"localhost": ""},
                  rpc._TMO_URGENT, NotImplemented, _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def _ReadTimeoutResponse(self, req):
    # Fake node for testReadTimeout; checks the timeout is passed through
    self.assertEqual(req.host, "192.0.2.13")
    self.assertEqual(req.port, 19176)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, 12356)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, -1))

  def testReadTimeout(self):
    resolver = rpc._StaticResolver(["192.0.2.13"])
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
    proc = rpc._RpcProcessor(resolver, 19176)
    host = "node31856"
    body = {host: ""}
    result = proc([host], "version", body, 12356, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, -1)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def testOfflineNode(self):
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    # The response function must never run for an offline node
    http_proc = _FakeRequestProcessor(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    host = "n17296"
    body = {host: ""}
    result = proc([host], "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    self.assertEqual(http_proc.reqcount, 0)

  def _GetMultiVersionResponse(self, req):
    # Fake node for testMultiVersionSuccess; answers for any "node*" host
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    nodes = ["node%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
    proc = rpc._RpcProcessor(resolver, 23245)
    result = proc(nodes, "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetVersionResponseFail(self, errinfo, req):
    # Fake node whose daemon reports failure with the given error info
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, errinfo))

  def testVersionFailure(self):
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
    proc = rpc._RpcProcessor(resolver, 5903)
    # Both a missing and a textual error description must yield a failure
    for errinfo in [None, "Unknown error"]:
      http_proc = \
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
                                             errinfo))
      host = "aef9ur4i.example.com"
      body = {host: ""}
      result = proc([host], "version", body, 60, NotImplemented,
                    _req_process_fn=http_proc)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    # Fake cluster: hosts in httperrnodes fail at the HTTP level, hosts in
    # failnodes return a 404 error body, all others succeed
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)

    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    http_proc = \
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                           httperrnodes, failnodes))
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    # Fake node returning a JSON list that is not a (status, payload) pair
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    # Fake node returning a bare JSON string instead of a pair
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      http_proc = _FakeRequestProcessor(fn)
      host = "oqo7lanhly.example.com"
      body = {host: ""}
      result = proc([host], "version", body, 60, NotImplemented,
                    _req_process_fn=http_proc)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetBodyTestResponse(self, test_data, req):
    # Fake node for testResponseBody; checks the POST body round-trips
    self.assertEqual(req.host, "192.0.2.84")
    self.assertEqual(req.port, 18700)
    self.assertEqual(req.path, "/upload_file")
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, None))

  def testResponseBody(self):
    test_data = {
      "Hello": "World",
      # An explicit list, so the value is JSON-serializable (and compares
      # equal after a round-trip) regardless of what range() returns
      "xyz": list(range(10)),
      }
    resolver = rpc._StaticResolver(["192.0.2.84"])
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                     test_data))
    proc = rpc._RpcProcessor(resolver, 18700)
    host = "node19759"
    body = {host: serializer.DumpJson(test_data)}
    result = proc([host], "upload_file", body, 30, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, None)
    self.assertEqual(lhresp.call, "upload_file")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)
322
323
class TestSsconfResolver(unittest.TestCase):
  """Tests for L{rpc._SsconfResolver}.

  The ssconf store is replaced through L{GetFakeSimpleStoreClass}. Per the
  tests below, nodes missing from the fake ssconf data are expected to be
  resolved through the C{nslookup_fn} callback.

  """
  def testSsconfLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # The fake ssconf data holds one "<name> <address>" entry per node
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    # list() keeps the comparison valid where zip() returns an iterator
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Empty ssconf data: every address must come from nslookup_fn
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division keeps n an integer usable as a slice index, whatever
    # the semantics of "/" are
    n = len(addr_list) // 2
    # Second half comes from ssconf, first half from nslookup_fn
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
364
365
class TestStaticResolver(unittest.TestCase):
  """Tests for L{rpc._StaticResolver}."""
  def test(self):
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # Names are paired positionally with the configured addresses; list()
    # keeps the comparison valid where zip() returns an iterator
    self.assertEqual(res(nodes, NotImplemented), list(zip(nodes, addresses)))

  def testWrongLength(self):
    # A mismatch between the number of addresses and names must raise
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
376
377
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for L{rpc._NodeConfigResolver}.

  The resolver takes a single-node lookup function, an all-nodes lookup
  function, the list of names and an option value. The tests pass
  C{NotImplemented} for whichever lookup function must not be used. They
  expect offline nodes to resolve to L{rpc._OFFLINE} unless
  L{rpc_defs.ACCEPT_OFFLINE_NODE} is given. They expect unknown names to
  resolve to themselves.

  """
  @staticmethod
  def _GetSingleOnlineNode(name):
    assert name == "node90.example.com"
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")

  @staticmethod
  def _GetSingleOfflineNode(name):
    assert name == "node100.example.com"
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")

  def testSingleOnline(self):
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                             NotImplemented,
                                             ["node90.example.com"], None),
                     [("node90.example.com", "192.0.2.90")])

  def testSingleOffline(self):
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                             NotImplemented,
                                             ["node100.example.com"], None),
                     [("node100.example.com", rpc._OFFLINE)])

  def testSingleOfflineWithAcceptOffline(self):
    fn = self._GetSingleOfflineNode
    # Sanity check on the fixture; unlike a bare "assert" statement this is
    # not skipped when Python runs with -O and reports a proper test failure
    self.assertTrue(fn("node100.example.com").offline)
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
                                             ["node100.example.com"],
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
                     [("node100.example.com", "192.0.2.100")])
    # Other option values, including false ones, must be rejected
    for i in [False, True, "", "Hello", 0, 1]:
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
                        fn, NotImplemented, ["node100.example.com"], i)

  def testUnknownSingleNode(self):
    # A name unknown to the configuration is passed through as its address
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                             ["node110.example.com"], None),
                     [("node110.example.com", "node110.example.com")])

  def testMultiEmpty(self):
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: {},
                                             [], None),
                     [])

  def testMultiSomeOffline(self):
    # Nodes whose index is divisible by 3 are marked offline
    nodes = dict(("node%s.example.com" % i,
                  objects.Node(name="node%s.example.com" % i,
                               offline=((i % 3) == 0),
                               primary_ip="192.0.2.%s" % i))
                 for i in range(1, 255))

    # Resolve no names
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: nodes,
                                             [], None),
                     [])

    # Offline, online and unknown hosts
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: nodes,
                                             ["node3.example.com",
                                              "node92.example.com",
                                              "node54.example.com",
                                              "unknown.example.com"],
                                             None), [
      ("node3.example.com", rpc._OFFLINE),
      ("node92.example.com", "192.0.2.92"),
      ("node54.example.com", rpc._OFFLINE),
      ("unknown.example.com", "unknown.example.com"),
      ])
449
450
if __name__ == "__main__":
  # Script entry point. testutils.GanetiTestProgram is presumably Ganeti's
  # wrapper around the unittest runner for the TestCase classes above.
  testutils.GanetiTestProgram()