Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ 1a2eb2dc

History | View | Annotate | Download (31.1 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011, 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27
import random
28
import tempfile
29

    
30
from ganeti import constants
31
from ganeti import compat
32
from ganeti import rpc
33
from ganeti import rpc_defs
34
from ganeti import http
35
from ganeti import errors
36
from ganeti import serializer
37
from ganeti import objects
38
from ganeti import backend
39

    
40
import testutils
41
import mocks
42

    
43

    
44
class _FakeRequestProcessor:
45
  def __init__(self, response_fn):
46
    self._response_fn = response_fn
47
    self.reqcount = 0
48

    
49
  def __call__(self, reqs, lock_monitor_cb=None):
50
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
51
    for req in reqs:
52
      self.reqcount += 1
53
      self._response_fn(req)
54

    
55

    
56
def GetFakeSimpleStoreClass(fn):
57
  class FakeSimpleStore:
58
    GetNodePrimaryIPList = fn
59
    GetPrimaryIPFamily = lambda _: None
60

    
61
  return FakeSimpleStore
62

    
63

    
64
def _RaiseNotImplemented():
65
  """Simple wrapper to raise NotImplementedError.
66

67
  """
68
  raise NotImplementedError
69

    
70

    
71
class TestRpcProcessor(unittest.TestCase):
72
  def _FakeAddressLookup(self, map):
73
    return lambda node_list: [map.get(node) for node in node_list]
74

    
75
  def _GetVersionResponse(self, req):
76
    self.assertEqual(req.host, "127.0.0.1")
77
    self.assertEqual(req.port, 24094)
78
    self.assertEqual(req.path, "/version")
79
    self.assertEqual(req.read_timeout, constants.RPC_TMO_URGENT)
80
    req.success = True
81
    req.resp_status_code = http.HTTP_OK
82
    req.resp_body = serializer.DumpJson((True, 123))
83

    
84
  def testVersionSuccess(self):
85
    resolver = rpc._StaticResolver(["127.0.0.1"])
86
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
87
    proc = rpc._RpcProcessor(resolver, 24094)
88
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
89
                  NotImplemented, _req_process_fn=http_proc)
90
    self.assertEqual(result.keys(), ["localhost"])
91
    lhresp = result["localhost"]
92
    self.assertFalse(lhresp.offline)
93
    self.assertEqual(lhresp.node, "localhost")
94
    self.assertFalse(lhresp.fail_msg)
95
    self.assertEqual(lhresp.payload, 123)
96
    self.assertEqual(lhresp.call, "version")
97
    lhresp.Raise("should not raise")
98
    self.assertEqual(http_proc.reqcount, 1)
99

    
100
  def _ReadTimeoutResponse(self, req):
101
    self.assertEqual(req.host, "192.0.2.13")
102
    self.assertEqual(req.port, 19176)
103
    self.assertEqual(req.path, "/version")
104
    self.assertEqual(req.read_timeout, 12356)
105
    req.success = True
106
    req.resp_status_code = http.HTTP_OK
107
    req.resp_body = serializer.DumpJson((True, -1))
108

    
109
  def testReadTimeout(self):
110
    resolver = rpc._StaticResolver(["192.0.2.13"])
111
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
112
    proc = rpc._RpcProcessor(resolver, 19176)
113
    host = "node31856"
114
    body = {host: ""}
115
    result = proc([host], "version", body, 12356, NotImplemented,
116
                  _req_process_fn=http_proc)
117
    self.assertEqual(result.keys(), [host])
118
    lhresp = result[host]
119
    self.assertFalse(lhresp.offline)
120
    self.assertEqual(lhresp.node, host)
121
    self.assertFalse(lhresp.fail_msg)
122
    self.assertEqual(lhresp.payload, -1)
123
    self.assertEqual(lhresp.call, "version")
124
    lhresp.Raise("should not raise")
125
    self.assertEqual(http_proc.reqcount, 1)
126

    
127
  def testOfflineNode(self):
128
    resolver = rpc._StaticResolver([rpc._OFFLINE])
