bash_completion: Enable extglob while parsing file
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2011, 2012 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27 import random
28 import tempfile
29
30 from ganeti import constants
31 from ganeti import compat
32 from ganeti import rpc
33 from ganeti import rpc_defs
34 from ganeti import http
35 from ganeti import errors
36 from ganeti import serializer
37 from ganeti import objects
38 from ganeti import backend
39
40 import testutils
41 import mocks
42
43
44 class _FakeRequestProcessor:
45   def __init__(self, response_fn):
46     self._response_fn = response_fn
47     self.reqcount = 0
48
49   def __call__(self, reqs, lock_monitor_cb=None):
50     assert lock_monitor_cb is None or callable(lock_monitor_cb)
51     for req in reqs:
52       self.reqcount += 1
53       self._response_fn(req)
54
55
def GetFakeSimpleStoreClass(fn):
  """Builds a fake simple store class around a lookup function.

  @param fn: installed as the class' C{GetNodePrimaryIPList} attribute; when
    it is a plain function it is called with the instance as its only
    argument
  @return: the newly created class (not an instance)

  """
  class FakeSimpleStore:
    """Fake simple store exposing only the two lookups used by the tests.

    """
    GetNodePrimaryIPList = fn

    def GetPrimaryIPFamily(_):
      return None

  return FakeSimpleStore
