Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ 7e6b6f1f

History | View | Annotate | Download (29.5 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27
import random
28
import tempfile
29

    
30
from ganeti import constants
31
from ganeti import compat
32
from ganeti import rpc
33
from ganeti import rpc_defs
34
from ganeti import http
35
from ganeti import errors
36
from ganeti import serializer
37
from ganeti import objects
38
from ganeti import backend
39

    
40
import testutils
41
import mocks
42

    
43

    
44
class _FakeRequestProcessor:
45
  def __init__(self, response_fn):
46
    self._response_fn = response_fn
47
    self.reqcount = 0
48

    
49
  def __call__(self, reqs, lock_monitor_cb=None):
50
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
51
    for req in reqs:
52
      self.reqcount += 1
53
      self._response_fn(req)
54

    
55

    
56
def GetFakeSimpleStoreClass(fn):
57
  class FakeSimpleStore:
58
    GetNodePrimaryIPList = fn
59
    GetPrimaryIPFamily = lambda _: None
60

    
61
  return FakeSimpleStore
62

    
63

    
64
class TestRpcProcessor(unittest.TestCase):
65
  def _FakeAddressLookup(self, map):
66
    return lambda node_list: [map.get(node) for node in node_list]
67

    
68
  def _GetVersionResponse(self, req):
69
    self.assertEqual(req.host, "127.0.0.1")
70
    self.assertEqual(req.port, 24094)
71
    self.assertEqual(req.path, "/version")
72
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
73
    req.success = True
74
    req.resp_status_code = http.HTTP_OK
75
    req.resp_body = serializer.DumpJson((True, 123))
76

    
77
  def testVersionSuccess(self):
78
    resolver = rpc._StaticResolver(["127.0.0.1"])
79
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
80
    proc = rpc._RpcProcessor(resolver, 24094)
81
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
82
                  NotImplemented, _req_process_fn=http_proc)
83
    self.assertEqual(result.keys(), ["localhost"])
84
    lhresp = result["localhost"]
85
    self.assertFalse(lhresp.offline)
86
    self.assertEqual(lhresp.node, "localhost")
87
    self.assertFalse(lhresp.fail_msg)
88
    self.assertEqual(lhresp.payload, 123)
89
    self.assertEqual(lhresp.call, "version")
90
    lhresp.Raise("should not raise")
91
    self.assertEqual(http_proc.reqcount, 1)
92

    
93
  def _ReadTimeoutResponse(self, req):
94
    self.assertEqual(req.host, "192.0.2.13")
95
    self.assertEqual(req.port, 19176)
96
    self.assertEqual(req.path, "/version")
97
    self.assertEqual(req.read_timeout, 12356)
98
    req.success = True
99
    req.resp_status_code = http.HTTP_OK
100
    req.resp_body = serializer.DumpJson((True, -1))
101

    
102
  def testReadTimeout(self):
103
    resolver = rpc._StaticResolver(["192.0.2.13"])
104
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
105
    proc = rpc._RpcProcessor(resolver, 19176)
106
    host = "node31856"
107
    body = {host: ""}
108
    result = proc([host], "version", body, 12356, NotImplemented,
109
                  _req_process_fn=http_proc)
110
    self.assertEqual(result.keys(), [host])
111
    lhresp = result[host]
112
    self.assertFalse(lhresp.offline)
113
    self.assertEqual(lhresp.node, host)
114
    self.assertFalse(lhresp.fail_msg)
115
    self.assertEqual(lhresp.payload, -1)
116
    self.assertEqual(lhresp.call, "version")
117
    lhresp.Raise("should not raise")
118
    self.assertEqual(http_proc.reqcount, 1)
119

    
120
  def testOfflineNode(self):
121
    resolver = rpc._StaticResolver([rpc._OFFLINE])
122
    http_proc = _FakeRequestProcessor(NotImplemented)
123
    proc = rpc._RpcProcessor(resolver, 30668)
124
    host = "n17296"
125
    body = {host: ""}
126
    result = proc([host], "version", body, 60, NotImplemented,
127
                  _req_process_fn=http_proc)
128
    self.assertEqual(result.keys(), [host])
129
    lhresp = result[host]
130
    self.assertTrue(lhresp.offline)
131
    self.assertEqual(lhresp.node, host)
132
    self.assertTrue(lhresp.fail_msg)
133
    self.assertFalse(lhresp.payload)
