Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.rpc_unittest.py @ 52a988f2

History | View | Annotate | Download (34.4 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011, 2012, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27
import random
28
import tempfile
29

    
30
from ganeti import constants
31
from ganeti import compat
32
from ganeti import rpc
33
from ganeti import rpc_defs
34
from ganeti import http
35
from ganeti import errors
36
from ganeti import serializer
37
from ganeti import objects
38
from ganeti import backend
39

    
40
import testutils
41
import mocks
42

    
43

    
44
class _FakeRequestProcessor:
45
  def __init__(self, response_fn):
46
    self._response_fn = response_fn
47
    self.reqcount = 0
48

    
49
  def __call__(self, reqs, lock_monitor_cb=None):
50
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
51
    for req in reqs:
52
      self.reqcount += 1
53
      self._response_fn(req)
54

    
55

    
56
def GetFakeSimpleStoreClass(fn):
57
  class FakeSimpleStore:
58
    GetNodePrimaryIPList = fn
59
    GetPrimaryIPFamily = lambda _: None
60

    
61
  return FakeSimpleStore
62

    
63

    
64
def _RaiseNotImplemented():
65
  """Simple wrapper to raise NotImplementedError.
66

67
  """
68
  raise NotImplementedError
69

    
70

    
71
class TestRpcProcessor(unittest.TestCase):
72
  def _FakeAddressLookup(self, map):
73
    return lambda node_list: [map.get(node) for node in node_list]
74

    
75
  def _GetVersionResponse(self, req):
76
    self.assertEqual(req.host, "127.0.0.1")
77
    self.assertEqual(req.port, 24094)
78
    self.assertEqual(req.path, "/version")
79
    self.assertEqual(req.read_timeout, constants.RPC_TMO_URGENT)
80
    req.success = True
81
    req.resp_status_code = http.HTTP_OK
82
    req.resp_body = serializer.DumpJson((True, 123))
83

    
84
  def testVersionSuccess(self):
85
    resolver = rpc._StaticResolver(["127.0.0.1"])
86
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
87
    proc = rpc._RpcProcessor(resolver, 24094)
88
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
89
                  NotImplemented, _req_process_fn=http_proc)
90
    self.assertEqual(result.keys(), ["localhost"])
91
    lhresp = result["localhost"]
92
    self.assertFalse(lhresp.offline)
93
    self.assertEqual(lhresp.node, "localhost")
94
    self.assertFalse(lhresp.fail_msg)
95
    self.assertEqual(lhresp.payload, 123)
96
    self.assertEqual(lhresp.call, "version")
97
    lhresp.Raise("should not raise")
98
    self.assertEqual(http_proc.reqcount, 1)
99

    
100
  def _ReadTimeoutResponse(self, req):
101
    self.assertEqual(req.host, "192.0.2.13")
102
    self.assertEqual(req.port, 19176)
103
    self.assertEqual(req.path, "/version")
104
    self.assertEqual(req.read_timeout, 12356)
105
    req.success = True
106
    req.resp_status_code = http.HTTP_OK
107
    req.resp_body = serializer.DumpJson((True, -1))
108

    
109
  def testReadTimeout(self):
110
    resolver = rpc._StaticResolver(["192.0.2.13"])
111
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
112
    proc = rpc._RpcProcessor(resolver, 19176)
113
    host = "node31856"
114
    body = {host: ""}
115
    result = proc([host], "version", body, 12356, NotImplemented,
116
                  _req_process_fn=http_proc)
117
    self.assertEqual(result.keys(), [host])
118
    lhresp = result[host]
119
    self.assertFalse(lhresp.offline)
120
    self.assertEqual(lhresp.node, host)
121
    self.assertFalse(lhresp.fail_msg)
122
    self.assertEqual(lhresp.payload, -1)
123
    self.assertEqual(lhresp.call, "version")
124
    lhresp.Raise("should not raise")
125
    self.assertEqual(http_proc.reqcount, 1)
126

    
127
  def testOfflineNode(self):
128
    resolver = rpc._StaticResolver([rpc._OFFLINE])
129
    http_proc = _FakeRequestProcessor(NotImplemented)
130
    proc = rpc._RpcProcessor(resolver, 30668)
131
    host = "n17296"
132
    body = {host: ""}
133
    result = proc([host], "version", body, 60, NotImplemented,
134
                  _req_process_fn=http_proc)
135
    self.assertEqual(result.keys(), [host])
136
    lhresp = result[host]
137
    self.assertTrue(lhresp.offline)
138
    self.assertEqual(lhresp.node, host)
139
    self.assertTrue(lhresp.fail_msg)
140
    self.assertFalse(lhresp.payload)