62
63
64 def _RaiseNotImplemented():
65   """Simple wrapper to raise NotImplementedError.
66
67   """
68   raise NotImplementedError
69
70
class TestRpcProcessor(unittest.TestCase):
  """Tests for L{rpc._RpcProcessor} driven by a fake HTTP request processor.

  """
  def _FakeAddressLookup(self, map):
    """Builds a function resolving a list of node names through C{map}.

    Names missing from C{map} resolve to C{None}.

    """
    return lambda node_list: [map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Verifies a "version" request to 127.0.0.1 and answers with 123.

    """
    self.assertEqual(req.host, "127.0.0.1")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    """Successful call to a single node.

    """
    resolver = rpc._StaticResolver(["127.0.0.1"])
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
    proc = rpc._RpcProcessor(resolver, 24094)
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
                  NotImplemented, _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def _ReadTimeoutResponse(self, req):
    """Verifies the read timeout of the request and answers with -1.

    """
    self.assertEqual(req.host, "192.0.2.13")
    self.assertEqual(req.port, 19176)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, 12356)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, -1))

  def testReadTimeout(self):
    """The timeout given to the processor ends up in the request.

    """
    resolver = rpc._StaticResolver(["192.0.2.13"])
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
    proc = rpc._RpcProcessor(resolver, 19176)
    host = "node31856"
    body = {host: ""}
    result = proc([host], "version", body, 12356, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, -1)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def testOfflineNode(self):
    """A node resolved as offline fails without any request being made.

    """
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    http_proc = _FakeRequestProcessor(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    host = "n17296"
    body = {host: ""}
    result = proc([host], "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    self.assertEqual(http_proc.reqcount, 0)

  def _GetMultiVersionResponse(self, req):
    """Verifies a "version" request to any "node*" host; answers with 987.

    """
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    """Successful call to fifty nodes at once.

    """
    nodes = ["node%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
    proc = rpc._RpcProcessor(resolver, 23245)
    result = proc(nodes, "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetVersionResponseFail(self, errinfo, req):
    """Answers a "version" request with a failure carrying C{errinfo}.

    """
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, errinfo))

  def testVersionFailure(self):
    """A failure reported by the node, with and without error details.

    """
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
    proc = rpc._RpcProcessor(resolver, 5903)
    for errinfo in [None, "Unknown error"]:
      http_proc = \
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
                                             errinfo))
      host = "aef9ur4i.example.com"
      body = {host: ""}
      result = proc([host], "version", body, 60, NotImplemented,
                    _req_process_fn=http_proc)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answers a "vg_list" request depending on the group of its host.

    Hosts in C{httperrnodes} get a failed request, hosts in C{failnodes} a
    404 response and all others a successful response whose payload is the
    hash of the host name.

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    """Mix of request errors, HTTP 404 responses and successful calls.

    """
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)

    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    http_proc = \
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                           httperrnodes, failnodes))
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answers a "version" request with a nine-element sequence.

    """
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answers a "version" request with a plain string.

    """
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    """Response bodies of the wrong shape are reported as failures.

    """
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      http_proc = _FakeRequestProcessor(fn)
      host = "oqo7lanhly.example.com"
      body = {host: ""}
      result = proc([host], "version", body, 60, NotImplemented,
                    _req_process_fn=http_proc)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetBodyTestResponse(self, test_data, req):
    """Verifies that the request's POST data decodes to C{test_data}.

    """
    self.assertEqual(req.host, "192.0.2.84")
    self.assertEqual(req.port, 18700)
    self.assertEqual(req.path, "/upload_file")
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, None))

  def testResponseBody(self):
    """The per-node body is sent unchanged as POST data.

    """
    test_data = {
      "Hello": "World",
      # Explicit list so the value does not depend on range() returning one
      "xyz": list(range(10)),
      }
    resolver = rpc._StaticResolver(["192.0.2.84"])
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                     test_data))
    proc = rpc._RpcProcessor(resolver, 18700)
    host = "node19759"
    body = {host: serializer.DumpJson(test_data)}
    result = proc([host], "upload_file", body, 30, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, None)
    self.assertEqual(lhresp.call, "upload_file")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)
333
334
class TestSsconfResolver(unittest.TestCase):
  """Tests for L{rpc._SsconfResolver} using a fake simple store class.

  """
  def testSsconfLookup(self):
    """All addresses come from the simple store's node/IP list.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    """With an empty simple store list, names go through C{nslookup_fn}.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testDisabledSsconfIp(self):
    """With the first argument C{False} the node/IP list must not be read.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(_RaiseNotImplemented)
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(False, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    """Second half of the nodes is in the simple store, first half is not.

    """
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division: the value is used as a slice index and must be an int
    n = len(addr_list) // 2
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    """Same as L{testSsconfLookup}, but with IPv6 addresses.

    """
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
                                 ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
385
386
class TestStaticResolver(unittest.TestCase):
  """Tests for L{rpc._StaticResolver}.

  """
  def test(self):
    """Nodes are paired with the configured addresses in order.

    """
    numbers = range(0, 123, 7)
    addresses = ["192.0.2.%d" % n for n in numbers]
    nodes = ["node%s.example.com" % n for n in numbers]
    resolver = rpc._StaticResolver(addresses)
    resolved = resolver(nodes, NotImplemented)
    self.assertEqual(resolved, zip(nodes, addresses))

  def testWrongLength(self):
    """Resolving one name against an empty address list fails an assertion.

    """
    resolver = rpc._StaticResolver([])
    self.assertRaises(AssertionError, resolver, ["abc"], NotImplemented)
397
398
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for L{rpc._NodeConfigResolver}.

  """
  @staticmethod
  def _GetSingleOnlineNode(name):
    """Returns an online node object; accepts only "node90.example.com".

    """
    assert name == "node90.example.com"
    return objects.Node(primary_ip="192.0.2.90", offline=False, name=name)

  @staticmethod
  def _GetSingleOfflineNode(name):
    """Returns an offline node object; accepts only "node100.example.com".

    """
    assert name == "node100.example.com"
    return objects.Node(primary_ip="192.0.2.100", offline=True, name=name)

  def testSingleOnline(self):
    """A single online node resolves to its primary IP.

    """
    name = "node90.example.com"
    resolved = rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                       NotImplemented, [name], None)
    self.assertEqual(resolved, [(name, "192.0.2.90")])

  def testSingleOffline(self):
    """A single offline node resolves to the offline marker.

    """
    name = "node100.example.com"
    resolved = rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                       NotImplemented, [name], None)
    self.assertEqual(resolved, [(name, rpc._OFFLINE)])

  def testSingleOfflineWithAcceptOffline(self):
    """With ACCEPT_OFFLINE_NODE an offline node resolves to its primary IP.

    """
    name = "node100.example.com"
    lookup = self._GetSingleOfflineNode
    assert lookup(name).offline

    resolved = rpc._NodeConfigResolver(lookup, NotImplemented, [name],
                                       rpc_defs.ACCEPT_OFFLINE_NODE)
    self.assertEqual(resolved, [(name, "192.0.2.100")])

    # Any other option value must be rejected
    for bad_option in [False, True, "", "Hello", 0, 1]:
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
                        lookup, NotImplemented, [name], bad_option)

  def testUnknownSingleNode(self):
    """A node unknown to the lookup function resolves to its own name.

    """
    name = "node110.example.com"
    resolved = rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                       [name], None)
    self.assertEqual(resolved, [(name, name)])

  def testMultiEmpty(self):
    """No names and no configured nodes give an empty result.

    """
    resolved = rpc._NodeConfigResolver(NotImplemented, lambda: {}, [], None)
    self.assertEqual(resolved, [])

  def testMultiSomeOffline(self):
    """Several names resolved against nodes of which every third is offline.

    """
    nodes = {}
    for i in range(1, 255):
      name = "node%s.example.com" % i
      nodes[name] = objects.Node(name=name,
                                 offline=((i % 3) == 0),
                                 primary_ip="192.0.2.%s" % i)

    get_all_nodes = lambda: nodes

    # Resolve no names
    resolved = rpc._NodeConfigResolver(NotImplemented, get_all_nodes,
                                       [], None)
    self.assertEqual(resolved, [])

    # Offline, online and unknown hosts
    wanted = [
      "node3.example.com",
      "node92.example.com",
      "node54.example.com",
      "unknown.example.com",
      ]
    resolved = rpc._NodeConfigResolver(NotImplemented, get_all_nodes,
                                       wanted, None)
    self.assertEqual(resolved, [
      ("node3.example.com", rpc._OFFLINE),
      ("node92.example.com", "192.0.2.92"),
      ("node54.example.com", rpc._OFFLINE),
      ("unknown.example.com", "unknown.example.com"),
      ])
