Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ 065be3f0

History | View | Annotate | Download (24.6 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27
import random
28

    
29
from ganeti import constants
30
from ganeti import compat
31
from ganeti import rpc
32
from ganeti import rpc_defs
33
from ganeti import http
34
from ganeti import errors
35
from ganeti import serializer
36
from ganeti import objects
37
from ganeti import backend
38

    
39
import testutils
40

    
41

    
42
class _FakeRequestProcessor:
43
  def __init__(self, response_fn):
44
    self._response_fn = response_fn
45
    self.reqcount = 0
46

    
47
  def __call__(self, reqs, lock_monitor_cb=None):
48
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
49
    for req in reqs:
50
      self.reqcount += 1
51
      self._response_fn(req)
52

    
53

    
54
def GetFakeSimpleStoreClass(fn):
55
  class FakeSimpleStore:
56
    GetNodePrimaryIPList = fn
57
    GetPrimaryIPFamily = lambda _: None
58

    
59
  return FakeSimpleStore
60

    
61

    
62
class TestRpcProcessor(unittest.TestCase):
63
  def _FakeAddressLookup(self, map):
64
    return lambda node_list: [map.get(node) for node in node_list]
65

    
66
  def _GetVersionResponse(self, req):
67
    self.assertEqual(req.host, "127.0.0.1")
68
    self.assertEqual(req.port, 24094)
69
    self.assertEqual(req.path, "/version")
70
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
71
    req.success = True
72
    req.resp_status_code = http.HTTP_OK
73
    req.resp_body = serializer.DumpJson((True, 123))
74

    
75
  def testVersionSuccess(self):
76
    resolver = rpc._StaticResolver(["127.0.0.1"])
77
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
78
    proc = rpc._RpcProcessor(resolver, 24094)
79
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
80
                  NotImplemented, _req_process_fn=http_proc)
81
    self.assertEqual(result.keys(), ["localhost"])
82
    lhresp = result["localhost"]
83
    self.assertFalse(lhresp.offline)
84
    self.assertEqual(lhresp.node, "localhost")
85
    self.assertFalse(lhresp.fail_msg)
86
    self.assertEqual(lhresp.payload, 123)
87
    self.assertEqual(lhresp.call, "version")
88
    lhresp.Raise("should not raise")
89
    self.assertEqual(http_proc.reqcount, 1)
90

    
91
  def _ReadTimeoutResponse(self, req):
92
    self.assertEqual(req.host, "192.0.2.13")
93
    self.assertEqual(req.port, 19176)
94
    self.assertEqual(req.path, "/version")
95
    self.assertEqual(req.read_timeout, 12356)
96
    req.success = True
97
    req.resp_status_code = http.HTTP_OK
98
    req.resp_body = serializer.DumpJson((True, -1))
99

    
100
  def testReadTimeout(self):
101
    resolver = rpc._StaticResolver(["192.0.2.13"])
102
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
103
    proc = rpc._RpcProcessor(resolver, 19176)
104
    host = "node31856"
105
    body = {host: ""}
106
    result = proc([host], "version", body, 12356, NotImplemented,
107
                  _req_process_fn=http_proc)
108
    self.assertEqual(result.keys(), [host])
109
    lhresp = result[host]
110
    self.assertFalse(lhresp.offline)
111
    self.assertEqual(lhresp.node, host)
112
    self.assertFalse(lhresp.fail_msg)
113
    self.assertEqual(lhresp.payload, -1)
114
    self.assertEqual(lhresp.call, "version")
115
    lhresp.Raise("should not raise")
116
    self.assertEqual(http_proc.reqcount, 1)
117

    
118
  def testOfflineNode(self):
119
    resolver = rpc._StaticResolver([rpc._OFFLINE])
120
    http_proc = _FakeRequestProcessor(NotImplemented)
121
    proc = rpc._RpcProcessor(resolver, 30668)
122
    host = "n17296"
123
    body = {host: ""}
124
    result = proc([host], "version", body, 60, NotImplemented,
125
                  _req_process_fn=http_proc)
126
    self.assertEqual(result.keys(), [host])
127
    lhresp = result[host]
128
    self.assertTrue(lhresp.offline)