134
    self.assertEqual(lhresp.call, "version")
135

    
136
    # With a message
137
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
138

    
139
    # No message
140
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
141

    
142
    self.assertEqual(http_proc.reqcount, 0)
143

    
144
  def _GetMultiVersionResponse(self, req):
145
    self.assert_(req.host.startswith("node"))
146
    self.assertEqual(req.port, 23245)
147
    self.assertEqual(req.path, "/version")
148
    req.success = True
149
    req.resp_status_code = http.HTTP_OK
150
    req.resp_body = serializer.DumpJson((True, 987))
151

    
152
  def testMultiVersionSuccess(self):
153
    nodes = ["node%s" % i for i in range(50)]
154
    body = dict((n, "") for n in nodes)
155
    resolver = rpc._StaticResolver(nodes)
156
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
157
    proc = rpc._RpcProcessor(resolver, 23245)
158
    result = proc(nodes, "version", body, 60, NotImplemented,
159
                  _req_process_fn=http_proc)
160
    self.assertEqual(sorted(result.keys()), sorted(nodes))
161

    
162
    for name in nodes:
163
      lhresp = result[name]
164
      self.assertFalse(lhresp.offline)
165
      self.assertEqual(lhresp.node, name)
166
      self.assertFalse(lhresp.fail_msg)
167
      self.assertEqual(lhresp.payload, 987)
168
      self.assertEqual(lhresp.call, "version")
169
      lhresp.Raise("should not raise")
170

    
171
    self.assertEqual(http_proc.reqcount, len(nodes))
172

    
173
  def _GetVersionResponseFail(self, errinfo, req):
174
    self.assertEqual(req.path, "/version")
175
    req.success = True
176
    req.resp_status_code = http.HTTP_OK
177
    req.resp_body = serializer.DumpJson((False, errinfo))
178

    
179
  def testVersionFailure(self):
180
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
181
    proc = rpc._RpcProcessor(resolver, 5903)
182
    for errinfo in [None, "Unknown error"]:
183
      http_proc = \
184
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
185
                                             errinfo))
186
      host = "aef9ur4i.example.com"
187
      body = {host: ""}
188
      result = proc(body.keys(), "version", body, 60, NotImplemented,
189
                    _req_process_fn=http_proc)
190
      self.assertEqual(result.keys(), [host])
191
      lhresp = result[host]
192
      self.assertFalse(lhresp.offline)
193
      self.assertEqual(lhresp.node, host)
194
      self.assert_(lhresp.fail_msg)
195
      self.assertFalse(lhresp.payload)
196
      self.assertEqual(lhresp.call, "version")
197
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
198
      self.assertEqual(http_proc.reqcount, 1)
199

    
200
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
201
    self.assertEqual(req.path, "/vg_list")
202
    self.assertEqual(req.port, 15165)
203

    
204
    if req.host in httperrnodes:
205
      req.success = False
206
      req.error = "Node set up for HTTP errors"
207

    
208
    elif req.host in failnodes:
209
      req.success = True
210
      req.resp_status_code = 404
211
      req.resp_body = serializer.DumpJson({
212
        "code": 404,
213
        "message": "Method not found",
214
        "explain": "Explanation goes here",
215
        })
216
    else:
217
      req.success = True
218
      req.resp_status_code = http.HTTP_OK
219
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
220

    
221
  def testHttpError(self):
222
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
223
    body = dict((n, "") for n in nodes)
224
    resolver = rpc._StaticResolver(nodes)
225

    
226
    httperrnodes = set(nodes[1::7])
227
    self.assertEqual(len(httperrnodes), 7)
228

    
229
    failnodes = set(nodes[2::3]) - httperrnodes
230
    self.assertEqual(len(failnodes), 14)
231

    
232
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
233

    
234
    proc = rpc._RpcProcessor(resolver, 15165)
235
    http_proc = \
236
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
237
                                           httperrnodes, failnodes))
238
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
239
                  _req_process_fn=http_proc)
240
    self.assertEqual(sorted(result.keys()), sorted(nodes))
241

    
242
    for name in nodes:
243
      lhresp = result[name]
244
      self.assertFalse(lhresp.offline)
245
      self.assertEqual(lhresp.node, name)
246
      self.assertEqual(lhresp.call, "vg_list")
247

    
248
      if name in httperrnodes:
249
        self.assert_(lhresp.fail_msg)
