rpc: Disable timeout check
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34 from ganeti import objects
35
36 import testutils
37
38
39 class _FakeRequestProcessor:
40   def __init__(self, response_fn):
41     self._response_fn = response_fn
42     self.reqcount = 0
43
44   def __call__(self, reqs, lock_monitor_cb=None):
45     assert lock_monitor_cb is None or callable(lock_monitor_cb)
46     for req in reqs:
47       self.reqcount += 1
48       self._response_fn(req)
49
50
def GetFakeSimpleStoreClass(fn):
  """Builds a fake ssconf store class around a node IP list function.

  @param fn: function installed as the C{GetNodePrimaryIPList} method; when
      called through an instance it receives that instance as its only
      argument
  @return: a new class whose C{GetPrimaryIPFamily} method returns C{None}

  """
  class FakeSimpleStore:
    def GetPrimaryIPFamily(self):
      return None

  FakeSimpleStore.GetNodePrimaryIPList = fn
  return FakeSimpleStore
57
58
59 class TestRpcProcessor(unittest.TestCase):
60   def _FakeAddressLookup(self, map):
61     return lambda node_list: [map.get(node) for node in node_list]
62
63   def _GetVersionResponse(self, req):
64     self.assertEqual(req.host, "127.0.0.1")
65     self.assertEqual(req.port, 24094)
66     self.assertEqual(req.path, "/version")
67     self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
68     req.success = True
69     req.resp_status_code = http.HTTP_OK
70     req.resp_body = serializer.DumpJson((True, 123))
71
72   def testVersionSuccess(self):
73     resolver = rpc._StaticResolver(["127.0.0.1"])
74     http_proc = _FakeRequestProcessor(self._GetVersionResponse)
75     proc = rpc._RpcProcessor(resolver, 24094)
76     result = proc(["localhost"], "version", None, _req_process_fn=http_proc)
77     self.assertEqual(result.keys(), ["localhost"])
78     lhresp = result["localhost"]
79     self.assertFalse(lhresp.offline)
80     self.assertEqual(lhresp.node, "localhost")
81     self.assertFalse(lhresp.fail_msg)
82     self.assertEqual(lhresp.payload, 123)
83     self.assertEqual(lhresp.call, "version")
84     lhresp.Raise("should not raise")
85     self.assertEqual(http_proc.reqcount, 1)
86
87   def _ReadTimeoutResponse(self, req):
88     self.assertEqual(req.host, "192.0.2.13")
89     self.assertEqual(req.port, 19176)
90     self.assertEqual(req.path, "/version")
91     self.assertEqual(req.read_timeout, 12356)
92     req.success = True
93     req.resp_status_code = http.HTTP_OK
94     req.resp_body = serializer.DumpJson((True, -1))
95
96   def testReadTimeout(self):
97     resolver = rpc._StaticResolver(["192.0.2.13"])
98     http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
99     proc = rpc._RpcProcessor(resolver, 19176)
100     result = proc(["node31856"], "version", None, _req_process_fn=http_proc,
101                   read_timeout=12356)
102     self.assertEqual(result.keys(), ["node31856"])
103     lhresp = result["node31856"]
104     self.assertFalse(lhresp.offline)
105     self.assertEqual(lhresp.node, "node31856")
106     self.assertFalse(lhresp.fail_msg)
107     self.assertEqual(lhresp.payload, -1)
108     self.assertEqual(lhresp.call, "version")
109     lhresp.Raise("should not raise")
110     self.assertEqual(http_proc.reqcount, 1)
111
112   def testOfflineNode(self):
113     resolver = rpc._StaticResolver([rpc._OFFLINE])
114     http_proc = _FakeRequestProcessor(NotImplemented)
115     proc = rpc._RpcProcessor(resolver, 30668)
116     result = proc(["n17296"], "version", None, _req_process_fn=http_proc)
117     self.assertEqual(result.keys(), ["n17296"])
118     lhresp = result["n17296"]
119     self.assertTrue(lhresp.offline)
120     self.assertEqual(lhresp.node, "n17296")
121     self.assertTrue(lhresp.fail_msg)
122     self.assertFalse(lhresp.payload)
123     self.assertEqual(lhresp.call, "version")
124
125     # With a message
126     self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
127
128     # No message
129     self.assertRaises(errors.OpExecError, lhresp.Raise, None)
130
131     self.assertEqual(http_proc.reqcount, 0)
132
133   def _GetMultiVersionResponse(self, req):
134     self.assert_(req.host.startswith("node"))
135     self.assertEqual(req.port, 23245)
136     self.assertEqual(req.path, "/version")
137     req.success = True
138     req.resp_status_code = http.HTTP_OK
139     req.resp_body = serializer.DumpJson((True, 987))
140
141   def testMultiVersionSuccess(self):
142     nodes = ["node%s" % i for i in range(50)]
143     resolver = rpc._StaticResolver(nodes)
144     http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
145     proc = rpc._RpcProcessor(resolver, 23245)
146     result = proc(nodes, "version", None, _req_process_fn=http_proc)
147     self.assertEqual(sorted(result.keys()), sorted(nodes))
148
149     for name in nodes:
150       lhresp = result[name]
151       self.assertFalse(lhresp.offline)
152       self.assertEqual(lhresp.node, name)
153       self.assertFalse(lhresp.fail_msg)
154       self.assertEqual(lhresp.payload, 987)
155       self.assertEqual(lhresp.call, "version")
156       lhresp.Raise("should not raise")
157
158     self.assertEqual(http_proc.reqcount, len(nodes))
159
160   def _GetVersionResponseFail(self, errinfo, req):
161     self.assertEqual(req.path, "/version")
162     req.success = True
163     req.resp_status_code = http.HTTP_OK
164     req.resp_body = serializer.DumpJson((False, errinfo))
165
166   def testVersionFailure(self):
167     resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
168     proc = rpc._RpcProcessor(resolver, 5903)
169     for errinfo in [None, "Unknown error"]:
170       http_proc = \
171         _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
172                                              errinfo))
173       result = proc(["aef9ur4i.example.com"], "version", None,
174                     _req_process_fn=http_proc)
175       self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
176       lhresp = result["aef9ur4i.example.com"]
177       self.assertFalse(lhresp.offline)
178       self.assertEqual(lhresp.node, "aef9ur4i.example.com")
179       self.assert_(lhresp.fail_msg)
180       self.assertFalse(lhresp.payload)
181       self.assertEqual(lhresp.call, "version")
182       self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
183       self.assertEqual(http_proc.reqcount, 1)
184
185   def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
186     self.assertEqual(req.path, "/vg_list")
187     self.assertEqual(req.port, 15165)
188
189     if req.host in httperrnodes:
190       req.success = False
191       req.error = "Node set up for HTTP errors"
192
193     elif req.host in failnodes:
194       req.success = True
195       req.resp_status_code = 404
196       req.resp_body = serializer.DumpJson({
197         "code": 404,
198         "message": "Method not found",
199         "explain": "Explanation goes here",
200         })
201     else:
202       req.success = True
203       req.resp_status_code = http.HTTP_OK
204       req.resp_body = serializer.DumpJson((True, hash(req.host)))
205
206   def testHttpError(self):
207     nodes = ["uaf6pbbv%s" % i for i in range(50)]
208     resolver = rpc._StaticResolver(nodes)
209
210     httperrnodes = set(nodes[1::7])
211     self.assertEqual(len(httperrnodes), 7)
212
213     failnodes = set(nodes[2::3]) - httperrnodes
214     self.assertEqual(len(failnodes), 14)
215
216     self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
217
218     proc = rpc._RpcProcessor(resolver, 15165)
219     http_proc = \
220       _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
221                                            httperrnodes, failnodes))
222     result = proc(nodes, "vg_list", None, _req_process_fn=http_proc,
223                   read_timeout=rpc._TMO_URGENT)
224     self.assertEqual(sorted(result.keys()), sorted(nodes))
225
226     for name in nodes:
227       lhresp = result[name]
228       self.assertFalse(lhresp.offline)
229       self.assertEqual(lhresp.node, name)
230       self.assertEqual(lhresp.call, "vg_list")
231
232       if name in httperrnodes:
233         self.assert_(lhresp.fail_msg)
234         self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
235       elif name in failnodes:
236         self.assert_(lhresp.fail_msg)
237         self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
238                           prereq=True, ecode=errors.ECODE_INVAL)
239       else:
240         self.assertFalse(lhresp.fail_msg)
241         self.assertEqual(lhresp.payload, hash(name))
242         lhresp.Raise("should not raise")
243
244     self.assertEqual(http_proc.reqcount, len(nodes))
245
246   def _GetInvalidResponseA(self, req):
247     self.assertEqual(req.path, "/version")
248     req.success = True
249     req.resp_status_code = http.HTTP_OK
250     req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
251                                          "response", "!", 1, 2, 3))
252
253   def _GetInvalidResponseB(self, req):
254     self.assertEqual(req.path, "/version")
255     req.success = True
256     req.resp_status_code = http.HTTP_OK
257     req.resp_body = serializer.DumpJson("invalid response")
258
259   def testInvalidResponse(self):
260     resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
261     proc = rpc._RpcProcessor(resolver, 19978)
262
263     for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
264       http_proc = _FakeRequestProcessor(fn)
265       result = proc(["oqo7lanhly.example.com"], "version", None,
266                     _req_process_fn=http_proc)
267       self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
268       lhresp = result["oqo7lanhly.example.com"]
269       self.assertFalse(lhresp.offline)
270       self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
271       self.assert_(lhresp.fail_msg)
272       self.assertFalse(lhresp.payload)
273       self.assertEqual(lhresp.call, "version")
274       self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
275       self.assertEqual(http_proc.reqcount, 1)
276
277   def _GetBodyTestResponse(self, test_data, req):
278     self.assertEqual(req.host, "192.0.2.84")
279     self.assertEqual(req.port, 18700)
280     self.assertEqual(req.path, "/upload_file")
281     self.assertEqual(serializer.LoadJson(req.post_data), test_data)
282     req.success = True
283     req.resp_status_code = http.HTTP_OK
284     req.resp_body = serializer.DumpJson((True, None))
285
286   def testResponseBody(self):
287     test_data = {
288       "Hello": "World",
289       "xyz": range(10),
290       }
291     resolver = rpc._StaticResolver(["192.0.2.84"])
292     http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
293                                                      test_data))
294     proc = rpc._RpcProcessor(resolver, 18700)
295     body = serializer.DumpJson(test_data)
296     result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc)
297     self.assertEqual(result.keys(), ["node19759"])
298     lhresp = result["node19759"]
299     self.assertFalse(lhresp.offline)
300     self.assertEqual(lhresp.node, "node19759")
301     self.assertFalse(lhresp.fail_msg)
302     self.assertEqual(lhresp.payload, None)
303     self.assertEqual(lhresp.call, "upload_file")
304     lhresp.Raise("should not raise")
305     self.assertEqual(http_proc.reqcount, 1)
306
307
308 class TestSsconfResolver(unittest.TestCase):
309   def testSsconfLookup(self):
310     addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
311     node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
312     node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
313     ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
314     result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
315     self.assertEqual(result, zip(node_list, addr_list))
316
317   def testNsLookup(self):
318     addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
319     node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
320     ssc = GetFakeSimpleStoreClass(lambda _: [])
321     node_addr_map = dict(zip(node_list, addr_list))
322     nslookup_fn = lambda name, family=None: node_addr_map.get(name)
323     result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
324     self.assertEqual(result, zip(node_list, addr_list))
325
326   def testBothLookups(self):
327     addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
328     node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
329     n = len(addr_list) / 2
330     node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
331     ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
332     node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
333     nslookup_fn = lambda name, family=None: node_addr_map.get(name)
334     result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
335     self.assertEqual(result, zip(node_list, addr_list))
336
337   def testAddressLookupIPv6(self):
338     addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
339     node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
340     node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
341     ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
342     result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
343     self.assertEqual(result, zip(node_list, addr_list))
344
345
class TestStaticResolver(unittest.TestCase):
  """Tests for L{rpc._StaticResolver}."""

  def test(self):
    """Node names are paired with the configured addresses in order."""
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # zip() returns an iterator on Python 3; compare against a real list
    self.assertEqual(res(nodes), list(zip(nodes, addresses)))

  def testWrongLength(self):
    """Resolving a name with no configured address raises AssertionError."""
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"])
356
357
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for L{rpc._NodeConfigResolver}.

  The resolver is called with a per-name lookup function, a function
  returning a dict of all nodes, and the list of names to resolve. Each test
  passes C{NotImplemented} for the lookup function it does not expect to be
  used.

  """
  @staticmethod
  def _GetSingleOnlineNode(name):
    """Per-name lookup knowing only the online node90.example.com."""
    assert name == "node90.example.com"
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")

  @staticmethod
  def _GetSingleOfflineNode(name):
    """Per-name lookup knowing only the offline node100.example.com."""
    assert name == "node100.example.com"
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")

  def testSingleOnline(self):
    """An online node resolves to its primary IP."""
    expected = [("node90.example.com", "192.0.2.90")]
    result = rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                     NotImplemented,
                                     ["node90.example.com"])
    self.assertEqual(result, expected)

  def testSingleOffline(self):
    """An offline node resolves to the C{rpc._OFFLINE} marker."""
    expected = [("node100.example.com", rpc._OFFLINE)]
    result = rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                     NotImplemented,
                                     ["node100.example.com"])
    self.assertEqual(result, expected)

  def testUnknownSingleNode(self):
    """A name unknown to the configuration resolves to itself."""
    expected = [("node110.example.com", "node110.example.com")]
    result = rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                     ["node110.example.com"])
    self.assertEqual(result, expected)

  def testMultiEmpty(self):
    """Resolving no names against an empty node dict gives no results."""
    result = rpc._NodeConfigResolver(NotImplemented, lambda: {}, [])
    self.assertEqual(result, [])

  def testMultiSomeOffline(self):
    """Mixed offline, online and unknown names resolve in input order."""
    # Every third node (by index) is offline
    nodes = {}
    for idx in range(1, 255):
      node_name = "node%s.example.com" % idx
      nodes[node_name] = objects.Node(name=node_name,
                                      offline=((idx % 3) == 0),
                                      primary_ip="192.0.2.%s" % idx)

    get_all_nodes = lambda: nodes

    # Resolve no names
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented, get_all_nodes,
                                             []),
                     [])

    # Offline, online and unknown hosts
    wanted = [
      "node3.example.com",
      "node92.example.com",
      "node54.example.com",
      "unknown.example.com",
      ]
    result = rpc._NodeConfigResolver(NotImplemented, get_all_nodes, wanted)
    expected = [
      ("node3.example.com", rpc._OFFLINE),
      ("node92.example.com", "192.0.2.92"),
      ("node54.example.com", rpc._OFFLINE),
      ("unknown.example.com", "unknown.example.com"),
      ]
    self.assertEqual(result, expected)
417
418
if __name__ == "__main__":
  # Delegate to Ganeti's test program wrapper from testutils
  testutils.GanetiTestProgram()