129
    http_proc = _FakeRequestProcessor(NotImplemented)
130
    proc = rpc._RpcProcessor(resolver, 30668)
131
    host = "n17296"
132
    body = {host: ""}
133
    result = proc([host], "version", body, 60, NotImplemented,
134
                  _req_process_fn=http_proc)
135
    self.assertEqual(result.keys(), [host])
136
    lhresp = result[host]
137
    self.assertTrue(lhresp.offline)
138
    self.assertEqual(lhresp.node, host)
139
    self.assertTrue(lhresp.fail_msg)
140
    self.assertFalse(lhresp.payload)
141
    self.assertEqual(lhresp.call, "version")
142

    
143
    # With a message
144
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
145

    
146
    # No message
147
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
148

    
149
    self.assertEqual(http_proc.reqcount, 0)
150

    
151
  def _GetMultiVersionResponse(self, req):
152
    self.assert_(req.host.startswith("node"))
153
    self.assertEqual(req.port, 23245)
154
    self.assertEqual(req.path, "/version")
155
    req.success = True
156
    req.resp_status_code = http.HTTP_OK
157
    req.resp_body = serializer.DumpJson((True, 987))
158

    
159
  def testMultiVersionSuccess(self):
160
    nodes = ["node%s" % i for i in range(50)]
161
    body = dict((n, "") for n in nodes)
162
    resolver = rpc._StaticResolver(nodes)
163
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
164
    proc = rpc._RpcProcessor(resolver, 23245)
165
    result = proc(nodes, "version", body, 60, NotImplemented,
166
                  _req_process_fn=http_proc)
167
    self.assertEqual(sorted(result.keys()), sorted(nodes))
168

    
169
    for name in nodes:
170
      lhresp = result[name]
171
      self.assertFalse(lhresp.offline)
172
      self.assertEqual(lhresp.node, name)
173
      self.assertFalse(lhresp.fail_msg)
174
      self.assertEqual(lhresp.payload, 987)
175
      self.assertEqual(lhresp.call, "version")
176
      lhresp.Raise("should not raise")
177

    
178
    self.assertEqual(http_proc.reqcount, len(nodes))
179

    
180
  def _GetVersionResponseFail(self, errinfo, req):
181
    self.assertEqual(req.path, "/version")
182
    req.success = True
183
    req.resp_status_code = http.HTTP_OK
184
    req.resp_body = serializer.DumpJson((False, errinfo))
185

    
186
  def testVersionFailure(self):
187
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
188
    proc = rpc._RpcProcessor(resolver, 5903)
189
    for errinfo in [None, "Unknown error"]:
190
      http_proc = \
191
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
192
                                             errinfo))
193
      host = "aef9ur4i.example.com"
194
      body = {host: ""}
195
      result = proc(body.keys(), "version", body, 60, NotImplemented,
196
                    _req_process_fn=http_proc)
197
      self.assertEqual(result.keys(), [host])
198
      lhresp = result[host]
199
      self.assertFalse(lhresp.offline)
200
      self.assertEqual(lhresp.node, host)
201
      self.assert_(lhresp.fail_msg)
202
      self.assertFalse(lhresp.payload)
203
      self.assertEqual(lhresp.call, "version")
204
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
205
      self.assertEqual(http_proc.reqcount, 1)
206

    
207
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
208
    self.assertEqual(req.path, "/vg_list")
209
    self.assertEqual(req.port, 15165)
210

    
211
    if req.host in httperrnodes:
212
      req.success = False
213
      req.error = "Node set up for HTTP errors"
214

    
215
    elif req.host in failnodes:
216
      req.success = True
217
      req.resp_status_code = 404
218
      req.resp_body = serializer.DumpJson({
219
        "code": 404,
220
        "message": "Method not found",
221
        "explain": "Explanation goes here",
222
        })
223
    else:
224
      req.success = True
225
      req.resp_status_code = http.HTTP_OK
226
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
227

    
228
  def testHttpError(self):
229
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
230
    body = dict((n, "") for n in nodes)
