Index nodes by their UUID
[ganeti-local] / test / py / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2011, 2012, 2013 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27 import random
28 import tempfile
29
30 from ganeti import constants
31 from ganeti import compat
32 from ganeti import rpc
33 from ganeti import rpc_defs
34 from ganeti import http
35 from ganeti import errors
36 from ganeti import serializer
37 from ganeti import objects
38 from ganeti import backend
39
40 import testutils
41 import mocks
42
43
44 class _FakeRequestProcessor:
45   def __init__(self, response_fn):
46     self._response_fn = response_fn
47     self.reqcount = 0
48
49   def __call__(self, reqs, lock_monitor_cb=None):
50     assert lock_monitor_cb is None or callable(lock_monitor_cb)
51     for req in reqs:
52       self.reqcount += 1
53       self._response_fn(req)
54
55
def GetFakeSimpleStoreClass(fn):
  """Creates a fake simple store class around a node IP list function.

  @param fn: function installed as the class attribute
    C{GetNodePrimaryIPList}
  @return: the newly defined class (not an instance of it)

  """
  class FakeSimpleStore:
    GetNodePrimaryIPList = fn

    def GetPrimaryIPFamily(self):
      # No address family is reported by the fake store
      return None

  return FakeSimpleStore
62
63
64 def _RaiseNotImplemented():
65   """Simple wrapper to raise NotImplementedError.
66
67   """
68   raise NotImplementedError
69
70
71 class TestRpcProcessor(unittest.TestCase):
72   def _FakeAddressLookup(self, map):
73     return lambda node_list: [map.get(node) for node in node_list]
74
75   def _GetVersionResponse(self, req):
76     self.assertEqual(req.host, "127.0.0.1")
77     self.assertEqual(req.port, 24094)
78     self.assertEqual(req.path, "/version")
79     self.assertEqual(req.read_timeout, constants.RPC_TMO_URGENT)
80     req.success = True
81     req.resp_status_code = http.HTTP_OK
82     req.resp_body = serializer.DumpJson((True, 123))
83
84   def testVersionSuccess(self):
85     resolver = rpc._StaticResolver(["127.0.0.1"])
86     http_proc = _FakeRequestProcessor(self._GetVersionResponse)
87     proc = rpc._RpcProcessor(resolver, 24094)
88     result = proc(["localhost"], "version", {"localhost": ""}, 60,
89                   NotImplemented, _req_process_fn=http_proc)
90     self.assertEqual(result.keys(), ["localhost"])
91     lhresp = result["localhost"]
92     self.assertFalse(lhresp.offline)
93     self.assertEqual(lhresp.node, "localhost")
94     self.assertFalse(lhresp.fail_msg)
95     self.assertEqual(lhresp.payload, 123)
96     self.assertEqual(lhresp.call, "version")
97     lhresp.Raise("should not raise")
98     self.assertEqual(http_proc.reqcount, 1)
99
100   def _ReadTimeoutResponse(self, req):
101     self.assertEqual(req.host, "192.0.2.13")
102     self.assertEqual(req.port, 19176)
103     self.assertEqual(req.path, "/version")
104     self.assertEqual(req.read_timeout, 12356)
105     req.success = True
106     req.resp_status_code = http.HTTP_OK
107     req.resp_body = serializer.DumpJson((True, -1))
108
109   def testReadTimeout(self):
110     resolver = rpc._StaticResolver(["192.0.2.13"])
111     http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
112     proc = rpc._RpcProcessor(resolver, 19176)
113     host = "node31856"
114     body = {host: ""}
115     result = proc([host], "version", body, 12356, NotImplemented,
116                   _req_process_fn=http_proc)
117     self.assertEqual(result.keys(), [host])
118     lhresp = result[host]
119     self.assertFalse(lhresp.offline)
120     self.assertEqual(lhresp.node, host)
121     self.assertFalse(lhresp.fail_msg)
122     self.assertEqual(lhresp.payload, -1)
123     self.assertEqual(lhresp.call, "version")
124     lhresp.Raise("should not raise")
125     self.assertEqual(http_proc.reqcount, 1)
126
127   def testOfflineNode(self):
128     resolver = rpc._StaticResolver([rpc._OFFLINE])
129     http_proc = _FakeRequestProcessor(NotImplemented)
130     proc = rpc._RpcProcessor(resolver, 30668)
131     host = "n17296"
132     body = {host: ""}
133     result = proc([host], "version", body, 60, NotImplemented,
134                   _req_process_fn=http_proc)
135     self.assertEqual(result.keys(), [host])
136     lhresp = result[host]
137     self.assertTrue(lhresp.offline)
138     self.assertEqual(lhresp.node, host)
139     self.assertTrue(lhresp.fail_msg)
140     self.assertFalse(lhresp.payload)
141     self.assertEqual(lhresp.call, "version")
142
143     # With a message
144     self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
145
146     # No message
147     self.assertRaises(errors.OpExecError, lhresp.Raise, None)
148
149     self.assertEqual(http_proc.reqcount, 0)
150
151   def _GetMultiVersionResponse(self, req):
152     self.assert_(req.host.startswith("node"))
153     self.assertEqual(req.port, 23245)
154     self.assertEqual(req.path, "/version")
155     req.success = True
156     req.resp_status_code = http.HTTP_OK
157     req.resp_body = serializer.DumpJson((True, 987))
158
159   def testMultiVersionSuccess(self):
160     nodes = ["node%s" % i for i in range(50)]
161     body = dict((n, "") for n in nodes)
162     resolver = rpc._StaticResolver(nodes)
163     http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
164     proc = rpc._RpcProcessor(resolver, 23245)
165     result = proc(nodes, "version", body, 60, NotImplemented,
166                   _req_process_fn=http_proc)
167     self.assertEqual(sorted(result.keys()), sorted(nodes))
168
169     for name in nodes:
170       lhresp = result[name]
171       self.assertFalse(lhresp.offline)
172       self.assertEqual(lhresp.node, name)
173       self.assertFalse(lhresp.fail_msg)
174       self.assertEqual(lhresp.payload, 987)
175       self.assertEqual(lhresp.call, "version")
176       lhresp.Raise("should not raise")
177
178     self.assertEqual(http_proc.reqcount, len(nodes))
179
180   def _GetVersionResponseFail(self, errinfo, req):
181     self.assertEqual(req.path, "/version")
182     req.success = True
183     req.resp_status_code = http.HTTP_OK
184     req.resp_body = serializer.DumpJson((False, errinfo))
185
186   def testVersionFailure(self):
187     resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
188     proc = rpc._RpcProcessor(resolver, 5903)
189     for errinfo in [None, "Unknown error"]:
190       http_proc = \
191         _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
192                                              errinfo))
193       host = "aef9ur4i.example.com"
194       body = {host: ""}
195       result = proc(body.keys(), "version", body, 60, NotImplemented,
196                     _req_process_fn=http_proc)
197       self.assertEqual(result.keys(), [host])
198       lhresp = result[host]
199       self.assertFalse(lhresp.offline)
200       self.assertEqual(lhresp.node, host)
201       self.assert_(lhresp.fail_msg)
202       self.assertFalse(lhresp.payload)
203       self.assertEqual(lhresp.call, "version")
204       self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
205       self.assertEqual(http_proc.reqcount, 1)
206
207   def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
208     self.assertEqual(req.path, "/vg_list")
209     self.assertEqual(req.port, 15165)
210
211     if req.host in httperrnodes:
212       req.success = False
213       req.error = "Node set up for HTTP errors"
214
215     elif req.host in failnodes:
216       req.success = True
217       req.resp_status_code = 404
218       req.resp_body = serializer.DumpJson({
219         "code": 404,
220         "message": "Method not found",
221         "explain": "Explanation goes here",
222         })
223     else:
224       req.success = True
225       req.resp_status_code = http.HTTP_OK
226       req.resp_body = serializer.DumpJson((True, hash(req.host)))
227
228   def testHttpError(self):
229     nodes = ["uaf6pbbv%s" % i for i in range(50)]
230     body = dict((n, "") for n in nodes)
231     resolver = rpc._StaticResolver(nodes)
232
233     httperrnodes = set(nodes[1::7])
234     self.assertEqual(len(httperrnodes), 7)
235
236     failnodes = set(nodes[2::3]) - httperrnodes
237     self.assertEqual(len(failnodes), 14)
238
239     self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
240
241     proc = rpc._RpcProcessor(resolver, 15165)
242     http_proc = \
243       _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
244                                            httperrnodes, failnodes))
245     result = proc(nodes, "vg_list", body,
246                   constants.RPC_TMO_URGENT, NotImplemented,
247                   _req_process_fn=http_proc)
248     self.assertEqual(sorted(result.keys()), sorted(nodes))
249
250     for name in nodes:
251       lhresp = result[name]
252       self.assertFalse(lhresp.offline)
253       self.assertEqual(lhresp.node, name)
254       self.assertEqual(lhresp.call, "vg_list")
255
256       if name in httperrnodes:
257         self.assert_(lhresp.fail_msg)
258         self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
259       elif name in failnodes:
260         self.assert_(lhresp.fail_msg)
261         self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
262                           prereq=True, ecode=errors.ECODE_INVAL)
263       else:
264         self.assertFalse(lhresp.fail_msg)
265         self.assertEqual(lhresp.payload, hash(name))
266         lhresp.Raise("should not raise")
267
268     self.assertEqual(http_proc.reqcount, len(nodes))
269
270   def _GetInvalidResponseA(self, req):
271     self.assertEqual(req.path, "/version")
272     req.success = True
273     req.resp_status_code = http.HTTP_OK
274     req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
275                                          "response", "!", 1, 2, 3))
276
277   def _GetInvalidResponseB(self, req):
278     self.assertEqual(req.path, "/version")
279     req.success = True
280     req.resp_status_code = http.HTTP_OK
281     req.resp_body = serializer.DumpJson("invalid response")
282
283   def testInvalidResponse(self):
284     resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
285     proc = rpc._RpcProcessor(resolver, 19978)
286
287     for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
288       http_proc = _FakeRequestProcessor(fn)
289       host = "oqo7lanhly.example.com"
290       body = {host: ""}
291       result = proc([host], "version", body, 60, NotImplemented,
292                     _req_process_fn=http_proc)
293       self.assertEqual(result.keys(), [host])
294       lhresp = result[host]
295       self.assertFalse(lhresp.offline)
296       self.assertEqual(lhresp.node, host)
297       self.assert_(lhresp.fail_msg)
298       self.assertFalse(lhresp.payload)
299       self.assertEqual(lhresp.call, "version")
300       self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
301       self.assertEqual(http_proc.reqcount, 1)
302
303   def _GetBodyTestResponse(self, test_data, req):
304     self.assertEqual(req.host, "192.0.2.84")
305     self.assertEqual(req.port, 18700)
306     self.assertEqual(req.path, "/upload_file")
307     self.assertEqual(serializer.LoadJson(req.post_data), test_data)
308     req.success = True
309     req.resp_status_code = http.HTTP_OK
310     req.resp_body = serializer.DumpJson((True, None))
311
312   def testResponseBody(self):
313     test_data = {
314       "Hello": "World",
315       "xyz": range(10),
316       }
317     resolver = rpc._StaticResolver(["192.0.2.84"])
318     http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
319                                                      test_data))
320     proc = rpc._RpcProcessor(resolver, 18700)
321     host = "node19759"
322     body = {host: serializer.DumpJson(test_data)}
323     result = proc([host], "upload_file", body, 30, NotImplemented,
324                   _req_process_fn=http_proc)
325     self.assertEqual(result.keys(), [host])
326     lhresp = result[host]
327     self.assertFalse(lhresp.offline)
328     self.assertEqual(lhresp.node, host)
329     self.assertFalse(lhresp.fail_msg)
330     self.assertEqual(lhresp.payload, None)
331     self.assertEqual(lhresp.call, "upload_file")
332     lhresp.Raise("should not raise")
333     self.assertEqual(http_proc.reqcount, 1)
334
335
class TestSsconfResolver(unittest.TestCase):
  """Tests for C{rpc._SsconfResolver}.

  Expected results are built with C{list(zip(...))} and slice indices with
  C{//}, so the comparisons do not depend on Python 2's list-returning
  C{zip} or its integer C{/}.

  """
  def testSsconfLookup(self):
    """Resolves all nodes through the fake ssconf node IP list.

    The name lookup function is C{NotImplemented} and hence unusable.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))

  def testNsLookup(self):
    """Resolves all nodes through the name lookup function.

    The fake ssconf node IP list is empty.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))

  def testDisabledSsconfIp(self):
    """Resolves all nodes by name lookup when the first argument is False.

    The fake ssconf node IP list function would raise if it were called.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(_RaiseNotImplemented)
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(False, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))

  def testBothLookups(self):
    """Resolves one half of the nodes via ssconf, the other by name lookup.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division: the value is used as a slice index and must be an int
    n = len(addr_list) // 2
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))

  def testAddressLookupIPv6(self):
    """Resolves nodes with IPv6 addresses through the fake ssconf list.

    """
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list, node_list)))
386
387
class TestStaticResolver(unittest.TestCase):
  """Tests for C{rpc._StaticResolver}.

  """
  def test(self):
    """Checks that the addresses are paired with the node names in order.

    """
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # Build a real list; a bare zip() is only a list on Python 2
    expected = list(zip(nodes, addresses, nodes))
    self.assertEqual(res(nodes, NotImplemented), expected)

  def testWrongLength(self):
    """Checks that a node list longer than the address list is rejected.

    """
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
398
399
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for C{rpc._NodeConfigResolver}.

  """
  @staticmethod
  def _GetSingleOnlineNode(uuid):
    """Returns an online node object; only "node90-uuid" is accepted.

    """
    assert uuid == "node90-uuid"
    return objects.Node(uuid=uuid,
                        name="node90.example.com",
                        primary_ip="192.0.2.90",
                        offline=False)

  @staticmethod
  def _GetSingleOfflineNode(uuid):
    """Returns an offline node object; only "node100-uuid" is accepted.

    """
    assert uuid == "node100-uuid"
    return objects.Node(uuid=uuid,
                        name="node100.example.com",
                        primary_ip="192.0.2.100",
                        offline=True)

  def testSingleOnline(self):
    """A single online node is resolved to its primary IP address.

    """
    resolved = rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                       NotImplemented,
                                       ["node90-uuid"], None)
    expected = [("node90.example.com", "192.0.2.90", "node90-uuid")]
    self.assertEqual(resolved, expected)

  def testSingleOffline(self):
    """A single offline node is resolved to the offline marker.

    """
    resolved = rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                       NotImplemented,
                                       ["node100-uuid"], None)
    expected = [("node100.example.com", rpc._OFFLINE, "node100-uuid")]
    self.assertEqual(resolved, expected)

  def testSingleOfflineWithAcceptOffline(self):
    """An offline node keeps its IP address if offline nodes are accepted.

    Other option values are expected to trigger an assertion.

    """
    fn = self._GetSingleOfflineNode
    assert fn("node100-uuid").offline

    resolved = rpc._NodeConfigResolver(fn, NotImplemented,
                                       ["node100-uuid"],
                                       rpc_defs.ACCEPT_OFFLINE_NODE)
    expected = [("node100.example.com", "192.0.2.100", "node100-uuid")]
    self.assertEqual(resolved, expected)

    invalid_options = [False, True, "", "Hello", 0, 1]
    for opts in invalid_options:
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
                        fn, NotImplemented, ["node100.example.com"], opts)

  def testUnknownSingleNode(self):
    """A node unknown to the configuration is resolved to its own name.

    """
    resolved = rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                       ["node110.example.com"], None)
    expected = [
      ("node110.example.com", "node110.example.com", "node110.example.com"),
      ]
    self.assertEqual(resolved, expected)

  def testMultiEmpty(self):
    """Resolving no nodes against an empty configuration gives no result.

    """
    resolved = rpc._NodeConfigResolver(NotImplemented, lambda: {}, [], None)
    self.assertEqual(resolved, [])

  def testMultiSomeOffline(self):
    """Resolves a mix of offline, online and unknown nodes.

    Every third node of the fake configuration is marked offline.

    """
    nodes = {}
    for i in range(1, 255):
      node_uuid = "node%s-uuid" % i
      nodes[node_uuid] = objects.Node(name="node%s.example.com" % i,
                                      offline=((i % 3) == 0),
                                      primary_ip="192.0.2.%s" % i,
                                      uuid=node_uuid)

    get_all_fn = lambda: nodes

    # Resolve no names
    resolved = rpc._NodeConfigResolver(NotImplemented, get_all_fn, [], None)
    self.assertEqual(resolved, [])

    # Offline, online and unknown hosts
    wanted = [
      "node3-uuid",
      "node92-uuid",
      "node54-uuid",
      "unknown.example.com",
      ]
    resolved = rpc._NodeConfigResolver(NotImplemented, get_all_fn,
                                       wanted, None)
    self.assertEqual(resolved, [
      ("node3.example.com", rpc._OFFLINE, "node3-uuid"),
      ("node92.example.com", "192.0.2.92", "node92-uuid"),
      ("node54.example.com", rpc._OFFLINE, "node54-uuid"),
      ("unknown.example.com", "unknown.example.com", "unknown.example.com"),
      ])
479
480
class TestCompress(unittest.TestCase):
  """Tests for C{rpc._Compress} and C{backend._Decompress}.

  """
  def test(self):
    """Checks pass-through of short data and round trips of long data.

    """
    short_samples = ["", "Hello", "Hello World!\nnew\nlines"]
    for text in short_samples:
      # Short data is expected back unchanged, tagged as not encoded
      expected = (constants.RPC_ENCODING_NONE, text)
      self.assertEqual(rpc._Compress(text), expected)

    long_samples = [512 * " ", 5242 * "Hello World!\n"]
    for text in long_samples:
      packed = rpc._Compress(text)
      self.assertEqual(len(packed), 2)
      self.assertEqual(backend._Decompress(packed), text)

  def testDecompression(self):
    """Checks that malformed input to the decompressor is rejected.

    """
    malformed = [
      "",
      [""],
      ("unknown compression", "data"),
      ]
    for value in malformed:
      self.assertRaises(AssertionError, backend._Decompress, value)

    broken_zlib = (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data")
    self.assertRaises(Exception, backend._Decompress, broken_zlib)
499
500
501 class TestRpcClientBase(unittest.TestCase):
502   def testNoHosts(self):
503     cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_SLOW, [],
504             None, None, NotImplemented)
505     http_proc = _FakeRequestProcessor(NotImplemented)
506     client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
507                                 _req_process_fn=http_proc)
508     self.assertEqual(client._Call(cdef, [], []), {})
509
510     # Test wrong number of arguments
511     self.assertRaises(errors.ProgrammerError, client._Call,
512                       cdef, [], [0, 1, 2])
513
514   def testTimeout(self):
515     def _CalcTimeout((arg1, arg2)):
516       return arg1 + arg2
517
518     def _VerifyRequest(exp_timeout, req):
519       self.assertEqual(req.read_timeout, exp_timeout)
520
521       req.success = True
522       req.resp_status_code = http.HTTP_OK
523       req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
524
525     resolver = rpc._StaticResolver([
526       "192.0.2.1",
527       "192.0.2.2",
528       ])
529
530     nodes = [
531       "node1.example.com",
532       "node2.example.com",
533       ]
534
535     tests = [(100, None, 100), (30, None, 30)]
536     tests.extend((_CalcTimeout, i, i + 300)
537                  for i in [0, 5, 16485, 30516])
538
539     for timeout, arg1, exp_timeout in tests:
540       cdef = ("test_call", NotImplemented, None, timeout, [
541         ("arg1", None, NotImplemented),
542         ("arg2", None, NotImplemented),
543         ], None, None, NotImplemented)
544
545       http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
546                                                        exp_timeout))
547       client = rpc._RpcClientBase(resolver, NotImplemented,
548                                   _req_process_fn=http_proc)
549       result = client._Call(cdef, nodes, [arg1, 300])
550       self.assertEqual(len(result), len(nodes))
551       self.assertTrue(compat.all(not res.fail_msg and
552                                  res.payload == hex(exp_timeout)
553                                  for res in result.values()))
554
555   def testArgumentEncoder(self):
556     (AT1, AT2) = range(1, 3)
557
558     resolver = rpc._StaticResolver([
559       "192.0.2.5",
560       "192.0.2.6",
561       ])
562
563     nodes = [
564       "node5.example.com",
565       "node6.example.com",
566       ]
567
568     encoders = {
569       AT1: hex,
570       AT2: hash,
571       }
572
573     cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
574       ("arg0", None, NotImplemented),
575       ("arg1", AT1, NotImplemented),
576       ("arg1", AT2, NotImplemented),
577       ], None, None, NotImplemented)
578
579     def _VerifyRequest(req):
580       req.success = True
581       req.resp_status_code = http.HTTP_OK
582       req.resp_body = serializer.DumpJson((True, req.post_data))
583
584     http_proc = _FakeRequestProcessor(_VerifyRequest)
585
586     for num in [0, 3796, 9032119]:
587       client = rpc._RpcClientBase(resolver, encoders.get,
588                                   _req_process_fn=http_proc)
589       result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
590       self.assertEqual(len(result), len(nodes))
591       for res in result.values():
592         self.assertFalse(res.fail_msg)
593         self.assertEqual(serializer.LoadJson(res.payload),
594                          ["foo", hex(num), hash("Hello%s" % num)])
595
596   def testPostProc(self):
597     def _VerifyRequest(nums, req):
598       req.success = True
599       req.resp_status_code = http.HTTP_OK
600       req.resp_body = serializer.DumpJson((True, nums))
601
602     resolver = rpc._StaticResolver([
603       "192.0.2.90",
604       "192.0.2.95",
605       ])
606
607     nodes = [
608       "node90.example.com",
609       "node95.example.com",
610       ]
611
612     def _PostProc(res):
613       self.assertFalse(res.fail_msg)
614       res.payload = sum(res.payload)
615       return res
616
617     cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [],
618             None, _PostProc, NotImplemented)
619
620     # Seeded random generator
621     rnd = random.Random(20299)
622
623     for i in [0, 4, 74, 1391]:
624       nums = [rnd.randint(0, 1000) for _ in range(i)]
625       http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
626       client = rpc._RpcClientBase(resolver, NotImplemented,
627                                   _req_process_fn=http_proc)
628       result = client._Call(cdef, nodes, [])
629       self.assertEqual(len(result), len(nodes))
630       for res in result.values():
631         self.assertFalse(res.fail_msg)
632         self.assertEqual(res.payload, sum(nums))
633
634   def testPreProc(self):
635     def _VerifyRequest(req):
636       req.success = True
637       req.resp_status_code = http.HTTP_OK
638       req.resp_body = serializer.DumpJson((True, req.post_data))
639
640     resolver = rpc._StaticResolver([
641       "192.0.2.30",
642       "192.0.2.35",
643       ])
644
645     nodes = [
646       "node30.example.com",
647       "node35.example.com",
648       ]
649
650     def _PreProc(node, data):
651       self.assertEqual(len(data), 1)
652       return data[0] + node
653
654     cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
655       ("arg0", None, NotImplemented),
656       ], _PreProc, None, NotImplemented)
657
658     http_proc = _FakeRequestProcessor(_VerifyRequest)
659     client = rpc._RpcClientBase(resolver, NotImplemented,
660                                 _req_process_fn=http_proc)
661
662     for prefix in ["foo", "bar", "baz"]:
663       result = client._Call(cdef, nodes, [prefix])
664       self.assertEqual(len(result), len(nodes))
665       for (idx, (node, res)) in enumerate(result.items()):
666         self.assertFalse(res.fail_msg)
667         self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
668
669   def testResolverOptions(self):
670     def _VerifyRequest(req):
671       req.success = True
672       req.resp_status_code = http.HTTP_OK
673       req.resp_body = serializer.DumpJson((True, req.post_data))
674
675     nodes = [
676       "node30.example.com",
677       "node35.example.com",
678       ]
679
680     def _Resolver(expected, hosts, options):
681       self.assertEqual(hosts, nodes)
682       self.assertEqual(options, expected)
683       return zip(hosts, nodes, hosts)
684
685     def _DynamicResolverOptions((arg0, )):
686       return sum(arg0)
687
688     tests = [
689       (None, None, None),
690       (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
691       (False, None, False),
692       (True, None, True),
693       (0, None, 0),
694       (_DynamicResolverOptions, [1, 2, 3], 6),
695       (_DynamicResolverOptions, range(4, 19), 165),
696       ]
697
698     for (resolver_opts, arg0, expected) in tests:
699       cdef = ("test_call", NotImplemented, resolver_opts,
700               constants.RPC_TMO_NORMAL, [
701         ("arg0", None, NotImplemented),
702         ], None, None, NotImplemented)
703
704       http_proc = _FakeRequestProcessor(_VerifyRequest)
705
706       client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
707                                   NotImplemented, _req_process_fn=http_proc)
708       result = client._Call(cdef, nodes, [arg0])
709       self.assertEqual(len(result), len(nodes))
710       for (idx, (node, res)) in enumerate(result.items()):
711         self.assertFalse(res.fail_msg)
712
713
714 class _FakeConfigForRpcRunner:
715   GetAllNodesInfo = NotImplemented
716
717   def __init__(self, cluster=NotImplemented):
718     self._cluster = cluster
719
720   def GetNodeInfo(self, name):
721     return objects.Node(name=name)
722
723   def GetClusterInfo(self):
724     return self._cluster
725
726   def GetInstanceDiskParams(self, _):
727     return constants.DISK_DT_DEFAULTS
728
729
730 class TestRpcRunner(unittest.TestCase):
731   def testUploadFile(self):
732     data = 1779 * "Hello World\n"
733
734     tmpfile = tempfile.NamedTemporaryFile()
735     tmpfile.write(data)
736     tmpfile.flush()
737     st = os.stat(tmpfile.name)
738
739     def _VerifyRequest(req):
740       (uldata, ) = serializer.LoadJson(req.post_data)
741       self.assertEqual(len(uldata), 7)
742       self.assertEqual(uldata[0], tmpfile.name)
743       self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
744       self.assertEqual(uldata[2], st.st_mode)
745       self.assertEqual(uldata[3], "user%s" % os.getuid())
746       self.assertEqual(uldata[4], "group%s" % os.getgid())
747       self.assertTrue(uldata[5] is not None)
748       self.assertEqual(uldata[6], st.st_mtime)
749
750       req.success = True
751       req.resp_status_code = http.HTTP_OK
752       req.resp_body = serializer.DumpJson((True, None))
753
754     http_proc = _FakeRequestProcessor(_VerifyRequest)
755
756     std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
757                                _req_process_fn=http_proc,
758                                _getents=mocks.FakeGetentResolver)
759
760     cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
761                                   _req_process_fn=http_proc,
762                                   _getents=mocks.FakeGetentResolver)
763
764     nodes = [
765       "node1.example.com",
766       ]
767
768     for runner in [std_runner, cfg_runner]:
769       result = runner.call_upload_file(nodes, tmpfile.name)
770       self.assertEqual(len(result), len(nodes))
771       for (idx, (node, res)) in enumerate(result.items()):
772         self.assertFalse(res.fail_msg)
773
774   def testEncodeInstance(self):
775     cluster = objects.Cluster(hvparams={
776       constants.HT_KVM: {
777         constants.HV_BLOCKDEV_PREFIX: "foo",
778         },
779       },
780       beparams={
781         constants.PP_DEFAULT: {
782           constants.BE_MAXMEM: 8192,
783           },
784         },
785       os_hvp={},
786       osparams={
787         "linux": {
788           "role": "unknown",
789           },
790         })
791     cluster.UpgradeConfig()
792
793     inst = objects.Instance(name="inst1.example.com",
794       hypervisor=constants.HT_FAKE,
795       os="linux",
796       hvparams={
797         constants.HT_KVM: {
798           constants.HV_BLOCKDEV_PREFIX: "bar",
799           constants.HV_ROOT_PATH: "/tmp",
800           },
801         },
802       beparams={
803         constants.BE_MINMEM: 128,
804         constants.BE_MAXMEM: 256,
805         },
806       nics=[
807         objects.NIC(nicparams={
808           constants.NIC_MODE: "mymode",
809           }),
810         ],
811       disk_template=constants.DT_PLAIN,
812       disks=[
813         objects.Disk(dev_type=constants.LD_LV, size=4096,
814                      logical_id=("vg", "disk6120")),
815         objects.Disk(dev_type=constants.LD_LV, size=1024,
816                      logical_id=("vg", "disk8508")),
817         ])
818     inst.UpgradeConfig()
819
820     cfg = _FakeConfigForRpcRunner(cluster=cluster)
821     runner = rpc.RpcRunner(cfg, None,
822                            _req_process_fn=NotImplemented,
823                            _getents=mocks.FakeGetentResolver)
824
825     def _CheckBasics(result):
826       self.assertEqual(result["name"], "inst1.example.com")
827       self.assertEqual(result["os"], "linux")
828       self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
829       self.assertEqual(len(result["hvparams"]), 1)
830       self.assertEqual(len(result["nics"]), 1)
831       self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
832                        "mymode")
833
834     # Generic object serialization
835     result = runner._encoder((rpc_defs.ED_OBJECT_DICT, inst))
836     _CheckBasics(result)
837
838     result = runner._encoder((rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
839     map(_CheckBasics, result)
840
841     # Just an instance
842     result = runner._encoder((rpc_defs.ED_INST_DICT, inst))
843     _CheckBasics(result)
844     self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
845     self.assertEqual(result["hvparams"][constants.HT_KVM], {
846       constants.HV_BLOCKDEV_PREFIX: "bar",
847       constants.HV_ROOT_PATH: "/tmp",
848       })
849     self.assertEqual(result["osparams"], {
850       "role": "unknown",
851       })
852
853     # Instance with OS parameters
854     result = runner._encoder((rpc_defs.ED_INST_DICT_OSP_DP, (inst, {
855       "role": "webserver",
856       "other": "field",
857       })))
858     _CheckBasics(result)
859     self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
860     self.assertEqual(result["hvparams"][constants.HT_KVM], {
861       constants.HV_BLOCKDEV_PREFIX: "bar",
862       constants.HV_ROOT_PATH: "/tmp",
863       })
864     self.assertEqual(result["osparams"], {
865       "role": "webserver",
866       "other": "field",
867       })
868
869     # Instance with hypervisor and backend parameters
870     result = runner._encoder((rpc_defs.ED_INST_DICT_HVP_BEP_DP, (inst, {
871       constants.HT_KVM: {
872         constants.HV_BOOT_ORDER: "xyz",
873         },
874       }, {
875       constants.BE_VCPUS: 100,
876       constants.BE_MAXMEM: 4096,
877       })))
878     _CheckBasics(result)
879     self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
880     self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
881     self.assertEqual(result["hvparams"][constants.HT_KVM], {
882       constants.HV_BOOT_ORDER: "xyz",
883       })
884     self.assertEqual(result["disks"], [{
885       "dev_type": constants.LD_LV,
886       "size": 4096,
887       "logical_id": ("vg", "disk6120"),
888       "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
889       }, {
890       "dev_type": constants.LD_LV,
891       "size": 1024,
892       "logical_id": ("vg", "disk8508"),
893       "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
894       }])
895
896     self.assertTrue(compat.all(disk.params == {} for disk in inst.disks),
897                     msg="Configuration objects were modified")
898
899
class TestLegacyNodeInfo(unittest.TestCase):
  """Tests for C{rpc.MakeLegacyNodeInfo}.

  The class attributes describe one node info structure in list form
  (C{STD_LST}) and the flat dictionary the tests expect for it
  (C{STD_DICT}).

  """
  # Dictionary keys used in the input data and in the expected result
  KEY_BOOT = "bootid"
  KEY_VG0 = "name"
  KEY_VG1 = "vg_free"
  KEY_VG2 = "vg_size"
  KEY_HV = "cpu_count"
  KEY_SP1 = "spindles_free"
  KEY_SP2 = "spindles_total"

  # Values stored under those keys
  VAL_BOOT = 0
  VAL_VG0 = "xy"
  VAL_VG1 = 11
  VAL_VG2 = 12
  VAL_HV = 2
  VAL_SP0 = "ab"
  VAL_SP1 = 31
  VAL_SP2 = 32

  # Hypervisor, volume group and spindle entries; the spindle entry uses the
  # same key names as the volume group entry
  DICT_HV = {KEY_HV: VAL_HV}
  DICT_VG = {KEY_VG0: VAL_VG0, KEY_VG1: VAL_VG1, KEY_VG2: VAL_VG2}
  DICT_SP = {KEY_VG0: VAL_SP0, KEY_VG1: VAL_SP1, KEY_VG2: VAL_SP2}

  # Complete input and the dictionary expected for it
  STD_LST = [VAL_BOOT, [DICT_VG, DICT_SP], [DICT_HV]]
  STD_DICT = {
    KEY_BOOT: VAL_BOOT,
    KEY_HV: VAL_HV,
    KEY_VG0: VAL_VG0,
    KEY_VG1: VAL_VG1,
    KEY_VG2: VAL_VG2,
    KEY_SP1: VAL_SP1,
    KEY_SP2: VAL_SP2,
    }

  def testStandard(self):
    """Complete input must yield the standard dictionary.

    """
    self.assertEqual(rpc.MakeLegacyNodeInfo(self.STD_LST), self.STD_DICT)

  def testReqVg(self):
    """An empty volume group list must raise C{ValueError} by default.

    """
    data_without_vg = [self.VAL_BOOT, [], [self.DICT_HV]]
    self.assertRaises(ValueError, rpc.MakeLegacyNodeInfo, data_without_vg)

  def testNoReqVg(self):
    """With C{require_vg_info=False} an empty volume group list is accepted.

    """
    data_without_vg = [self.VAL_BOOT, [], [self.DICT_HV]]
    expected = {
      self.KEY_BOOT: self.VAL_BOOT,
      self.KEY_HV: self.VAL_HV,
      }
    self.assertEqual(
      rpc.MakeLegacyNodeInfo(data_without_vg, require_vg_info=False),
      expected)

    # Complete input must still produce the complete dictionary
    self.assertEqual(
      rpc.MakeLegacyNodeInfo(self.STD_LST, require_vg_info=False),
      self.STD_DICT)
953
954
# Script entry point: when this file is executed directly, hand control to
# testutils.GanetiTestProgram().
# NOTE(review): presumably this runs the TestCase classes defined in this
# module; confirm against testutils.
if __name__ == "__main__":
  testutils.GanetiTestProgram()