Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ 601dfcbb

History | View | Annotate | Download (26 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27
import random
28
import tempfile
29

    
30
from ganeti import constants
31
from ganeti import compat
32
from ganeti import rpc
33
from ganeti import rpc_defs
34
from ganeti import http
35
from ganeti import errors
36
from ganeti import serializer
37
from ganeti import objects
38
from ganeti import backend
39

    
40
import testutils
41
import mocks
42

    
43

    
44
class _FakeRequestProcessor:
45
  def __init__(self, response_fn):
46
    self._response_fn = response_fn
47
    self.reqcount = 0
48

    
49
  def __call__(self, reqs, lock_monitor_cb=None):
50
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
51
    for req in reqs:
52
      self.reqcount += 1
53
      self._response_fn(req)
54

    
55

    
56
def GetFakeSimpleStoreClass(fn):
57
  class FakeSimpleStore:
58
    GetNodePrimaryIPList = fn
59
    GetPrimaryIPFamily = lambda _: None
60

    
61
  return FakeSimpleStore
62

    
63

    
64
class TestRpcProcessor(unittest.TestCase):
65
  def _FakeAddressLookup(self, map):
66
    return lambda node_list: [map.get(node) for node in node_list]
67

    
68
  def _GetVersionResponse(self, req):
69
    self.assertEqual(req.host, "127.0.0.1")
70
    self.assertEqual(req.port, 24094)
71
    self.assertEqual(req.path, "/version")
72
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
73
    req.success = True
74
    req.resp_status_code = http.HTTP_OK
75
    req.resp_body = serializer.DumpJson((True, 123))
76

    
77
  def testVersionSuccess(self):
78
    resolver = rpc._StaticResolver(["127.0.0.1"])
79
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
80
    proc = rpc._RpcProcessor(resolver, 24094)
81
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
82
                  NotImplemented, _req_process_fn=http_proc)
83
    self.assertEqual(result.keys(), ["localhost"])
84
    lhresp = result["localhost"]
85
    self.assertFalse(lhresp.offline)
86
    self.assertEqual(lhresp.node, "localhost")
87
    self.assertFalse(lhresp.fail_msg)
88
    self.assertEqual(lhresp.payload, 123)
89
    self.assertEqual(lhresp.call, "version")
90
    lhresp.Raise("should not raise")
91
    self.assertEqual(http_proc.reqcount, 1)
92

    
93
  def _ReadTimeoutResponse(self, req):
94
    self.assertEqual(req.host, "192.0.2.13")
95
    self.assertEqual(req.port, 19176)
96
    self.assertEqual(req.path, "/version")
97
    self.assertEqual(req.read_timeout, 12356)
98
    req.success = True
99
    req.resp_status_code = http.HTTP_OK
100
    req.resp_body = serializer.DumpJson((True, -1))
101

    
102
  def testReadTimeout(self):
103
    resolver = rpc._StaticResolver(["192.0.2.13"])
104
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
105
    proc = rpc._RpcProcessor(resolver, 19176)
106
    host = "node31856"
107
    body = {host: ""}
108
    result = proc([host], "version", body, 12356, NotImplemented,
109
                  _req_process_fn=http_proc)
110
    self.assertEqual(result.keys(), [host])
111
    lhresp = result[host]
112
    self.assertFalse(lhresp.offline)
113
    self.assertEqual(lhresp.node, host)
114
    self.assertFalse(lhresp.fail_msg)
115
    self.assertEqual(lhresp.payload, -1)
116
    self.assertEqual(lhresp.call, "version")
117
    lhresp.Raise("should not raise")
118
    self.assertEqual(http_proc.reqcount, 1)
119

    
120
  def testOfflineNode(self):
121
    resolver = rpc._StaticResolver([rpc._OFFLINE])
122
    http_proc = _FakeRequestProcessor(NotImplemented)
123
    proc = rpc._RpcProcessor(resolver, 30668)
124
    host = "n17296"
125
    body = {host: ""}
126
    result = proc([host], "version", body, 60, NotImplemented,
127
                  _req_process_fn=http_proc)
128
    self.assertEqual(result.keys(), [host])
129
    lhresp = result[host]
130
    self.assertTrue(lhresp.offline)
131
    self.assertEqual(lhresp.node, host)
132
    self.assertTrue(lhresp.fail_msg)
133
    self.assertFalse(lhresp.payload)
134
    self.assertEqual(lhresp.call, "version")
135

    
136
    # With a message