231
    resolver = rpc._StaticResolver(nodes)
232

    
233
    httperrnodes = set(nodes[1::7])
234
    self.assertEqual(len(httperrnodes), 7)
235

    
236
    failnodes = set(nodes[2::3]) - httperrnodes
237
    self.assertEqual(len(failnodes), 14)
238

    
239
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
240

    
241
    proc = rpc._RpcProcessor(resolver, 15165)
242
    http_proc = \
243
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
244
                                           httperrnodes, failnodes))
245
    result = proc(nodes, "vg_list", body,
246
                  constants.RPC_TMO_URGENT, NotImplemented,
247
                  _req_process_fn=http_proc)
248
    self.assertEqual(sorted(result.keys()), sorted(nodes))
249

    
250
    for name in nodes:
251
      lhresp = result[name]
252
      self.assertFalse(lhresp.offline)
253
      self.assertEqual(lhresp.node, name)
254
      self.assertEqual(lhresp.call, "vg_list")
255

    
256
      if name in httperrnodes:
257
        self.assert_(lhresp.fail_msg)
258
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
259
      elif name in failnodes:
260
        self.assert_(lhresp.fail_msg)
261
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
262
                          prereq=True, ecode=errors.ECODE_INVAL)
263
      else:
264
        self.assertFalse(lhresp.fail_msg)
265
        self.assertEqual(lhresp.payload, hash(name))
266
        lhresp.Raise("should not raise")
267

    
268
    self.assertEqual(http_proc.reqcount, len(nodes))
269

    
270
  def _GetInvalidResponseA(self, req):
271
    self.assertEqual(req.path, "/version")
272
    req.success = True
273
    req.resp_status_code = http.HTTP_OK
274
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
275
                                         "response", "!", 1, 2, 3))
276

    
277
  def _GetInvalidResponseB(self, req):
278
    self.assertEqual(req.path, "/version")
279
    req.success = True
280
    req.resp_status_code = http.HTTP_OK
281
    req.resp_body = serializer.DumpJson("invalid response")
282

    
283
  def testInvalidResponse(self):
284
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
285
    proc = rpc._RpcProcessor(resolver, 19978)
286

    
287
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
288
      http_proc = _FakeRequestProcessor(fn)
289
      host = "oqo7lanhly.example.com"
290
      body = {host: ""}
291
      result = proc([host], "version", body, 60, NotImplemented,
292
                    _req_process_fn=http_proc)
293
      self.assertEqual(result.keys(), [host])
294
      lhresp = result[host]
295
      self.assertFalse(lhresp.offline)
296
      self.assertEqual(lhresp.node, host)
297
      self.assert_(lhresp.fail_msg)
298
      self.assertFalse(lhresp.payload)
299
      self.assertEqual(lhresp.call, "version")
300
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
301
      self.assertEqual(http_proc.reqcount, 1)
302

    
303
  def _GetBodyTestResponse(self, test_data, req):
304
    self.assertEqual(req.host, "192.0.2.84")
305
    self.assertEqual(req.port, 18700)
306
    self.assertEqual(req.path, "/upload_file")
307
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
308
    req.success = True
309
    req.resp_status_code = http.HTTP_OK
310
    req.resp_body = serializer.DumpJson((True, None))
311

    
312
  def testResponseBody(self):
313
    test_data = {
314
      "Hello": "World",
315
      "xyz": range(10),
316
      }
317
    resolver = rpc._StaticResolver(["192.0.2.84"])
318
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
319
                                                     test_data))
320
    proc = rpc._RpcProcessor(resolver, 18700)
321
    host = "node19759"
322
    body = {host: serializer.DumpJson(test_data)}
323
    result = proc([host], "upload_file", body, 30, NotImplemented,
324
                  _req_process_fn=http_proc)
325
    self.assertEqual(result.keys(), [host])
326
    lhresp = result[host]
327
    self.assertFalse(lhresp.offline)
328
    self.assertEqual(lhresp.node, host)
329
    self.assertFalse(lhresp.fail_msg)