129
    self.assertEqual(lhresp.node, host)
130
    self.assertTrue(lhresp.fail_msg)
131
    self.assertFalse(lhresp.payload)
132
    self.assertEqual(lhresp.call, "version")
133

    
134
    # With a message
135
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
136

    
137
    # No message
138
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
139

    
140
    self.assertEqual(http_proc.reqcount, 0)
141

    
142
  def _GetMultiVersionResponse(self, req):
143
    self.assert_(req.host.startswith("node"))
144
    self.assertEqual(req.port, 23245)
145
    self.assertEqual(req.path, "/version")
146
    req.success = True
147
    req.resp_status_code = http.HTTP_OK
148
    req.resp_body = serializer.DumpJson((True, 987))
149

    
150
  def testMultiVersionSuccess(self):
151
    nodes = ["node%s" % i for i in range(50)]
152
    body = dict((n, "") for n in nodes)
153
    resolver = rpc._StaticResolver(nodes)
154
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
155
    proc = rpc._RpcProcessor(resolver, 23245)
156
    result = proc(nodes, "version", body, 60, NotImplemented,
157
                  _req_process_fn=http_proc)
158
    self.assertEqual(sorted(result.keys()), sorted(nodes))
159

    
160
    for name in nodes:
161
      lhresp = result[name]
162
      self.assertFalse(lhresp.offline)
163
      self.assertEqual(lhresp.node, name)
164
      self.assertFalse(lhresp.fail_msg)
165
      self.assertEqual(lhresp.payload, 987)
166
      self.assertEqual(lhresp.call, "version")
167
      lhresp.Raise("should not raise")
168

    
169
    self.assertEqual(http_proc.reqcount, len(nodes))
170

    
171
  def _GetVersionResponseFail(self, errinfo, req):
172
    self.assertEqual(req.path, "/version")
173
    req.success = True
174
    req.resp_status_code = http.HTTP_OK
175
    req.resp_body = serializer.DumpJson((False, errinfo))
176

    
177
  def testVersionFailure(self):
178
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
179
    proc = rpc._RpcProcessor(resolver, 5903)
180
    for errinfo in [None, "Unknown error"]:
181
      http_proc = \
182
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
183
                                             errinfo))
184
      host = "aef9ur4i.example.com"
185
      body = {host: ""}
186
      result = proc(body.keys(), "version", body, 60, NotImplemented,
187
                    _req_process_fn=http_proc)
188
      self.assertEqual(result.keys(), [host])
189
      lhresp = result[host]
190
      self.assertFalse(lhresp.offline)
191
      self.assertEqual(lhresp.node, host)
192
      self.assert_(lhresp.fail_msg)
193
      self.assertFalse(lhresp.payload)
194
      self.assertEqual(lhresp.call, "version")
195
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
196
      self.assertEqual(http_proc.reqcount, 1)
197

    
198
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
199
    self.assertEqual(req.path, "/vg_list")
200
    self.assertEqual(req.port, 15165)
201

    
202
    if req.host in httperrnodes:
203
      req.success = False
204
      req.error = "Node set up for HTTP errors"
205

    
206
    elif req.host in failnodes:
207
      req.success = True
208
      req.resp_status_code = 404
209
      req.resp_body = serializer.DumpJson({
210
        "code": 404,
211
        "message": "Method not found",
212
        "explain": "Explanation goes here",
213
        })
214
    else:
215
      req.success = True
216
      req.resp_status_code = http.HTTP_OK
217
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
218

    
219
  def testHttpError(self):
220
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
221
    body = dict((n, "") for n in nodes)
222
    resolver = rpc._StaticResolver(nodes)
223

    
224
    httperrnodes = set(nodes[1::7])
225
    self.assertEqual(len(httperrnodes), 7)
226

    
227
    failnodes = set(nodes[2::3]) - httperrnodes
228
    self.assertEqual(len(failnodes), 14)
229

    
230
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
231

    
232
    proc = rpc._RpcProcessor(resolver, 15165)
233
    http_proc = \
234
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
235
                                           httperrnodes, failnodes))
236
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
237
                  _req_process_fn=http_proc)
