ace49af8993700a262a49f00557aaf09ce68ce26
[ganeti-local] / test / py / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2011, 2012 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27 import random
28 import tempfile
29
30 from ganeti import constants
31 from ganeti import compat
32 from ganeti import rpc
33 from ganeti import rpc_defs
34 from ganeti import http
35 from ganeti import errors
36 from ganeti import serializer
37 from ganeti import objects
38 from ganeti import backend
39
40 import testutils
41 import mocks
42
43
44 class _FakeRequestProcessor:
45   def __init__(self, response_fn):
46     self._response_fn = response_fn
47     self.reqcount = 0
48
49   def __call__(self, reqs, lock_monitor_cb=None):
50     assert lock_monitor_cb is None or callable(lock_monitor_cb)
51     for req in reqs:
52       self.reqcount += 1
53       self._response_fn(req)
54
55
def GetFakeSimpleStoreClass(fn):
  """Builds a fake simple store class around a node IP list function.

  The returned class exposes C{fn} as C{GetNodePrimaryIPList} and a
  C{GetPrimaryIPFamily} which always returns C{None}. A new class is created
  on every call.

  """
  def _NoPrimaryIPFamily(_):
    return None

  class FakeSimpleStore:
    GetPrimaryIPFamily = _NoPrimaryIPFamily
    GetNodePrimaryIPList = fn

  return FakeSimpleStore