141
    self.assertEqual(lhresp.call, "version")
142

    
143
    # With a message
144
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
145

    
146
    # No message
147
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
148

    
149
    self.assertEqual(http_proc.reqcount, 0)
150

    
151
  def _GetMultiVersionResponse(self, req):
152
    self.assert_(req.host.startswith("node"))
153
    self.assertEqual(req.port, 23245)
154
    self.assertEqual(req.path, "/version")
155
    req.success = True
156
    req.resp_status_code = http.HTTP_OK
157
    req.resp_body = serializer.DumpJson((True, 987))
158

    
159
  def testMultiVersionSuccess(self):
160
    nodes = ["node%s" % i for i in range(50)]
161
    body = dict((n, "") for n in nodes)
162
    resolver = rpc._StaticResolver(nodes)
163
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
164
    proc = rpc._RpcProcessor(resolver, 23245)
165
    result = proc(nodes, "version", body, 60, NotImplemented,
166
                  _req_process_fn=http_proc)
167
    self.assertEqual(sorted(result.keys()), sorted(nodes))
168

    
169
    for name in nodes:
170
      lhresp = result[name]
171
      self.assertFalse(lhresp.offline)
172
      self.assertEqual(lhresp.node, name)
173
      self.assertFalse(lhresp.fail_msg)
174
      self.assertEqual(lhresp.payload, 987)
175
      self.assertEqual(lhresp.call, "version")
176
      lhresp.Raise("should not raise")
177

    
178
    self.assertEqual(http_proc.reqcount, len(nodes))
179

    
180
  def _GetVersionResponseFail(self, errinfo, req):
181
    self.assertEqual(req.path, "/version")
182
    req.success = True
183
    req.resp_status_code = http.HTTP_OK
184
    req.resp_body = serializer.DumpJson((False, errinfo))
185

    
186
  def testVersionFailure(self):
187
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
188
    proc = rpc._RpcProcessor(resolver, 5903)
189
    for errinfo in [None, "Unknown error"]:
190
      http_proc = \
191
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
192
                                             errinfo))
193
      host = "aef9ur4i.example.com"
194
      body = {host: ""}
195
      result = proc(body.keys(), "version", body, 60, NotImplemented,
196
                    _req_process_fn=http_proc)
197
      self.assertEqual(result.keys(), [host])
198
      lhresp = result[host]
199
      self.assertFalse(lhresp.offline)
200
      self.assertEqual(lhresp.node, host)
201
      self.assert_(lhresp.fail_msg)
202
      self.assertFalse(lhresp.payload)
203
      self.assertEqual(lhresp.call, "version")
204
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
205
      self.assertEqual(http_proc.reqcount, 1)
206

    
207
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
208
    self.assertEqual(req.path, "/vg_list")
209
    self.assertEqual(req.port, 15165)
210

    
211
    if req.host in httperrnodes:
212
      req.success = False
213
      req.error = "Node set up for HTTP errors"
214

    
215
    elif req.host in failnodes:
216
      req.success = True
217
      req.resp_status_code = 404
218
      req.resp_body = serializer.DumpJson({
219
        "code": 404,
220
        "message": "Method not found",
221
        "explain": "Explanation goes here",
222
        })
223
    else:
224
      req.success = True
225
      req.resp_status_code = http.HTTP_OK
226
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
227

    
228
  def testHttpError(self):
229
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
230
    body = dict((n, "") for n in nodes)
231
    resolver = rpc._StaticResolver(nodes)
232

    
233
    httperrnodes = set(nodes[1::7])
234
    self.assertEqual(len(httperrnodes), 7)
235

    
236
    failnodes = set(nodes[2::3]) - httperrnodes
237
    self.assertEqual(len(failnodes), 14)
238

    
239
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
240

    
241
    proc = rpc._RpcProcessor(resolver, 15165)
242
    http_proc = \
243
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
244
                                           httperrnodes, failnodes))
245
    result = proc(nodes, "vg_list", body,
246
                  constants.RPC_TMO_URGENT, NotImplemented,
247
                  _req_process_fn=http_proc)
248
    self.assertEqual(sorted(result.keys()), sorted(nodes))
249

    
250
    for name in nodes:
251
      lhresp = result[name]
252
      self.assertFalse(lhresp.offline)
253
      self.assertEqual(lhresp.node, name)
254
      self.assertEqual(lhresp.call, "vg_list")
255

    
256
      if name in httperrnodes:
257
        self.assert_(lhresp.fail_msg)
258
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
259
      elif name in failnodes:
260
        self.assert_(lhresp.fail_msg)
261
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
262
                          prereq=True, ecode=errors.ECODE_INVAL)
263
      else:
264
        self.assertFalse(lhresp.fail_msg)
265
        self.assertEqual(lhresp.payload, hash(name))
266
        lhresp.Raise("should not raise")
267

    
268
    self.assertEqual(http_proc.reqcount, len(nodes))
269

    
270
  def _GetInvalidResponseA(self, req):
271
    self.assertEqual(req.path, "/version")
272
    req.success = True
273
    req.resp_status_code = http.HTTP_OK
274
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
275
                                         "response", "!", 1, 2, 3))
276

    
277
  def _GetInvalidResponseB(self, req):
278
    self.assertEqual(req.path, "/version")
279
    req.success = True
280
    req.resp_status_code = http.HTTP_OK
281
    req.resp_body = serializer.DumpJson("invalid response")
282

    
283
  def testInvalidResponse(self):
284
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
285
    proc = rpc._RpcProcessor(resolver, 19978)
286

    
287
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
288
      http_proc = _FakeRequestProcessor(fn)
289
      host = "oqo7lanhly.example.com"
290
      body = {host: ""}
291
      result = proc([host], "version", body, 60, NotImplemented,
292
                    _req_process_fn=http_proc)
293
      self.assertEqual(result.keys(), [host])
294
      lhresp = result[host]
295
      self.assertFalse(lhresp.offline)
296
      self.assertEqual(lhresp.node, host)
297
      self.assert_(lhresp.fail_msg)
298
      self.assertFalse(lhresp.payload)
299
      self.assertEqual(lhresp.call, "version")
300
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
301
      self.assertEqual(http_proc.reqcount, 1)
302

    
303
  def _GetBodyTestResponse(self, test_data, req):
304
    self.assertEqual(req.host, "192.0.2.84")
305
    self.assertEqual(req.port, 18700)
306
    self.assertEqual(req.path, "/upload_file")
307
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
308
    req.success = True
309
    req.resp_status_code = http.HTTP_OK
310
    req.resp_body = serializer.DumpJson((True, None))
311

    
312
  def testResponseBody(self):
313
    test_data = {
314
      "Hello": "World",
315
      "xyz": range(10),
316
      }
317
    resolver = rpc._StaticResolver(["192.0.2.84"])
318
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
319
                                                     test_data))
320
    proc = rpc._RpcProcessor(resolver, 18700)
321
    host = "node19759"
322
    body = {host: serializer.DumpJson(test_data)}
323
    result = proc([host], "upload_file", body, 30, NotImplemented,
324
                  _req_process_fn=http_proc)
325
    self.assertEqual(result.keys(), [host])
326
    lhresp = result[host]
327
    self.assertFalse(lhresp.offline)
328
    self.assertEqual(lhresp.node, host)
329
    self.assertFalse(lhresp.fail_msg)
330
    self.assertEqual(lhresp.payload, None)
331
    self.assertEqual(lhresp.call, "upload_file")
332
    lhresp.Raise("should not raise")
333
    self.assertEqual(http_proc.reqcount, 1)
334

    
335

    
336
class TestSsconfResolver(unittest.TestCase):
337
  def testSsconfLookup(self):
338
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
339
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
340
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
341
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
342
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
343
                                 ssc=ssc, nslookup_fn=NotImplemented)
344
    self.assertEqual(result, zip(node_list, addr_list, node_list))
345

    
346
  def testNsLookup(self):
347
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
348
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
349
    ssc = GetFakeSimpleStoreClass(lambda _: [])
350
    node_addr_map = dict(zip(node_list, addr_list))
351
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
352
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
353
                                 ssc=ssc, nslookup_fn=nslookup_fn)
354
    self.assertEqual(result, zip(node_list, addr_list, node_list))
355

    
356
  def testDisabledSsconfIp(self):
357
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
358
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
359
    ssc = GetFakeSimpleStoreClass(_RaiseNotImplemented)
360
    node_addr_map = dict(zip(node_list, addr_list))
361
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
362
    result = rpc._SsconfResolver(False, node_list, NotImplemented,
363
                                 ssc=ssc, nslookup_fn=nslookup_fn)
364
    self.assertEqual(result, zip(node_list, addr_list, node_list))
365

    
366
  def testBothLookups(self):
367
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
368
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
369
    n = len(addr_list) / 2
370
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
371
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
372
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
373
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
374
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
375
                                 ssc=ssc, nslookup_fn=nslookup_fn)
376
    self.assertEqual(result, zip(node_list, addr_list, node_list))
377

    
378
  def testAddressLookupIPv6(self):
379
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
380
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
381
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
382
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
383
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
384
                                 ssc=ssc, nslookup_fn=NotImplemented)
385
    self.assertEqual(result, zip(node_list, addr_list, node_list))