250
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
251
      elif name in failnodes:
252
        self.assert_(lhresp.fail_msg)
253
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
254
                          prereq=True, ecode=errors.ECODE_INVAL)
255
      else:
256
        self.assertFalse(lhresp.fail_msg)
257
        self.assertEqual(lhresp.payload, hash(name))
258
        lhresp.Raise("should not raise")
259

    
260
    self.assertEqual(http_proc.reqcount, len(nodes))
261

    
262
  def _GetInvalidResponseA(self, req):
263
    self.assertEqual(req.path, "/version")
264
    req.success = True
265
    req.resp_status_code = http.HTTP_OK
266
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
267
                                         "response", "!", 1, 2, 3))
268

    
269
  def _GetInvalidResponseB(self, req):
270
    self.assertEqual(req.path, "/version")
271
    req.success = True
272
    req.resp_status_code = http.HTTP_OK
273
    req.resp_body = serializer.DumpJson("invalid response")
274

    
275
  def testInvalidResponse(self):
276
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
277
    proc = rpc._RpcProcessor(resolver, 19978)
278

    
279
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
280
      http_proc = _FakeRequestProcessor(fn)
281
      host = "oqo7lanhly.example.com"
282
      body = {host: ""}
283
      result = proc([host], "version", body, 60, NotImplemented,
284
                    _req_process_fn=http_proc)
285
      self.assertEqual(result.keys(), [host])
286
      lhresp = result[host]
287
      self.assertFalse(lhresp.offline)
288
      self.assertEqual(lhresp.node, host)
289
      self.assert_(lhresp.fail_msg)
290
      self.assertFalse(lhresp.payload)
291
      self.assertEqual(lhresp.call, "version")
292
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
293
      self.assertEqual(http_proc.reqcount, 1)
294

    
295
  def _GetBodyTestResponse(self, test_data, req):
296
    self.assertEqual(req.host, "192.0.2.84")
297
    self.assertEqual(req.port, 18700)
298
    self.assertEqual(req.path, "/upload_file")
299
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
300
    req.success = True
301
    req.resp_status_code = http.HTTP_OK
302
    req.resp_body = serializer.DumpJson((True, None))
303

    
304
  def testResponseBody(self):
305
    test_data = {
306
      "Hello": "World",
307
      "xyz": range(10),
308
      }
309
    resolver = rpc._StaticResolver(["192.0.2.84"])
310
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
311
                                                     test_data))
312
    proc = rpc._RpcProcessor(resolver, 18700)
313
    host = "node19759"
314
    body = {host: serializer.DumpJson(test_data)}
315
    result = proc([host], "upload_file", body, 30, NotImplemented,
316
                  _req_process_fn=http_proc)
317
    self.assertEqual(result.keys(), [host])
318
    lhresp = result[host]
319
    self.assertFalse(lhresp.offline)
320
    self.assertEqual(lhresp.node, host)
321
    self.assertFalse(lhresp.fail_msg)
322
    self.assertEqual(lhresp.payload, None)
323
    self.assertEqual(lhresp.call, "upload_file")
324
    lhresp.Raise("should not raise")
325
    self.assertEqual(http_proc.reqcount, 1)
326

    
327

    
328
class TestSsconfResolver(unittest.TestCase):
329
  def testSsconfLookup(self):
330
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
331
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
332
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
333
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
334
    result = rpc._SsconfResolver(node_list, NotImplemented,
335
                                 ssc=ssc, nslookup_fn=NotImplemented)
336
    self.assertEqual(result, zip(node_list, addr_list))
337

    
338
  def testNsLookup(self):
339
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
340
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
341
    ssc = GetFakeSimpleStoreClass(lambda _: [])
342
    node_addr_map = dict(zip(node_list, addr_list))
343
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
344
    result = rpc._SsconfResolver(node_list, NotImplemented,
345
                                 ssc=ssc, nslookup_fn=nslookup_fn)
346
    self.assertEqual(result, zip(node_list, addr_list))
347

    
348
  def testBothLookups(self):
349
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
350
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
351
    n = len(addr_list) / 2
352
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
353
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
354
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
355
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
356
    result = rpc._SsconfResolver(node_list, NotImplemented,
357
                                 ssc=ssc, nslookup_fn=nslookup_fn)
358
    self.assertEqual(result, zip(node_list, addr_list))
