Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ bd6d1202

History | View | Annotate | Download (30.2 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011, 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27
import random
28
import tempfile
29

    
30
from ganeti import constants
31
from ganeti import compat
32
from ganeti import rpc
33
from ganeti import rpc_defs
34
from ganeti import http
35
from ganeti import errors
36
from ganeti import serializer
37
from ganeti import objects
38
from ganeti import backend
39

    
40
import testutils
41
import mocks
42

    
43

    
44
class _FakeRequestProcessor:
45
  def __init__(self, response_fn):
46
    self._response_fn = response_fn
47
    self.reqcount = 0
48

    
49
  def __call__(self, reqs, lock_monitor_cb=None):
50
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
51
    for req in reqs:
52
      self.reqcount += 1
53
      self._response_fn(req)
54

    
55

    
56
def GetFakeSimpleStoreClass(fn):
57
  class FakeSimpleStore:
58
    GetNodePrimaryIPList = fn
59
    GetPrimaryIPFamily = lambda _: None
60

    
61
  return FakeSimpleStore
62

    
63

    
64
def _RaiseNotImplemented():
65
  """Simple wrapper to raise NotImplementedError.
66

67
  """
68
  raise NotImplementedError
69

    
70

    
71
class TestRpcProcessor(unittest.TestCase):
72
  def _FakeAddressLookup(self, map):
73
    return lambda node_list: [map.get(node) for node in node_list]
74

    
75
  def _GetVersionResponse(self, req):
76
    self.assertEqual(req.host, "127.0.0.1")
77
    self.assertEqual(req.port, 24094)
78
    self.assertEqual(req.path, "/version")
79
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
80
    req.success = True
81
    req.resp_status_code = http.HTTP_OK
82
    req.resp_body = serializer.DumpJson((True, 123))
83

    
84
  def testVersionSuccess(self):
85
    resolver = rpc._StaticResolver(["127.0.0.1"])
86
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
87
    proc = rpc._RpcProcessor(resolver, 24094)
88
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
89
                  NotImplemented, _req_process_fn=http_proc)
90
    self.assertEqual(result.keys(), ["localhost"])
91
    lhresp = result["localhost"]
92
    self.assertFalse(lhresp.offline)
93
    self.assertEqual(lhresp.node, "localhost")
94
    self.assertFalse(lhresp.fail_msg)
95
    self.assertEqual(lhresp.payload, 123)
96
    self.assertEqual(lhresp.call, "version")
97
    lhresp.Raise("should not raise")
98
    self.assertEqual(http_proc.reqcount, 1)
99

    
100
  def _ReadTimeoutResponse(self, req):
101
    self.assertEqual(req.host, "192.0.2.13")
102
    self.assertEqual(req.port, 19176)
103
    self.assertEqual(req.path, "/version")
104
    self.assertEqual(req.read_timeout, 12356)
105
    req.success = True
106
    req.resp_status_code = http.HTTP_OK
107
    req.resp_body = serializer.DumpJson((True, -1))
108

    
109
  def testReadTimeout(self):
110
    resolver = rpc._StaticResolver(["192.0.2.13"])
111
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
112
    proc = rpc._RpcProcessor(resolver, 19176)
113
    host = "node31856"
114
    body = {host: ""}
115
    result = proc([host], "version", body, 12356, NotImplemented,
116
                  _req_process_fn=http_proc)
117
    self.assertEqual(result.keys(), [host])
118
    lhresp = result[host]
119
    self.assertFalse(lhresp.offline)
120
    self.assertEqual(lhresp.node, host)
121
    self.assertFalse(lhresp.fail_msg)
122
    self.assertEqual(lhresp.payload, -1)
123
    self.assertEqual(lhresp.call, "version")
124
    lhresp.Raise("should not raise")
125
    self.assertEqual(http_proc.reqcount, 1)
126

    
127
  def testOfflineNode(self):
