Merge branch 'stable-2.9' into stable-2.10
[ganeti-local] / test / py / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2011, 2012, 2013 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27 import random
28 import tempfile
29
30 from ganeti import constants
31 from ganeti import compat
32 from ganeti import rpc
33 from ganeti import rpc_defs
34 from ganeti import http
35 from ganeti import errors
36 from ganeti import serializer
37 from ganeti import objects
38 from ganeti import backend
39
40 import testutils
41 import mocks
42
43
44 class _FakeRequestProcessor:
45   def __init__(self, response_fn):
46     self._response_fn = response_fn
47     self.reqcount = 0
48
49   def __call__(self, reqs, lock_monitor_cb=None):
50     assert lock_monitor_cb is None or callable(lock_monitor_cb)
51     for req in reqs:
52       self.reqcount += 1
53       self._response_fn(req)
54
55
def GetFakeSimpleStoreClass(fn):
  """Builds a minimal stand-in for an ssconf simple store class.

  @param fn: function installed as the C{GetNodePrimaryIPList} method of the
    returned class; as a plain function it is bound like any other method,
    so it receives the store instance as its first argument
  @return: a class whose instances provide C{GetNodePrimaryIPList} and a
    C{GetPrimaryIPFamily} method that always returns C{None}

  """
  class FakeSimpleStore:
    def GetPrimaryIPFamily(self):
      return None

  FakeSimpleStore.GetNodePrimaryIPList = fn
  return FakeSimpleStore