137
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
138

    
139
    # No message
140
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
141

    
142
    self.assertEqual(http_proc.reqcount, 0)
143

    
144
  def _GetMultiVersionResponse(self, req):
145
    self.assert_(req.host.startswith("node"))
146
    self.assertEqual(req.port, 23245)
147
    self.assertEqual(req.path, "/version")
148
    req.success = True
149
    req.resp_status_code = http.HTTP_OK
150
    req.resp_body = serializer.DumpJson((True, 987))
151

    
152
  def testMultiVersionSuccess(self):
153
    nodes = ["node%s" % i for i in range(50)]
154
    body = dict((n, "") for n in nodes)
155
    resolver = rpc._StaticResolver(nodes)
156
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
157
    proc = rpc._RpcProcessor(resolver, 23245)
158
    result = proc(nodes, "version", body, 60, NotImplemented,
159
                  _req_process_fn=http_proc)
160
    self.assertEqual(sorted(result.keys()), sorted(nodes))
161

    
162
    for name in nodes:
163
      lhresp = result[name]
164
      self.assertFalse(lhresp.offline)
165
      self.assertEqual(lhresp.node, name)
166
      self.assertFalse(lhresp.fail_msg)
167
      self.assertEqual(lhresp.payload, 987)
168
      self.assertEqual(lhresp.call, "version")
169
      lhresp.Raise("should not raise")
170

    
171
    self.assertEqual(http_proc.reqcount, len(nodes))
172

    
173
  def _GetVersionResponseFail(self, errinfo, req):
174
    self.assertEqual(req.path, "/version")
175
    req.success = True
176
    req.resp_status_code = http.HTTP_OK
177
    req.resp_body = serializer.DumpJson((False, errinfo))
178

    
179
  def testVersionFailure(self):
180
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
181
    proc = rpc._RpcProcessor(resolver, 5903)
182
    for errinfo in [None, "Unknown error"]:
183
      http_proc = \
184
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
185
                                             errinfo))
186
      host = "aef9ur4i.example.com"
187
      body = {host: ""}
188
      result = proc(body.keys(), "version", body, 60, NotImplemented,
189
                    _req_process_fn=http_proc)
190
      self.assertEqual(result.keys(), [host])
191
      lhresp = result[host]
192
      self.assertFalse(lhresp.offline)
193
      self.assertEqual(lhresp.node, host)
194
      self.assert_(lhresp.fail_msg)
195
      self.assertFalse(lhresp.payload)
196
      self.assertEqual(lhresp.call, "version")
197
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
198
      self.assertEqual(http_proc.reqcount, 1)
199

    
200
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
201
    self.assertEqual(req.path, "/vg_list")
202
    self.assertEqual(req.port, 15165)
203

    
204
    if req.host in httperrnodes:
205
      req.success = False
206
      req.error = "Node set up for HTTP errors"
207

    
208
    elif req.host in failnodes:
209
      req.success = True
210
      req.resp_status_code = 404
211
      req.resp_body = serializer.DumpJson({
212
        "code": 404,
213
        "message": "Method not found",
214
        "explain": "Explanation goes here",
215
        })
216
    else:
217
      req.success = True
218
      req.resp_status_code = http.HTTP_OK
219
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
220

    
221
  def testHttpError(self):
222
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
223
    body = dict((n, "") for n in nodes)
224
    resolver = rpc._StaticResolver(nodes)
225

    
226
    httperrnodes = set(nodes[1::7])
227
    self.assertEqual(len(httperrnodes), 7)
228

    
229
    failnodes = set(nodes[2::3]) - httperrnodes
230
    self.assertEqual(len(failnodes), 14)
231

    
232
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
233

    
234
    proc = rpc._RpcProcessor(resolver, 15165)
235
    http_proc = \
236
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
237
                                           httperrnodes, failnodes))
238
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
239
                  _req_process_fn=http_proc)
240
    self.assertEqual(sorted(result.keys()), sorted(nodes))
241

    
242
    for name in nodes:
243
      lhresp = result[name]
244
      self.assertFalse(lhresp.offline)
245
      self.assertEqual(lhresp.node, name)
246
      self.assertEqual(lhresp.call, "vg_list")
247

    
248
      if name in httperrnodes:
249
        self.assert_(lhresp.fail_msg)
250
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
251
      elif name in failnodes:
252
        self.assert_(lhresp.fail_msg)
253
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
254
                          prereq=True, ecode=errors.ECODE_INVAL)