128
    resolver = rpc._StaticResolver([rpc._OFFLINE])
129
    http_proc = _FakeRequestProcessor(NotImplemented)
130
    proc = rpc._RpcProcessor(resolver, 30668)
131
    host = "n17296"
132
    body = {host: ""}
133
    result = proc([host], "version", body, 60, NotImplemented,
134
                  _req_process_fn=http_proc)
135
    self.assertEqual(result.keys(), [host])
136
    lhresp = result[host]
137
    self.assertTrue(lhresp.offline)
138
    self.assertEqual(lhresp.node, host)
139
    self.assertTrue(lhresp.fail_msg)
140
    self.assertFalse(lhresp.payload)
141
    self.assertEqual(lhresp.call, "version")
142

    
143
    # With a message
144
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
145

    
146
    # No message
147
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
148

    
149
    self.assertEqual(http_proc.reqcount, 0)
150

    
151
  def _GetMultiVersionResponse(self, req):
152
    self.assert_(req.host.startswith("node"))
153
    self.assertEqual(req.port, 23245)
154
    self.assertEqual(req.path, "/version")
155
    req.success = True
156
    req.resp_status_code = http.HTTP_OK
157
    req.resp_body = serializer.DumpJson((True, 987))
158

    
159
  def testMultiVersionSuccess(self):
160
    nodes = ["node%s" % i for i in range(50)]
161
    body = dict((n, "") for n in nodes)
162
    resolver = rpc._StaticResolver(nodes)
163
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
164
    proc = rpc._RpcProcessor(resolver, 23245)
165
    result = proc(nodes, "version", body, 60, NotImplemented,
166
                  _req_process_fn=http_proc)
167
    self.assertEqual(sorted(result.keys()), sorted(nodes))
168

    
169
    for name in nodes:
170
      lhresp = result[name]
171
      self.assertFalse(lhresp.offline)
172
      self.assertEqual(lhresp.node, name)
173
      self.assertFalse(lhresp.fail_msg)
174
      self.assertEqual(lhresp.payload, 987)
175
      self.assertEqual(lhresp.call, "version")
176
      lhresp.Raise("should not raise")
177

    
178
    self.assertEqual(http_proc.reqcount, len(nodes))
179

    
180
  def _GetVersionResponseFail(self, errinfo, req):
181
    self.assertEqual(req.path, "/version")
182
    req.success = True
183
    req.resp_status_code = http.HTTP_OK
184
    req.resp_body = serializer.DumpJson((False, errinfo))
185

    
186
  def testVersionFailure(self):
187
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
188
    proc = rpc._RpcProcessor(resolver, 5903)
189
    for errinfo in [None, "Unknown error"]:
190
      http_proc = \
191
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
192
                                             errinfo))
193
      host = "aef9ur4i.example.com"
194
      body = {host: ""}
195
      result = proc(body.keys(), "version", body, 60, NotImplemented,
196
                    _req_process_fn=http_proc)
197
      self.assertEqual(result.keys(), [host])
198
      lhresp = result[host]
199
      self.assertFalse(lhresp.offline)
200
      self.assertEqual(lhresp.node, host)
201
      self.assert_(lhresp.fail_msg)
202
      self.assertFalse(lhresp.payload)
203
      self.assertEqual(lhresp.call, "version")
204
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
205
      self.assertEqual(http_proc.reqcount, 1)
206

    
207
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
208
    self.assertEqual(req.path, "/vg_list")
209
    self.assertEqual(req.port, 15165)
210

    
211
    if req.host in httperrnodes:
212
      req.success = False
213
      req.error = "Node set up for HTTP errors"
214

    
215
    elif req.host in failnodes:
216
      req.success = True
217
      req.resp_status_code = 404
218
      req.resp_body = serializer.DumpJson({
219
        "code": 404,
220
        "message": "Method not found",
221
        "explain": "Explanation goes here",
222
        })