330
    self.assertEqual(lhresp.payload, None)
331
    self.assertEqual(lhresp.call, "upload_file")
332
    lhresp.Raise("should not raise")
333
    self.assertEqual(http_proc.reqcount, 1)
334

    
335

    
336
class TestSsconfResolver(unittest.TestCase):
337
  def testSsconfLookup(self):
338
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
339
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
340
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
341
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
342
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
343
                                 ssc=ssc, nslookup_fn=NotImplemented)
344
    self.assertEqual(result, zip(node_list, addr_list))
345

    
346
  def testNsLookup(self):
347
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
348
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
349
    ssc = GetFakeSimpleStoreClass(lambda _: [])
350
    node_addr_map = dict(zip(node_list, addr_list))
351
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
352
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
353
                                 ssc=ssc, nslookup_fn=nslookup_fn)
354
    self.assertEqual(result, zip(node_list, addr_list))
355

    
356
  def testDisabledSsconfIp(self):
357
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
358
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
359
    ssc = GetFakeSimpleStoreClass(_RaiseNotImplemented)
360
    node_addr_map = dict(zip(node_list, addr_list))
361
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
362
    result = rpc._SsconfResolver(False, node_list, NotImplemented,
363
                                 ssc=ssc, nslookup_fn=nslookup_fn)
364
    self.assertEqual(result, zip(node_list, addr_list))
365

    
366
  def testBothLookups(self):
367
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
368
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
369
    n = len(addr_list) / 2
370
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
371
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
372
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
373
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
374
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
375
                                 ssc=ssc, nslookup_fn=nslookup_fn)
376
    self.assertEqual(result, zip(node_list, addr_list))
377

    
378
  def testAddressLookupIPv6(self):
379
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
380
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
381
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
382
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
383
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
384
                                 ssc=ssc, nslookup_fn=NotImplemented)
385
    self.assertEqual(result, zip(node_list, addr_list))
386

    
387

    
388
class TestStaticResolver(unittest.TestCase):
389
  def test(self):
390
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
391
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
392
    res = rpc._StaticResolver(addresses)
393
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
394

    
395
  def testWrongLength(self):
396
    res = rpc._StaticResolver([])
397
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
398

    
399

    
400
class TestNodeConfigResolver(unittest.TestCase):
401
  @staticmethod
402
  def _GetSingleOnlineNode(name):
403
    assert name == "node90.example.com"
404
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
405

    
406
  @staticmethod
407
  def _GetSingleOfflineNode(name):
408
    assert name == "node100.example.com"
409
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
410

    
411
  def testSingleOnline(self):
412
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
413
                                             NotImplemented,
414
                                             ["node90.example.com"], None),
415
                     [("node90.example.com", "192.0.2.90")])
416

    
417
  def testSingleOffline(self):
418
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
419
                                             NotImplemented,
420
                                             ["node100.example.com"], None),
421
                     [("node100.example.com", rpc._OFFLINE)])
422

    
423
  def testSingleOfflineWithAcceptOffline(self):
424
    fn = self._GetSingleOfflineNode
425
    assert fn("node100.example.com").offline
426
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
427
                                             ["node100.example.com"],
428
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
429
                     [("node100.example.com", "192.0.2.100")])
430
    for i in [False, True, "", "Hello", 0, 1]:
431
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
432
                        fn, NotImplemented, ["node100.example.com"], i)
433

    
434
  def testUnknownSingleNode(self):
435
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
436
                                             ["node110.example.com"], None),
437
                     [("node110.example.com", "node110.example.com")])
438

    
439
  def testMultiEmpty(self):
440
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
441
                                             lambda: {},
442
                                             [], None),
443
                     [])
444

    
445
  def testMultiSomeOffline(self):
446
    nodes = dict(("node%s.example.com" % i,
447
                  objects.Node(name="node%s.example.com" % i,
448
                               offline=((i % 3) == 0),
449
                               primary_ip="192.0.2.%s" % i))
450
                  for i in range(1, 255))