62
63
64 def _RaiseNotImplemented():
65   """Simple wrapper to raise NotImplementedError.
66
67   """
68   raise NotImplementedError
69
70
class TestRpcProcessor(unittest.TestCase):
  """Tests for L{rpc._RpcProcessor} driven by L{_FakeRequestProcessor}."""

  def _FakeAddressLookup(self, map):
    """Returns a function looking up each node of a list in C{map}.

    Nodes missing from C{map} resolve to C{None}. Not used by the tests in
    this class.

    """
    return lambda node_list: [map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Checks a "version" request to 127.0.0.1 and answers with 123."""
    self.assertEqual(req.host, "127.0.0.1")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, constants.RPC_TMO_URGENT)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    resolver = rpc._StaticResolver(["127.0.0.1"])
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
    proc = rpc._RpcProcessor(resolver, 24094)
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
                  NotImplemented, _req_process_fn=http_proc)
    # list() keeps the comparison valid when keys() returns a view
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def _ReadTimeoutResponse(self, req):
    """Checks that the requested read timeout reaches the HTTP request."""
    self.assertEqual(req.host, "192.0.2.13")
    self.assertEqual(req.port, 19176)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, 12356)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, -1))

  def testReadTimeout(self):
    resolver = rpc._StaticResolver(["192.0.2.13"])
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
    proc = rpc._RpcProcessor(resolver, 19176)
    host = "node31856"
    body = {host: ""}
    result = proc([host], "version", body, 12356, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, -1)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def testOfflineNode(self):
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    http_proc = _FakeRequestProcessor(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    host = "n17296"
    body = {host: ""}
    result = proc([host], "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    # Offline nodes must not be contacted at all
    self.assertEqual(http_proc.reqcount, 0)

  def _GetMultiVersionResponse(self, req):
    """Answers any "version" request to a "node*" host with 987."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    nodes = ["node%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
    proc = rpc._RpcProcessor(resolver, 23245)
    result = proc(nodes, "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetVersionResponseFail(self, errinfo, req):
    """Answers a "version" request with a failure carrying C{errinfo}."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, errinfo))

  def testVersionFailure(self):
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
    proc = rpc._RpcProcessor(resolver, 5903)
    for errinfo in [None, "Unknown error"]:
      http_proc = \
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
                                             errinfo))
      host = "aef9ur4i.example.com"
      body = {host: ""}
      result = proc(body.keys(), "version", body, 60, NotImplemented,
                    _req_process_fn=http_proc)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answers a "vg_list" request depending on the target host.

    Hosts in C{httperrnodes} get a transport-level failure, hosts in
    C{failnodes} an HTTP 404 response, and all others a successful response
    whose payload is C{hash(host)}.

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)

    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    http_proc = \
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                           httperrnodes, failnodes))
    result = proc(nodes, "vg_list", body,
                  constants.RPC_TMO_URGENT, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answers with a JSON list that is not a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answers with a bare JSON string instead of a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      http_proc = _FakeRequestProcessor(fn)
      host = "oqo7lanhly.example.com"
      body = {host: ""}
      result = proc([host], "version", body, 60, NotImplemented,
                    _req_process_fn=http_proc)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetBodyTestResponse(self, test_data, req):
    """Checks that the POST body decodes to C{test_data}."""
    self.assertEqual(req.host, "192.0.2.84")
    self.assertEqual(req.port, 18700)
    self.assertEqual(req.path, "/upload_file")
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, None))

  def testResponseBody(self):
    test_data = {
      "Hello": "World",
      "xyz": range(10),
      }
    resolver = rpc._StaticResolver(["192.0.2.84"])
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                     test_data))
    proc = rpc._RpcProcessor(resolver, 18700)
    host = "node19759"
    body = {host: serializer.DumpJson(test_data)}
    result = proc([host], "upload_file", body, 30, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, None)
    self.assertEqual(lhresp.call, "upload_file")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)
334
335
class TestSsconfResolver(unittest.TestCase):
  """Tests for L{rpc._SsconfResolver}."""

  @staticmethod
  def _MakeAddressesAndNodes(addr_fmt, step):
    """Builds index-aligned address and node name lists.

    @param addr_fmt: format string producing an address from a node number
    @param step: distance between consecutive node numbers in [0, 255)
    @rtype: tuple
    @return: (list of addresses, list of node names)

    """
    numbers = range(0, 255, step)
    return ([addr_fmt % n for n in numbers],
            ["node%d.example.com" % n for n in numbers])

  def testSsconfLookup(self):
    """All addresses come from ssconf; the name lookup is never called."""
    (addr_list, node_list) = self._MakeAddressesAndNodes("192.0.2.%d", 13)
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))

  def testNsLookup(self):
    """Nodes missing from ssconf are resolved through the name lookup."""
    (addr_list, node_list) = self._MakeAddressesAndNodes("192.0.2.%d", 13)
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))

  def testDisabledSsconfIp(self):
    """With ssconf IPs disabled, only the name lookup is used."""
    (addr_list, node_list) = self._MakeAddressesAndNodes("192.0.2.%d", 13)
    ssc = GetFakeSimpleStoreClass(_RaiseNotImplemented)
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(False, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))

  def testBothLookups(self):
    """Half of the nodes come from ssconf, the other half from the lookup."""
    (addr_list, node_list) = self._MakeAddressesAndNodes("192.0.2.%d", 13)
    # Floor division: "/" would produce a float under true division and
    # break the slicing below
    n = len(addr_list) // 2
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))

  def testAddressLookupIPv6(self):
    """IPv6 addresses from ssconf are passed through unchanged."""
    (addr_list, node_list) = self._MakeAddressesAndNodes("2001:db8::%d", 11)
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))
386
387
class TestStaticResolver(unittest.TestCase):
  """Tests for L{rpc._StaticResolver}."""

  def test(self):
    """Each node name is paired with the address at the same position."""
    numbers = range(0, 123, 7)
    addresses = ["192.0.2.%d" % num for num in numbers]
    nodes = ["node%s.example.com" % num for num in numbers]
    resolver = rpc._StaticResolver(addresses)
    expected = zip(nodes, addresses, nodes)
    self.assertEqual(resolver(nodes, NotImplemented), expected)

  def testWrongLength(self):
    """Resolving a name without a configured address raises AssertionError."""
    resolver = rpc._StaticResolver([])
    self.assertRaises(AssertionError, resolver, ["abc"], NotImplemented)
398
399
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for L{rpc._NodeConfigResolver}."""

  @staticmethod
  def _GetSingleOnlineNode(uuid):
    """Fake single-node lookup returning an online node for "node90-uuid"."""
    assert uuid == "node90-uuid"
    return objects.Node(name="node90.example.com",
                        uuid=uuid,
                        offline=False,
                        primary_ip="192.0.2.90")

  @staticmethod
  def _GetSingleOfflineNode(uuid):
    """Fake single-node lookup returning an offline node for "node100-uuid".

    """
    assert uuid == "node100-uuid"
    return objects.Node(name="node100.example.com",
                        uuid=uuid,
                        offline=True,
                        primary_ip="192.0.2.100")

  def testSingleOnline(self):
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                             NotImplemented,
                                             ["node90-uuid"], None),
                     [("node90.example.com", "192.0.2.90", "node90-uuid")])

  def testSingleOffline(self):
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                             NotImplemented,
                                             ["node100-uuid"], None),
                     [("node100.example.com", rpc._OFFLINE, "node100-uuid")])

  def testSingleOfflineWithAcceptOffline(self):
    fn = self._GetSingleOfflineNode
    assert fn("node100-uuid").offline
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
                                             ["node100-uuid"],
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
                     [("node100.example.com", "192.0.2.100", "node100-uuid")])
    for i in [False, True, "", "Hello", 0, 1]:
      # Use the UUID accepted by the fake lookup function; with any other
      # value the fake's own assertion would raise AssertionError and the
      # option validation would not actually be exercised
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
                        fn, NotImplemented, ["node100-uuid"], i)

  def testUnknownSingleNode(self):
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                             ["node110.example.com"], None),
                     [("node110.example.com", "node110.example.com",
                       "node110.example.com")])

  def testMultiEmpty(self):
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: {},
                                             [], None),
                     [])

  def testMultiSomeOffline(self):
    # Every node whose number is a multiple of three is offline
    nodes = dict(("node%s-uuid" % i,
                  objects.Node(name="node%s.example.com" % i,
                               offline=((i % 3) == 0),
                               primary_ip="192.0.2.%s" % i,
                               uuid="node%s-uuid" % i))
                  for i in range(1, 255))

    # Resolve no names
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: nodes,
                                             [], None),
                     [])

    # Offline, online and unknown hosts
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                             lambda: nodes,
                                             ["node3-uuid",
                                              "node92-uuid",
                                              "node54-uuid",
                                              "unknown.example.com",],
                                             None), [
      ("node3.example.com", rpc._OFFLINE, "node3-uuid"),
      ("node92.example.com", "192.0.2.92", "node92-uuid"),
      ("node54.example.com", rpc._OFFLINE, "node54-uuid"),
      ("unknown.example.com", "unknown.example.com", "unknown.example.com"),
      ])
