rpc: Overhaul client structure
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34 from ganeti import objects
35
36 import testutils
37
38
class TestTimeouts(unittest.TestCase):
  """Checks the C{rpc._TIMEOUTS} table against C{rpc.RpcRunner}.

  Every C{call_*} method of C{RpcRunner} must have exactly one entry in
  C{rpc._TIMEOUTS} (keyed by the method name without the C{call_} prefix),
  and each entry must be C{None} or a positive number.

  """
  def test(self):
    names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
             if name.startswith("call_")]

    # Comparing the names themselves (instead of only the counts) reports
    # which call lacks an entry or which entry is superfluous, rather than
    # failing with a bare length mismatch or a KeyError below
    self.assertEqual(sorted(names), sorted(rpc._TIMEOUTS.keys()))

    # Timeouts must either be unset (None) or strictly positive
    self.assertFalse([name for name in names
                      if not (rpc._TIMEOUTS[name] is None or
                              rpc._TIMEOUTS[name] > 0)])
47
48
class FakeHttpPool:
  """Replacement for the HTTP client pool handed to the RPC processor.

  No network traffic happens: each request object is passed to the
  callback given at construction time, which is expected to fill in the
  response attributes. The total number of requests seen is kept in
  C{reqcount}.

  """
  def __init__(self, response_fn):
    self.reqcount = 0
    self._response_fn = response_fn

  def ProcessRequests(self, reqs):
    """Counts each request, then hands it to the response callback.

    @param reqs: iterable of request objects

    """
    for request in reqs:
      self.reqcount = self.reqcount + 1
      self._response_fn(request)
58
59
def GetFakeSimpleStoreClass(fn):
  """Builds a minimal stand-in class for the ssconf store.

  The returned class is passed as C{ssc} to C{rpc._SsconfResolver} in the
  tests below. C{fn} is installed as the class attribute
  C{GetNodePrimaryIPList}, so it is invoked as a method and receives the
  store instance as its first argument. C{GetPrimaryIPFamily} always
  returns C{None}.

  @param fn: function taking the store instance and returning the
      node/address entries to report

  """
  def _NoFamily(_):
    return None

  class FakeSimpleStore:
    GetNodePrimaryIPList = fn
    GetPrimaryIPFamily = _NoFamily

  return FakeSimpleStore