451

    
452
    # Resolve no names
453
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
454
                                             lambda: nodes,
455
                                             [], None),
456
                     [])
457

    
458
    # Offline, online and unknown hosts
459
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
460
                                             lambda: nodes,
461
                                             ["node3.example.com",
462
                                              "node92.example.com",
463
                                              "node54.example.com",
464
                                              "unknown.example.com",],
465
                                             None), [
466
      ("node3.example.com", rpc._OFFLINE),
467
      ("node92.example.com", "192.0.2.92"),
468
      ("node54.example.com", rpc._OFFLINE),
469
      ("unknown.example.com", "unknown.example.com"),
470
      ])
471

    
472

    
473
class TestCompress(unittest.TestCase):
474
  def test(self):
475
    for data in ["", "Hello", "Hello World!\nnew\nlines"]:
476
      self.assertEqual(rpc._Compress(data),
477
                       (constants.RPC_ENCODING_NONE, data))
478

    
479
    for data in [512 * " ", 5242 * "Hello World!\n"]:
480
      compressed = rpc._Compress(data)
481
      self.assertEqual(len(compressed), 2)
482
      self.assertEqual(backend._Decompress(compressed), data)
483

    
484
  def testDecompression(self):
485
    self.assertRaises(AssertionError, backend._Decompress, "")
486
    self.assertRaises(AssertionError, backend._Decompress, [""])
487
    self.assertRaises(AssertionError, backend._Decompress,
488
                      ("unknown compression", "data"))
489
    self.assertRaises(Exception, backend._Decompress,
490
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
491

    
492

    
493
class TestRpcClientBase(unittest.TestCase):
494
  def testNoHosts(self):
495
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_SLOW, [],
496
            None, None, NotImplemented)
497
    http_proc = _FakeRequestProcessor(NotImplemented)
498
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
499
                                _req_process_fn=http_proc)
500
    self.assertEqual(client._Call(cdef, [], []), {})
501

    
502
    # Test wrong number of arguments
503
    self.assertRaises(errors.ProgrammerError, client._Call,
504
                      cdef, [], [0, 1, 2])
505

    
506
  def testTimeout(self):
507
    def _CalcTimeout((arg1, arg2)):
508
      return arg1 + arg2
509

    
510
    def _VerifyRequest(exp_timeout, req):
511
      self.assertEqual(req.read_timeout, exp_timeout)
512

    
513
      req.success = True
514
      req.resp_status_code = http.HTTP_OK
515
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
516

    
517
    resolver = rpc._StaticResolver([
518
      "192.0.2.1",
519
      "192.0.2.2",
520
      ])
521

    
522
    nodes = [
523
      "node1.example.com",
524
      "node2.example.com",
525
      ]
526

    
527
    tests = [(100, None, 100), (30, None, 30)]
528
    tests.extend((_CalcTimeout, i, i + 300)
529
                 for i in [0, 5, 16485, 30516])
530

    
531
    for timeout, arg1, exp_timeout in tests:
532
      cdef = ("test_call", NotImplemented, None, timeout, [
533
        ("arg1", None, NotImplemented),
534
        ("arg2", None, NotImplemented),
535
        ], None, None, NotImplemented)
536

    
537
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
538
                                                       exp_timeout))
539
      client = rpc._RpcClientBase(resolver, NotImplemented,
540
                                  _req_process_fn=http_proc)
541
      result = client._Call(cdef, nodes, [arg1, 300])
542
      self.assertEqual(len(result), len(nodes))
543
      self.assertTrue(compat.all(not res.fail_msg and
544
                                 res.payload == hex(exp_timeout)
545
                                 for res in result.values()))
546

    
547
  def testArgumentEncoder(self):
548
    (AT1, AT2) = range(1, 3)
549

    
550
    resolver = rpc._StaticResolver([
551
      "192.0.2.5",
552
      "192.0.2.6",
553
      ])
554

    
555
    nodes = [
556
      "node5.example.com",
557
      "node6.example.com",
558
      ]