359

    
360
  def testAddressLookupIPv6(self):
361
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
362
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
363
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
364
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
365
    result = rpc._SsconfResolver(node_list, NotImplemented,
366
                                 ssc=ssc, nslookup_fn=NotImplemented)
367
    self.assertEqual(result, zip(node_list, addr_list))
368

    
369

    
370
class TestStaticResolver(unittest.TestCase):
371
  def test(self):
372
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
373
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
374
    res = rpc._StaticResolver(addresses)
375
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
376

    
377
  def testWrongLength(self):
378
    res = rpc._StaticResolver([])
379
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
380

    
381

    
382
class TestNodeConfigResolver(unittest.TestCase):
383
  @staticmethod
384
  def _GetSingleOnlineNode(name):
385
    assert name == "node90.example.com"
386
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
387

    
388
  @staticmethod
389
  def _GetSingleOfflineNode(name):
390
    assert name == "node100.example.com"
391
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
392

    
393
  def testSingleOnline(self):
394
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
395
                                             NotImplemented,
396
                                             ["node90.example.com"], None),
397
                     [("node90.example.com", "192.0.2.90")])
398

    
399
  def testSingleOffline(self):
400
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
401
                                             NotImplemented,
402
                                             ["node100.example.com"], None),
403
                     [("node100.example.com", rpc._OFFLINE)])
404

    
405
  def testSingleOfflineWithAcceptOffline(self):
406
    fn = self._GetSingleOfflineNode
407
    assert fn("node100.example.com").offline
408
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
409
                                             ["node100.example.com"],
410
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
411
                     [("node100.example.com", "192.0.2.100")])
412
    for i in [False, True, "", "Hello", 0, 1]:
413
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
414
                        fn, NotImplemented, ["node100.example.com"], i)
415

    
416
  def testUnknownSingleNode(self):
417
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
418
                                             ["node110.example.com"], None),
419
                     [("node110.example.com", "node110.example.com")])
420

    
421
  def testMultiEmpty(self):
422
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
423
                                             lambda: {},
424
                                             [], None),
425
                     [])
426

    
427
  def testMultiSomeOffline(self):
428
    nodes = dict(("node%s.example.com" % i,
429
                  objects.Node(name="node%s.example.com" % i,
430
                               offline=((i % 3) == 0),
431
                               primary_ip="192.0.2.%s" % i))
432
                  for i in range(1, 255))
433

    
434
    # Resolve no names
435
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
436
                                             lambda: nodes,
437
                                             [], None),
438
                     [])
439

    
440
    # Offline, online and unknown hosts
441
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
442
                                             lambda: nodes,
443
                                             ["node3.example.com",
444
                                              "node92.example.com",
445
                                              "node54.example.com",
446
                                              "unknown.example.com",],
447
                                             None), [
448
      ("node3.example.com", rpc._OFFLINE),
449
      ("node92.example.com", "192.0.2.92"),
450
      ("node54.example.com", rpc._OFFLINE),
451
      ("unknown.example.com", "unknown.example.com"),
452
      ])
453

    
454

    
455
class TestCompress(unittest.TestCase):
456
  def test(self):
457
    for data in ["", "Hello", "Hello World!\nnew\nlines"]:
458
      self.assertEqual(rpc._Compress(data),
459
                       (constants.RPC_ENCODING_NONE, data))
460

    
461
    for data in [512 * " ", 5242 * "Hello World!\n"]:
462
      compressed = rpc._Compress(data)
463
      self.assertEqual(len(compressed), 2)
464
      self.assertEqual(backend._Decompress(compressed), data)
465

    
466
  def testDecompression(self):
467
    self.assertRaises(AssertionError, backend._Decompress, "")
468
    self.assertRaises(AssertionError, backend._Decompress, [""])
469
    self.assertRaises(AssertionError, backend._Decompress,
470
                      ("unknown compression", "data"))
471
    self.assertRaises(Exception, backend._Decompress,
472
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
473

    
474

    
475
class TestRpcClientBase(unittest.TestCase):
476
  def testNoHosts(self):
477
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_SLOW, [],
478
            None, None, NotImplemented)
479
    http_proc = _FakeRequestProcessor(NotImplemented)
480
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
481
                                _req_process_fn=http_proc)
482
    self.assertEqual(client._Call(cdef, [], []), {})
483

    
484
    # Test wrong number of arguments