386

    
387

    
388
class TestStaticResolver(unittest.TestCase):
389
  def test(self):
390
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
391
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
392
    res = rpc._StaticResolver(addresses)
393
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses, nodes))
394

    
395
  def testWrongLength(self):
396
    res = rpc._StaticResolver([])
397
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
398

    
399

    
400
class TestNodeConfigResolver(unittest.TestCase):
401
  @staticmethod
402
  def _GetSingleOnlineNode(uuid):
403
    assert uuid == "node90-uuid"
404
    return objects.Node(name="node90.example.com",
405
                        uuid=uuid,
406
                        offline=False,
407
                        primary_ip="192.0.2.90")
408

    
409
  @staticmethod
410
  def _GetSingleOfflineNode(uuid):
411
    assert uuid == "node100-uuid"
412
    return objects.Node(name="node100.example.com",
413
                        uuid=uuid,
414
                        offline=True,
415
                        primary_ip="192.0.2.100")
416

    
417
  def testSingleOnline(self):
418
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
419
                                             NotImplemented,
420
                                             ["node90-uuid"], None),
421
                     [("node90.example.com", "192.0.2.90", "node90-uuid")])
422

    
423
  def testSingleOffline(self):
424
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
425
                                             NotImplemented,
426
                                             ["node100-uuid"], None),
427
                     [("node100.example.com", rpc._OFFLINE, "node100-uuid")])
428

    
429
  def testSingleOfflineWithAcceptOffline(self):
430
    fn = self._GetSingleOfflineNode
431
    assert fn("node100-uuid").offline
432
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
433
                                             ["node100-uuid"],
434
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
435
                     [("node100.example.com", "192.0.2.100", "node100-uuid")])
436
    for i in [False, True, "", "Hello", 0, 1]:
437
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
438
                        fn, NotImplemented, ["node100.example.com"], i)
439

    
440
  def testUnknownSingleNode(self):
441
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
442
                                             ["node110.example.com"], None),
443
                     [("node110.example.com", "node110.example.com",
444
                       "node110.example.com")])
445

    
446
  def testMultiEmpty(self):
447
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
448
                                             lambda: {},
449
                                             [], None),
450
                     [])
451

    
452
  def testMultiSomeOffline(self):
453
    nodes = dict(("node%s-uuid" % i,
454
                  objects.Node(name="node%s.example.com" % i,
455
                               offline=((i % 3) == 0),
456
                               primary_ip="192.0.2.%s" % i,
457
                               uuid="node%s-uuid" % i))
458
                  for i in range(1, 255))
459

    
460
    # Resolve no names
461
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
462
                                             lambda: nodes,
463
                                             [], None),
464
                     [])
465

    
466
    # Offline, online and unknown hosts
467
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
468
                                             lambda: nodes,
469
                                             ["node3-uuid",
470
                                              "node92-uuid",
471
                                              "node54-uuid",
472
                                              "unknown.example.com",],
473
                                             None), [
474
      ("node3.example.com", rpc._OFFLINE, "node3-uuid"),
475
      ("node92.example.com", "192.0.2.92", "node92-uuid"),
476
      ("node54.example.com", rpc._OFFLINE, "node54-uuid"),
477
      ("unknown.example.com", "unknown.example.com", "unknown.example.com"),
478
      ])
479

    
480

    
481
class TestCompress(unittest.TestCase):
482
  def test(self):
483
    for data in ["", "Hello", "Hello World!\nnew\nlines"]:
484
      self.assertEqual(rpc._Compress(data),
485
                       (constants.RPC_ENCODING_NONE, data))
486

    
487
    for data in [512 * " ", 5242 * "Hello World!\n"]:
488
      compressed = rpc._Compress(data)
489
      self.assertEqual(len(compressed), 2)
490
      self.assertEqual(backend._Decompress(compressed), data)
491

    
492
  def testDecompression(self):
493
    self.assertRaises(AssertionError, backend._Decompress, "")
494
    self.assertRaises(AssertionError, backend._Decompress, [""])
495
    self.assertRaises(AssertionError, backend._Decompress,
496
                      ("unknown compression", "data"))
497
    self.assertRaises(Exception, backend._Decompress,
498
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
499

    
500

    
501
class TestRpcClientBase(unittest.TestCase):
502
  def testNoHosts(self):
503
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_SLOW, [],
504
            None, None, NotImplemented)
505
    http_proc = _FakeRequestProcessor(NotImplemented)
506
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
507
                                _req_process_fn=http_proc)
508
    self.assertEqual(client._Call(cdef, [], []), {})
509

    
510
    # Test wrong number of arguments
511
    self.assertRaises(errors.ProgrammerError, client._Call,
512
                      cdef, [], [0, 1, 2])