255
      else:
256
        self.assertFalse(lhresp.fail_msg)
257
        self.assertEqual(lhresp.payload, hash(name))
258
        lhresp.Raise("should not raise")
259

    
260
    self.assertEqual(http_proc.reqcount, len(nodes))
261

    
262
  def _GetInvalidResponseA(self, req):
263
    self.assertEqual(req.path, "/version")
264
    req.success = True
265
    req.resp_status_code = http.HTTP_OK
266
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
267
                                         "response", "!", 1, 2, 3))
268

    
269
  def _GetInvalidResponseB(self, req):
270
    self.assertEqual(req.path, "/version")
271
    req.success = True
272
    req.resp_status_code = http.HTTP_OK
273
    req.resp_body = serializer.DumpJson("invalid response")
274

    
275
  def testInvalidResponse(self):
276
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
277
    proc = rpc._RpcProcessor(resolver, 19978)
278

    
279
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
280
      http_proc = _FakeRequestProcessor(fn)
281
      host = "oqo7lanhly.example.com"
282
      body = {host: ""}
283
      result = proc([host], "version", body, 60, NotImplemented,
284
                    _req_process_fn=http_proc)
285
      self.assertEqual(result.keys(), [host])
286
      lhresp = result[host]
287
      self.assertFalse(lhresp.offline)
288
      self.assertEqual(lhresp.node, host)
289
      self.assert_(lhresp.fail_msg)
290
      self.assertFalse(lhresp.payload)
291
      self.assertEqual(lhresp.call, "version")
292
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
293
      self.assertEqual(http_proc.reqcount, 1)
294

    
295
  def _GetBodyTestResponse(self, test_data, req):
296
    self.assertEqual(req.host, "192.0.2.84")
297
    self.assertEqual(req.port, 18700)
298
    self.assertEqual(req.path, "/upload_file")
299
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
300
    req.success = True
301
    req.resp_status_code = http.HTTP_OK
302
    req.resp_body = serializer.DumpJson((True, None))
303

    
304
  def testResponseBody(self):
305
    test_data = {
306
      "Hello": "World",
307
      "xyz": range(10),
308
      }
309
    resolver = rpc._StaticResolver(["192.0.2.84"])
310
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
311
                                                     test_data))
312
    proc = rpc._RpcProcessor(resolver, 18700)
313
    host = "node19759"
314
    body = {host: serializer.DumpJson(test_data)}
315
    result = proc([host], "upload_file", body, 30, NotImplemented,
316
                  _req_process_fn=http_proc)
317
    self.assertEqual(result.keys(), [host])
318
    lhresp = result[host]
319
    self.assertFalse(lhresp.offline)
320
    self.assertEqual(lhresp.node, host)
321
    self.assertFalse(lhresp.fail_msg)
322
    self.assertEqual(lhresp.payload, None)
323
    self.assertEqual(lhresp.call, "upload_file")
324
    lhresp.Raise("should not raise")
325
    self.assertEqual(http_proc.reqcount, 1)
326

    
327

    
328
class TestSsconfResolver(unittest.TestCase):
329
  def testSsconfLookup(self):
330
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
331
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
332
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
333
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
334
    result = rpc._SsconfResolver(node_list, NotImplemented,
335
                                 ssc=ssc, nslookup_fn=NotImplemented)
336
    self.assertEqual(result, zip(node_list, addr_list))
337

    
338
  def testNsLookup(self):
339
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
340
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
341
    ssc = GetFakeSimpleStoreClass(lambda _: [])
342
    node_addr_map = dict(zip(node_list, addr_list))
343
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
344
    result = rpc._SsconfResolver(node_list, NotImplemented,
345
                                 ssc=ssc, nslookup_fn=nslookup_fn)
346
    self.assertEqual(result, zip(node_list, addr_list))
347

    
348
  def testBothLookups(self):
349
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
350
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
351
    n = len(addr_list) / 2
352
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
353
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
354
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
355
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
356
    result = rpc._SsconfResolver(node_list, NotImplemented,
357
                                 ssc=ssc, nslookup_fn=nslookup_fn)
358
    self.assertEqual(result, zip(node_list, addr_list))
359

    
360
  def testAddressLookupIPv6(self):
361
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
362
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
363
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
364
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
365
    result = rpc._SsconfResolver(node_list, NotImplemented,
366
                                 ssc=ssc, nslookup_fn=NotImplemented)
367
    self.assertEqual(result, zip(node_list, addr_list))