485
    self.assertRaises(errors.ProgrammerError, client._Call,
486
                      cdef, [], [0, 1, 2])
487

    
488
  def testTimeout(self):
489
    def _CalcTimeout((arg1, arg2)):
490
      return arg1 + arg2
491

    
492
    def _VerifyRequest(exp_timeout, req):
493
      self.assertEqual(req.read_timeout, exp_timeout)
494

    
495
      req.success = True
496
      req.resp_status_code = http.HTTP_OK
497
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
498

    
499
    resolver = rpc._StaticResolver([
500
      "192.0.2.1",
501
      "192.0.2.2",
502
      ])
503

    
504
    nodes = [
505
      "node1.example.com",
506
      "node2.example.com",
507
      ]
508

    
509
    tests = [(100, None, 100), (30, None, 30)]
510
    tests.extend((_CalcTimeout, i, i + 300)
511
                 for i in [0, 5, 16485, 30516])
512

    
513
    for timeout, arg1, exp_timeout in tests:
514
      cdef = ("test_call", NotImplemented, None, timeout, [
515
        ("arg1", None, NotImplemented),
516
        ("arg2", None, NotImplemented),
517
        ], None, None, NotImplemented)
518

    
519
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
520
                                                       exp_timeout))
521
      client = rpc._RpcClientBase(resolver, NotImplemented,
522
                                  _req_process_fn=http_proc)
523
      result = client._Call(cdef, nodes, [arg1, 300])
524
      self.assertEqual(len(result), len(nodes))
525
      self.assertTrue(compat.all(not res.fail_msg and
526
                                 res.payload == hex(exp_timeout)
527
                                 for res in result.values()))
528

    
529
  def testArgumentEncoder(self):
530
    (AT1, AT2) = range(1, 3)
531

    
532
    resolver = rpc._StaticResolver([
533
      "192.0.2.5",
534
      "192.0.2.6",
535
      ])
536

    
537
    nodes = [
538
      "node5.example.com",
539
      "node6.example.com",
540
      ]
541

    
542
    encoders = {
543
      AT1: hex,
544
      AT2: hash,
545
      }
546

    
547
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
548
      ("arg0", None, NotImplemented),
549
      ("arg1", AT1, NotImplemented),
550
      ("arg1", AT2, NotImplemented),
551
      ], None, None, NotImplemented)
552

    
553
    def _VerifyRequest(req):
554
      req.success = True
555
      req.resp_status_code = http.HTTP_OK
556
      req.resp_body = serializer.DumpJson((True, req.post_data))
557

    
558
    http_proc = _FakeRequestProcessor(_VerifyRequest)
559

    
560
    for num in [0, 3796, 9032119]:
561
      client = rpc._RpcClientBase(resolver, encoders.get,
562
                                  _req_process_fn=http_proc)
563
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
564
      self.assertEqual(len(result), len(nodes))
565
      for res in result.values():
566
        self.assertFalse(res.fail_msg)
567
        self.assertEqual(serializer.LoadJson(res.payload),
568
                         ["foo", hex(num), hash("Hello%s" % num)])
569

    
570
  def testPostProc(self):
571
    def _VerifyRequest(nums, req):
572
      req.success = True
573
      req.resp_status_code = http.HTTP_OK
574
      req.resp_body = serializer.DumpJson((True, nums))
575

    
576
    resolver = rpc._StaticResolver([
577
      "192.0.2.90",
578
      "192.0.2.95",
579
      ])
580

    
581
    nodes = [
582
      "node90.example.com",
583
      "node95.example.com",
584
      ]
585

    
586
    def _PostProc(res):
587
      self.assertFalse(res.fail_msg)
588
      res.payload = sum(res.payload)
589
      return res
590

    
591
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [],
592
            None, _PostProc, NotImplemented)
593

    
594
    # Seeded random generator
595
    rnd = random.Random(20299)
596

    
597
    for i in [0, 4, 74, 1391]:
598
      nums = [rnd.randint(0, 1000) for _ in range(i)]
599
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
600
      client = rpc._RpcClientBase(resolver, NotImplemented,
601
                                  _req_process_fn=http_proc)
602
      result = client._Call(cdef, nodes, [])
603
      self.assertEqual(len(result), len(nodes))
604
      for res in result.values():
605
        self.assertFalse(res.fail_msg)
606
        self.assertEqual(res.payload, sum(nums))