513

    
514
  def testTimeout(self):
515
    def _CalcTimeout((arg1, arg2)):
516
      return arg1 + arg2
517

    
518
    def _VerifyRequest(exp_timeout, req):
519
      self.assertEqual(req.read_timeout, exp_timeout)
520

    
521
      req.success = True
522
      req.resp_status_code = http.HTTP_OK
523
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
524

    
525
    resolver = rpc._StaticResolver([
526
      "192.0.2.1",
527
      "192.0.2.2",
528
      ])
529

    
530
    nodes = [
531
      "node1.example.com",
532
      "node2.example.com",
533
      ]
534

    
535
    tests = [(100, None, 100), (30, None, 30)]
536
    tests.extend((_CalcTimeout, i, i + 300)
537
                 for i in [0, 5, 16485, 30516])
538

    
539
    for timeout, arg1, exp_timeout in tests:
540
      cdef = ("test_call", NotImplemented, None, timeout, [
541
        ("arg1", None, NotImplemented),
542
        ("arg2", None, NotImplemented),
543
        ], None, None, NotImplemented)
544

    
545
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
546
                                                       exp_timeout))
547
      client = rpc._RpcClientBase(resolver, NotImplemented,
548
                                  _req_process_fn=http_proc)
549
      result = client._Call(cdef, nodes, [arg1, 300])
550
      self.assertEqual(len(result), len(nodes))
551
      self.assertTrue(compat.all(not res.fail_msg and
552
                                 res.payload == hex(exp_timeout)
553
                                 for res in result.values()))
554

    
555
  def testArgumentEncoder(self):
556
    (AT1, AT2) = range(1, 3)
557

    
558
    resolver = rpc._StaticResolver([
559
      "192.0.2.5",
560
      "192.0.2.6",
561
      ])
562

    
563
    nodes = [
564
      "node5.example.com",
565
      "node6.example.com",
566
      ]
567

    
568
    encoders = {
569
      AT1: hex,
570
      AT2: hash,
571
      }
572

    
573
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
574
      ("arg0", None, NotImplemented),
575
      ("arg1", AT1, NotImplemented),
576
      ("arg1", AT2, NotImplemented),
577
      ], None, None, NotImplemented)
578

    
579
    def _VerifyRequest(req):
580
      req.success = True
581
      req.resp_status_code = http.HTTP_OK
582
      req.resp_body = serializer.DumpJson((True, req.post_data))
583

    
584
    http_proc = _FakeRequestProcessor(_VerifyRequest)
585

    
586
    for num in [0, 3796, 9032119]:
587
      client = rpc._RpcClientBase(resolver, encoders.get,
588
                                  _req_process_fn=http_proc)
589
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
590
      self.assertEqual(len(result), len(nodes))
591
      for res in result.values():
592
        self.assertFalse(res.fail_msg)
593
        self.assertEqual(serializer.LoadJson(res.payload),
594
                         ["foo", hex(num), hash("Hello%s" % num)])
595

    
596
  def testPostProc(self):
597
    def _VerifyRequest(nums, req):
598
      req.success = True
599
      req.resp_status_code = http.HTTP_OK
600
      req.resp_body = serializer.DumpJson((True, nums))
601

    
602
    resolver = rpc._StaticResolver([
603
      "192.0.2.90",
604
      "192.0.2.95",
605
      ])
606

    
607
    nodes = [
608
      "node90.example.com",
609
      "node95.example.com",
610
      ]
611

    
612
    def _PostProc(res):
613
      self.assertFalse(res.fail_msg)
614
      res.payload = sum(res.payload)
615
      return res
616

    
617
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [],
618
            None, _PostProc, NotImplemented)
619

    
620
    # Seeded random generator
621
    rnd = random.Random(20299)
622

    
623
    for i in [0, 4, 74, 1391]:
624
      nums = [rnd.randint(0, 1000) for _ in range(i)]
625
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
626
      client = rpc._RpcClientBase(resolver, NotImplemented,
627
                                  _req_process_fn=http_proc)
628
      result = client._Call(cdef, nodes, [])
629
      self.assertEqual(len(result), len(nodes))
630
      for res in result.values():
631
        self.assertFalse(res.fail_msg)
632
        self.assertEqual(res.payload, sum(nums))
633

    
634
  def testPreProc(self):
635
    def _VerifyRequest(req):
636
      req.success = True
637
      req.resp_status_code = http.HTTP_OK
638
      req.resp_body = serializer.DumpJson((True, req.post_data))
639

    
640
    resolver = rpc._StaticResolver([
641
      "192.0.2.30",
642
      "192.0.2.35",
643
      ])