223
    else:
224
      req.success = True
225
      req.resp_status_code = http.HTTP_OK
226
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
227

    
228
  def testHttpError(self):
229
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
230
    body = dict((n, "") for n in nodes)
231
    resolver = rpc._StaticResolver(nodes)
232

    
233
    httperrnodes = set(nodes[1::7])
234
    self.assertEqual(len(httperrnodes), 7)
235

    
236
    failnodes = set(nodes[2::3]) - httperrnodes
237
    self.assertEqual(len(failnodes), 14)
238

    
239
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
240

    
241
    proc = rpc._RpcProcessor(resolver, 15165)
242
    http_proc = \
243
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
244
                                           httperrnodes, failnodes))
245
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
246
                  _req_process_fn=http_proc)
247
    self.assertEqual(sorted(result.keys()), sorted(nodes))
248

    
249
    for name in nodes:
250
      lhresp = result[name]
251
      self.assertFalse(lhresp.offline)
252
      self.assertEqual(lhresp.node, name)
253
      self.assertEqual(lhresp.call, "vg_list")
254

    
255
      if name in httperrnodes:
256
        self.assert_(lhresp.fail_msg)
257
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
258
      elif name in failnodes:
259
        self.assert_(lhresp.fail_msg)
260
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
261
                          prereq=True, ecode=errors.ECODE_INVAL)
262
      else:
263
        self.assertFalse(lhresp.fail_msg)
264
        self.assertEqual(lhresp.payload, hash(name))
265
        lhresp.Raise("should not raise")
266

    
267
    self.assertEqual(http_proc.reqcount, len(nodes))
268

    
269
  def _GetInvalidResponseA(self, req):
270
    self.assertEqual(req.path, "/version")
271
    req.success = True
272
    req.resp_status_code = http.HTTP_OK
273
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
274
                                         "response", "!", 1, 2, 3))
275

    
276
  def _GetInvalidResponseB(self, req):
277
    self.assertEqual(req.path, "/version")
278
    req.success = True
279
    req.resp_status_code = http.HTTP_OK
280
    req.resp_body = serializer.DumpJson("invalid response")
281

    
282
  def testInvalidResponse(self):
283
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
284
    proc = rpc._RpcProcessor(resolver, 19978)
285

    
286
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
287
      http_proc = _FakeRequestProcessor(fn)
288
      host = "oqo7lanhly.example.com"
289
      body = {host: ""}
290
      result = proc([host], "version", body, 60, NotImplemented,
291
                    _req_process_fn=http_proc)
292
      self.assertEqual(result.keys(), [host])
293
      lhresp = result[host]
294
      self.assertFalse(lhresp.offline)
295
      self.assertEqual(lhresp.node, host)
296
      self.assert_(lhresp.fail_msg)
297
      self.assertFalse(lhresp.payload)
298
      self.assertEqual(lhresp.call, "version")
299
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
300
      self.assertEqual(http_proc.reqcount, 1)
301

    
302
  def _GetBodyTestResponse(self, test_data, req):
303
    self.assertEqual(req.host, "192.0.2.84")
304
    self.assertEqual(req.port, 18700)
305
    self.assertEqual(req.path, "/upload_file")
306
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
307
    req.success = True
308
    req.resp_status_code = http.HTTP_OK
309
    req.resp_body = serializer.DumpJson((True, None))
310

    
311
  def testResponseBody(self):
312
    test_data = {
313
      "Hello": "World",
314
      "xyz": range(10),
315
      }
316
    resolver = rpc._StaticResolver(["192.0.2.84"])
317
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
318
                                                     test_data))
319
    proc = rpc._RpcProcessor(resolver, 18700)
320
    host = "node19759"
321
    body = {host: serializer.DumpJson(test_data)}
322
    result = proc([host], "upload_file", body, 30, NotImplemented,
323
                  _req_process_fn=http_proc)