607

    
608
  def testPreProc(self):
609
    def _VerifyRequest(req):
610
      req.success = True
611
      req.resp_status_code = http.HTTP_OK
612
      req.resp_body = serializer.DumpJson((True, req.post_data))
613

    
614
    resolver = rpc._StaticResolver([
615
      "192.0.2.30",
616
      "192.0.2.35",
617
      ])
618

    
619
    nodes = [
620
      "node30.example.com",
621
      "node35.example.com",
622
      ]
623

    
624
    def _PreProc(node, data):
625
      self.assertEqual(len(data), 1)
626
      return data[0] + node
627

    
628
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
629
      ("arg0", None, NotImplemented),
630
      ], _PreProc, None, NotImplemented)
631

    
632
    http_proc = _FakeRequestProcessor(_VerifyRequest)
633
    client = rpc._RpcClientBase(resolver, NotImplemented,
634
                                _req_process_fn=http_proc)
635

    
636
    for prefix in ["foo", "bar", "baz"]:
637
      result = client._Call(cdef, nodes, [prefix])
638
      self.assertEqual(len(result), len(nodes))
639
      for (idx, (node, res)) in enumerate(result.items()):
640
        self.assertFalse(res.fail_msg)
641
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
642

    
643
  def testResolverOptions(self):
644
    def _VerifyRequest(req):
645
      req.success = True
646
      req.resp_status_code = http.HTTP_OK
647
      req.resp_body = serializer.DumpJson((True, req.post_data))
648

    
649
    nodes = [
650
      "node30.example.com",
651
      "node35.example.com",
652
      ]
653

    
654
    def _Resolver(expected, hosts, options):
655
      self.assertEqual(hosts, nodes)
656
      self.assertEqual(options, expected)
657
      return zip(hosts, nodes)
658

    
659
    def _DynamicResolverOptions((arg0, )):
660
      return sum(arg0)
661

    
662
    tests = [
663
      (None, None, None),
664
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
665
      (False, None, False),
666
      (True, None, True),
667
      (0, None, 0),
668
      (_DynamicResolverOptions, [1, 2, 3], 6),
669
      (_DynamicResolverOptions, range(4, 19), 165),
670
      ]
671

    
672
    for (resolver_opts, arg0, expected) in tests:
673
      cdef = ("test_call", NotImplemented, resolver_opts, rpc_defs.TMO_NORMAL, [
674
        ("arg0", None, NotImplemented),
675
        ], None, None, NotImplemented)
676

    
677
      http_proc = _FakeRequestProcessor(_VerifyRequest)
678

    
679
      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
680
                                  NotImplemented, _req_process_fn=http_proc)
681
      result = client._Call(cdef, nodes, [arg0])
682
      self.assertEqual(len(result), len(nodes))
683
      for (idx, (node, res)) in enumerate(result.items()):
684
        self.assertFalse(res.fail_msg)
685

    
686

    
687
class _FakeConfigForRpcRunner:
688
  GetAllNodesInfo = NotImplemented
689

    
690
  def __init__(self, cluster=NotImplemented):
691
    self._cluster = cluster
692

    
693
  def GetNodeInfo(self, name):
694
    return objects.Node(name=name)
695

    
696
  def GetClusterInfo(self):
697
    return self._cluster
698

    
699

    
700
class TestRpcRunner(unittest.TestCase):
701
  def testUploadFile(self):
702
    data = 1779 * "Hello World\n"
703

    
704
    tmpfile = tempfile.NamedTemporaryFile()
705
    tmpfile.write(data)
706
    tmpfile.flush()
707
    st = os.stat(tmpfile.name)
708

    
709
    def _VerifyRequest(req):
710
      (uldata, ) = serializer.LoadJson(req.post_data)
711
      self.assertEqual(len(uldata), 7)
712
      self.assertEqual(uldata[0], tmpfile.name)
713
      self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
714
      self.assertEqual(uldata[2], st.st_mode)
715
      self.assertEqual(uldata[3], "user%s" % os.getuid())
716
      self.assertEqual(uldata[4], "group%s" % os.getgid())
717
      self.assertTrue(uldata[5] is not None)
718
      self.assertEqual(uldata[6], st.st_mtime)
719

    
720
      req.success = True
721
      req.resp_status_code = http.HTTP_OK
722
      req.resp_body = serializer.DumpJson((True, None))