644

    
645
    nodes = [
646
      "node30.example.com",
647
      "node35.example.com",
648
      ]
649

    
650
    def _PreProc(node, data):
651
      self.assertEqual(len(data), 1)
652
      return data[0] + node
653

    
654
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
655
      ("arg0", None, NotImplemented),
656
      ], _PreProc, None, NotImplemented)
657

    
658
    http_proc = _FakeRequestProcessor(_VerifyRequest)
659
    client = rpc._RpcClientBase(resolver, NotImplemented,
660
                                _req_process_fn=http_proc)
661

    
662
    for prefix in ["foo", "bar", "baz"]:
663
      result = client._Call(cdef, nodes, [prefix])
664
      self.assertEqual(len(result), len(nodes))
665
      for (idx, (node, res)) in enumerate(result.items()):
666
        self.assertFalse(res.fail_msg)
667
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
668

    
669
  def testResolverOptions(self):
670
    def _VerifyRequest(req):
671
      req.success = True
672
      req.resp_status_code = http.HTTP_OK
673
      req.resp_body = serializer.DumpJson((True, req.post_data))
674

    
675
    nodes = [
676
      "node30.example.com",
677
      "node35.example.com",
678
      ]
679

    
680
    def _Resolver(expected, hosts, options):
681
      self.assertEqual(hosts, nodes)
682
      self.assertEqual(options, expected)
683
      return zip(hosts, nodes, hosts)
684

    
685
    def _DynamicResolverOptions((arg0, )):
686
      return sum(arg0)
687

    
688
    tests = [
689
      (None, None, None),
690
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
691
      (False, None, False),
692
      (True, None, True),
693
      (0, None, 0),
694
      (_DynamicResolverOptions, [1, 2, 3], 6),
695
      (_DynamicResolverOptions, range(4, 19), 165),
696
      ]
697

    
698
    for (resolver_opts, arg0, expected) in tests:
699
      cdef = ("test_call", NotImplemented, resolver_opts,
700
              constants.RPC_TMO_NORMAL, [
701
        ("arg0", None, NotImplemented),
702
        ], None, None, NotImplemented)
703

    
704
      http_proc = _FakeRequestProcessor(_VerifyRequest)
705

    
706
      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
707
                                  NotImplemented, _req_process_fn=http_proc)
708
      result = client._Call(cdef, nodes, [arg0])
709
      self.assertEqual(len(result), len(nodes))
710
      for (idx, (node, res)) in enumerate(result.items()):
711
        self.assertFalse(res.fail_msg)
712

    
713

    
714
class _FakeConfigForRpcRunner:
715
  GetAllNodesInfo = NotImplemented
716

    
717
  def __init__(self, cluster=NotImplemented):
718
    self._cluster = cluster
719

    
720
  def GetNodeInfo(self, name):
721
    return objects.Node(name=name)
722

    
723
  def GetClusterInfo(self):
724
    return self._cluster
725

    
726
  def GetInstanceDiskParams(self, _):
727
    return constants.DISK_DT_DEFAULTS
728

    
729

    
730
class TestRpcRunner(unittest.TestCase):
731
  def testUploadFile(self):
732
    data = 1779 * "Hello World\n"
733

    
734
    tmpfile = tempfile.NamedTemporaryFile()
735
    tmpfile.write(data)
736
    tmpfile.flush()
737
    st = os.stat(tmpfile.name)
738

    
739
    def _VerifyRequest(req):
740
      (uldata, ) = serializer.LoadJson(req.post_data)
741
      self.assertEqual(len(uldata), 7)
742
      self.assertEqual(uldata[0], tmpfile.name)
743
      self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
744
      self.assertEqual(uldata[2], st.st_mode)
745
      self.assertEqual(uldata[3], "user%s" % os.getuid())
746
      self.assertEqual(uldata[4], "group%s" % os.getgid())
747
      self.assertTrue(uldata[5] is not None)
748
      self.assertEqual(uldata[6], st.st_mtime)
749

    
750
      req.success = True
751
      req.resp_status_code = http.HTTP_OK
752
      req.resp_body = serializer.DumpJson((True, None))
753

    
754
    http_proc = _FakeRequestProcessor(_VerifyRequest)
755

    
756
    std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
757
                               _req_process_fn=http_proc,
758
                               _getents=mocks.FakeGetentResolver)
759

    
760
    cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
761
                                  _req_process_fn=http_proc,
762
                                  _getents=mocks.FakeGetentResolver)
763

    
764
    nodes = [
765
      "node1.example.com",
766
      ]
767

    
768
    for runner in [std_runner, cfg_runner]:
769
      result = runner.call_upload_file(nodes, tmpfile.name)