324
    self.assertEqual(result.keys(), [host])
325
    lhresp = result[host]
326
    self.assertFalse(lhresp.offline)
327
    self.assertEqual(lhresp.node, host)
328
    self.assertFalse(lhresp.fail_msg)
329
    self.assertEqual(lhresp.payload, None)
330
    self.assertEqual(lhresp.call, "upload_file")
331
    lhresp.Raise("should not raise")
332
    self.assertEqual(http_proc.reqcount, 1)
333

    
334

    
335
class TestSsconfResolver(unittest.TestCase):
336
  def testSsconfLookup(self):
337
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
338
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
339
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
340
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
341
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
342
                                 ssc=ssc, nslookup_fn=NotImplemented)
343
    self.assertEqual(result, zip(node_list, addr_list))
344

    
345
  def testNsLookup(self):
346
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
347
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
348
    ssc = GetFakeSimpleStoreClass(lambda _: [])
349
    node_addr_map = dict(zip(node_list, addr_list))
350
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
351
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
352
                                 ssc=ssc, nslookup_fn=nslookup_fn)
353
    self.assertEqual(result, zip(node_list, addr_list))
354

    
355
  def testDisabledSsconfIp(self):
356
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
357
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
358
    ssc = GetFakeSimpleStoreClass(_RaiseNotImplemented)
359
    node_addr_map = dict(zip(node_list, addr_list))
360
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
361
    result = rpc._SsconfResolver(False, node_list, NotImplemented,
362
                                 ssc=ssc, nslookup_fn=nslookup_fn)
363
    self.assertEqual(result, zip(node_list, addr_list))
364

    
365
  def testBothLookups(self):
366
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
367
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
368
    n = len(addr_list) / 2
369
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
370
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
371
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
372
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
373
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
374
                                 ssc=ssc, nslookup_fn=nslookup_fn)
375
    self.assertEqual(result, zip(node_list, addr_list))
376

    
377
  def testAddressLookupIPv6(self):
378
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
379
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
380
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
381
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
382
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
383
                                 ssc=ssc, nslookup_fn=NotImplemented)
384
    self.assertEqual(result, zip(node_list, addr_list))
385

    
386

    
387
class TestStaticResolver(unittest.TestCase):
388
  def test(self):
389
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
390
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
391
    res = rpc._StaticResolver(addresses)
392
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
393

    
394
  def testWrongLength(self):
395
    res = rpc._StaticResolver([])
396
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
397

    
398

    
399
class TestNodeConfigResolver(unittest.TestCase):
400
  @staticmethod
401
  def _GetSingleOnlineNode(name):
402
    assert name == "node90.example.com"
403
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
404

    
405
  @staticmethod
406
  def _GetSingleOfflineNode(name):
407
    assert name == "node100.example.com"
408
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
409

    
410
  def testSingleOnline(self):
411
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
412
                                             NotImplemented,
413
                                             ["node90.example.com"], None),
414
                     [("node90.example.com", "192.0.2.90")])
415

    
416
  def testSingleOffline(self):
417
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
418
                                             NotImplemented,
419
                                             ["node100.example.com"], None),
420
                     [("node100.example.com", rpc._OFFLINE)])
421

    
422
  def testSingleOfflineWithAcceptOffline(self):
423
    fn = self._GetSingleOfflineNode
424
    assert fn("node100.example.com").offline
425
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
426
                                             ["node100.example.com"],
427
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
428
                     [("node100.example.com", "192.0.2.100")])
429
    for i in [False, True, "", "Hello", 0, 1]:
430
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
431
                        fn, NotImplemented, ["node100.example.com"], i)
432

    
433
  def testUnknownSingleNode(self):
434
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
435
                                             ["node110.example.com"], None),
436
                     [("node110.example.com", "node110.example.com")])
437

    
438
  def testMultiEmpty(self):
439
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
440
                                             lambda: {},
441
                                             [], None),