559

    
560
    encoders = {
561
      AT1: hex,
562
      AT2: hash,
563
      }
564

    
565
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
566
      ("arg0", None, NotImplemented),
567
      ("arg1", AT1, NotImplemented),
568
      ("arg1", AT2, NotImplemented),
569
      ], None, None, NotImplemented)
570

    
571
    def _VerifyRequest(req):
572
      req.success = True
573
      req.resp_status_code = http.HTTP_OK
574
      req.resp_body = serializer.DumpJson((True, req.post_data))
575

    
576
    http_proc = _FakeRequestProcessor(_VerifyRequest)
577

    
578
    for num in [0, 3796, 9032119]:
579
      client = rpc._RpcClientBase(resolver, encoders.get,
580
                                  _req_process_fn=http_proc)
581
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
582
      self.assertEqual(len(result), len(nodes))
583
      for res in result.values():
584
        self.assertFalse(res.fail_msg)
585
        self.assertEqual(serializer.LoadJson(res.payload),
586
                         ["foo", hex(num), hash("Hello%s" % num)])
587

    
588
  def testPostProc(self):
589
    def _VerifyRequest(nums, req):
590
      req.success = True
591
      req.resp_status_code = http.HTTP_OK
592
      req.resp_body = serializer.DumpJson((True, nums))
593

    
594
    resolver = rpc._StaticResolver([
595
      "192.0.2.90",
596
      "192.0.2.95",
597
      ])
598

    
599
    nodes = [
600
      "node90.example.com",
601
      "node95.example.com",
602
      ]
603

    
604
    def _PostProc(res):
605
      self.assertFalse(res.fail_msg)
606
      res.payload = sum(res.payload)
607
      return res
608

    
609
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [],
610
            None, _PostProc, NotImplemented)
611

    
612
    # Seeded random generator
613
    rnd = random.Random(20299)
614

    
615
    for i in [0, 4, 74, 1391]:
616
      nums = [rnd.randint(0, 1000) for _ in range(i)]
617
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
618
      client = rpc._RpcClientBase(resolver, NotImplemented,
619
                                  _req_process_fn=http_proc)
620
      result = client._Call(cdef, nodes, [])
621
      self.assertEqual(len(result), len(nodes))
622
      for res in result.values():
623
        self.assertFalse(res.fail_msg)
624
        self.assertEqual(res.payload, sum(nums))
625

    
626
  def testPreProc(self):
627
    def _VerifyRequest(req):
628
      req.success = True
629
      req.resp_status_code = http.HTTP_OK
630
      req.resp_body = serializer.DumpJson((True, req.post_data))
631

    
632
    resolver = rpc._StaticResolver([
633
      "192.0.2.30",
634
      "192.0.2.35",
635
      ])
636

    
637
    nodes = [
638
      "node30.example.com",
639
      "node35.example.com",
640
      ]
641

    
642
    def _PreProc(node, data):
643
      self.assertEqual(len(data), 1)
644
      return data[0] + node
645

    
646
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
647
      ("arg0", None, NotImplemented),
648
      ], _PreProc, None, NotImplemented)
649

    
650
    http_proc = _FakeRequestProcessor(_VerifyRequest)
651
    client = rpc._RpcClientBase(resolver, NotImplemented,
652
                                _req_process_fn=http_proc)
653

    
654
    for prefix in ["foo", "bar", "baz"]:
655
      result = client._Call(cdef, nodes, [prefix])
656
      self.assertEqual(len(result), len(nodes))
657
      for (idx, (node, res)) in enumerate(result.items()):
658
        self.assertFalse(res.fail_msg)
659
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
660

    
661
  def testResolverOptions(self):
662
    def _VerifyRequest(req):
663
      req.success = True
664
      req.resp_status_code = http.HTTP_OK
665
      req.resp_body = serializer.DumpJson((True, req.post_data))
666

    
667
    nodes = [
668
      "node30.example.com",
669
      "node35.example.com",
670
      ]
671

    
672
    def _Resolver(expected, hosts, options):