238
    self.assertEqual(sorted(result.keys()), sorted(nodes))
239

    
240
    for name in nodes:
241
      lhresp = result[name]
242
      self.assertFalse(lhresp.offline)
243
      self.assertEqual(lhresp.node, name)
244
      self.assertEqual(lhresp.call, "vg_list")
245

    
246
      if name in httperrnodes:
247
        self.assert_(lhresp.fail_msg)
248
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
249
      elif name in failnodes:
250
        self.assert_(lhresp.fail_msg)
251
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
252
                          prereq=True, ecode=errors.ECODE_INVAL)
253
      else:
254
        self.assertFalse(lhresp.fail_msg)
255
        self.assertEqual(lhresp.payload, hash(name))
256
        lhresp.Raise("should not raise")
257

    
258
    self.assertEqual(http_proc.reqcount, len(nodes))
259

    
260
  def _GetInvalidResponseA(self, req):
261
    self.assertEqual(req.path, "/version")
262
    req.success = True
263
    req.resp_status_code = http.HTTP_OK
264
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
265
                                         "response", "!", 1, 2, 3))
266

    
267
  def _GetInvalidResponseB(self, req):
268
    self.assertEqual(req.path, "/version")
269
    req.success = True
270
    req.resp_status_code = http.HTTP_OK
271
    req.resp_body = serializer.DumpJson("invalid response")
272

    
273
  def testInvalidResponse(self):
274
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
275
    proc = rpc._RpcProcessor(resolver, 19978)
276

    
277
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
278
      http_proc = _FakeRequestProcessor(fn)
279
      host = "oqo7lanhly.example.com"
280
      body = {host: ""}
281
      result = proc([host], "version", body, 60, NotImplemented,
282
                    _req_process_fn=http_proc)
283
      self.assertEqual(result.keys(), [host])
284
      lhresp = result[host]
285
      self.assertFalse(lhresp.offline)
286
      self.assertEqual(lhresp.node, host)
287
      self.assert_(lhresp.fail_msg)
288
      self.assertFalse(lhresp.payload)
289
      self.assertEqual(lhresp.call, "version")
290
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
291
      self.assertEqual(http_proc.reqcount, 1)
292

    
293
  def _GetBodyTestResponse(self, test_data, req):
294
    self.assertEqual(req.host, "192.0.2.84")
295
    self.assertEqual(req.port, 18700)
296
    self.assertEqual(req.path, "/upload_file")
297
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
298
    req.success = True
299
    req.resp_status_code = http.HTTP_OK
300
    req.resp_body = serializer.DumpJson((True, None))
301

    
302
  def testResponseBody(self):
303
    test_data = {
304
      "Hello": "World",
305
      "xyz": range(10),
306
      }
307
    resolver = rpc._StaticResolver(["192.0.2.84"])
308
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
309
                                                     test_data))
310
    proc = rpc._RpcProcessor(resolver, 18700)
311
    host = "node19759"
312
    body = {host: serializer.DumpJson(test_data)}
313
    result = proc([host], "upload_file", body, 30, NotImplemented,
314
                  _req_process_fn=http_proc)
315
    self.assertEqual(result.keys(), [host])
316
    lhresp = result[host]
317
    self.assertFalse(lhresp.offline)
318
    self.assertEqual(lhresp.node, host)
319
    self.assertFalse(lhresp.fail_msg)
320
    self.assertEqual(lhresp.payload, None)
321
    self.assertEqual(lhresp.call, "upload_file")
322
    lhresp.Raise("should not raise")
323
    self.assertEqual(http_proc.reqcount, 1)
324

    
325

    
326
class TestSsconfResolver(unittest.TestCase):
327
  def testSsconfLookup(self):
328
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
329
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
330
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
331
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
332
    result = rpc._SsconfResolver(node_list, NotImplemented,
333
                                 ssc=ssc, nslookup_fn=NotImplemented)
334
    self.assertEqual(result, zip(node_list, addr_list))
335

    
336
  def testNsLookup(self):
337
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
338
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
339
    ssc = GetFakeSimpleStoreClass(lambda _: [])
340
    node_addr_map = dict(zip(node_list, addr_list))