442
                     [])
443

    
444
  def testMultiSomeOffline(self):
445
    nodes = dict(("node%s.example.com" % i,
446
                  objects.Node(name="node%s.example.com" % i,
447
                               offline=((i % 3) == 0),
448
                               primary_ip="192.0.2.%s" % i))
449
                  for i in range(1, 255))
450

    
451
    # Resolve no names
452
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
453
                                             lambda: nodes,
454
                                             [], None),
455
                     [])
456

    
457
    # Offline, online and unknown hosts
458
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
459
                                             lambda: nodes,
460
                                             ["node3.example.com",
461
                                              "node92.example.com",
462
                                              "node54.example.com",
463
                                              "unknown.example.com",],
464
                                             None), [
465
      ("node3.example.com", rpc._OFFLINE),
466
      ("node92.example.com", "192.0.2.92"),
467
      ("node54.example.com", rpc._OFFLINE),
468
      ("unknown.example.com", "unknown.example.com"),
469
      ])
470

    
471

    
472
class TestCompress(unittest.TestCase):
473
  def test(self):
474
    for data in ["", "Hello", "Hello World!\nnew\nlines"]:
475
      self.assertEqual(rpc._Compress(data),
476
                       (constants.RPC_ENCODING_NONE, data))
477

    
478
    for data in [512 * " ", 5242 * "Hello World!\n"]:
479
      compressed = rpc._Compress(data)
480
      self.assertEqual(len(compressed), 2)
481
      self.assertEqual(backend._Decompress(compressed), data)
482

    
483
  def testDecompression(self):
484
    self.assertRaises(AssertionError, backend._Decompress, "")
485
    self.assertRaises(AssertionError, backend._Decompress, [""])
486
    self.assertRaises(AssertionError, backend._Decompress,
487
                      ("unknown compression", "data"))
488
    self.assertRaises(Exception, backend._Decompress,
489
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
490

    
491

    
492
class TestRpcClientBase(unittest.TestCase):
493
  def testNoHosts(self):
494
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_SLOW, [],
495
            None, None, NotImplemented)
496
    http_proc = _FakeRequestProcessor(NotImplemented)
497
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
498
                                _req_process_fn=http_proc)
499
    self.assertEqual(client._Call(cdef, [], []), {})
500

    
501
    # Test wrong number of arguments
502
    self.assertRaises(errors.ProgrammerError, client._Call,
503
                      cdef, [], [0, 1, 2])
504

    
505
  def testTimeout(self):
506
    def _CalcTimeout((arg1, arg2)):
507
      return arg1 + arg2
508

    
509
    def _VerifyRequest(exp_timeout, req):
510
      self.assertEqual(req.read_timeout, exp_timeout)
511

    
512
      req.success = True
513
      req.resp_status_code = http.HTTP_OK
514
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
515

    
516
    resolver = rpc._StaticResolver([
517
      "192.0.2.1",
518
      "192.0.2.2",
519
      ])
520

    
521
    nodes = [
522
      "node1.example.com",
523
      "node2.example.com",
524
      ]
525

    
526
    tests = [(100, None, 100), (30, None, 30)]
527
    tests.extend((_CalcTimeout, i, i + 300)
528
                 for i in [0, 5, 16485, 30516])
529

    
530
    for timeout, arg1, exp_timeout in tests:
531
      cdef = ("test_call", NotImplemented, None, timeout, [
532
        ("arg1", None, NotImplemented),
533
        ("arg2", None, NotImplemented),
534
        ], None, None, NotImplemented)
535

    
536
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
537
                                                       exp_timeout))
538
      client = rpc._RpcClientBase(resolver, NotImplemented,
539
                                  _req_process_fn=http_proc)
540
      result = client._Call(cdef, nodes, [arg1, 300])
541
      self.assertEqual(len(result), len(nodes))
