config: Add check for disk's “iv_name”
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2011 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27 import random
28 import tempfile
29
30 from ganeti import constants
31 from ganeti import compat
32 from ganeti import rpc
33 from ganeti import rpc_defs
34 from ganeti import http
35 from ganeti import errors
36 from ganeti import serializer
37 from ganeti import objects
38 from ganeti import backend
39
40 import testutils
41 import mocks
42
43
44 class _FakeRequestProcessor:
45   def __init__(self, response_fn):
46     self._response_fn = response_fn
47     self.reqcount = 0
48
49   def __call__(self, reqs, lock_monitor_cb=None):
50     assert lock_monitor_cb is None or callable(lock_monitor_cb)
51     for req in reqs:
52       self.reqcount += 1
53       self._response_fn(req)
54
55
def GetFakeSimpleStoreClass(fn):
  """Build a fake simple store class around C{fn}.

  The returned class exposes C{fn} as its C{GetNodePrimaryIPList} method
  and a C{GetPrimaryIPFamily} method which always returns C{None}.

  """
  def _NoPrimaryIPFamily(_):
    return None

  class FakeSimpleStore:
    GetPrimaryIPFamily = _NoPrimaryIPFamily
    GetNodePrimaryIPList = fn

  return FakeSimpleStore
62
63
class TestRpcProcessor(unittest.TestCase):
  """Tests for rpc._RpcProcessor driven by a fake HTTP request processor.

  Each test installs a response callback which checks the outgoing request
  and fills in the response fields (C{success}, C{resp_status_code},
  C{resp_body} or C{error}) before the processor evaluates them.

  """
  def _FakeAddressLookup(self, map):
    """Return a function resolving node names through the C{map} dictionary.

    Names missing from C{map} resolve to C{None}.

    """
    # The parameter name shadows the builtin, but is kept as-is so the
    # method's signature does not change.
    return lambda node_list: [map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Verify an urgent "/version" request and answer with payload 123."""
    self.assertEqual(req.host, "127.0.0.1")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    """A single successful call yields exactly one usable result."""
    resolver = rpc._StaticResolver(["127.0.0.1"])
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
    proc = rpc._RpcProcessor(resolver, 24094)
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
                  NotImplemented, _req_process_fn=http_proc)
    # Materialize the keys; a dictionary view never equals a list
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def _ReadTimeoutResponse(self, req):
    """Verify that the read timeout 12356 reached the request."""
    self.assertEqual(req.host, "192.0.2.13")
    self.assertEqual(req.port, 19176)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, 12356)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, -1))

  def testReadTimeout(self):
    """The timeout passed to the processor is used for the request."""
    resolver = rpc._StaticResolver(["192.0.2.13"])
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
    proc = rpc._RpcProcessor(resolver, 19176)
    host = "node31856"
    body = {host: ""}
    result = proc([host], "version", body, 12356, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, -1)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def testOfflineNode(self):
    """Nodes resolved as offline fail without any request being made."""
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    # The response callback must never be used, hence the placeholder
    http_proc = _FakeRequestProcessor(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    host = "n17296"
    body = {host: ""}
    result = proc([host], "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    self.assertEqual(http_proc.reqcount, 0)

  def _GetMultiVersionResponse(self, req):
    """Answer "/version" with payload 987 for any host named "node*"."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    """Fifty nodes are queried and every one of them succeeds."""
    nodes = ["node%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
    proc = rpc._RpcProcessor(resolver, 23245)
    result = proc(nodes, "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetVersionResponseFail(self, errinfo, req):
    """Answer with HTTP 200 but an RPC-level failure carrying C{errinfo}."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, errinfo))

  def testVersionFailure(self):
    """A (False, ...) response is a failure, with or without error text."""
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
    proc = rpc._RpcProcessor(resolver, 5903)
    for errinfo in [None, "Unknown error"]:
      http_proc = \
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
                                             errinfo))
      host = "aef9ur4i.example.com"
      body = {host: ""}
      result = proc(list(body.keys()), "version", body, 60, NotImplemented,
                    _req_process_fn=http_proc)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answer "/vg_list" depending on which set the host belongs to.

    Hosts in C{httperrnodes} get a failed request, hosts in C{failnodes} an
    HTTP 404 response and all others a success whose payload is the hash of
    the host name.

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    """Transport errors, HTTP errors and successes in a single call."""
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)

    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    http_proc = \
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                           httperrnodes, failnodes))
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answer with a JSON list which is not a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answer with a plain JSON string instead of a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    """Malformed response bodies are reported as failures."""
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      http_proc = _FakeRequestProcessor(fn)
      host = "oqo7lanhly.example.com"
      body = {host: ""}
      result = proc([host], "version", body, 60, NotImplemented,
                    _req_process_fn=http_proc)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetBodyTestResponse(self, test_data, req):
    """Verify that the posted body decodes to C{test_data}."""
    self.assertEqual(req.host, "192.0.2.84")
    self.assertEqual(req.port, 18700)
    self.assertEqual(req.path, "/upload_file")
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, None))

  def testResponseBody(self):
    """The per-node body is sent unchanged as the request's POST data."""
    test_data = {
      "Hello": "World",
      # A real list, so the value can be serialized and compared after the
      # JSON round-trip
      "xyz": list(range(10)),
      }
    resolver = rpc._StaticResolver(["192.0.2.84"])
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                     test_data))
    proc = rpc._RpcProcessor(resolver, 18700)
    host = "node19759"
    body = {host: serializer.DumpJson(test_data)}
    result = proc([host], "upload_file", body, 30, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, None)
    self.assertEqual(lhresp.call, "upload_file")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)