479
480
class TestCompress(unittest.TestCase):
  """Tests for L{rpc._Compress} and L{backend._Decompress}."""

  def test(self):
    """Small payloads stay unencoded; larger ones must round-trip."""
    small_inputs = ["", "Hello", "Hello World!\nnew\nlines"]
    for payload in small_inputs:
      expected = (constants.RPC_ENCODING_NONE, payload)
      self.assertEqual(rpc._Compress(NotImplemented, payload), expected)

    large_inputs = [512 * " ", 5242 * "Hello World!\n"]
    for payload in large_inputs:
      packed = rpc._Compress(NotImplemented, payload)
      self.assertEqual(len(packed), 2)
      self.assertEqual(backend._Decompress(packed), payload)

  def testDecompression(self):
    """Malformed or unknown encodings are rejected."""
    for bad_input in ["", [""], ("unknown compression", "data")]:
      self.assertRaises(AssertionError, backend._Decompress, bad_input)
    self.assertRaises(Exception, backend._Decompress,
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
499
500
class TestRpcClientBase(unittest.TestCase):
  """Tests for L{rpc._RpcClientBase}."""

  def testNoHosts(self):
    """No nodes yield an empty result; a wrong argument count is rejected."""
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_SLOW, [],
            None, None, NotImplemented)
    http_proc = _FakeRequestProcessor(NotImplemented)
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
                                _req_process_fn=http_proc)
    self.assertEqual(client._Call(cdef, [], []), {})

    # Test wrong number of arguments
    self.assertRaises(errors.ProgrammerError, client._Call,
                      cdef, [], [0, 1, 2])

  def testTimeout(self):
    """Fixed timeouts are used as-is; callable ones receive the arguments."""
    def _CalcTimeout(args):
      # Explicit unpacking; tuple parameters are a syntax error in Python 3
      (arg1, arg2) = args
      return arg1 + arg2

    def _VerifyRequest(exp_timeout, req):
      self.assertEqual(req.read_timeout, exp_timeout)

      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))

    resolver = rpc._StaticResolver([
      "192.0.2.1",
      "192.0.2.2",
      ])

    nodes = [
      "node1.example.com",
      "node2.example.com",
      ]

    tests = [(100, None, 100), (30, None, 30)]
    tests.extend((_CalcTimeout, i, i + 300)
                 for i in [0, 5, 16485, 30516])

    for timeout, arg1, exp_timeout in tests:
      cdef = ("test_call", NotImplemented, None, timeout, [
        ("arg1", None, NotImplemented),
        ("arg2", None, NotImplemented),
        ], None, None, NotImplemented)

      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
                                                       exp_timeout))
      client = rpc._RpcClientBase(resolver, NotImplemented,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [arg1, 300])
      self.assertEqual(len(result), len(nodes))
      self.assertTrue(compat.all(not res.fail_msg and
                                 res.payload == hex(exp_timeout)
                                 for res in result.values()))

  def testArgumentEncoder(self):
    """Typed arguments are passed through their encoder, untyped as-is."""
    (AT1, AT2) = range(1, 3)

    resolver = rpc._StaticResolver([
      "192.0.2.5",
      "192.0.2.6",
      ])

    nodes = [
      "node5.example.com",
      "node6.example.com",
      ]

    encoders = {
      AT1: lambda _, value: hex(value),
      AT2: lambda _, value: hash(value),
      }

    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
      ("arg0", None, NotImplemented),
      ("arg1", AT1, NotImplemented),
      ("arg2", AT2, NotImplemented),
      ], None, None, NotImplemented)

    def _VerifyRequest(req):
      # Echo the encoded request body back as the payload
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    http_proc = _FakeRequestProcessor(_VerifyRequest)

    for num in [0, 3796, 9032119]:
      client = rpc._RpcClientBase(resolver, encoders.get,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
      self.assertEqual(len(result), len(nodes))
      for res in result.values():
        self.assertFalse(res.fail_msg)
        self.assertEqual(serializer.LoadJson(res.payload),
                         ["foo", hex(num), hash("Hello%s" % num)])

  def testPostProc(self):
    """The post-processing function is applied to every node's result."""
    def _VerifyRequest(nums, req):
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, nums))

    resolver = rpc._StaticResolver([
      "192.0.2.90",
      "192.0.2.95",
      ])

    nodes = [
      "node90.example.com",
      "node95.example.com",
      ]

    def _PostProc(res):
      self.assertFalse(res.fail_msg)
      res.payload = sum(res.payload)
      return res

    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [],
            None, _PostProc, NotImplemented)

    # Seeded random generator
    rnd = random.Random(20299)

    for i in [0, 4, 74, 1391]:
      nums = [rnd.randint(0, 1000) for _ in range(i)]
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
      client = rpc._RpcClientBase(resolver, NotImplemented,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [])
      self.assertEqual(len(result), len(nodes))
      for res in result.values():
        self.assertFalse(res.fail_msg)
        self.assertEqual(res.payload, sum(nums))

  def testPreProc(self):
    """The pre-processing function builds a per-node request body."""
    def _VerifyRequest(req):
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    resolver = rpc._StaticResolver([
      "192.0.2.30",
      "192.0.2.35",
      ])

    nodes = [
      "node30.example.com",
      "node35.example.com",
      ]

    def _PreProc(node, data):
      self.assertEqual(len(data), 1)
      return data[0] + node

    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
      ("arg0", None, NotImplemented),
      ], _PreProc, None, NotImplemented)

    http_proc = _FakeRequestProcessor(_VerifyRequest)
    client = rpc._RpcClientBase(resolver, NotImplemented,
                                _req_process_fn=http_proc)

    for prefix in ["foo", "bar", "baz"]:
      result = client._Call(cdef, nodes, [prefix])
      self.assertEqual(len(result), len(nodes))
      for (node, res) in result.items():
        self.assertFalse(res.fail_msg)
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)

  def testResolverOptions(self):
    """Resolver options, fixed or computed from arguments, reach the resolver.

    """
    def _VerifyRequest(req):
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    nodes = [
      "node30.example.com",
      "node35.example.com",
      ]

    def _Resolver(expected, hosts, options):
      self.assertEqual(hosts, nodes)
      self.assertEqual(options, expected)
      return list(zip(hosts, nodes, hosts))

    def _DynamicResolverOptions(args):
      # Explicit unpacking; tuple parameters are a syntax error in Python 3
      (arg0, ) = args
      return sum(arg0)

    tests = [
      (None, None, None),
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
      (False, None, False),
      (True, None, True),
      (0, None, 0),
      (_DynamicResolverOptions, [1, 2, 3], 6),
      (_DynamicResolverOptions, range(4, 19), 165),
      ]

    for (resolver_opts, arg0, expected) in tests:
      cdef = ("test_call", NotImplemented, resolver_opts,
              constants.RPC_TMO_NORMAL, [
        ("arg0", None, NotImplemented),
        ], None, None, NotImplemented)

      http_proc = _FakeRequestProcessor(_VerifyRequest)

      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
                                  NotImplemented, _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [arg0])
      self.assertEqual(len(result), len(nodes))
      for res in result.values():
        self.assertFalse(res.fail_msg)