542
      self.assertTrue(compat.all(not res.fail_msg and
543
                                 res.payload == hex(exp_timeout)
544
                                 for res in result.values()))
545

    
546
  def testArgumentEncoder(self):
547
    (AT1, AT2) = range(1, 3)
548

    
549
    resolver = rpc._StaticResolver([
550
      "192.0.2.5",
551
      "192.0.2.6",
552
      ])
553

    
554
    nodes = [
555
      "node5.example.com",
556
      "node6.example.com",
557
      ]
558

    
559
    encoders = {
560
      AT1: hex,
561
      AT2: hash,
562
      }
563

    
564
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
565
      ("arg0", None, NotImplemented),
566
      ("arg1", AT1, NotImplemented),
567
      ("arg1", AT2, NotImplemented),
568
      ], None, None, NotImplemented)
569

    
570
    def _VerifyRequest(req):
571
      req.success = True
572
      req.resp_status_code = http.HTTP_OK
573
      req.resp_body = serializer.DumpJson((True, req.post_data))
574

    
575
    http_proc = _FakeRequestProcessor(_VerifyRequest)
576

    
577
    for num in [0, 3796, 9032119]:
578
      client = rpc._RpcClientBase(resolver, encoders.get,
579
                                  _req_process_fn=http_proc)
580
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
581
      self.assertEqual(len(result), len(nodes))
582
      for res in result.values():
583
        self.assertFalse(res.fail_msg)
584
        self.assertEqual(serializer.LoadJson(res.payload),
585
                         ["foo", hex(num), hash("Hello%s" % num)])
586

    
587
  def testPostProc(self):
588
    def _VerifyRequest(nums, req):
589
      req.success = True
590
      req.resp_status_code = http.HTTP_OK
591
      req.resp_body = serializer.DumpJson((True, nums))
592

    
593
    resolver = rpc._StaticResolver([
594
      "192.0.2.90",
595
      "192.0.2.95",
596
      ])
597

    
598
    nodes = [
599
      "node90.example.com",
600
      "node95.example.com",
601
      ]
602

    
603
    def _PostProc(res):
604
      self.assertFalse(res.fail_msg)
605
      res.payload = sum(res.payload)
606
      return res
607

    
608
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [],
609
            None, _PostProc, NotImplemented)
610

    
611
    # Seeded random generator
612
    rnd = random.Random(20299)
613

    
614
    for i in [0, 4, 74, 1391]:
615
      nums = [rnd.randint(0, 1000) for _ in range(i)]
616
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
617
      client = rpc._RpcClientBase(resolver, NotImplemented,
618
                                  _req_process_fn=http_proc)
619
      result = client._Call(cdef, nodes, [])
620
      self.assertEqual(len(result), len(nodes))
621
      for res in result.values():
622
        self.assertFalse(res.fail_msg)
623
        self.assertEqual(res.payload, sum(nums))
624

    
625
  def testPreProc(self):
626
    def _VerifyRequest(req):
627
      req.success = True
628
      req.resp_status_code = http.HTTP_OK
629
      req.resp_body = serializer.DumpJson((True, req.post_data))
630

    
631
    resolver = rpc._StaticResolver([
632
      "192.0.2.30",
633
      "192.0.2.35",
634
      ])
635

    
636
    nodes = [
637
      "node30.example.com",
638
      "node35.example.com",
639
      ]
640

    
641
    def _PreProc(node, data):
642
      self.assertEqual(len(data), 1)
643
      return data[0] + node
644

    
645
    cdef = ("test_call", NotImplemented, None, rpc_defs.TMO_NORMAL, [
646
      ("arg0", None, NotImplemented),
647
      ], _PreProc, None, NotImplemented)
648

    
649
    http_proc = _FakeRequestProcessor(_VerifyRequest)
650
    client = rpc._RpcClientBase(resolver, NotImplemented,
651
                                _req_process_fn=http_proc)