368

    
369

    
370
class TestStaticResolver(unittest.TestCase):
371
  def test(self):
372
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
373
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
374
    res = rpc._StaticResolver(addresses)
375
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
376

    
377
  def testWrongLength(self):
378
    res = rpc._StaticResolver([])
379
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
380

    
381

    
382
class TestNodeConfigResolver(unittest.TestCase):
383
  @staticmethod
384
  def _GetSingleOnlineNode(name):
385
    assert name == "node90.example.com"
386
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
387

    
388
  @staticmethod
389
  def _GetSingleOfflineNode(name):
390
    assert name == "node100.example.com"
391
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
392

    
393
  def testSingleOnline(self):
394
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
395
                                             NotImplemented,
396
                                             ["node90.example.com"], None),
397
                     [("node90.example.com", "192.0.2.90")])
398

    
399
  def testSingleOffline(self):
400
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
401
                                             NotImplemented,
402
                                             ["node100.example.com"], None),
403
                     [("node100.example.com", rpc._OFFLINE)])
404

    
405
  def testSingleOfflineWithAcceptOffline(self):
406
    fn = self._GetSingleOfflineNode
407
    assert fn("node100.example.com").offline
408
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
409
                                             ["node100.example.com"],
410
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
411
                     [("node100.example.com", "192.0.2.100")])
412
    for i in [False, True, "", "Hello", 0, 1]:
413
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
414
                        fn, NotImplemented, ["node100.example.com"], i)
415

    
416
  def testUnknownSingleNode(self):
417
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
418
                                             ["node110.example.com"], None),
419
                     [("node110.example.com", "node110.example.com")])
420

    
421
  def testMultiEmpty(self):
422
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
423
                                             lambda: {},
424
                                             [], None),
425
                     [])
426

    
427
  def testMultiSomeOffline(self):
428
    nodes = dict(("node%s.example.com" % i,
429
                  objects.Node(name="node%s.example.com" % i,
430
                               offline=((i % 3) == 0),
431
                               primary_ip="192.0.2.%s" % i))
432
                  for i in range(1, 255))
433

    
434
    # Resolve no names
435
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
436
                                             lambda: nodes,
437
                                             [], None),
438
                     [])
439

    
440
    # Offline, online and unknown hosts
441
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
442
                                             lambda: nodes,
443
                                             ["node3.example.com",
444
                                              "node92.example.com",
445
                                              "node54.example.com",
446
                                              "unknown.example.com",],
447
                                             None), [
448
      ("node3.example.com", rpc._OFFLINE),
449
      ("node92.example.com", "192.0.2.92"),
450
      ("node54.example.com", rpc._OFFLINE),
451
      ("unknown.example.com", "unknown.example.com"),
452
      ])
453

    
454

    
455
class TestCompress(unittest.TestCase):
456
  def test(self):
457
    for data in ["", "Hello", "Hello World!\nnew\nlines"]:
458
      self.assertEqual(rpc._Compress(data),
459
                       (constants.RPC_ENCODING_NONE, data))
460

    
461
    for data in [512 * " ", 5242 * "Hello World!\n"]:
462
      compressed = rpc._Compress(data)
463
      self.assertEqual(len(compressed), 2)
464
      self.assertEqual(backend._Decompress(compressed), data)
465

    
466
  def testDecompression(self):
467
    self.assertRaises(AssertionError, backend._Decompress, "")
468
    self.assertRaises(AssertionError, backend._Decompress, [""])
469
    self.assertRaises(AssertionError, backend._Decompress,
470
                      ("unknown compression", "data"))
471
    self.assertRaises(Exception, backend._Decompress,
472
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
473

    
474

    
475
class TestRpcClientBase(unittest.TestCase):
476
  def testNoHosts(self):
477
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_SLOW, [],
478
            None, None, NotImplemented)
479
    http_proc = _FakeRequestProcessor(NotImplemented)
480
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
481
                                _req_process_fn=http_proc)
482
    self.assertEqual(client._Call(cdef, [], []), {})
483

    
484
    # Test wrong number of arguments
485
    self.assertRaises(errors.ProgrammerError, client._Call,
486
                      cdef, [], [0, 1, 2])
487

    
488
  def testTimeout(self):
489
    def _CalcTimeout((arg1, arg2)):
490
      return arg1 + arg2
491

    
492
    def _VerifyRequest(exp_timeout, req):
493
      self.assertEqual(req.read_timeout, exp_timeout)