770
      self.assertEqual(len(result), len(nodes))
771
      for (idx, (node, res)) in enumerate(result.items()):
772
        self.assertFalse(res.fail_msg)
773

    
774
  def testEncodeInstance(self):
775
    cluster = objects.Cluster(hvparams={
776
      constants.HT_KVM: {
777
        constants.HV_CDROM_IMAGE_PATH: "foo",
778
        },
779
      },
780
      beparams={
781
        constants.PP_DEFAULT: {
782
          constants.BE_MAXMEM: 8192,
783
          },
784
        },
785
      os_hvp={},
786
      osparams={
787
        "linux": {
788
          "role": "unknown",
789
          },
790
        })
791
    cluster.UpgradeConfig()
792

    
793
    inst = objects.Instance(name="inst1.example.com",
794
      hypervisor=constants.HT_KVM,
795
      os="linux",
796
      hvparams={
797
        constants.HV_CDROM_IMAGE_PATH: "bar",
798
        constants.HV_ROOT_PATH: "/tmp",
799
        },
800
      beparams={
801
        constants.BE_MINMEM: 128,
802
        constants.BE_MAXMEM: 256,
803
        },
804
      nics=[
805
        objects.NIC(nicparams={
806
          constants.NIC_MODE: "mymode",
807
          }),
808
        ],
809
      disk_template=constants.DT_PLAIN,
810
      disks=[
811
        objects.Disk(dev_type=constants.DT_PLAIN, size=4096,
812
                     logical_id=("vg", "disk6120")),
813
        objects.Disk(dev_type=constants.DT_PLAIN, size=1024,
814
                     logical_id=("vg", "disk8508")),
815
        ])
816
    inst.UpgradeConfig()
817

    
818
    cfg = _FakeConfigForRpcRunner(cluster=cluster)
819
    runner = rpc.RpcRunner(cfg, None,
820
                           _req_process_fn=NotImplemented,
821
                           _getents=mocks.FakeGetentResolver)
822

    
823
    def _CheckBasics(result):
824
      self.assertEqual(result["name"], "inst1.example.com")
825
      self.assertEqual(result["os"], "linux")
826
      self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
827
      self.assertEqual(len(result["nics"]), 1)
828
      self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
829
                       "mymode")
830

    
831
    # Generic object serialization
832
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT, inst))
833
    _CheckBasics(result)
834
    self.assertEqual(len(result["hvparams"]), 2)