652

    
653
    for prefix in ["foo", "bar", "baz"]:
654
      result = client._Call(cdef, nodes, [prefix])
655
      self.assertEqual(len(result), len(nodes))
656
      for (idx, (node, res)) in enumerate(result.items()):
657
        self.assertFalse(res.fail_msg)
658
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
659

    
660
  def testResolverOptions(self):
661
    def _VerifyRequest(req):
662
      req.success = True
663
      req.resp_status_code = http.HTTP_OK
664
      req.resp_body = serializer.DumpJson((True, req.post_data))
665

    
666
    nodes = [
667
      "node30.example.com",
668
      "node35.example.com",
669
      ]
670

    
671
    def _Resolver(expected, hosts, options):
672
      self.assertEqual(hosts, nodes)
673
      self.assertEqual(options, expected)
674
      return zip(hosts, nodes)
675

    
676
    def _DynamicResolverOptions((arg0, )):
677
      return sum(arg0)
678

    
679
    tests = [
680
      (None, None, None),
681
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
682
      (False, None, False),
683
      (True, None, True),
684
      (0, None, 0),
685
      (_DynamicResolverOptions, [1, 2, 3], 6),
686
      (_DynamicResolverOptions, range(4, 19), 165),
687
      ]
688

    
689
    for (resolver_opts, arg0, expected) in tests:
690
      cdef = ("test_call", NotImplemented, resolver_opts, rpc_defs.TMO_NORMAL, [
691
        ("arg0", None, NotImplemented),
692
        ], None, None, NotImplemented)
693

    
694
      http_proc = _FakeRequestProcessor(_VerifyRequest)
695

    
696
      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
697
                                  NotImplemented, _req_process_fn=http_proc)
698
      result = client._Call(cdef, nodes, [arg0])
699
      self.assertEqual(len(result), len(nodes))
700
      for (idx, (node, res)) in enumerate(result.items()):
701
        self.assertFalse(res.fail_msg)
702

    
703

    
704
class _FakeConfigForRpcRunner:
705
  GetAllNodesInfo = NotImplemented
706

    
707
  def __init__(self, cluster=NotImplemented):
708
    self._cluster = cluster
709

    
710
  def GetNodeInfo(self, name):
711
    return objects.Node(name=name)
712

    
713
  def GetClusterInfo(self):
714
    return self._cluster
715

    
716

    
717
class TestRpcRunner(unittest.TestCase):
718
  def testUploadFile(self):
719
    data = 1779 * "Hello World\n"
720

    
721
    tmpfile = tempfile.NamedTemporaryFile()
722
    tmpfile.write(data)
723
    tmpfile.flush()
724
    st = os.stat(tmpfile.name)
725

    
726
    def _VerifyRequest(req):
727
      (uldata, ) = serializer.LoadJson(req.post_data)
728
      self.assertEqual(len(uldata), 7)
729
      self.assertEqual(uldata[0], tmpfile.name)
730
      self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
731
      self.assertEqual(uldata[2], st.st_mode)
732
      self.assertEqual(uldata[3], "user%s" % os.getuid())
733
      self.assertEqual(uldata[4], "group%s" % os.getgid())
734
      self.assertTrue(uldata[5] is not None)
735
      self.assertEqual(uldata[6], st.st_mtime)
736

    
737
      req.success = True
738
      req.resp_status_code = http.HTTP_OK
739
      req.resp_body = serializer.DumpJson((True, None))
740

    
741
    http_proc = _FakeRequestProcessor(_VerifyRequest)
742

    
743
    std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
744
                               _req_process_fn=http_proc,
745
                               _getents=mocks.FakeGetentResolver)
746

    
747
    cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
748
                                  _req_process_fn=http_proc,
749
                                  _getents=mocks.FakeGetentResolver)
750

    
751
    nodes = [
752
      "node1.example.com",
753
      ]
754

    
755
    for runner in [std_runner, cfg_runner]:
756
      result = runner.call_upload_file(nodes, tmpfile.name)
757
      self.assertEqual(len(result), len(nodes))
758
      for (idx, (node, res)) in enumerate(result.items()):
759
        self.assertFalse(res.fail_msg)
760

    
761
  def testEncodeInstance(self):
762
    cluster = objects.Cluster(hvparams={
763
      constants.HT_KVM: {
764
        constants.HV_BLOCKDEV_PREFIX: "foo",
765
        },
766
      },
767
      beparams={
768
        constants.PP_DEFAULT: {
769
          constants.BE_MAXMEM: 8192,
770
          },
771
        },
772
      os_hvp={},
773
      osparams={
774
        "linux": {
775
          "role": "unknown",
776
          },
777
        })
778
    cluster.UpgradeConfig()
779

    
780
    inst = objects.Instance(name="inst1.example.com",
781
      hypervisor=constants.HT_FAKE,
782
      os="linux",
783
      hvparams={
784
        constants.HT_KVM: {
785
          constants.HV_BLOCKDEV_PREFIX: "bar",
786
          constants.HV_ROOT_PATH: "/tmp",
787
          },
788
        },
789
      beparams={
790
        constants.BE_MINMEM: 128,
791
        constants.BE_MAXMEM: 256,
792
        },
793
      nics=[
794
        objects.NIC(nicparams={
795
          constants.NIC_MODE: "mymode",
796
          }),
797
        ],
798
      disks=[])
799
    inst.UpgradeConfig()
800

    
801
    cfg = _FakeConfigForRpcRunner(cluster=cluster)
802
    runner = rpc.RpcRunner(cfg, None,
803
                           _req_process_fn=NotImplemented,
804
                           _getents=mocks.FakeGetentResolver)
805

    
806
    def _CheckBasics(result):
807
      self.assertEqual(result["name"], "inst1.example.com")
808
      self.assertEqual(result["os"], "linux")
809
      self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
810
      self.assertEqual(len(result["hvparams"]), 1)
811
      self.assertEqual(len(result["nics"]), 1)
812
      self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
813
                       "mymode")
814

    
815
    # Generic object serialization
816
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT, inst))
817
    _CheckBasics(result)
818

    
819
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
820
    map(_CheckBasics, result)
821

    
822
    # Just an instance
823
    result = runner._encoder((rpc_defs.ED_INST_DICT, inst))
824
    _CheckBasics(result)
825
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
826
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
827
      constants.HV_BLOCKDEV_PREFIX: "bar",
828
      constants.HV_ROOT_PATH: "/tmp",
829
      })
830
    self.assertEqual(result["osparams"], {
831
      "role": "unknown",
832
      })
833

    
834
    # Instance with OS parameters
835
    result = runner._encoder((rpc_defs.ED_INST_DICT_OSP, (inst, {
836
      "role": "webserver",
837
      "other": "field",
838
      })))
839
    _CheckBasics(result)
840
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
841
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
842
      constants.HV_BLOCKDEV_PREFIX: "bar",
843
      constants.HV_ROOT_PATH: "/tmp",
844
      })
845
    self.assertEqual(result["osparams"], {
846
      "role": "webserver",
847
      "other": "field",
848
      })
849

    
850
    # Instance with hypervisor and backend parameters
851
    result = runner._encoder((rpc_defs.ED_INST_DICT_HVP_BEP, (inst, {
852
      constants.HT_KVM: {
853
        constants.HV_BOOT_ORDER: "xyz",
854
        },
855
      }, {
856
      constants.BE_VCPUS: 100,
857
      constants.BE_MAXMEM: 4096,
858
      })))
859
    _CheckBasics(result)
860
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
861
    self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
862
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
863
      constants.HV_BOOT_ORDER: "xyz",
864
      })
865

    
866

    
867
if __name__ == "__main__":
868
  testutils.GanetiTestProgram()