Fix RPC unittest
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34 from ganeti import objects
35
36 import testutils
37
38
39 class _FakeRequestProcessor:
40   def __init__(self, response_fn):
41     self._response_fn = response_fn
42     self.reqcount = 0
43
44   def __call__(self, reqs, lock_monitor_cb=None):
45     assert lock_monitor_cb is None or callable(lock_monitor_cb)
46     for req in reqs:
47       self.reqcount += 1
48       self._response_fn(req)
49
50
def GetFakeSimpleStoreClass(fn):
  """Build a minimal stand-in for an ssconf store class.

  @param fn: function installed as the C{GetNodePrimaryIPList} method; as a
      plain class attribute it receives the instance as its only argument
  @return: the class itself (not an instance); its C{GetPrimaryIPFamily}
      method always returns C{None}

  """
  class FakeSimpleStore:
    GetNodePrimaryIPList = fn

    def GetPrimaryIPFamily(self):
      return None

  return FakeSimpleStore
57
58
class TestRpcProcessor(unittest.TestCase):
  """Tests for rpc._RpcProcessor.

  No network traffic happens: every test passes a L{_FakeRequestProcessor}
  as C{_req_process_fn}, and one of the response methods of this class
  checks each outgoing request and fills in its response fields.

  """
  def _FakeAddressLookup(self, addr_map):
    """Return a function mapping a list of node names via C{addr_map}.

    Names missing from C{addr_map} map to C{None}. Not referenced by any
    test in this class.

    """
    return lambda node_list: [addr_map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Check a "version" request to 127.0.0.1:24094 and answer C{123}."""
    self.assertEqual(req.host, "127.0.0.1")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    resolver = rpc._StaticResolver(["127.0.0.1"])
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
    proc = rpc._RpcProcessor(resolver, 24094)
    # _GetVersionResponse checks for rpc._TMO_URGENT, so pass exactly that
    # instead of a literal that only matches it by coincidence
    result = proc(["localhost"], "version", None, _req_process_fn=http_proc,
                  read_timeout=rpc._TMO_URGENT)
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def _ReadTimeoutResponse(self, req):
    """Check that a custom read timeout (12356) reaches the request."""
    self.assertEqual(req.host, "192.0.2.13")
    self.assertEqual(req.port, 19176)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, 12356)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, -1))

  def testReadTimeout(self):
    resolver = rpc._StaticResolver(["192.0.2.13"])
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
    proc = rpc._RpcProcessor(resolver, 19176)
    result = proc(["node31856"], "version", None, _req_process_fn=http_proc,
                  read_timeout=12356)
    self.assertEqual(list(result.keys()), ["node31856"])
    lhresp = result["node31856"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "node31856")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, -1)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def testOfflineNode(self):
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    # No request may be sent to an offline node; NotImplemented is not
    # callable, so any attempt would make the test fail
    http_proc = _FakeRequestProcessor(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    result = proc(["n17296"], "version", None, _req_process_fn=http_proc,
                  read_timeout=60)
    self.assertEqual(list(result.keys()), ["n17296"])
    lhresp = result["n17296"]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, "n17296")
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    self.assertEqual(http_proc.reqcount, 0)

  def _GetMultiVersionResponse(self, req):
    """Answer C{987} to a "version" request on port 23245 for any "node*"."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    nodes = ["node%s" % i for i in range(50)]
    resolver = rpc._StaticResolver(nodes)
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
    proc = rpc._RpcProcessor(resolver, 23245)
    result = proc(nodes, "version", None, _req_process_fn=http_proc,
                  read_timeout=60)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetVersionResponseFail(self, errinfo, req):
    """Answer a "version" request with a failed result carrying C{errinfo}."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, errinfo))

  def testVersionFailure(self):
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
    proc = rpc._RpcProcessor(resolver, 5903)
    for errinfo in [None, "Unknown error"]:
      http_proc = \
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
                                             errinfo))
      result = proc(["aef9ur4i.example.com"], "version", None,
                    _req_process_fn=http_proc, read_timeout=60)
      self.assertEqual(list(result.keys()), ["aef9ur4i.example.com"])
      lhresp = result["aef9ur4i.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answer a "vg_list" request on port 15165 based on the target host.

    Hosts not in either set get a successful result whose payload is
    C{hash(host)}.

    @param httperrnodes: hosts whose request fails at the HTTP level
    @param failnodes: hosts answered with an HTTP 404 error body

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    resolver = rpc._StaticResolver(nodes)

    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    http_proc = \
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                           httperrnodes, failnodes))
    result = proc(nodes, "vg_list", None, _req_process_fn=http_proc,
                  read_timeout=rpc._TMO_URGENT)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answer with a JSON sequence that is not a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answer with a bare JSON string instead of a result pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      http_proc = _FakeRequestProcessor(fn)
      result = proc(["oqo7lanhly.example.com"], "version", None,
                    _req_process_fn=http_proc, read_timeout=60)
      self.assertEqual(list(result.keys()), ["oqo7lanhly.example.com"])
      lhresp = result["oqo7lanhly.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetBodyTestResponse(self, test_data, req):
    """Check that an "upload_file" request body decodes to C{test_data}."""
    self.assertEqual(req.host, "192.0.2.84")
    self.assertEqual(req.port, 18700)
    self.assertEqual(req.path, "/upload_file")
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, None))

  def testResponseBody(self):
    test_data = {
      "Hello": "World",
      # A real list, so it is JSON-serialisable and compares equal to the
      # decoded body regardless of what range() returns
      "xyz": list(range(10)),
      }
    resolver = rpc._StaticResolver(["192.0.2.84"])
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                     test_data))
    proc = rpc._RpcProcessor(resolver, 18700)
    body = serializer.DumpJson(test_data)
    result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc,
                  read_timeout=30)
    self.assertEqual(list(result.keys()), ["node19759"])
    lhresp = result["node19759"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "node19759")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, None)
    self.assertEqual(lhresp.call, "upload_file")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)
310
311
class TestSsconfResolver(unittest.TestCase):
  """Tests for rpc._SsconfResolver.

  Addresses come either from the node list of a fake ssconf store (see
  GetFakeSimpleStoreClass) or from the given C{nslookup_fn}. The tests cover
  each source alone and a mix of both.

  """
  def testSsconfLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    # Every node is known to ssconf, so no name lookup may happen
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division: n is used as a slice index and must stay an integer
    n = len(addr_list) // 2
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
347     self.assertEqual(result, zip(node_list, addr_list))
348
349
class TestStaticResolver(unittest.TestCase):
  """Tests for rpc._StaticResolver, which pairs names with fixed addresses."""

  def test(self):
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # Compare against a real list rather than whatever zip() returns
    self.assertEqual(res(nodes), list(zip(nodes, addresses)))

  def testWrongLength(self):
    # More names than configured addresses must be rejected
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"])
360
361
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for rpc._NodeConfigResolver.

  The resolver is called with a single-node lookup function, an all-nodes
  function and a list of names. Lookup functions a test does not expect to
  be used are passed as C{NotImplemented}, which is not callable, so an
  unexpected call fails the test.

  """
  @staticmethod
  def _GetSingleOnlineNode(name):
    """Single-node lookup returning an online node with a known IP."""
    assert name == "node90.example.com"
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")

  @staticmethod
  def _GetSingleOfflineNode(name):
    """Single-node lookup returning an offline node."""
    assert name == "node100.example.com"
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")

  def testSingleOnline(self):
    name = "node90.example.com"
    result = rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                     NotImplemented, [name])
    self.assertEqual(result, [(name, "192.0.2.90")])

  def testSingleOffline(self):
    name = "node100.example.com"
    result = rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                     NotImplemented, [name])
    self.assertEqual(result, [(name, rpc._OFFLINE)])

  def testUnknownSingleNode(self):
    # A name the lookup does not know is expected back as its own address
    name = "node110.example.com"
    result = rpc._NodeConfigResolver(lambda _: None, NotImplemented, [name])
    self.assertEqual(result, [(name, name)])

  def testMultiEmpty(self):
    result = rpc._NodeConfigResolver(NotImplemented, lambda: {}, [])
    self.assertEqual(result, [])

  def testMultiSomeOffline(self):
    nodes = {}
    for idx in range(1, 255):
      name = "node%s.example.com" % idx
      # Every third node is marked offline
      nodes[name] = objects.Node(name=name, offline=(idx % 3 == 0),
                                 primary_ip="192.0.2.%s" % idx)
    get_all_nodes = lambda: nodes

    # Resolve no names
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented, get_all_nodes,
                                             []),
                     [])

    # Offline, online and unknown hosts
    names = [
      "node3.example.com",
      "node92.example.com",
      "node54.example.com",
      "unknown.example.com",
      ]
    expected = [
      ("node3.example.com", rpc._OFFLINE),
      ("node92.example.com", "192.0.2.92"),
      ("node54.example.com", rpc._OFFLINE),
      ("unknown.example.com", "unknown.example.com"),
      ]
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented, get_all_nodes,
                                             names),
                     expected)
421
422
if __name__ == "__main__":
  # Script entry point: hand control to Ganeti's shared test runner
  testutils.GanetiTestProgram()