62
63
64 def _RaiseNotImplemented():
65   """Simple wrapper to raise NotImplementedError.
66
67   """
68   raise NotImplementedError
69
70
71 class TestRpcProcessor(unittest.TestCase):
72   def _FakeAddressLookup(self, map):
73     return lambda node_list: [map.get(node) for node in node_list]
74
75   def _GetVersionResponse(self, req):
76     self.assertEqual(req.host, "127.0.0.1")
77     self.assertEqual(req.port, 24094)
78     self.assertEqual(req.path, "/version")
79     self.assertEqual(req.read_timeout, constants.RPC_TMO_URGENT)
80     req.success = True
81     req.resp_status_code = http.HTTP_OK
82     req.resp_body = serializer.DumpJson((True, 123))
83
84   def testVersionSuccess(self):
85     resolver = rpc._StaticResolver(["127.0.0.1"])
86     http_proc = _FakeRequestProcessor(self._GetVersionResponse)
87     proc = rpc._RpcProcessor(resolver, 24094)
88     result = proc(["localhost"], "version", {"localhost": ""}, 60,
89                   NotImplemented, _req_process_fn=http_proc)
90     self.assertEqual(result.keys(), ["localhost"])
91     lhresp = result["localhost"]
92     self.assertFalse(lhresp.offline)
93     self.assertEqual(lhresp.node, "localhost")
94     self.assertFalse(lhresp.fail_msg)
95     self.assertEqual(lhresp.payload, 123)
96     self.assertEqual(lhresp.call, "version")
97     lhresp.Raise("should not raise")
98     self.assertEqual(http_proc.reqcount, 1)
99
100   def _ReadTimeoutResponse(self, req):
101     self.assertEqual(req.host, "192.0.2.13")
102     self.assertEqual(req.port, 19176)
103     self.assertEqual(req.path, "/version")
104     self.assertEqual(req.read_timeout, 12356)
105     req.success = True
106     req.resp_status_code = http.HTTP_OK
107     req.resp_body = serializer.DumpJson((True, -1))
108
109   def testReadTimeout(self):
110     resolver = rpc._StaticResolver(["192.0.2.13"])
111     http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
112     proc = rpc._RpcProcessor(resolver, 19176)
113     host = "node31856"
114     body = {host: ""}
115     result = proc([host], "version", body, 12356, NotImplemented,
116                   _req_process_fn=http_proc)
117     self.assertEqual(result.keys(), [host])
118     lhresp = result[host]
119     self.assertFalse(lhresp.offline)
120     self.assertEqual(lhresp.node, host)
121     self.assertFalse(lhresp.fail_msg)
122     self.assertEqual(lhresp.payload, -1)
123     self.assertEqual(lhresp.call, "version")
124     lhresp.Raise("should not raise")
125     self.assertEqual(http_proc.reqcount, 1)
126
127   def testOfflineNode(self):
128     resolver = rpc._StaticResolver([rpc._OFFLINE])
129     http_proc = _FakeRequestProcessor(NotImplemented)
130     proc = rpc._RpcProcessor(resolver, 30668)
131     host = "n17296"
132     body = {host: ""}
133     result = proc([host], "version", body, 60, NotImplemented,
134                   _req_process_fn=http_proc)
135     self.assertEqual(result.keys(), [host])
136     lhresp = result[host]
137     self.assertTrue(lhresp.offline)
138     self.assertEqual(lhresp.node, host)
139     self.assertTrue(lhresp.fail_msg)
140     self.assertFalse(lhresp.payload)
141     self.assertEqual(lhresp.call, "version")
142
143     # With a message
144     self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
145
146     # No message
147     self.assertRaises(errors.OpExecError, lhresp.Raise, None)
148
149     self.assertEqual(http_proc.reqcount, 0)
150
151   def _GetMultiVersionResponse(self, req):
152     self.assert_(req.host.startswith("node"))
153     self.assertEqual(req.port, 23245)
154     self.assertEqual(req.path, "/version")
155     req.success = True
156     req.resp_status_code = http.HTTP_OK
157     req.resp_body = serializer.DumpJson((True, 987))
158
159   def testMultiVersionSuccess(self):
160     nodes = ["node%s" % i for i in range(50)]
161     body = dict((n, "") for n in nodes)
162     resolver = rpc._StaticResolver(nodes)
163     http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
164     proc = rpc._RpcProcessor(resolver, 23245)
165     result = proc(nodes, "version", body, 60, NotImplemented,
166                   _req_process_fn=http_proc)
167     self.assertEqual(sorted(result.keys()), sorted(nodes))
168
169     for name in nodes:
170       lhresp = result[name]
171       self.assertFalse(lhresp.offline)
172       self.assertEqual(lhresp.node, name)
173       self.assertFalse(lhresp.fail_msg)
174       self.assertEqual(lhresp.payload, 987)
175       self.assertEqual(lhresp.call, "version")
176       lhresp.Raise("should not raise")
177
178     self.assertEqual(http_proc.reqcount, len(nodes))
179
180   def _GetVersionResponseFail(self, errinfo, req):
181     self.assertEqual(req.path, "/version")
182     req.success = True
183     req.resp_status_code = http.HTTP_OK
184     req.resp_body = serializer.DumpJson((False, errinfo))
185
186   def testVersionFailure(self):
187     resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
188     proc = rpc._RpcProcessor(resolver, 5903)
189     for errinfo in [None, "Unknown error"]:
190       http_proc = \
191         _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
192                                              errinfo))
193       host = "aef9ur4i.example.com"
194       body = {host: ""}
195       result = proc(body.keys(), "version", body, 60, NotImplemented,
196                     _req_process_fn=http_proc)
197       self.assertEqual(result.keys(), [host])
198       lhresp = result[host]
199       self.assertFalse(lhresp.offline)
200       self.assertEqual(lhresp.node, host)
201       self.assert_(lhresp.fail_msg)
202       self.assertFalse(lhresp.payload)
203       self.assertEqual(lhresp.call, "version")
204       self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
205       self.assertEqual(http_proc.reqcount, 1)
206
207   def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
208     self.assertEqual(req.path, "/vg_list")
209     self.assertEqual(req.port, 15165)
210
211     if req.host in httperrnodes:
212       req.success = False
213       req.error = "Node set up for HTTP errors"
214
215     elif req.host in failnodes:
216       req.success = True
217       req.resp_status_code = 404
218       req.resp_body = serializer.DumpJson({
219         "code": 404,
220         "message": "Method not found",
221         "explain": "Explanation goes here",
222         })
223     else:
224       req.success = True
225       req.resp_status_code = http.HTTP_OK
226       req.resp_body = serializer.DumpJson((True, hash(req.host)))
227
228   def testHttpError(self):
229     nodes = ["uaf6pbbv%s" % i for i in range(50)]
230     body = dict((n, "") for n in nodes)
231     resolver = rpc._StaticResolver(nodes)
232
233     httperrnodes = set(nodes[1::7])
234     self.assertEqual(len(httperrnodes), 7)
235
236     failnodes = set(nodes[2::3]) - httperrnodes
237     self.assertEqual(len(failnodes), 14)
238
239     self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
240
241     proc = rpc._RpcProcessor(resolver, 15165)
242     http_proc = \
243       _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
244                                            httperrnodes, failnodes))
245     result = proc(nodes, "vg_list", body,
246                   constants.RPC_TMO_URGENT, NotImplemented,
247                   _req_process_fn=http_proc)
248     self.assertEqual(sorted(result.keys()), sorted(nodes))
249
250     for name in nodes:
251       lhresp = result[name]
252       self.assertFalse(lhresp.offline)
253       self.assertEqual(lhresp.node, name)
254       self.assertEqual(lhresp.call, "vg_list")
255
256       if name in httperrnodes:
257         self.assert_(lhresp.fail_msg)
258         self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
259       elif name in failnodes:
260         self.assert_(lhresp.fail_msg)
261         self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
262                           prereq=True, ecode=errors.ECODE_INVAL)
263       else:
264         self.assertFalse(lhresp.fail_msg)
265         self.assertEqual(lhresp.payload, hash(name))
266         lhresp.Raise("should not raise")
267
268     self.assertEqual(http_proc.reqcount, len(nodes))
269
270   def _GetInvalidResponseA(self, req):
271     self.assertEqual(req.path, "/version")
272     req.success = True
273     req.resp_status_code = http.HTTP_OK
274     req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
275                                          "response", "!", 1, 2, 3))
276
277   def _GetInvalidResponseB(self, req):
278     self.assertEqual(req.path, "/version")
279     req.success = True
280     req.resp_status_code = http.HTTP_OK
281     req.resp_body = serializer.DumpJson("invalid response")
282
283   def testInvalidResponse(self):
284     resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
285     proc = rpc._RpcProcessor(resolver, 19978)
286
287     for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
288       http_proc = _FakeRequestProcessor(fn)
289       host = "oqo7lanhly.example.com"
290       body = {host: ""}
291       result = proc([host], "version", body, 60, NotImplemented,
292                     _req_process_fn=http_proc)
293       self.assertEqual(result.keys(), [host])
294       lhresp = result[host]
295       self.assertFalse(lhresp.offline)
296       self.assertEqual(lhresp.node, host)
297       self.assert_(lhresp.fail_msg)
298       self.assertFalse(lhresp.payload)
299       self.assertEqual(lhresp.call, "version")
300       self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
301       self.assertEqual(http_proc.reqcount, 1)
302
303   def _GetBodyTestResponse(self, test_data, req):
304     self.assertEqual(req.host, "192.0.2.84")
305     self.assertEqual(req.port, 18700)
306     self.assertEqual(req.path, "/upload_file")
307     self.assertEqual(serializer.LoadJson(req.post_data), test_data)
308     req.success = True
309     req.resp_status_code = http.HTTP_OK
310     req.resp_body = serializer.DumpJson((True, None))
311
312   def testResponseBody(self):
313     test_data = {
314       "Hello": "World",
315       "xyz": range(10),
316       }
317     resolver = rpc._StaticResolver(["192.0.2.84"])
318     http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
319                                                      test_data))
320     proc = rpc._RpcProcessor(resolver, 18700)
321     host = "node19759"
322     body = {host: serializer.DumpJson(test_data)}
323     result = proc([host], "upload_file", body, 30, NotImplemented,
324                   _req_process_fn=http_proc)
325     self.assertEqual(result.keys(), [host])
326     lhresp = result[host]
327     self.assertFalse(lhresp.offline)
328     self.assertEqual(lhresp.node, host)
329     self.assertFalse(lhresp.fail_msg)
330     self.assertEqual(lhresp.payload, None)
331     self.assertEqual(lhresp.call, "upload_file")
332     lhresp.Raise("should not raise")
333     self.assertEqual(http_proc.reqcount, 1)
334
335
class TestSsconfResolver(unittest.TestCase):
  """Tests for C{rpc._SsconfResolver}.

  The simple store is replaced by a class from L{GetFakeSimpleStoreClass}
  and name lookups by a dictionary-based function. Expected values are
  built with C{list(zip(...))} so that they are real lists no matter what
  C{zip} returns.

  """
  def testSsconfLookup(self):
    """Tests resolving all nodes through the ssconf node IP list.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    # The lookup function is a non-callable placeholder
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    """Tests resolving all nodes by name lookup with an empty ssconf list.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testDisabledSsconfIp(self):
    """Tests name lookup when the first resolver argument is C{False}.

    The fake store's node IP list function raises NotImplementedError if it
    is ever called.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(_RaiseNotImplemented)
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(False, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    """Tests a setup where each lookup method knows half of the nodes.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division; the value is used as a slice index and must be an
    # integer
    n = len(addr_list) // 2
    # Second half comes from ssconf, first half from name lookup
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    """Tests the ssconf node IP list with IPv6 addresses.

    """
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
386
387
class TestStaticResolver(unittest.TestCase):
  """Tests for C{rpc._StaticResolver}.

  """
  def test(self):
    """Tests that nodes are paired with the static addresses in order.

    """
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # list() ensures a list is compared even where zip() returns an iterator
    self.assertEqual(res(nodes, NotImplemented), list(zip(nodes, addresses)))

  def testWrongLength(self):
    """Tests that a node/address count mismatch raises AssertionError.

    """
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
398
399
400 class TestNodeConfigResolver(unittest.TestCase):
401   @staticmethod
402   def _GetSingleOnlineNode(name):
403     assert name == "node90.example.com"
404     return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
405
406   @staticmethod
407   def _GetSingleOfflineNode(name):
408     assert name == "node100.example.com"
409     return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
410
411   def testSingleOnline(self):
412     self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
413                                              NotImplemented,
414                                              ["node90.example.com"], None),
415                      [("node90.example.com", "192.0.2.90")])
416
417   def testSingleOffline(self):
418     self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
419                                              NotImplemented,
420                                              ["node100.example.com"], None),
421                      [("node100.example.com", rpc._OFFLINE)])
422
423   def testSingleOfflineWithAcceptOffline(self):
424     fn = self._GetSingleOfflineNode
425     assert fn("node100.example.com").offline
426     self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
427                                              ["node100.example.com"],
428                                              rpc_defs.ACCEPT_OFFLINE_NODE),
429                      [("node100.example.com", "192.0.2.100")])
430     for i in [False, True, "", "Hello", 0, 1]:
431       self.assertRaises(AssertionError, rpc._NodeConfigResolver,
432                         fn, NotImplemented, ["node100.example.com"], i)
433
434   def testUnknownSingleNode(self):
435     self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
436                                              ["node110.example.com"], None),
437                      [("node110.example.com", "node110.example.com")])
438
439   def testMultiEmpty(self):
440     self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
441                                              lambda: {},
442                                              [], None),
443                      [])
444
445   def testMultiSomeOffline(self):
446     nodes = dict(("node%s.example.com" % i,
447                   objects.Node(name="node%s.example.com" % i,
448                                offline=((i % 3) == 0),
449                                primary_ip="192.0.2.%s" % i))
450                   for i in range(1, 255))
451
452     # Resolve no names
453     self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
454                                              lambda: nodes,
455                                              [], None),
456                      [])
457
458     # Offline, online and unknown hosts
459     self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
460                                              lambda: nodes,
461                                              ["node3.example.com",
462                                               "node92.example.com",
463                                               "node54.example.com",
464                                               "unknown.example.com",],
465                                              None), [
466       ("node3.example.com", rpc._OFFLINE),
467       ("node92.example.com", "192.0.2.92"),
468       ("node54.example.com", rpc._OFFLINE),
469       ("unknown.example.com", "unknown.example.com"),
470       ])
471
472
class TestCompress(unittest.TestCase):
  """Tests for C{rpc._Compress} and C{backend._Decompress}.

  """
  def test(self):
    """Tests the results for short data and a round trip for larger data.

    """
    for short in ["", "Hello", "Hello World!\nnew\nlines"]:
      encoded = rpc._Compress(short)
      self.assertEqual(encoded, (constants.RPC_ENCODING_NONE, short))

    for large in [512 * " ", 5242 * "Hello World!\n"]:
      packed = rpc._Compress(large)
      self.assertEqual(len(packed), 2)
      self.assertEqual(backend._Decompress(packed), large)

  def testDecompression(self):
    """Tests that invalid input to the decompression function is rejected.

    """
    for bogus in ["", [""], ("unknown compression", "data")]:
      self.assertRaises(AssertionError, backend._Decompress, bogus)

    self.assertRaises(Exception, backend._Decompress,
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
491
492
493 class TestRpcClientBase(unittest.TestCase):
494   def testNoHosts(self):
495     cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_SLOW, [],
496             None, None, NotImplemented)
497     http_proc = _FakeRequestProcessor(NotImplemented)
498     client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
499                                 _req_process_fn=http_proc)
500     self.assertEqual(client._Call(cdef, [], []), {})
501
502     # Test wrong number of arguments
503     self.assertRaises(errors.ProgrammerError, client._Call,
504                       cdef, [], [0, 1, 2])
505
506   def testTimeout(self):
507     def _CalcTimeout((arg1, arg2)):
508       return arg1 + arg2
509
510     def _VerifyRequest(exp_timeout, req):
511       self.assertEqual(req.read_timeout, exp_timeout)
512
513       req.success = True
514       req.resp_status_code = http.HTTP_OK
515       req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
516
517     resolver = rpc._StaticResolver([
518       "192.0.2.1",
519       "192.0.2.2",
520       ])
521
522     nodes = [
523       "node1.example.com",
524       "node2.example.com",
525       ]
526
527     tests = [(100, None, 100), (30, None, 30)]
528     tests.extend((_CalcTimeout, i, i + 300)
529                  for i in [0, 5, 16485, 30516])
530
531     for timeout, arg1, exp_timeout in tests:
532       cdef = ("test_call", NotImplemented, None, timeout, [
533         ("arg1", None, NotImplemented),
534         ("arg2", None, NotImplemented),
535         ], None, None, NotImplemented)
536
537       http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
538                                                        exp_timeout))
539       client = rpc._RpcClientBase(resolver, NotImplemented,
540                                   _req_process_fn=http_proc)
541       result = client._Call(cdef, nodes, [arg1, 300])
542       self.assertEqual(len(result), len(nodes))
543       self.assertTrue(compat.all(not res.fail_msg and
544                                  res.payload == hex(exp_timeout)
545                                  for res in result.values()))
546
547   def testArgumentEncoder(self):
548     (AT1, AT2) = range(1, 3)
549
550     resolver = rpc._StaticResolver([
551       "192.0.2.5",
552       "192.0.2.6",
553       ])
554
555     nodes = [
556       "node5.example.com",
557       "node6.example.com",
558       ]
559
560     encoders = {
561       AT1: hex,
562       AT2: hash,
563       }
564
565     cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
566       ("arg0", None, NotImplemented),
567       ("arg1", AT1, NotImplemented),
568       ("arg1", AT2, NotImplemented),
569       ], None, None, NotImplemented)
570
571     def _VerifyRequest(req):
572       req.success = True
573       req.resp_status_code = http.HTTP_OK
574       req.resp_body = serializer.DumpJson((True, req.post_data))
575
576     http_proc = _FakeRequestProcessor(_VerifyRequest)
577
578     for num in [0, 3796, 9032119]:
579       client = rpc._RpcClientBase(resolver, encoders.get,
580                                   _req_process_fn=http_proc)
581       result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
582       self.assertEqual(len(result), len(nodes))
583       for res in result.values():
584         self.assertFalse(res.fail_msg)
585         self.assertEqual(serializer.LoadJson(res.payload),
586                          ["foo", hex(num), hash("Hello%s" % num)])
587
588   def testPostProc(self):
589     def _VerifyRequest(nums, req):
590       req.success = True
591       req.resp_status_code = http.HTTP_OK
592       req.resp_body = serializer.DumpJson((True, nums))
593
594     resolver = rpc._StaticResolver([
595       "192.0.2.90",
596       "192.0.2.95",
597       ])
598
599     nodes = [
600       "node90.example.com",
601       "node95.example.com",
602       ]
603
604     def _PostProc(res):
605       self.assertFalse(res.fail_msg)
606       res.payload = sum(res.payload)
607       return res
608
609     cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [],
610             None, _PostProc, NotImplemented)
611
612     # Seeded random generator
613     rnd = random.Random(20299)
614
615     for i in [0, 4, 74, 1391]:
616       nums = [rnd.randint(0, 1000) for _ in range(i)]
617       http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
618       client = rpc._RpcClientBase(resolver, NotImplemented,
619                                   _req_process_fn=http_proc)
620       result = client._Call(cdef, nodes, [])
621       self.assertEqual(len(result), len(nodes))
622       for res in result.values():
623         self.assertFalse(res.fail_msg)
624         self.assertEqual(res.payload, sum(nums))
625
626   def testPreProc(self):
627     def _VerifyRequest(req):
628       req.success = True
629       req.resp_status_code = http.HTTP_OK
630       req.resp_body = serializer.DumpJson((True, req.post_data))
631
632     resolver = rpc._StaticResolver([
633       "192.0.2.30",
634       "192.0.2.35",
635       ])
636
637     nodes = [
638       "node30.example.com",
639       "node35.example.com",
640       ]
641
642     def _PreProc(node, data):
643       self.assertEqual(len(data), 1)
644       return data[0] + node
645
646     cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
647       ("arg0", None, NotImplemented),
648       ], _PreProc, None, NotImplemented)
649
650     http_proc = _FakeRequestProcessor(_VerifyRequest)
651     client = rpc._RpcClientBase(resolver, NotImplemented,
652                                 _req_process_fn=http_proc)
653
654     for prefix in ["foo", "bar", "baz"]:
655       result = client._Call(cdef, nodes, [prefix])
656       self.assertEqual(len(result), len(nodes))
657       for (idx, (node, res)) in enumerate(result.items()):
658         self.assertFalse(res.fail_msg)
659         self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
660
661   def testResolverOptions(self):
662     def _VerifyRequest(req):
663       req.success = True
664       req.resp_status_code = http.HTTP_OK
665       req.resp_body = serializer.DumpJson((True, req.post_data))
666
667     nodes = [
668       "node30.example.com",
669       "node35.example.com",
670       ]
671
672     def _Resolver(expected, hosts, options):
673       self.assertEqual(hosts, nodes)
674       self.assertEqual(options, expected)
675       return zip(hosts, nodes)
676
677     def _DynamicResolverOptions((arg0, )):
678       return sum(arg0)
679
680     tests = [
681       (None, None, None),
682       (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
683       (False, None, False),
684       (True, None, True),
685       (0, None, 0),
686       (_DynamicResolverOptions, [1, 2, 3], 6),
687       (_DynamicResolverOptions, range(4, 19), 165),
688       ]
689
690     for (resolver_opts, arg0, expected) in tests:
691       cdef = ("test_call", NotImplemented, resolver_opts,
692               constants.RPC_TMO_NORMAL, [
693         ("arg0", None, NotImplemented),
694         ], None, None, NotImplemented)
695
696       http_proc = _FakeRequestProcessor(_VerifyRequest)
697
698       client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
699                                   NotImplemented, _req_process_fn=http_proc)
700       result = client._Call(cdef, nodes, [arg0])
701       self.assertEqual(len(result), len(nodes))
702       for (idx, (node, res)) in enumerate(result.items()):
703         self.assertFalse(res.fail_msg)
704
705
706 class _FakeConfigForRpcRunner:
707   GetAllNodesInfo = NotImplemented
708
709   def __init__(self, cluster=NotImplemented):
710     self._cluster = cluster
711
712   def GetNodeInfo(self, name):
713     return objects.Node(name=name)
714
715   def GetClusterInfo(self):
716     return self._cluster
717
718   def GetInstanceDiskParams(self, _):
719     return constants.DISK_DT_DEFAULTS
720
721
class TestRpcRunner(unittest.TestCase):
  """Tests for C{rpc.RpcRunner} and C{rpc.ConfigRunner}.

  """
  def testUploadFile(self):
    """Tests the data sent by C{call_upload_file} for both runner classes.

    """
    data = 1779 * "Hello World\n"

    tmpfile = tempfile.NamedTemporaryFile()
    try:
      tmpfile.write(data)
      tmpfile.flush()
      st = os.stat(tmpfile.name)

      def _VerifyRequest(req):
        (uldata, ) = serializer.LoadJson(req.post_data)
        self.assertEqual(len(uldata), 7)
        self.assertEqual(uldata[0], tmpfile.name)
        self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
        self.assertEqual(uldata[2], st.st_mode)
        self.assertEqual(uldata[3], "user%s" % os.getuid())
        self.assertEqual(uldata[4], "group%s" % os.getgid())
        self.assertTrue(uldata[5] is not None)
        self.assertEqual(uldata[6], st.st_mtime)

        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson((True, None))

      http_proc = _FakeRequestProcessor(_VerifyRequest)

      std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
                                 _req_process_fn=http_proc,
                                 _getents=mocks.FakeGetentResolver)

      cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
                                    _req_process_fn=http_proc,
                                    _getents=mocks.FakeGetentResolver)

      nodes = [
        "node1.example.com",
        ]

      for runner in [std_runner, cfg_runner]:
        result = runner.call_upload_file(nodes, tmpfile.name)
        self.assertEqual(len(result), len(nodes))
        for res in result.values():
          self.assertFalse(res.fail_msg)
    finally:
      # Close (and thereby remove) the temporary file even if a check fails
      tmpfile.close()

  def testEncodeInstance(self):
    """Tests encoding of an instance with the different encoder kinds.

    """
    cluster = objects.Cluster(hvparams={
      constants.HT_KVM: {
        constants.HV_BLOCKDEV_PREFIX: "foo",
        },
      },
      beparams={
        constants.PP_DEFAULT: {
          constants.BE_MAXMEM: 8192,
          },
        },
      os_hvp={},
      osparams={
        "linux": {
          "role": "unknown",
          },
        })
    cluster.UpgradeConfig()

    inst = objects.Instance(name="inst1.example.com",
      hypervisor=constants.HT_FAKE,
      os="linux",
      hvparams={
        constants.HT_KVM: {
          constants.HV_BLOCKDEV_PREFIX: "bar",
          constants.HV_ROOT_PATH: "/tmp",
          },
        },
      beparams={
        constants.BE_MINMEM: 128,
        constants.BE_MAXMEM: 256,
        },
      nics=[
        objects.NIC(nicparams={
          constants.NIC_MODE: "mymode",
          }),
        ],
      disk_template=constants.DT_PLAIN,
      disks=[
        objects.Disk(dev_type=constants.LD_LV, size=4096,
                     logical_id=("vg", "disk6120")),
        objects.Disk(dev_type=constants.LD_LV, size=1024,
                     logical_id=("vg", "disk8508")),
        ])
    inst.UpgradeConfig()

    cfg = _FakeConfigForRpcRunner(cluster=cluster)
    # The request processor is a non-callable placeholder
    runner = rpc.RpcRunner(cfg, None,
                           _req_process_fn=NotImplemented,
                           _getents=mocks.FakeGetentResolver)

    def _CheckBasics(result):
      self.assertEqual(result["name"], "inst1.example.com")
      self.assertEqual(result["os"], "linux")
      self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
      self.assertEqual(len(result["hvparams"]), 1)
      self.assertEqual(len(result["nics"]), 1)
      self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
                       "mymode")

    # Generic object serialization
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT, inst))
    _CheckBasics(result)

    result = runner._encoder((rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
    # Explicit loop; a lazy map() would never run the checks
    for encoded in result:
      _CheckBasics(encoded)

    # Just an instance
    result = runner._encoder((rpc_defs.ED_INST_DICT, inst))
    _CheckBasics(result)
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
      constants.HV_BLOCKDEV_PREFIX: "bar",
      constants.HV_ROOT_PATH: "/tmp",
      })
    self.assertEqual(result["osparams"], {
      "role": "unknown",
      })

    # Instance with OS parameters
    result = runner._encoder((rpc_defs.ED_INST_DICT_OSP_DP, (inst, {
      "role": "webserver",
      "other": "field",
      })))
    _CheckBasics(result)
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
      constants.HV_BLOCKDEV_PREFIX: "bar",
      constants.HV_ROOT_PATH: "/tmp",
      })
    self.assertEqual(result["osparams"], {
      "role": "webserver",
      "other": "field",
      })

    # Instance with hypervisor and backend parameters
    result = runner._encoder((rpc_defs.ED_INST_DICT_HVP_BEP_DP, (inst, {
      constants.HT_KVM: {
        constants.HV_BOOT_ORDER: "xyz",
        },
      }, {
      constants.BE_VCPUS: 100,
      constants.BE_MAXMEM: 4096,
      })))
    _CheckBasics(result)
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
    self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
      constants.HV_BOOT_ORDER: "xyz",
      })
    self.assertEqual(result["disks"], [{
      "dev_type": constants.LD_LV,
      "size": 4096,
      "logical_id": ("vg", "disk6120"),
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
      }, {
      "dev_type": constants.LD_LV,
      "size": 1024,
      "logical_id": ("vg", "disk8508"),
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
      }])

    # Encoding must not have written parameters back to the instance's disks
    self.assertTrue(compat.all(disk.params == {} for disk in inst.disks),
                    msg="Configuration objects were modified")
890
891
# Run the tests through the project's test program when executed as a script
if __name__ == "__main__":
  testutils.GanetiTestProgram()