673
      self.assertEqual(hosts, nodes)
674
      self.assertEqual(options, expected)
675
      return zip(hosts, nodes)
676

    
677
    def _DynamicResolverOptions((arg0, )):
678
      return sum(arg0)
679

    
680
    tests = [
681
      (None, None, None),
682
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
683
      (False, None, False),
684
      (True, None, True),
685
      (0, None, 0),
686
      (_DynamicResolverOptions, [1, 2, 3], 6),
687
      (_DynamicResolverOptions, range(4, 19), 165),
688
      ]
689

    
690
    for (resolver_opts, arg0, expected) in tests:
691
      cdef = ("test_call", NotImplemented, resolver_opts,
692
              constants.RPC_TMO_NORMAL, [
693
        ("arg0", None, NotImplemented),
694
        ], None, None, NotImplemented)
695

    
696
      http_proc = _FakeRequestProcessor(_VerifyRequest)
697

    
698
      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
699
                                  NotImplemented, _req_process_fn=http_proc)
700
      result = client._Call(cdef, nodes, [arg0])
701
      self.assertEqual(len(result), len(nodes))
702
      for (idx, (node, res)) in enumerate(result.items()):
703
        self.assertFalse(res.fail_msg)
704

    
705

    
706
class _FakeConfigForRpcRunner:
707
  GetAllNodesInfo = NotImplemented
708

    
709
  def __init__(self, cluster=NotImplemented):
710
    self._cluster = cluster
711

    
712
  def GetNodeInfo(self, name):
713
    return objects.Node(name=name)
714

    
715
  def GetClusterInfo(self):
716
    return self._cluster
717

    
718
  def GetInstanceDiskParams(self, _):
719
    return constants.DISK_DT_DEFAULTS
720

    
721

    
722
class TestRpcRunner(unittest.TestCase):
723
  def testUploadFile(self):
724
    data = 1779 * "Hello World\n"
725

    
726
    tmpfile = tempfile.NamedTemporaryFile()
727
    tmpfile.write(data)
728
    tmpfile.flush()
729
    st = os.stat(tmpfile.name)
730

    
731
    def _VerifyRequest(req):
732
      (uldata, ) = serializer.LoadJson(req.post_data)
733
      self.assertEqual(len(uldata), 7)
734
      self.assertEqual(uldata[0], tmpfile.name)
735
      self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
736
      self.assertEqual(uldata[2], st.st_mode)
737
      self.assertEqual(uldata[3], "user%s" % os.getuid())
738
      self.assertEqual(uldata[4], "group%s" % os.getgid())
739
      self.assertTrue(uldata[5] is not None)
740
      self.assertEqual(uldata[6], st.st_mtime)
741

    
742
      req.success = True
743
      req.resp_status_code = http.HTTP_OK
744
      req.resp_body = serializer.DumpJson((True, None))
745

    
746
    http_proc = _FakeRequestProcessor(_VerifyRequest)
747

    
748
    std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
749
                               _req_process_fn=http_proc,
750
                               _getents=mocks.FakeGetentResolver)
751

    
752
    cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
753
                                  _req_process_fn=http_proc,
754
                                  _getents=mocks.FakeGetentResolver)
755

    
756
    nodes = [
757
      "node1.example.com",
758
      ]
759

    
760
    for runner in [std_runner, cfg_runner]:
761
      result = runner.call_upload_file(nodes, tmpfile.name)
762
      self.assertEqual(len(result), len(nodes))
763
      for (idx, (node, res)) in enumerate(result.items()):
764
        self.assertFalse(res.fail_msg)
765

    
766
  def testEncodeInstance(self):
767
    cluster = objects.Cluster(hvparams={
768
      constants.HT_KVM: {
769
        constants.HV_BLOCKDEV_PREFIX: "foo",
770
        },
771
      },
772
      beparams={
773
        constants.PP_DEFAULT: {
774
          constants.BE_MAXMEM: 8192,
775
          },
776
        },
777
      os_hvp={},
778
      osparams={
779
        "linux": {
780
          "role": "unknown",
781
          },
782
        })