723

    
724
    http_proc = _FakeRequestProcessor(_VerifyRequest)
725

    
726
    std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
727
                               _req_process_fn=http_proc,
728
                               _getents=mocks.FakeGetentResolver)
729

    
730
    cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
731
                                  _req_process_fn=http_proc,
732
                                  _getents=mocks.FakeGetentResolver)
733

    
734
    nodes = [
735
      "node1.example.com",
736
      ]
737

    
738
    for runner in [std_runner, cfg_runner]:
739
      result = runner.call_upload_file(nodes, tmpfile.name)
740
      self.assertEqual(len(result), len(nodes))
741
      for (idx, (node, res)) in enumerate(result.items()):
742
        self.assertFalse(res.fail_msg)
743

    
744
  def testEncodeInstance(self):
745
    cluster = objects.Cluster(hvparams={
746
      constants.HT_KVM: {
747
        constants.HV_BLOCKDEV_PREFIX: "foo",
748
        },
749
      },
750
      beparams={
751
        constants.PP_DEFAULT: {
752
          constants.BE_MAXMEM: 8192,
753
          },
754
        },
755
      os_hvp={},
756
      osparams={
757
        "linux": {
758
          "role": "unknown",
759
          },
760
        })
761
    cluster.UpgradeConfig()
762

    
763
    inst = objects.Instance(name="inst1.example.com",
764
      hypervisor=constants.HT_FAKE,
765
      os="linux",
766
      hvparams={
767
        constants.HT_KVM: {
768
          constants.HV_BLOCKDEV_PREFIX: "bar",
769
          constants.HV_ROOT_PATH: "/tmp",
770
          },
771
        },
772
      beparams={
773
        constants.BE_MINMEM: 128,
774
        constants.BE_MAXMEM: 256,
775
        },
776
      nics=[
777
        objects.NIC(nicparams={
778
          constants.NIC_MODE: "mymode",
779
          }),
780
        ],
781
      disks=[])
782
    inst.UpgradeConfig()
783

    
784
    cfg = _FakeConfigForRpcRunner(cluster=cluster)
785
    runner = rpc.RpcRunner(cfg, None,
786
                           _req_process_fn=NotImplemented,
787
                           _getents=mocks.FakeGetentResolver)
788

    
789
    def _CheckBasics(result):
790
      self.assertEqual(result["name"], "inst1.example.com")
791
      self.assertEqual(result["os"], "linux")
792
      self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
793
      self.assertEqual(len(result["hvparams"]), 1)
794
      self.assertEqual(len(result["nics"]), 1)
795
      self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
796
                       "mymode")
797

    
798
    # Generic object serialization
799
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT, inst))
800
    _CheckBasics(result)
801

    
802
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
803
    map(_CheckBasics, result)
804

    
805
    # Just an instance
806
    result = runner._encoder((rpc_defs.ED_INST_DICT, inst))
807
    _CheckBasics(result)
808
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
809
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
810
      constants.HV_BLOCKDEV_PREFIX: "bar",
811
      constants.HV_ROOT_PATH: "/tmp",
812
      })
813
    self.assertEqual(result["osparams"], {
814
      "role": "unknown",
815
      })
816

    
817
    # Instance with OS parameters
818
    result = runner._encoder((rpc_defs.ED_INST_DICT_OSP, (inst, {
819
      "role": "webserver",
820
      "other": "field",
821
      })))
822
    _CheckBasics(result)
823
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
824
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
825
      constants.HV_BLOCKDEV_PREFIX: "bar",
826
      constants.HV_ROOT_PATH: "/tmp",
827
      })
828
    self.assertEqual(result["osparams"], {
829
      "role": "webserver",
830
      "other": "field",
831
      })
832

    
833
    # Instance with hypervisor and backend parameters
834
    result = runner._encoder((rpc_defs.ED_INST_DICT_HVP_BEP, (inst, {
835
      constants.HT_KVM: {
836
        constants.HV_BOOT_ORDER: "xyz",
837
        },
838
      }, {
839
      constants.BE_VCPUS: 100,
840
      constants.BE_MAXMEM: 4096,
841
      })))
842
    _CheckBasics(result)
843
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
844
    self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
845
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
846
      constants.HV_BOOT_ORDER: "xyz",
847
      })
848

    
849

    
850
if __name__ == "__main__":
851
  testutils.GanetiTestProgram()