341
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
342
    result = rpc._SsconfResolver(node_list, NotImplemented,
343
                                 ssc=ssc, nslookup_fn=nslookup_fn)
344
    self.assertEqual(result, zip(node_list, addr_list))
345

    
346
  def testBothLookups(self):
347
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
348
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
349
    n = len(addr_list) / 2
350
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
351
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
352
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
353
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
354
    result = rpc._SsconfResolver(node_list, NotImplemented,
355
                                 ssc=ssc, nslookup_fn=nslookup_fn)
356
    self.assertEqual(result, zip(node_list, addr_list))
357

    
358
  def testAddressLookupIPv6(self):
359
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
360
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
361
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
362
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
363
    result = rpc._SsconfResolver(node_list, NotImplemented,
364
                                 ssc=ssc, nslookup_fn=NotImplemented)
365
    self.assertEqual(result, zip(node_list, addr_list))
366

    
367

    
368
class TestStaticResolver(unittest.TestCase):
369
  def test(self):
370
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
371
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
372
    res = rpc._StaticResolver(addresses)
373
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
374

    
375
  def testWrongLength(self):
376
    res = rpc._StaticResolver([])
377
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
378

    
379

    
380
class TestNodeConfigResolver(unittest.TestCase):
381
  @staticmethod
382
  def _GetSingleOnlineNode(name):
383
    assert name == "node90.example.com"
384
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
385

    
386
  @staticmethod
387
  def _GetSingleOfflineNode(name):
388
    assert name == "node100.example.com"
389
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
390

    
391
  def testSingleOnline(self):
392
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
393
                                             NotImplemented,
394
                                             ["node90.example.com"], None),
395
                     [("node90.example.com", "192.0.2.90")])
396

    
397
  def testSingleOffline(self):
398
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
399
                                             NotImplemented,
400
                                             ["node100.example.com"], None),
401
                     [("node100.example.com", rpc._OFFLINE)])
402

    
403
  def testSingleOfflineWithAcceptOffline(self):
404
    fn = self._GetSingleOfflineNode
405
    assert fn("node100.example.com").offline
406
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
407
                                             ["node100.example.com"],
408
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
409
                     [("node100.example.com", "192.0.2.100")])
410
    for i in [False, True, "", "Hello", 0, 1]:
411
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
412
                        fn, NotImplemented, ["node100.example.com"], i)
413

    
414
  def testUnknownSingleNode(self):
415
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
416
                                             ["node110.example.com"], None),
417
                     [("node110.example.com", "node110.example.com")])
418

    
419
  def testMultiEmpty(self):
420
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
421
                                             lambda: {},
422
                                             [], None),
423
                     [])
424

    
425
  def testMultiSomeOffline(self):
426
    nodes = dict(("node%s.example.com" % i,
427
                  objects.Node(name="node%s.example.com" % i,
428
                               offline=((i % 3) == 0),
429
                               primary_ip="192.0.2.%s" % i))
430
                  for i in range(1, 255))
431

    
432
    # Resolve no names
433
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
434
                                             lambda: nodes,
435
                                             [], None),
436
                     [])
437

    
438
    # Offline, online and unknown hosts
439
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
440
                                             lambda: nodes,
441
                                             ["node3.example.com",
442
                                              "node92.example.com",
443
                                              "node54.example.com",
444
                                              "unknown.example.com",],
445
                                             None), [
446
      ("node3.example.com", rpc._OFFLINE),
447
      ("node92.example.com", "192.0.2.92"),
448
      ("node54.example.com", rpc._OFFLINE),
449
      ("unknown.example.com", "unknown.example.com"),
450
      ])
451

    
452

    
453
class TestCompress(unittest.TestCase):
454
  def test(self):
455
    for data in ["", "Hello", "Hello World!\nnew\nlines"]:
456
      self.assertEqual(rpc._Compress(data),
457
                       (constants.RPC_ENCODING_NONE, data))
458

    
459
    for data in [512 * " ", 5242 * "Hello World!\n"]:
460
      compressed = rpc._Compress(data)
461
      self.assertEqual(len(compressed), 2)
462
      self.assertEqual(backend._Decompress(compressed), data)