494

    
495
      req.success = True
496
      req.resp_status_code = http.HTTP_OK
497
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
498

    
499
    resolver = rpc._StaticResolver([
500
      "192.0.2.1",
501
      "192.0.2.2",
502
      ])
503

    
504
    nodes = [
505
      "node1.example.com",
506
      "node2.example.com",
507
      ]
508

    
509
    tests = [(100, None, 100), (30, None, 30)]
510
    tests.extend((_CalcTimeout, i, i + 300)
511
                 for i in [0, 5, 16485, 30516])
512

    
513
    for timeout, arg1, exp_timeout in tests:
514
      cdef = ("test_call", NotImplemented, None, timeout, [
515
        ("arg1", None, NotImplemented),
516
        ("arg2", None, NotImplemented),
517
        ], None, None, NotImplemented)
518

    
519
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
520
                                                       exp_timeout))
521
      client = rpc._RpcClientBase(resolver, NotImplemented,
522
                                  _req_process_fn=http_proc)
523
      result = client._Call(cdef, nodes, [arg1, 300])
524
      self.assertEqual(len(result), len(nodes))
525
      self.assertTrue(compat.all(not res.fail_msg and
526
                                 res.payload == hex(exp_timeout)
527
                                 for res in result.values()))
528

    
529
  def testArgumentEncoder(self):
530
    (AT1, AT2) = range(1, 3)
531

    
532
    resolver = rpc._StaticResolver([
533
      "192.0.2.5",
534
      "192.0.2.6",
535
      ])
536

    
537
    nodes = [
538
      "node5.example.com",
539
      "node6.example.com",
540
      ]
541

    
542
    encoders = {
543
      AT1: hex,
544
      AT2: hash,
545
      }
546

    
547
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
548
      ("arg0", None, NotImplemented),
549
      ("arg1", AT1, NotImplemented),
550
      ("arg1", AT2, NotImplemented),
551
      ], None, None, NotImplemented)
552

    
553
    def _VerifyRequest(req):
554
      req.success = True
555
      req.resp_status_code = http.HTTP_OK
556
      req.resp_body = serializer.DumpJson((True, req.post_data))
557

    
558
    http_proc = _FakeRequestProcessor(_VerifyRequest)
559

    
560
    for num in [0, 3796, 9032119]:
561
      client = rpc._RpcClientBase(resolver, encoders.get,
562
                                  _req_process_fn=http_proc)
563
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
564
      self.assertEqual(len(result), len(nodes))
565
      for res in result.values():
566
        self.assertFalse(res.fail_msg)
567
        self.assertEqual(serializer.LoadJson(res.payload),
568
                         ["foo", hex(num), hash("Hello%s" % num)])
569

    
570
  def testPostProc(self):
571
    def _VerifyRequest(nums, req):
572
      req.success = True
573
      req.resp_status_code = http.HTTP_OK
574
      req.resp_body = serializer.DumpJson((True, nums))
575

    
576
    resolver = rpc._StaticResolver([
577
      "192.0.2.90",
578
      "192.0.2.95",
579
      ])
580

    
581
    nodes = [
582
      "node90.example.com",
583
      "node95.example.com",
584
      ]
585

    
586
    def _PostProc(res):
587
      self.assertFalse(res.fail_msg)
588
      res.payload = sum(res.payload)
589
      return res
590

    
591
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [],
592
            None, _PostProc, NotImplemented)
593

    
594
    # Seeded random generator
595
    rnd = random.Random(20299)
596

    
597
    for i in [0, 4, 74, 1391]:
598
      nums = [rnd.randint(0, 1000) for _ in range(i)]
599
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
600
      client = rpc._RpcClientBase(resolver, NotImplemented,
601
                                  _req_process_fn=http_proc)
602
      result = client._Call(cdef, nodes, [])
603
      self.assertEqual(len(result), len(nodes))
604
      for res in result.values():
605
        self.assertFalse(res.fail_msg)
606
        self.assertEqual(res.payload, sum(nums))
607

    
608
  def testPreProc(self):
609
    def _VerifyRequest(req):
610
      req.success = True
611
      req.resp_status_code = http.HTTP_OK
612
      req.resp_body = serializer.DumpJson((True, req.post_data))
613

    
614
    resolver = rpc._StaticResolver([
615
      "192.0.2.30",
616
      "192.0.2.35",
617
      ])
618

    
619
    nodes = [
620
      "node30.example.com",
621
      "node35.example.com",
622
      ]