470
471
class TestCompress(unittest.TestCase):
  """Tests for L{rpc._Compress} and L{backend._Decompress}.

  """
  def test(self):
    """Short data is passed through, longer data survives a round trip.

    """
    short_inputs = ["", "Hello", "Hello World!\nnew\nlines"]
    for text in short_inputs:
      expected = (constants.RPC_ENCODING_NONE, text)
      self.assertEqual(rpc._Compress(text), expected)

    long_inputs = [" " * 512, "Hello World!\n" * 5242]
    for text in long_inputs:
      packed = rpc._Compress(text)
      self.assertEqual(len(packed), 2)
      self.assertEqual(backend._Decompress(packed), text)

  def testDecompression(self):
    """Malformed input to the decompression function is rejected.

    """
    # Wrong shape or unknown encoding
    malformed = ["", [""], ("unknown compression", "data")]
    for value in malformed:
      self.assertRaises(AssertionError, backend._Decompress, value)

    # Known encoding, but the payload is not valid
    broken = (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data")
    self.assertRaises(Exception, backend._Decompress, broken)
490
491
class TestRpcClientBase(unittest.TestCase):
  """Tests for L{rpc._RpcClientBase} using hand-written call definitions.

  """
  def testNoHosts(self):
    """Calling with no nodes gives an empty result.

    """
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_SLOW, [],
            None, None, NotImplemented)
    http_proc = _FakeRequestProcessor(NotImplemented)
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
                                _req_process_fn=http_proc)
    self.assertEqual(client._Call(cdef, [], []), {})

    # Test wrong number of arguments
    self.assertRaises(errors.ProgrammerError, client._Call,
                      cdef, [], [0, 1, 2])

  def testTimeout(self):
    """Fixed and computed timeouts end up as the request's read timeout.

    """
    def _CalcTimeout(args):
      # Unpacked in the body; tuple parameters are not portable syntax
      (arg1, arg2) = args
      return arg1 + arg2

    def _VerifyRequest(exp_timeout, req):
      self.assertEqual(req.read_timeout, exp_timeout)

      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))

    resolver = rpc._StaticResolver([
      "192.0.2.1",
      "192.0.2.2",
      ])

    nodes = [
      "node1.example.com",
      "node2.example.com",
      ]

    # Tuples of (timeout definition, first call argument, expected timeout)
    tests = [(100, None, 100), (30, None, 30)]
    tests.extend((_CalcTimeout, i, i + 300)
                 for i in [0, 5, 16485, 30516])

    for timeout, arg1, exp_timeout in tests:
      cdef = ("test_call", NotImplemented, None, timeout, [
        ("arg1", None, NotImplemented),
        ("arg2", None, NotImplemented),
        ], None, None, NotImplemented)

      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
                                                       exp_timeout))
      client = rpc._RpcClientBase(resolver, NotImplemented,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [arg1, 300])
      self.assertEqual(len(result), len(nodes))
      self.assertTrue(compat.all(not res.fail_msg and
                                 res.payload == hex(exp_timeout)
                                 for res in result.values()))

  def testArgumentEncoder(self):
    """Arguments with an argument kind are passed through its encoder.

    """
    (AT1, AT2) = range(1, 3)

    resolver = rpc._StaticResolver([
      "192.0.2.5",
      "192.0.2.6",
      ])

    nodes = [
      "node5.example.com",
      "node6.example.com",
      ]

    encoders = {
      AT1: hex,
      AT2: hash,
      }

    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
      ("arg0", None, NotImplemented),
      ("arg1", AT1, NotImplemented),
      ("arg1", AT2, NotImplemented),
      ], None, None, NotImplemented)

    def _VerifyRequest(req):
      # Echo the POST data back so the test can inspect the encoded arguments
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    http_proc = _FakeRequestProcessor(_VerifyRequest)

    for num in [0, 3796, 9032119]:
      client = rpc._RpcClientBase(resolver, encoders.get,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
      self.assertEqual(len(result), len(nodes))
      for res in result.values():
        self.assertFalse(res.fail_msg)
        self.assertEqual(serializer.LoadJson(res.payload),
                         ["foo", hex(num), hash("Hello%s" % num)])

  def testPostProc(self):
    """The post-processing function of the definition rewrites results.

    """
    def _VerifyRequest(nums, req):
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, nums))

    resolver = rpc._StaticResolver([
      "192.0.2.90",
      "192.0.2.95",
      ])

    nodes = [
      "node90.example.com",
      "node95.example.com",
      ]

    def _PostProc(res):
      self.assertFalse(res.fail_msg)
      res.payload = sum(res.payload)
      return res

    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [],
            None, _PostProc, NotImplemented)

    # Seeded random generator
    rnd = random.Random(20299)

    for i in [0, 4, 74, 1391]:
      nums = [rnd.randint(0, 1000) for _ in range(i)]
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
      client = rpc._RpcClientBase(resolver, NotImplemented,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [])
      self.assertEqual(len(result), len(nodes))
      for res in result.values():
        self.assertFalse(res.fail_msg)
        self.assertEqual(res.payload, sum(nums))

  def testPreProc(self):
    """The pre-processing function of the definition builds per-node data.

    """
    def _VerifyRequest(req):
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    resolver = rpc._StaticResolver([
      "192.0.2.30",
      "192.0.2.35",
      ])

    nodes = [
      "node30.example.com",
      "node35.example.com",
      ]

    def _PreProc(node, data):
      self.assertEqual(len(data), 1)
      return data[0] + node

    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
      ("arg0", None, NotImplemented),
      ], _PreProc, None, NotImplemented)

    http_proc = _FakeRequestProcessor(_VerifyRequest)
    client = rpc._RpcClientBase(resolver, NotImplemented,
                                _req_process_fn=http_proc)

    for prefix in ["foo", "bar", "baz"]:
      result = client._Call(cdef, nodes, [prefix])
      self.assertEqual(len(result), len(nodes))
      for (node, res) in result.items():
        self.assertFalse(res.fail_msg)
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)

  def testResolverOptions(self):
    """Static and computed resolver options are passed to the resolver.

    """
    def _VerifyRequest(req):
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    nodes = [
      "node30.example.com",
      "node35.example.com",
      ]

    def _Resolver(expected, hosts, options):
      self.assertEqual(hosts, nodes)
      self.assertEqual(options, expected)
      return zip(hosts, nodes)

    def _DynamicResolverOptions(args):
      # Unpacked in the body; tuple parameters are not portable syntax
      (arg0, ) = args
      return sum(arg0)

    # Tuples of (resolver options definition, call argument, expected options)
    tests = [
      (None, None, None),
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
      (False, None, False),
      (True, None, True),
      (0, None, 0),
      (_DynamicResolverOptions, [1, 2, 3], 6),
      (_DynamicResolverOptions, list(range(4, 19)), 165),
      ]

    for (resolver_opts, arg0, expected) in tests:
      cdef = ("test_call", NotImplemented, resolver_opts, rpc_defs.TMO_NORMAL, [
        ("arg0", None, NotImplemented),
        ], None, None, NotImplemented)

      http_proc = _FakeRequestProcessor(_VerifyRequest)

      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
                                  NotImplemented, _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [arg0])
      self.assertEqual(len(result), len(nodes))
      for res in result.values():
        self.assertFalse(res.fail_msg)