463

    
464
  def testDecompression(self):
465
    self.assertRaises(AssertionError, backend._Decompress, "")
466
    self.assertRaises(AssertionError, backend._Decompress, [""])
467
    self.assertRaises(AssertionError, backend._Decompress,
468
                      ("unknown compression", "data"))
469
    self.assertRaises(Exception, backend._Decompress,
470
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
471

    
472

    
473
class TestRpcClientBase(unittest.TestCase):
474
  def testNoHosts(self):
475
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_SLOW, [],
476
            None, None, NotImplemented)
477
    http_proc = _FakeRequestProcessor(NotImplemented)
478
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
479
                                _req_process_fn=http_proc)
480
    self.assertEqual(client._Call(cdef, [], []), {})
481

    
482
    # Test wrong number of arguments
483
    self.assertRaises(errors.ProgrammerError, client._Call,
484
                      cdef, [], [0, 1, 2])
485

    
486
  def testTimeout(self):
487
    def _CalcTimeout((arg1, arg2)):
488
      return arg1 + arg2
489

    
490
    def _VerifyRequest(exp_timeout, req):
491
      self.assertEqual(req.read_timeout, exp_timeout)
492

    
493
      req.success = True
494
      req.resp_status_code = http.HTTP_OK
495
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
496

    
497
    resolver = rpc._StaticResolver([
498
      "192.0.2.1",
499
      "192.0.2.2",
500
      ])
501

    
502
    nodes = [
503
      "node1.example.com",
504
      "node2.example.com",
505
      ]
506

    
507
    tests = [(100, None, 100), (30, None, 30)]
508
    tests.extend((_CalcTimeout, i, i + 300)
509
                 for i in [0, 5, 16485, 30516])
510

    
511
    for timeout, arg1, exp_timeout in tests:
512
      cdef = ("test_call", NotImplemented, None, timeout, [
513
        ("arg1", None, NotImplemented),
514
        ("arg2", None, NotImplemented),
515
        ], None, None, NotImplemented)
516

    
517
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
518
                                                       exp_timeout))
519
      client = rpc._RpcClientBase(resolver, NotImplemented,
520
                                  _req_process_fn=http_proc)
521
      result = client._Call(cdef, nodes, [arg1, 300])
522
      self.assertEqual(len(result), len(nodes))
523
      self.assertTrue(compat.all(not res.fail_msg and
524
                                 res.payload == hex(exp_timeout)
525
                                 for res in result.values()))
526

    
527
  def testArgumentEncoder(self):
528
    (AT1, AT2) = range(1, 3)
529

    
530
    resolver = rpc._StaticResolver([
531
      "192.0.2.5",
532
      "192.0.2.6",
533
      ])
534

    
535
    nodes = [
536
      "node5.example.com",
537
      "node6.example.com",
538
      ]
539

    
540
    encoders = {
541
      AT1: hex,
542
      AT2: hash,
543
      }
544

    
545
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
546
      ("arg0", None, NotImplemented),
547
      ("arg1", AT1, NotImplemented),
548
      ("arg1", AT2, NotImplemented),
549
      ], None, None, NotImplemented)
550

    
551
    def _VerifyRequest(req):
552
      req.success = True
553
      req.resp_status_code = http.HTTP_OK
554
      req.resp_body = serializer.DumpJson((True, req.post_data))
555

    
556
    http_proc = _FakeRequestProcessor(_VerifyRequest)
557

    
558
    for num in [0, 3796, 9032119]:
559
      client = rpc._RpcClientBase(resolver, encoders.get,
560
                                  _req_process_fn=http_proc)
561
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
562
      self.assertEqual(len(result), len(nodes))
563
      for res in result.values():
564
        self.assertFalse(res.fail_msg)
565
        self.assertEqual(serializer.LoadJson(res.payload),
566
                         ["foo", hex(num), hash("Hello%s" % num)])
567

    
568
  def testPostProc(self):
569
    def _VerifyRequest(nums, req):
570
      req.success = True
571
      req.resp_status_code = http.HTTP_OK
572
      req.resp_body = serializer.DumpJson((True, nums))
573

    
574
    resolver = rpc._StaticResolver([
575
      "192.0.2.90",
576
      "192.0.2.95",
577
      ])