712
713
class _FakeConfigForRpcRunner:
  """Minimal configuration object used to build RPC runners in tests."""

  GetAllNodesInfo = NotImplemented

  def __init__(self, cluster=NotImplemented):
    """@param cluster: object handed out by L{GetClusterInfo}"""
    self._cluster = cluster

  def GetNodeInfo(self, name):
    """Returns a new L{objects.Node} with only its name set."""
    return objects.Node(name=name)

  def GetMultiNodeInfo(self, names):
    """Returns a list of C{(name, node object)} pairs, one per name."""
    get_info = self.GetNodeInfo
    return [(node_name, get_info(node_name)) for node_name in names]

  def GetClusterInfo(self):
    """Returns the cluster object given at construction time."""
    return self._cluster

  def GetInstanceDiskParams(self, _):
    """Returns the default disk parameters, whatever the instance."""
    return constants.DISK_DT_DEFAULTS
731
732
733 class TestRpcRunner(unittest.TestCase):
734   def testUploadFile(self):
735     data = 1779 * "Hello World\n"
736
737     tmpfile = tempfile.NamedTemporaryFile()
738     tmpfile.write(data)
739     tmpfile.flush()
740     st = os.stat(tmpfile.name)
741
742     nodes = [
743       "node1.example.com",
744       ]
745
746     def _VerifyRequest(req):
747       (uldata, ) = serializer.LoadJson(req.post_data)
748       self.assertEqual(len(uldata), 7)
749       self.assertEqual(uldata[0], tmpfile.name)
750       self.assertEqual(list(uldata[1]), list(rpc._Compress(nodes[0], data)))
751       self.assertEqual(uldata[2], st.st_mode)
752       self.assertEqual(uldata[3], "user%s" % os.getuid())
753       self.assertEqual(uldata[4], "group%s" % os.getgid())
754       self.assertTrue(uldata[5] is not None)
755       self.assertEqual(uldata[6], st.st_mtime)
756
757       req.success = True
758       req.resp_status_code = http.HTTP_OK
759       req.resp_body = serializer.DumpJson((True, None))
760
761     http_proc = _FakeRequestProcessor(_VerifyRequest)
762
763     std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
764                                _req_process_fn=http_proc,
765                                _getents=mocks.FakeGetentResolver)
766
767     cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
768                                   _req_process_fn=http_proc,
769                                   _getents=mocks.FakeGetentResolver)
770
771     for runner in [std_runner, cfg_runner]:
772       result = runner.call_upload_file(nodes, tmpfile.name)
773       self.assertEqual(len(result), len(nodes))
774       for (idx, (node, res)) in enumerate(result.items()):
775         self.assertFalse(res.fail_msg)
776
777   def testEncodeInstance(self):
778     cluster = objects.Cluster(hvparams={
779       constants.HT_KVM: {
780         constants.HV_CDROM_IMAGE_PATH: "foo",
781         },
782       },
783       beparams={
784         constants.PP_DEFAULT: {
785           constants.BE_MAXMEM: 8192,
786           },
787         },
788       os_hvp={},
789       osparams={
790         "linux": {
791           "role": "unknown",
792           },
793         })
794     cluster.UpgradeConfig()
795
796     inst = objects.Instance(name="inst1.example.com",
797       hypervisor=constants.HT_KVM,
798       os="linux",
799       hvparams={
800         constants.HV_CDROM_IMAGE_PATH: "bar",
801         constants.HV_ROOT_PATH: "/tmp",
802         },
803       beparams={
804         constants.BE_MINMEM: 128,
805         constants.BE_MAXMEM: 256,
806         },
807       nics=[
808         objects.NIC(nicparams={
809           constants.NIC_MODE: "mymode",
810           }),
811         ],
812       disk_template=constants.DT_PLAIN,
813       disks=[
814         objects.Disk(dev_type=constants.DT_PLAIN, size=4096,
815                      logical_id=("vg", "disk6120")),
816         objects.Disk(dev_type=constants.DT_PLAIN, size=1024,
817                      logical_id=("vg", "disk8508")),
818         ])
819     inst.UpgradeConfig()
820
821     cfg = _FakeConfigForRpcRunner(cluster=cluster)
822     runner = rpc.RpcRunner(cfg, None,
823                            _req_process_fn=NotImplemented,
824                            _getents=mocks.FakeGetentResolver)
825
826     def _CheckBasics(result):
827       self.assertEqual(result["name"], "inst1.example.com")
828       self.assertEqual(result["os"], "linux")
829       self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
830       self.assertEqual(len(result["nics"]), 1)
831       self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
832                        "mymode")
833
834     # Generic object serialization
835     result = runner._encoder(NotImplemented, (rpc_defs.ED_OBJECT_DICT, inst))
836     _CheckBasics(result)
837     self.assertEqual(len(result["hvparams"]), 2)
838
839     result = runner._encoder(NotImplemented,
840                              (rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
841     map(_CheckBasics, result)
842     map(lambda r: self.assertEqual(len(r["hvparams"]), 2), result)
843
844     # Just an instance
845     result = runner._encoder(NotImplemented, (rpc_defs.ED_INST_DICT, inst))
846     _CheckBasics(result)
847     self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
848     self.assertEqual(result["hvparams"][constants.HV_CDROM_IMAGE_PATH], "bar")
849     self.assertEqual(result["hvparams"][constants.HV_ROOT_PATH], "/tmp")
850     self.assertEqual(result["osparams"], {
851       "role": "unknown",
852       })
853     self.assertEqual(len(result["hvparams"]),
854                      len(constants.HVC_DEFAULTS[constants.HT_KVM]))
855
856     # Instance with OS parameters
857     result = runner._encoder(NotImplemented,
858                              (rpc_defs.ED_INST_DICT_OSP_DP, (inst, {
859                                "role": "webserver",
860                                "other": "field",
861                              })))
862     _CheckBasics(result)
863     self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
864     self.assertEqual(result["hvparams"][constants.HV_CDROM_IMAGE_PATH], "bar")
865     self.assertEqual(result["hvparams"][constants.HV_ROOT_PATH], "/tmp")
866     self.assertEqual(result["osparams"], {
867       "role": "webserver",
868       "other": "field",
869       })
870
871     # Instance with hypervisor and backend parameters
872     result = runner._encoder(NotImplemented,
873                              (rpc_defs.ED_INST_DICT_HVP_BEP_DP, (inst, {
874       constants.HT_KVM: {
875         constants.HV_BOOT_ORDER: "xyz",
876         },
877       }, {
878       constants.BE_VCPUS: 100,
879       constants.BE_MAXMEM: 4096,
880       })))
881     _CheckBasics(result)
882     self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
883     self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
884     self.assertEqual(result["hvparams"][constants.HT_KVM], {
885       constants.HV_BOOT_ORDER: "xyz",
886       })
887     self.assertEqual(result["disks"], [{
888       "dev_type": constants.DT_PLAIN,
889       "dynamic_params": {},
890       "size": 4096,
891       "logical_id": ("vg", "disk6120"),
892       "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
893       }, {
894       "dev_type": constants.DT_PLAIN,
895       "dynamic_params": {},
896       "size": 1024,
897       "logical_id": ("vg", "disk8508"),
898       "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
899       }])
900
901     self.assertTrue(compat.all(disk.params == {} for disk in inst.disks),
902                     msg="Configuration objects were modified")
903
904
class TestLegacyNodeInfo(unittest.TestCase):
  """Tests for L{rpc.MakeLegacyNodeInfo}.

  """
  # Keys of the flat dictionary returned by rpc.MakeLegacyNodeInfo
  KEY_BOOT = "bootid"
  KEY_NAME = "name"
  KEY_STORAGE_FREE = "storage_free"
  KEY_STORAGE_TOTAL = "storage_size"
  KEY_CPU_COUNT = "cpu_count"
  KEY_SPINDLES_FREE = "spindles_free"
  KEY_SPINDLES_TOTAL = "spindles_total"
  KEY_STORAGE_TYPE = "type" # key for storage type
  VAL_BOOT = 0
  # Values describing the volume group (lvm-vg) storage entry
  VAL_VG_NAME = "xy"
  VAL_VG_FREE = 11
  VAL_VG_TOTAL = 12
  VAL_VG_TYPE = "lvm-vg"
  VAL_CPU_COUNT = 2
  # Values describing the physical volume (lvm-pv) entry, which the expected
  # result reports as spindles
  VAL_PV_NAME = "ab"
  VAL_PV_FREE = 31
  VAL_PV_TOTAL = 32
  VAL_PV_TYPE = "lvm-pv"
  DICT_VG = {
    KEY_NAME: VAL_VG_NAME,
    KEY_STORAGE_FREE: VAL_VG_FREE,
    KEY_STORAGE_TOTAL: VAL_VG_TOTAL,
    KEY_STORAGE_TYPE: VAL_VG_TYPE,
    }
  DICT_HV = {KEY_CPU_COUNT: VAL_CPU_COUNT}
  DICT_SP = {
    KEY_STORAGE_TYPE: VAL_PV_TYPE,
    KEY_NAME: VAL_PV_NAME,
    KEY_STORAGE_FREE: VAL_PV_FREE,
    KEY_STORAGE_TOTAL: VAL_PV_TOTAL,
    }
  # Input in list form: [boot value, [storage dicts], [hypervisor dicts]]
  STD_LST = [VAL_BOOT, [DICT_VG, DICT_SP], [DICT_HV]]
  # Expected flat dictionary for STD_LST
  STD_DICT = {
    KEY_BOOT: VAL_BOOT,
    KEY_NAME: VAL_VG_NAME,
    KEY_STORAGE_FREE: VAL_VG_FREE,
    KEY_STORAGE_TOTAL: VAL_VG_TOTAL,
    KEY_SPINDLES_FREE: VAL_PV_FREE,
    KEY_SPINDLES_TOTAL: VAL_PV_TOTAL,
    KEY_CPU_COUNT: VAL_CPU_COUNT,
    }

  def testWithSpindles(self):
    """The lvm-pv entry supplies the spindle values.

    """
    result = rpc.MakeLegacyNodeInfo(self.STD_LST, constants.DT_PLAIN)
    self.assertEqual(result, self.STD_DICT)

  def testNoSpindles(self):
    """Without an lvm-pv entry, spindle counts are reported as zero.

    """
    my_lst = [self.VAL_BOOT, [self.DICT_VG], [self.DICT_HV]]
    result = rpc.MakeLegacyNodeInfo(my_lst, constants.DT_PLAIN)
    # Shallow copy so the shared class-level STD_DICT is not modified
    expected_dict = dict(self.STD_DICT)
    expected_dict[self.KEY_SPINDLES_FREE] = 0
    expected_dict[self.KEY_SPINDLES_TOTAL] = 0
    self.assertEqual(result, expected_dict)
959
960
if __name__ == "__main__":
  # Run every test case in this module through Ganeti's test program
  run_tests = testutils.GanetiTestProgram
  run_tests()