623

    
624
    def _PreProc(node, data):
625
      self.assertEqual(len(data), 1)
626
      return data[0] + node
627

    
628
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
629
      ("arg0", None, NotImplemented),
630
      ], _PreProc, None, NotImplemented)
631

    
632
    http_proc = _FakeRequestProcessor(_VerifyRequest)
633
    client = rpc._RpcClientBase(resolver, NotImplemented,
634
                                _req_process_fn=http_proc)
635

    
636
    for prefix in ["foo", "bar", "baz"]:
637
      result = client._Call(cdef, nodes, [prefix])
638
      self.assertEqual(len(result), len(nodes))
639
      for (idx, (node, res)) in enumerate(result.items()):
640
        self.assertFalse(res.fail_msg)
641
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
642

    
643
  def testResolverOptions(self):
644
    def _VerifyRequest(req):
645
      req.success = True
646
      req.resp_status_code = http.HTTP_OK
647
      req.resp_body = serializer.DumpJson((True, req.post_data))
648

    
649
    nodes = [
650
      "node30.example.com",
651
      "node35.example.com",
652
      ]
653

    
654
    def _Resolver(expected, hosts, options):
655
      self.assertEqual(hosts, nodes)
656
      self.assertEqual(options, expected)
657
      return zip(hosts, nodes)
658

    
659
    def _DynamicResolverOptions((arg0, )):
660
      return sum(arg0)
661

    
662
    tests = [
663
      (None, None, None),
664
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
665
      (False, None, False),
666
      (True, None, True),
667
      (0, None, 0),
668
      (_DynamicResolverOptions, [1, 2, 3], 6),
669
      (_DynamicResolverOptions, range(4, 19), 165),
670
      ]
671

    
672
    for (resolver_opts, arg0, expected) in tests:
673
      cdef = ("test_call", NotImplemented, resolver_opts, rpc_defs.TMO_NORMAL, [
674
        ("arg0", None, NotImplemented),
675
        ], None, None, NotImplemented)
676

    
677
      http_proc = _FakeRequestProcessor(_VerifyRequest)
678

    
679
      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
680
                                  NotImplemented, _req_process_fn=http_proc)
681
      result = client._Call(cdef, nodes, [arg0])
682
      self.assertEqual(len(result), len(nodes))
683
      for (idx, (node, res)) in enumerate(result.items()):
684
        self.assertFalse(res.fail_msg)
685

    
686

    
687
class _FakeConfigForRpcRunner:
688
  GetAllNodesInfo = NotImplemented
689

    
690
  def GetNodeInfo(self, name):
691
    return objects.Node(name=name)
692

    
693

    
694
class TestRpcRunner(unittest.TestCase):
695
  def testUploadFile(self):
696
    data = 1779 * "Hello World\n"
697

    
698
    tmpfile = tempfile.NamedTemporaryFile()
699
    tmpfile.write(data)
700
    tmpfile.flush()
701
    st = os.stat(tmpfile.name)
702

    
703
    def _VerifyRequest(req):
704
      (uldata, ) = serializer.LoadJson(req.post_data)
705
      self.assertEqual(len(uldata), 7)
706
      self.assertEqual(uldata[0], tmpfile.name)
707
      self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
708
      self.assertEqual(uldata[2], st.st_mode)
709
      self.assertEqual(uldata[3], "user%s" % os.getuid())
710
      self.assertEqual(uldata[4], "group%s" % os.getgid())
711
      self.assertEqual(uldata[5], st.st_atime)
712
      self.assertEqual(uldata[6], st.st_mtime)
713

    
714
      req.success = True
715
      req.resp_status_code = http.HTTP_OK
716
      req.resp_body = serializer.DumpJson((True, None))
717

    
718
    http_proc = _FakeRequestProcessor(_VerifyRequest)
719
    cfg = _FakeConfigForRpcRunner()
720
    runner = rpc.RpcRunner(cfg, None, _req_process_fn=http_proc,
721
                           _getents=mocks.FakeGetentResolver)
722

    
723
    nodes = [
724
      "node1.example.com",
725
      ]
726

    
727
    result = runner.call_upload_file(nodes, tmpfile.name)
728
    self.assertEqual(len(result), len(nodes))
729
    for (idx, (node, res)) in enumerate(result.items()):
730
      self.assertFalse(res.fail_msg)
731

    
732

    
733
if __name__ == "__main__":
734
  testutils.GanetiTestProgram()