783
    cluster.UpgradeConfig()
784

    
785
    inst = objects.Instance(name="inst1.example.com",
786
      hypervisor=constants.HT_FAKE,
787
      os="linux",
788
      hvparams={
789
        constants.HT_KVM: {
790
          constants.HV_BLOCKDEV_PREFIX: "bar",
791
          constants.HV_ROOT_PATH: "/tmp",
792
          },
793
        },
794
      beparams={
795
        constants.BE_MINMEM: 128,
796
        constants.BE_MAXMEM: 256,
797
        },
798
      nics=[
799
        objects.NIC(nicparams={
800
          constants.NIC_MODE: "mymode",
801
          }),
802
        ],
803
      disk_template=constants.DT_PLAIN,
804
      disks=[
805
        objects.Disk(dev_type=constants.LD_LV, size=4096,
806
                     logical_id=("vg", "disk6120")),
807
        objects.Disk(dev_type=constants.LD_LV, size=1024,
808
                     logical_id=("vg", "disk8508")),
809
        ])
810
    inst.UpgradeConfig()
811

    
812
    cfg = _FakeConfigForRpcRunner(cluster=cluster)
813
    runner = rpc.RpcRunner(cfg, None,
814
                           _req_process_fn=NotImplemented,
815
                           _getents=mocks.FakeGetentResolver)
816

    
817
    def _CheckBasics(result):
818
      self.assertEqual(result["name"], "inst1.example.com")
819
      self.assertEqual(result["os"], "linux")
820
      self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
821
      self.assertEqual(len(result["hvparams"]), 1)
822
      self.assertEqual(len(result["nics"]), 1)
823
      self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
824
                       "mymode")
825

    
826
    # Generic object serialization
827
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT, inst))
828
    _CheckBasics(result)
829

    
830
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
831
    map(_CheckBasics, result)
832

    
833
    # Just an instance
834
    result = runner._encoder((rpc_defs.ED_INST_DICT, inst))
835
    _CheckBasics(result)
836
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
837
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
838
      constants.HV_BLOCKDEV_PREFIX: "bar",
839
      constants.HV_ROOT_PATH: "/tmp",
840
      })
841
    self.assertEqual(result["osparams"], {
842
      "role": "unknown",
843
      })
844

    
845
    # Instance with OS parameters
846
    result = runner._encoder((rpc_defs.ED_INST_DICT_OSP_DP, (inst, {
847
      "role": "webserver",
848
      "other": "field",
849
      })))
850
    _CheckBasics(result)
851
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
852
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
853
      constants.HV_BLOCKDEV_PREFIX: "bar",
854
      constants.HV_ROOT_PATH: "/tmp",
855
      })
856
    self.assertEqual(result["osparams"], {
857
      "role": "webserver",
858
      "other": "field",
859
      })
860

    
861
    # Instance with hypervisor and backend parameters
862
    result = runner._encoder((rpc_defs.ED_INST_DICT_HVP_BEP_DP, (inst, {
863
      constants.HT_KVM: {
864
        constants.HV_BOOT_ORDER: "xyz",
865
        },
866
      }, {
867
      constants.BE_VCPUS: 100,
868
      constants.BE_MAXMEM: 4096,
869
      })))
870
    _CheckBasics(result)
871
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
872
    self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
873
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
874
      constants.HV_BOOT_ORDER: "xyz",
875
      })
876
    self.assertEqual(result["disks"], [{
877
      "dev_type": constants.LD_LV,
878
      "size": 4096,
879
      "logical_id": ("vg", "disk6120"),
880
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
881
      }, {
882
      "dev_type": constants.LD_LV,
883
      "size": 1024,
884
      "logical_id": ("vg", "disk8508"),
885
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
886
      }])
887

    
888
    self.assertTrue(compat.all(disk.params == {} for disk in inst.disks),
889
                    msg="Configuration objects were modified")
890

    
891

    
892
if __name__ == "__main__":
893
  testutils.GanetiTestProgram()