701         self.assertFalse(res.fail_msg)
702
703
704 class _FakeConfigForRpcRunner:
705   GetAllNodesInfo = NotImplemented
706
707   def __init__(self, cluster=NotImplemented):
708     self._cluster = cluster
709
710   def GetNodeInfo(self, name):
711     return objects.Node(name=name)
712
713   def GetClusterInfo(self):
714     return self._cluster
715
716   def GetInstanceDiskParams(self, _):
717     return constants.DISK_DT_DEFAULTS
718
719
class TestRpcRunner(unittest.TestCase):
  """Tests for L{rpc.RpcRunner} and L{rpc.ConfigRunner}.

  """
  def testUploadFile(self):
    """Uploads a temporary file through both runner classes.

    The fake request processor checks the seven fields sent for the file.

    """
    data = 1779 * "Hello World\n"

    tmpfile = tempfile.NamedTemporaryFile()
    try:
      tmpfile.write(data)
      tmpfile.flush()
      st = os.stat(tmpfile.name)

      def _VerifyRequest(req):
        (uldata, ) = serializer.LoadJson(req.post_data)
        self.assertEqual(len(uldata), 7)
        self.assertEqual(uldata[0], tmpfile.name)
        self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
        self.assertEqual(uldata[2], st.st_mode)
        self.assertEqual(uldata[3], "user%s" % os.getuid())
        self.assertEqual(uldata[4], "group%s" % os.getgid())
        self.assertTrue(uldata[5] is not None)
        self.assertEqual(uldata[6], st.st_mtime)

        req.success = True
        req.resp_status_code = http.HTTP_OK
        req.resp_body = serializer.DumpJson((True, None))

      http_proc = _FakeRequestProcessor(_VerifyRequest)

      std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
                                 _req_process_fn=http_proc,
                                 _getents=mocks.FakeGetentResolver)

      cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
                                    _req_process_fn=http_proc,
                                    _getents=mocks.FakeGetentResolver)

      nodes = [
        "node1.example.com",
        ]

      for runner in [std_runner, cfg_runner]:
        result = runner.call_upload_file(nodes, tmpfile.name)
        self.assertEqual(len(result), len(nodes))
        for res in result.values():
          self.assertFalse(res.fail_msg)
    finally:
      # Closing removes the temporary file even if an assertion failed
      tmpfile.close()

  def testEncodeInstance(self):
    """Encodes one instance with each of the instance-related encoders.

    """
    cluster = objects.Cluster(hvparams={
      constants.HT_KVM: {
        constants.HV_BLOCKDEV_PREFIX: "foo",
        },
      },
      beparams={
        constants.PP_DEFAULT: {
          constants.BE_MAXMEM: 8192,
          },
        },
      os_hvp={},
      osparams={
        "linux": {
          "role": "unknown",
          },
        })
    cluster.UpgradeConfig()

    inst = objects.Instance(name="inst1.example.com",
      hypervisor=constants.HT_FAKE,
      os="linux",
      hvparams={
        constants.HT_KVM: {
          constants.HV_BLOCKDEV_PREFIX: "bar",
          constants.HV_ROOT_PATH: "/tmp",
          },
        },
      beparams={
        constants.BE_MINMEM: 128,
        constants.BE_MAXMEM: 256,
        },
      nics=[
        objects.NIC(nicparams={
          constants.NIC_MODE: "mymode",
          }),
        ],
      disk_template=constants.DT_PLAIN,
      disks=[
        objects.Disk(dev_type=constants.LD_LV, size=4096,
                     logical_id=("vg", "disk6120")),
        objects.Disk(dev_type=constants.LD_LV, size=1024,
                     logical_id=("vg", "disk8508")),
        ])
    inst.UpgradeConfig()

    cfg = _FakeConfigForRpcRunner(cluster=cluster)
    runner = rpc.RpcRunner(cfg, None,
                           _req_process_fn=NotImplemented,
                           _getents=mocks.FakeGetentResolver)

    def _CheckBasics(result):
      self.assertEqual(result["name"], "inst1.example.com")
      self.assertEqual(result["os"], "linux")
      self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
      self.assertEqual(len(result["hvparams"]), 1)
      self.assertEqual(len(result["nics"]), 1)
      self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
                       "mymode")

    # Generic object serialization
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT, inst))
    _CheckBasics(result)

    result = runner._encoder((rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
    # Explicit loop: the checks must run even where map() is lazy
    for entry in result:
      _CheckBasics(entry)

    # Just an instance
    result = runner._encoder((rpc_defs.ED_INST_DICT, inst))
    _CheckBasics(result)
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
      constants.HV_BLOCKDEV_PREFIX: "bar",
      constants.HV_ROOT_PATH: "/tmp",
      })
    self.assertEqual(result["osparams"], {
      "role": "unknown",
      })

    # Instance with OS parameters
    result = runner._encoder((rpc_defs.ED_INST_DICT_OSP_DP, (inst, {
      "role": "webserver",
      "other": "field",
      })))
    _CheckBasics(result)
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
      constants.HV_BLOCKDEV_PREFIX: "bar",
      constants.HV_ROOT_PATH: "/tmp",
      })
    self.assertEqual(result["osparams"], {
      "role": "webserver",
      "other": "field",
      })

    # Instance with hypervisor and backend parameters
    result = runner._encoder((rpc_defs.ED_INST_DICT_HVP_BEP_DP, (inst, {
      constants.HT_KVM: {
        constants.HV_BOOT_ORDER: "xyz",
        },
      }, {
      constants.BE_VCPUS: 100,
      constants.BE_MAXMEM: 4096,
      })))
    _CheckBasics(result)
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
    self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
      constants.HV_BOOT_ORDER: "xyz",
      })
    self.assertEqual(result["disks"], [{
      "dev_type": constants.LD_LV,
      "size": 4096,
      "logical_id": ("vg", "disk6120"),
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
      }, {
      "dev_type": constants.LD_LV,
      "size": 1024,
      "logical_id": ("vg", "disk8508"),
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
      }])

    # Encoding must work on copies and leave the original disks untouched
    self.assertTrue(compat.all(disk.params == {} for disk in inst.disks),
                    msg="Configuration objects were modified")
888
889
# Script entry point: when this file is executed directly (rather than
# imported), control is handed to testutils.GanetiTestProgram.
890 if __name__ == "__main__":
891   testutils.GanetiTestProgram()