835

    
836
    result = runner._encoder((rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
837
    map(_CheckBasics, result)
838
    map(lambda r: self.assertEqual(len(r["hvparams"]), 2), result)
839

    
840
    # Just an instance
841
    result = runner._encoder((rpc_defs.ED_INST_DICT, inst))
842
    _CheckBasics(result)
843
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
844
    self.assertEqual(result["hvparams"][constants.HV_CDROM_IMAGE_PATH], "bar")
845
    self.assertEqual(result["hvparams"][constants.HV_ROOT_PATH], "/tmp")
846
    self.assertEqual(result["osparams"], {
847
      "role": "unknown",
848
      })
849
    self.assertEqual(len(result["hvparams"]),
850
                     len(constants.HVC_DEFAULTS[constants.HT_KVM]))
851

    
852
    # Instance with OS parameters
853
    result = runner._encoder((rpc_defs.ED_INST_DICT_OSP_DP, (inst, {
854
      "role": "webserver",
855
      "other": "field",
856
      })))
857
    _CheckBasics(result)
858
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
859
    self.assertEqual(result["hvparams"][constants.HV_CDROM_IMAGE_PATH], "bar")
860
    self.assertEqual(result["hvparams"][constants.HV_ROOT_PATH], "/tmp")
861
    self.assertEqual(result["osparams"], {
862
      "role": "webserver",
863
      "other": "field",
864
      })
865

    
866
    # Instance with hypervisor and backend parameters
867
    result = runner._encoder((rpc_defs.ED_INST_DICT_HVP_BEP_DP, (inst, {
868
      constants.HV_BOOT_ORDER: "xyz",
869
      }, {
870
      constants.BE_VCPUS: 100,
871
      constants.BE_MAXMEM: 4096,
872
      })))
873
    _CheckBasics(result)
874
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
875
    self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
876
    self.assertEqual(result["hvparams"][constants.HV_BOOT_ORDER], "xyz")
877
    self.assertEqual(result["disks"], [{
878
      "dev_type": constants.DT_PLAIN,
879
      "size": 4096,
880
      "logical_id": ("vg", "disk6120"),
881
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
882
      }, {
883
      "dev_type": constants.DT_PLAIN,
884
      "size": 1024,
885
      "logical_id": ("vg", "disk8508"),
886
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
887
      }])
888

    
889
    self.assertTrue(compat.all(disk.params == {} for disk in inst.disks),
890
                    msg="Configuration objects were modified")
891

    
892

    
893
class TestLegacyNodeInfo(unittest.TestCase):
894
  KEY_BOOT = "bootid"
895
  KEY_VG0 = "name"
896
  KEY_VG1 = "storage_free"
897
  KEY_VG2 = "storage_size"
898
  KEY_HV = "cpu_count"
899
  KEY_SP1 = "spindles_free"
900
  KEY_SP2 = "spindles_total"
901
  KEY_ST = "type" # key for storage type
902
  VAL_BOOT = 0
903
  VAL_VG0 = "xy"
904
  VAL_VG1 = 11
905
  VAL_VG2 = 12
906
  VAL_VG3 = "lvm-vg"
907
  VAL_HV = 2
908
  VAL_SP0 = "ab"
909
  VAL_SP1 = 31
910
  VAL_SP2 = 32
911
  VAL_SP3 = "lvm-pv"
912
  DICT_VG = {
913
    KEY_VG0: VAL_VG0,
914
    KEY_VG1: VAL_VG1,
915
    KEY_VG2: VAL_VG2,
916
    KEY_ST: VAL_VG3,
917
    }
918
  DICT_HV = {KEY_HV: VAL_HV}
919
  DICT_SP = {
920
    KEY_ST: VAL_SP3,
921
    KEY_VG0: VAL_SP0,
922
    KEY_VG1: VAL_SP1,
923
    KEY_VG2: VAL_SP2,
924
    }
925
  STD_LST = [VAL_BOOT, [DICT_VG, DICT_SP], [DICT_HV]]
926
  STD_DICT = {
927
    KEY_BOOT: VAL_BOOT,
928
    KEY_VG0: VAL_VG0,
929
    KEY_VG1: VAL_VG1,
930
    KEY_VG2: VAL_VG2,
931
    KEY_HV: VAL_HV,
932
    }
933

    
934
  def testStandard(self):
935
    result = rpc.MakeLegacyNodeInfo(self.STD_LST)
936
    self.assertEqual(result, self.STD_DICT)
937

    
938
  def testSpindlesRequired(self):
939
    my_lst = [self.VAL_BOOT, [], [self.DICT_HV]]
940
    self.assertRaises(errors.OpExecError, rpc.MakeLegacyNodeInfo, my_lst,
941
        require_spindles=True)
942

    
943
  def testNoSpindlesRequired(self):
944
    my_lst = [self.VAL_BOOT, [], [self.DICT_HV]]
945
    result = rpc.MakeLegacyNodeInfo(my_lst, require_spindles = False)
946
    self.assertEqual(result, {self.KEY_BOOT: self.VAL_BOOT,
947
                              self.KEY_HV: self.VAL_HV})
948
    result = rpc.MakeLegacyNodeInfo(self.STD_LST, require_spindles = False)
949
    self.assertEqual(result, self.STD_DICT)
950

    
951

    
952
class TestAddDefaultStorageInfoToLegacyNodeInfo(unittest.TestCase):
953

    
954
  def setUp(self):
955
    self.free_storage_file = 23
956
    self.total_storage_file = 42
957
    self.free_storage_lvm = 69
958
    self.total_storage_lvm = 666
959
    self.node_info = [{"name": "myfile",
960
                       "type": constants.ST_FILE,
961
                       "storage_free": self.free_storage_file,
962
                       "storage_size": self.total_storage_file},
963
                      {"name": "myvg",
964
                       "type": constants.ST_LVM_VG,
965
                       "storage_free": self.free_storage_lvm,
966
                       "storage_size": self.total_storage_lvm},
967
                      {"name": "myspindle",
968
                       "type": constants.ST_LVM_PV,
969
                       "storage_free": 33,
970
                       "storage_size": 44}]
971

    
972
  def testAddDefaultStorageInfoToLegacyNodeInfo(self):
973
    result = {}
974
    rpc._AddDefaultStorageInfoToLegacyNodeInfo(result, self.node_info)
975
    self.assertEqual(self.free_storage_file, result["storage_free"])
976
    self.assertEqual(self.total_storage_file, result["storage_size"])
977

    
978
  def testAddDefaultStorageInfoToLegacyNodeInfoNoDefaults(self):
979
    result = {}
980
    rpc._AddDefaultStorageInfoToLegacyNodeInfo(result, self.node_info[-1:])
981
    self.assertFalse("storage_free" in result)
982
    self.assertFalse("storage_size" in result)
983

    
984

    
985
if __name__ == "__main__":
986
  testutils.GanetiTestProgram()