578

    
579
    nodes = [
580
      "node90.example.com",
581
      "node95.example.com",
582
      ]
583

    
584
    def _PostProc(res):
585
      self.assertFalse(res.fail_msg)
586
      res.payload = sum(res.payload)
587
      return res
588

    
589
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [],
590
            None, _PostProc, NotImplemented)
591

    
592
    # Seeded random generator
593
    rnd = random.Random(20299)
594

    
595
    for i in [0, 4, 74, 1391]:
596
      nums = [rnd.randint(0, 1000) for _ in range(i)]
597
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
598
      client = rpc._RpcClientBase(resolver, NotImplemented,
599
                                  _req_process_fn=http_proc)
600
      result = client._Call(cdef, nodes, [])
601
      self.assertEqual(len(result), len(nodes))
602
      for res in result.values():
603
        self.assertFalse(res.fail_msg)
604
        self.assertEqual(res.payload, sum(nums))
605

    
606
  def testPreProc(self):
607
    def _VerifyRequest(req):
608
      req.success = True
609
      req.resp_status_code = http.HTTP_OK
610
      req.resp_body = serializer.DumpJson((True, req.post_data))
611

    
612
    resolver = rpc._StaticResolver([
613
      "192.0.2.30",
614
      "192.0.2.35",
615
      ])
616

    
617
    nodes = [
618
      "node30.example.com",
619
      "node35.example.com",
620
      ]
621

    
622
    def _PreProc(node, data):
623
      self.assertEqual(len(data), 1)
624
      return data[0] + node
625

    
626
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
627
      ("arg0", None, NotImplemented),
628
      ], _PreProc, None, NotImplemented)
629

    
630
    http_proc = _FakeRequestProcessor(_VerifyRequest)
631
    client = rpc._RpcClientBase(resolver, NotImplemented,
632
                                _req_process_fn=http_proc)
633

    
634
    for prefix in ["foo", "bar", "baz"]:
635
      result = client._Call(cdef, nodes, [prefix])
636
      self.assertEqual(len(result), len(nodes))
637
      for (idx, (node, res)) in enumerate(result.items()):
638
        self.assertFalse(res.fail_msg)
639
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
640

    
641
  def testResolverOptions(self):
642
    def _VerifyRequest(req):
643
      req.success = True
644
      req.resp_status_code = http.HTTP_OK
645
      req.resp_body = serializer.DumpJson((True, req.post_data))
646

    
647
    nodes = [
648
      "node30.example.com",
649
      "node35.example.com",
650
      ]
651

    
652
    def _Resolver(expected, hosts, options):
653
      self.assertEqual(hosts, nodes)
654
      self.assertEqual(options, expected)
655
      return zip(hosts, nodes)
656

    
657
    def _DynamicResolverOptions((arg0, )):
658
      return sum(arg0)
659

    
660
    tests = [
661
      (None, None, None),
662
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
663
      (False, None, False),
664
      (True, None, True),
665
      (0, None, 0),
666
      (_DynamicResolverOptions, [1, 2, 3], 6),
667
      (_DynamicResolverOptions, range(4, 19), 165),
668
      ]
669

    
670
    for (resolver_opts, arg0, expected) in tests:
671
      cdef = ("test_call", NotImplemented, resolver_opts, rpc_defs.TMO_NORMAL, [
672
        ("arg0", None, NotImplemented),
673
        ], None, None, NotImplemented)
674

    
675
      http_proc = _FakeRequestProcessor(_VerifyRequest)
676

    
677
      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
678
                                  NotImplemented, _req_process_fn=http_proc)
679
      result = client._Call(cdef, nodes, [arg0])
680
      self.assertEqual(len(result), len(nodes))
681
      for (idx, (node, res)) in enumerate(result.items()):
682
        self.assertFalse(res.fail_msg)
683

    
684

    
685
class TestRpcRunner(unittest.TestCase):
686
  def testUploadFile(self):
687
    runner = rpc.RpcRunner(_req_process_fn=http_proc)
688

    
689

    
690
if __name__ == "__main__":
691
  testutils.GanetiTestProgram()