326
327
class TestSsconfResolver(unittest.TestCase):
  """Tests for rpc._SsconfResolver with fake ssconf and name lookups."""

  def testSsconfLookup(self):
    """All addresses come from the simple store; DNS is never needed."""
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    # zip() is materialized so that a list is compared with a list
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    """With an empty simple store every name goes through the lookup."""
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    """Second half of the nodes is in the simple store, first half in DNS."""
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division: the value is used as a slice index and must be an
    # integer
    n = len(addr_list) // 2
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    """IPv6 addresses from the simple store are returned as stored."""
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
368
369
class TestStaticResolver(unittest.TestCase):
  """Tests for rpc._StaticResolver."""

  def test(self):
    """Nodes are paired with the configured addresses in order."""
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # zip() is materialized so that the expected value is always a list
    self.assertEqual(res(nodes, NotImplemented), list(zip(nodes, addresses)))

  def testWrongLength(self):
    """Resolving more hosts than there are addresses trips an assertion."""
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
380
381
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for rpc._NodeConfigResolver.

  The resolver is called with a single-node lookup function, an all-nodes
  lookup function, the list of host names and the resolver options.

  """
  @staticmethod
  def _GetSingleOnlineNode(name):
    """Return an online node; only "node90.example.com" may be requested."""
    assert name == "node90.example.com"
    return objects.Node(name=name, primary_ip="192.0.2.90", offline=False)

  @staticmethod
  def _GetSingleOfflineNode(name):
    """Return an offline node; only "node100.example.com" may be requested."""
    assert name == "node100.example.com"
    return objects.Node(name=name, primary_ip="192.0.2.100", offline=True)

  def testSingleOnline(self):
    """An online node resolves to its primary IP address."""
    hostname = "node90.example.com"
    resolved = rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                       NotImplemented, [hostname], None)
    self.assertEqual(resolved, [(hostname, "192.0.2.90")])

  def testSingleOffline(self):
    """An offline node resolves to the offline marker."""
    hostname = "node100.example.com"
    resolved = rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                       NotImplemented, [hostname], None)
    self.assertEqual(resolved, [(hostname, rpc._OFFLINE)])

  def testSingleOfflineWithAcceptOffline(self):
    """ACCEPT_OFFLINE_NODE yields the real address; other options assert."""
    hostname = "node100.example.com"
    lookup_fn = self._GetSingleOfflineNode
    assert lookup_fn(hostname).offline

    resolved = rpc._NodeConfigResolver(lookup_fn, NotImplemented, [hostname],
                                       rpc_defs.ACCEPT_OFFLINE_NODE)
    self.assertEqual(resolved, [(hostname, "192.0.2.100")])

    # Anything but None or the dedicated constant must be rejected
    for bad_opts in [False, True, "", "Hello", 0, 1]:
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
                        lookup_fn, NotImplemented, [hostname], bad_opts)

  def testUnknownSingleNode(self):
    """A node unknown to the configuration resolves to its own name."""
    hostname = "node110.example.com"
    resolved = rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                       [hostname], None)
    self.assertEqual(resolved, [(hostname, hostname)])

  def testMultiEmpty(self):
    """Resolving no names against an empty configuration gives nothing."""
    resolved = rpc._NodeConfigResolver(NotImplemented, lambda: {}, [], None)
    self.assertEqual(resolved, [])

  def testMultiSomeOffline(self):
    """Several names are resolved through the all-nodes lookup."""
    # node1..node254; every node whose number is divisible by 3 is offline
    all_nodes = {}
    for idx in range(1, 255):
      hostname = "node%s.example.com" % idx
      all_nodes[hostname] = objects.Node(name=hostname,
                                         offline=((idx % 3) == 0),
                                         primary_ip="192.0.2.%s" % idx)

    # Resolve no names
    nothing = rpc._NodeConfigResolver(NotImplemented, lambda: all_nodes,
                                      [], None)
    self.assertEqual(nothing, [])

    # Offline, online and unknown hosts
    wanted = [
      "node3.example.com",
      "node92.example.com",
      "node54.example.com",
      "unknown.example.com",
      ]
    expected = [
      ("node3.example.com", rpc._OFFLINE),
      ("node92.example.com", "192.0.2.92"),
      ("node54.example.com", rpc._OFFLINE),
      ("unknown.example.com", "unknown.example.com"),
      ]
    resolved = rpc._NodeConfigResolver(NotImplemented, lambda: all_nodes,
                                       wanted, None)
    self.assertEqual(resolved, expected)
453
454
class TestCompress(unittest.TestCase):
  """Tests for rpc._Compress and its counterpart backend._Decompress."""

  def test(self):
    """Check the encoding of short data and the round-trip of long data."""
    # These short strings are expected to be passed through unencoded
    short_samples = ["", "Hello", "Hello World!\nnew\nlines"]
    for data in short_samples:
      expected = (constants.RPC_ENCODING_NONE, data)
      self.assertEqual(rpc._Compress(data), expected)

    # Longer strings only need to survive the round-trip
    long_samples = [512 * " ", 5242 * "Hello World!\n"]
    for data in long_samples:
      compressed = rpc._Compress(data)
      self.assertEqual(len(compressed), 2)
      restored = backend._Decompress(compressed)
      self.assertEqual(restored, data)

  def testDecompression(self):
    """Malformed input to the decompression function raises an error."""
    for bad_input in ["", [""], ("unknown compression", "data")]:
      self.assertRaises(AssertionError, backend._Decompress, bad_input)

    corrupt = (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data")
    self.assertRaises(Exception, backend._Decompress, corrupt)
473
474
class TestRpcClientBase(unittest.TestCase):
  """Tests for rpc._RpcClientBase._Call.

  Call definitions are passed as 8-tuples. As used by these tests the
  elements are: call name, an unused placeholder, resolver options, timeout
  (or a function computing it), argument definitions, pre-processing
  function, post-processing function and another unused placeholder.

  """
  def testNoHosts(self):
    """Calling no hosts returns an empty result; arguments are counted."""
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_SLOW, [],
            None, None, NotImplemented)
    http_proc = _FakeRequestProcessor(NotImplemented)
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
                                _req_process_fn=http_proc)
    self.assertEqual(client._Call(cdef, [], []), {})

    # Test wrong number of arguments
    self.assertRaises(errors.ProgrammerError, client._Call,
                      cdef, [], [0, 1, 2])

  def testTimeout(self):
    """Static and computed timeouts both end up as the read timeout."""
    def _CalcTimeout(args):
      # Explicit unpacking; tuple parameters were removed by PEP 3113
      (arg1, arg2) = args
      return arg1 + arg2

    def _VerifyRequest(exp_timeout, req):
      self.assertEqual(req.read_timeout, exp_timeout)

      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))

    resolver = rpc._StaticResolver([
      "192.0.2.1",
      "192.0.2.2",
      ])

    nodes = [
      "node1.example.com",
      "node2.example.com",
      ]

    # (timeout definition, first call argument, expected read timeout)
    tests = [(100, None, 100), (30, None, 30)]
    tests.extend((_CalcTimeout, i, i + 300)
                 for i in [0, 5, 16485, 30516])

    for timeout, arg1, exp_timeout in tests:
      cdef = ("test_call", NotImplemented, None, timeout, [
        ("arg1", None, NotImplemented),
        ("arg2", None, NotImplemented),
        ], None, None, NotImplemented)

      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
                                                       exp_timeout))
      client = rpc._RpcClientBase(resolver, NotImplemented,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [arg1, 300])
      self.assertEqual(len(result), len(nodes))
      self.assertTrue(compat.all(not res.fail_msg and
                                 res.payload == hex(exp_timeout)
                                 for res in result.values()))

  def testArgumentEncoder(self):
    """Arguments with an encoder kind are encoded before being sent."""
    (AT1, AT2) = range(1, 3)

    resolver = rpc._StaticResolver([
      "192.0.2.5",
      "192.0.2.6",
      ])

    nodes = [
      "node5.example.com",
      "node6.example.com",
      ]

    encoders = {
      AT1: hex,
      AT2: hash,
      }

    # NOTE(review): two arguments are named "arg1"; the names do not seem to
    # matter for this test -- confirm before renaming
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
      ("arg0", None, NotImplemented),
      ("arg1", AT1, NotImplemented),
      ("arg1", AT2, NotImplemented),
      ], None, None, NotImplemented)

    def _VerifyRequest(req):
      # Echo the posted data so the test can inspect what was sent
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    http_proc = _FakeRequestProcessor(_VerifyRequest)

    for num in [0, 3796, 9032119]:
      client = rpc._RpcClientBase(resolver, encoders.get,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
      self.assertEqual(len(result), len(nodes))
      for res in result.values():
        self.assertFalse(res.fail_msg)
        self.assertEqual(serializer.LoadJson(res.payload),
                         ["foo", hex(num), hash("Hello%s" % num)])

  def testPostProc(self):
    """The post-processing function is applied to every result."""
    def _VerifyRequest(nums, req):
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, nums))

    resolver = rpc._StaticResolver([
      "192.0.2.90",
      "192.0.2.95",
      ])

    nodes = [
      "node90.example.com",
      "node95.example.com",
      ]

    def _PostProc(res):
      self.assertFalse(res.fail_msg)
      res.payload = sum(res.payload)
      return res

    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [],
            None, _PostProc, NotImplemented)

    # Seeded random generator
    rnd = random.Random(20299)

    for i in [0, 4, 74, 1391]:
      nums = [rnd.randint(0, 1000) for _ in range(i)]
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
      client = rpc._RpcClientBase(resolver, NotImplemented,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [])
      self.assertEqual(len(result), len(nodes))
      for res in result.values():
        self.assertFalse(res.fail_msg)
        self.assertEqual(res.payload, sum(nums))

  def testPreProc(self):
    """The pre-processing function builds the body for each node."""
    def _VerifyRequest(req):
      # Echo the posted data so the test can inspect what was sent
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    resolver = rpc._StaticResolver([
      "192.0.2.30",
      "192.0.2.35",
      ])

    nodes = [
      "node30.example.com",
      "node35.example.com",
      ]

    def _PreProc(node, data):
      self.assertEqual(len(data), 1)
      return data[0] + node

    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
      ("arg0", None, NotImplemented),
      ], _PreProc, None, NotImplemented)

    http_proc = _FakeRequestProcessor(_VerifyRequest)
    client = rpc._RpcClientBase(resolver, NotImplemented,
                                _req_process_fn=http_proc)

    for prefix in ["foo", "bar", "baz"]:
      result = client._Call(cdef, nodes, [prefix])
      self.assertEqual(len(result), len(nodes))
      for (node, res) in result.items():
        self.assertFalse(res.fail_msg)
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)

  def testResolverOptions(self):
    """Resolver options are passed on as-is or computed from the arguments."""
    def _VerifyRequest(req):
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    nodes = [
      "node30.example.com",
      "node35.example.com",
      ]

    def _Resolver(expected, hosts, options):
      self.assertEqual(hosts, nodes)
      self.assertEqual(options, expected)
      return zip(hosts, nodes)

    def _DynamicResolverOptions(args):
      # Explicit unpacking; tuple parameters were removed by PEP 3113
      (arg0, ) = args
      return sum(arg0)

    # (resolver options definition, call argument, options seen by resolver)
    tests = [
      (None, None, None),
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
      (False, None, False),
      (True, None, True),
      (0, None, 0),
      (_DynamicResolverOptions, [1, 2, 3], 6),
      (_DynamicResolverOptions, range(4, 19), 165),
      ]

    for (resolver_opts, arg0, expected) in tests:
      cdef = ("test_call", NotImplemented, resolver_opts, rpc_defs.TMO_NORMAL, [
        ("arg0", None, NotImplemented),
        ], None, None, NotImplemented)

      http_proc = _FakeRequestProcessor(_VerifyRequest)

      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
                                  NotImplemented, _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [arg0])
      self.assertEqual(len(result), len(nodes))
      for res in result.values():
        self.assertFalse(res.fail_msg)
685
686
687 class _FakeConfigForRpcRunner:
688   GetAllNodesInfo = NotImplemented
689
690   def __init__(self, cluster=NotImplemented):
691     self._cluster = cluster
692
693   def GetNodeInfo(self, name):
694     return objects.Node(name=name)
695
696   def GetClusterInfo(self):
697     return self._cluster
698
699
class TestRpcRunner(unittest.TestCase):
  """Tests for the RPC runner classes (rpc.RpcRunner and rpc.ConfigRunner)."""

  def testUploadFile(self):
    """call_upload_file sends the file's contents and metadata."""
    data = 1779 * "Hello World\n"

    tmpfile = tempfile.NamedTemporaryFile()
    # Close (and thereby remove) the temporary file even if the test fails
    try:
      tmpfile.write(data)
      tmpfile.flush()
      st = os.stat(tmpfile.name)

      def _VerifyRequest(req):
        # Expected layout: name, compressed data, mode, user, group,
        # a non-None field (only checked for presence) and mtime
        (uldata, ) = serializer.LoadJson(req.post_data)
        self.assertEqual(len(uldata), 7)
        self.assertEqual(uldata[0], tmpfile.name)
        self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
        self.assertEqual(uldata[2], st.st_mode)
        self.assertEqual(uldata[3], "user%s" % os.getuid())
        self.assertEqual(uldata[4], "group%s" % os.getgid())
        self.assertTrue(uldata[5] is not None)
        self.assertEqual(uldata[6], st.st_mtime)

        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson((True, None))

      http_proc = _FakeRequestProcessor(_VerifyRequest)

      std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
                                 _req_process_fn=http_proc,
                                 _getents=mocks.FakeGetentResolver)

      cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
                                    _req_process_fn=http_proc,
                                    _getents=mocks.FakeGetentResolver)

      nodes = [
        "node1.example.com",
        ]

      for runner in [std_runner, cfg_runner]:
        result = runner.call_upload_file(nodes, tmpfile.name)
        self.assertEqual(len(result), len(nodes))
        for res in result.values():
          self.assertFalse(res.fail_msg)
    finally:
      tmpfile.close()

  def testEncodeInstance(self):
    """Instances are encoded according to the requested encoder kind."""
    cluster = objects.Cluster(hvparams={
      constants.HT_KVM: {
        constants.HV_BLOCKDEV_PREFIX: "foo",
        },
      },
      beparams={
        constants.PP_DEFAULT: {
          constants.BE_MAXMEM: 8192,
          },
        },
      os_hvp={},
      osparams={
        "linux": {
          "role": "unknown",
          },
        })
    cluster.UpgradeConfig()

    inst = objects.Instance(name="inst1.example.com",
      hypervisor=constants.HT_FAKE,
      os="linux",
      hvparams={
        constants.HT_KVM: {
          constants.HV_BLOCKDEV_PREFIX: "bar",
          constants.HV_ROOT_PATH: "/tmp",
          },
        },
      beparams={
        constants.BE_MINMEM: 128,
        constants.BE_MAXMEM: 256,
        },
      nics=[
        objects.NIC(nicparams={
          constants.NIC_MODE: "mymode",
          }),
        ],
      disks=[])
    inst.UpgradeConfig()

    cfg = _FakeConfigForRpcRunner(cluster=cluster)
    runner = rpc.RpcRunner(cfg, None,
                           _req_process_fn=NotImplemented,
                           _getents=mocks.FakeGetentResolver)

    def _CheckBasics(result):
      self.assertEqual(result["name"], "inst1.example.com")
      self.assertEqual(result["os"], "linux")
      self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
      self.assertEqual(len(result["hvparams"]), 1)
      self.assertEqual(len(result["nics"]), 1)
      self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
                       "mymode")

    # Generic object serialization
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT, inst))
    _CheckBasics(result)

    result = runner._encoder((rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
    # Explicit loop: map() is lazy on Python 3 and would skip the checks
    for item in result:
      _CheckBasics(item)

    # Just an instance
    result = runner._encoder((rpc_defs.ED_INST_DICT, inst))
    _CheckBasics(result)
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
      constants.HV_BLOCKDEV_PREFIX: "bar",
      constants.HV_ROOT_PATH: "/tmp",
      })
    self.assertEqual(result["osparams"], {
      "role": "unknown",
      })

    # Instance with OS parameters
    result = runner._encoder((rpc_defs.ED_INST_DICT_OSP, (inst, {
      "role": "webserver",
      "other": "field",
      })))
    _CheckBasics(result)
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
      constants.HV_BLOCKDEV_PREFIX: "bar",
      constants.HV_ROOT_PATH: "/tmp",
      })
    self.assertEqual(result["osparams"], {
      "role": "webserver",
      "other": "field",
      })

    # Instance with hypervisor and backend parameters
    result = runner._encoder((rpc_defs.ED_INST_DICT_HVP_BEP, (inst, {
      constants.HT_KVM: {
        constants.HV_BOOT_ORDER: "xyz",
        },
      }, {
      constants.BE_VCPUS: 100,
      constants.BE_MAXMEM: 4096,
      })))
    _CheckBasics(result)
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
    self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
      constants.HV_BOOT_ORDER: "xyz",
      })
848
849
# Run the tests through testutils.GanetiTestProgram when this file is
# executed directly rather than imported.
if __name__ == "__main__":
  testutils.GanetiTestProgram()