66
67
class TestRpcProcessor(unittest.TestCase):
  """Tests for C{rpc._RpcProcessor}, driven through L{FakeHttpPool}.

  Each C{_Get*Response} helper plays the role of the remote node: it checks
  the request built by the processor and fills in the HTTP response.

  """
  def _FakeAddressLookup(self, addr_map):
    """Returns a function mapping a list of node names to addresses.

    Names missing from C{addr_map} are mapped to C{None}.

    """
    # NOTE(review): not referenced by any test in this class
    return lambda node_list: [addr_map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Answers a "version" call to 127.0.0.1:24094 with payload 123."""
    self.assertEqual(req.host, "127.0.0.1")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    # The test passes no read timeout, so the processor is expected to use
    # rpc._TMO_URGENT for this call
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    resolver = rpc._StaticResolver(["127.0.0.1"])
    pool = FakeHttpPool(self._GetVersionResponse)
    proc = rpc._RpcProcessor(resolver, 24094)
    result = proc(["localhost"], "version", None, http_pool=pool)
    # list() keeps the comparison valid where keys() returns a view
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)

  def _ReadTimeoutResponse(self, req):
    """Answers a "version" call, checking the explicit read timeout."""
    self.assertEqual(req.host, "192.0.2.13")
    self.assertEqual(req.port, 19176)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, 12356)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, -1))

  def testReadTimeout(self):
    resolver = rpc._StaticResolver(["192.0.2.13"])
    pool = FakeHttpPool(self._ReadTimeoutResponse)
    proc = rpc._RpcProcessor(resolver, 19176)
    result = proc(["node31856"], "version", None, http_pool=pool,
                  read_timeout=12356)
    self.assertEqual(list(result.keys()), ["node31856"])
    lhresp = result["node31856"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "node31856")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, -1)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)

  def testOfflineNode(self):
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    # NotImplemented is not callable, so any request reaching the pool
    # would make this test fail
    pool = FakeHttpPool(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    result = proc(["n17296"], "version", None, http_pool=pool)
    self.assertEqual(list(result.keys()), ["n17296"])
    lhresp = result["n17296"]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, "n17296")
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    self.assertEqual(pool.reqcount, 0)

  def _GetMultiVersionResponse(self, req):
    """Answers a "version" call from any "node*" host with payload 987."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    nodes = ["node%s" % i for i in range(50)]
    resolver = rpc._StaticResolver(nodes)
    pool = FakeHttpPool(self._GetMultiVersionResponse)
    proc = rpc._RpcProcessor(resolver, 23245)
    result = proc(nodes, "version", None, http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetVersionResponseFail(self, errinfo, req):
    """Answers a "version" call with a failure status and C{errinfo}."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, errinfo))

  def testVersionFailure(self):
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
    proc = rpc._RpcProcessor(resolver, 5903)
    for errinfo in [None, "Unknown error"]:
      pool = FakeHttpPool(compat.partial(self._GetVersionResponseFail, errinfo))
      result = proc(["aef9ur4i.example.com"], "version", None, http_pool=pool)
      self.assertEqual(list(result.keys()), ["aef9ur4i.example.com"])
      lhresp = result["aef9ur4i.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(pool.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answers "vg_list" calls depending on which set the host is in.

    Hosts in C{httperrnodes} get a transport-level failure, hosts in
    C{failnodes} an HTTP 404 error body, all others a successful reply
    whose payload is C{hash(host)}.

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    resolver = rpc._StaticResolver(nodes)

    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
                                       httperrnodes, failnodes))
    result = proc(nodes, "vg_list", None, http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answers with a JSON list that is not a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answers with a plain JSON string instead of a result pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      pool = FakeHttpPool(fn)
      result = proc(["oqo7lanhly.example.com"], "version", None, http_pool=pool)
      self.assertEqual(list(result.keys()), ["oqo7lanhly.example.com"])
      lhresp = result["oqo7lanhly.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(pool.reqcount, 1)

  def _GetBodyTestResponse(self, test_data, req):
    """Answers "upload_file" after checking the posted body equals C{test_data}."""
    self.assertEqual(req.host, "192.0.2.84")
    self.assertEqual(req.port, 18700)
    self.assertEqual(req.path, "/upload_file")
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, None))

  def testResponseBody(self):
    test_data = {
      "Hello": "World",
      # An explicit list: JSON round-trips it as a list, and range() is not
      # JSON-serializable where it returns a lazy object
      "xyz": list(range(10)),
      }
    resolver = rpc._StaticResolver(["192.0.2.84"])
    pool = FakeHttpPool(compat.partial(self._GetBodyTestResponse, test_data))
    proc = rpc._RpcProcessor(resolver, 18700)
    body = serializer.DumpJson(test_data)
    result = proc(["node19759"], "upload_file", body, http_pool=pool)
    self.assertEqual(list(result.keys()), ["node19759"])
    lhresp = result["node19759"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "node19759")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, None)
    self.assertEqual(lhresp.call, "upload_file")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)
308
309
class TestSsconfResolver(unittest.TestCase):
  """Tests for C{rpc._SsconfResolver}.

  Addresses come either from the ssconf node list (entries of the form
  "<name> <address>") or, for nodes not listed there, from C{nslookup_fn}.
  Expected results are built with list(zip(...)) so they stay lists where
  zip() returns an iterator.

  """
  def testSsconfLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    # NotImplemented is not callable: a name-service lookup would fail
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # ssconf knows no nodes, so every address must come from nslookup_fn
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division keeps n an integer usable as a slice index
    n = len(addr_list) // 2
    # Second half is known to ssconf, first half only to nslookup_fn
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
346
347
class TestStaticResolver(unittest.TestCase):
  """Tests for C{rpc._StaticResolver}, which pairs names with fixed addresses."""

  def test(self):
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # list() keeps the expected value a list where zip() is lazy
    self.assertEqual(res(nodes), list(zip(nodes, addresses)))

  def testWrongLength(self):
    # Resolving more names than there are addresses must be rejected
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"])
358
359
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for C{rpc._NodeConfigResolver}.

  The resolver receives a single-node lookup function, an all-nodes lookup
  function and a list of names. Each test passes C{NotImplemented} for the
  lookup it does not expect to be called for its input.

  """
  @staticmethod
  def _GetSingleOnlineNode(name):
    """Single-node lookup yielding an online node object."""
    assert name == "node90.example.com"
    return objects.Node(primary_ip="192.0.2.90", name=name, offline=False)

  @staticmethod
  def _GetSingleOfflineNode(name):
    """Single-node lookup yielding an offline node object."""
    assert name == "node100.example.com"
    return objects.Node(primary_ip="192.0.2.100", name=name, offline=True)

  def testSingleOnline(self):
    resolved = rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                       NotImplemented,
                                       ["node90.example.com"])
    # An online node resolves to its primary IP
    self.assertEqual(resolved, [("node90.example.com", "192.0.2.90")])

  def testSingleOffline(self):
    resolved = rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                       NotImplemented,
                                       ["node100.example.com"])
    # An offline node resolves to the rpc._OFFLINE marker
    self.assertEqual(resolved, [("node100.example.com", rpc._OFFLINE)])

  def testUnknownSingleNode(self):
    # The lookup knows nothing about the node, so its name is used as the
    # address
    name = "node110.example.com"
    resolved = rpc._NodeConfigResolver(lambda _: None, NotImplemented, [name])
    self.assertEqual(resolved, [(name, name)])

  def testMultiEmpty(self):
    resolved = rpc._NodeConfigResolver(NotImplemented, lambda: {}, [])
    self.assertEqual(resolved, [])

  def testMultiSomeOffline(self):
    # Every third node is offline
    nodes = {}
    for idx in range(1, 255):
      node_name = "node%s.example.com" % idx
      nodes[node_name] = objects.Node(name=node_name,
                                      offline=((idx % 3) == 0),
                                      primary_ip="192.0.2.%s" % idx)
    get_all_nodes = lambda: nodes

    # Resolve no names
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented, get_all_nodes,
                                             []),
                     [])

    # Offline, online and unknown hosts
    expected = [
      ("node3.example.com", rpc._OFFLINE),
      ("node92.example.com", "192.0.2.92"),
      ("node54.example.com", rpc._OFFLINE),
      ("unknown.example.com", "unknown.example.com"),
      ]
    query = [host for (host, _) in expected]
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented, get_all_nodes,
                                             query),
                     expected)
419
420
def _Main():
  """Script entry point; delegates to testutils.GanetiTestProgram."""
  testutils.GanetiTestProgram()


if __name__ == "__main__":
  _Main()