Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.rpc_unittest.py @ 4869595d

History | View | Annotate | Download (33.6 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011, 2012, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27
import random
28
import tempfile
29

    
30
from ganeti import constants
31
from ganeti import compat
32
from ganeti.rpc import node as rpc
33
from ganeti import rpc_defs
34
from ganeti import http
35
from ganeti import errors
36
from ganeti import serializer
37
from ganeti import objects
38
from ganeti import backend
39

    
40
import testutils
41
import mocks
42

    
43

    
44
class _FakeRequestProcessor:
45
  def __init__(self, response_fn):
46
    self._response_fn = response_fn
47
    self.reqcount = 0
48

    
49
  def __call__(self, reqs, lock_monitor_cb=None):
50
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
51
    for req in reqs:
52
      self.reqcount += 1
53
      self._response_fn(req)
54

    
55

    
56
def GetFakeSimpleStoreClass(fn):
57
  class FakeSimpleStore:
58
    GetNodePrimaryIPList = fn
59
    GetPrimaryIPFamily = lambda _: None
60

    
61
  return FakeSimpleStore
62

    
63

    
64
def _RaiseNotImplemented():
65
  """Simple wrapper to raise NotImplementedError.
66

67
  """
68
  raise NotImplementedError
69

    
70

    
71
class TestRpcProcessor(unittest.TestCase):
72
  def _FakeAddressLookup(self, map):
73
    return lambda node_list: [map.get(node) for node in node_list]
74

    
75
  def _GetVersionResponse(self, req):
76
    self.assertEqual(req.host, "127.0.0.1")
77
    self.assertEqual(req.port, 24094)
78
    self.assertEqual(req.path, "/version")
79
    self.assertEqual(req.read_timeout, constants.RPC_TMO_URGENT)
80
    req.success = True
81
    req.resp_status_code = http.HTTP_OK
82
    req.resp_body = serializer.DumpJson((True, 123))
83

    
84
  def testVersionSuccess(self):
85
    resolver = rpc._StaticResolver(["127.0.0.1"])
86
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
87
    proc = rpc._RpcProcessor(resolver, 24094)
88
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
89
                  NotImplemented, _req_process_fn=http_proc)
90
    self.assertEqual(result.keys(), ["localhost"])
91
    lhresp = result["localhost"]
92
    self.assertFalse(lhresp.offline)
93
    self.assertEqual(lhresp.node, "localhost")
94
    self.assertFalse(lhresp.fail_msg)
95
    self.assertEqual(lhresp.payload, 123)
96
    self.assertEqual(lhresp.call, "version")
97
    lhresp.Raise("should not raise")
98
    self.assertEqual(http_proc.reqcount, 1)
99

    
100
  def _ReadTimeoutResponse(self, req):
101
    self.assertEqual(req.host, "192.0.2.13")
102
    self.assertEqual(req.port, 19176)
103
    self.assertEqual(req.path, "/version")
104
    self.assertEqual(req.read_timeout, 12356)
105
    req.success = True
106
    req.resp_status_code = http.HTTP_OK
107
    req.resp_body = serializer.DumpJson((True, -1))
108

    
109
  def testReadTimeout(self):
110
    resolver = rpc._StaticResolver(["192.0.2.13"])
111
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
112
    proc = rpc._RpcProcessor(resolver, 19176)
113
    host = "node31856"
114
    body = {host: ""}
115
    result = proc([host], "version", body, 12356, NotImplemented,
116
                  _req_process_fn=http_proc)
117
    self.assertEqual(result.keys(), [host])
118
    lhresp = result[host]
119
    self.assertFalse(lhresp.offline)
120
    self.assertEqual(lhresp.node, host)
121
    self.assertFalse(lhresp.fail_msg)
122
    self.assertEqual(lhresp.payload, -1)
123
    self.assertEqual(lhresp.call, "version")
124
    lhresp.Raise("should not raise")
125
    self.assertEqual(http_proc.reqcount, 1)
126

    
127
  def testOfflineNode(self):
128
    resolver = rpc._StaticResolver([rpc._OFFLINE])
129
    http_proc = _FakeRequestProcessor(NotImplemented)
130
    proc = rpc._RpcProcessor(resolver, 30668)
131
    host = "n17296"
132
    body = {host: ""}
133
    result = proc([host], "version", body, 60, NotImplemented,
134
                  _req_process_fn=http_proc)
135
    self.assertEqual(result.keys(), [host])
136
    lhresp = result[host]
137
    self.assertTrue(lhresp.offline)
138
    self.assertEqual(lhresp.node, host)
139
    self.assertTrue(lhresp.fail_msg)
140
    self.assertFalse(lhresp.payload)
141
    self.assertEqual(lhresp.call, "version")
142

    
143
    # With a message
144
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
145

    
146
    # No message
147
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
148

    
149
    self.assertEqual(http_proc.reqcount, 0)
150

    
151
  def _GetMultiVersionResponse(self, req):
152
    self.assert_(req.host.startswith("node"))
153
    self.assertEqual(req.port, 23245)
154
    self.assertEqual(req.path, "/version")
155
    req.success = True
156
    req.resp_status_code = http.HTTP_OK
157
    req.resp_body = serializer.DumpJson((True, 987))
158

    
159
  def testMultiVersionSuccess(self):
160
    nodes = ["node%s" % i for i in range(50)]
161
    body = dict((n, "") for n in nodes)
162
    resolver = rpc._StaticResolver(nodes)
163
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
164
    proc = rpc._RpcProcessor(resolver, 23245)
165
    result = proc(nodes, "version", body, 60, NotImplemented,
166
                  _req_process_fn=http_proc)
167
    self.assertEqual(sorted(result.keys()), sorted(nodes))
168

    
169
    for name in nodes:
170
      lhresp = result[name]
171
      self.assertFalse(lhresp.offline)
172
      self.assertEqual(lhresp.node, name)
173
      self.assertFalse(lhresp.fail_msg)
174
      self.assertEqual(lhresp.payload, 987)
175
      self.assertEqual(lhresp.call, "version")
176
      lhresp.Raise("should not raise")
177

    
178
    self.assertEqual(http_proc.reqcount, len(nodes))
179

    
180
  def _GetVersionResponseFail(self, errinfo, req):
181
    self.assertEqual(req.path, "/version")
182
    req.success = True
183
    req.resp_status_code = http.HTTP_OK
184
    req.resp_body = serializer.DumpJson((False, errinfo))
185

    
186
  def testVersionFailure(self):
187
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
188
    proc = rpc._RpcProcessor(resolver, 5903)
189
    for errinfo in [None, "Unknown error"]:
190
      http_proc = \
191
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
192
                                             errinfo))
193
      host = "aef9ur4i.example.com"
194
      body = {host: ""}
195
      result = proc(body.keys(), "version", body, 60, NotImplemented,
196
                    _req_process_fn=http_proc)
197
      self.assertEqual(result.keys(), [host])
198
      lhresp = result[host]
199
      self.assertFalse(lhresp.offline)
200
      self.assertEqual(lhresp.node, host)
201
      self.assert_(lhresp.fail_msg)
202
      self.assertFalse(lhresp.payload)
203
      self.assertEqual(lhresp.call, "version")
204
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
205
      self.assertEqual(http_proc.reqcount, 1)
206

    
207
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
208
    self.assertEqual(req.path, "/vg_list")
209
    self.assertEqual(req.port, 15165)
210

    
211
    if req.host in httperrnodes:
212
      req.success = False
213
      req.error = "Node set up for HTTP errors"
214

    
215
    elif req.host in failnodes:
216
      req.success = True
217
      req.resp_status_code = 404
218
      req.resp_body = serializer.DumpJson({
219
        "code": 404,
220
        "message": "Method not found",
221
        "explain": "Explanation goes here",
222
        })
223
    else:
224
      req.success = True
225
      req.resp_status_code = http.HTTP_OK
226
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
227

    
228
  def testHttpError(self):
229
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
230
    body = dict((n, "") for n in nodes)
231
    resolver = rpc._StaticResolver(nodes)
232

    
233
    httperrnodes = set(nodes[1::7])
234
    self.assertEqual(len(httperrnodes), 7)
235

    
236
    failnodes = set(nodes[2::3]) - httperrnodes
237
    self.assertEqual(len(failnodes), 14)
238

    
239
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
240

    
241
    proc = rpc._RpcProcessor(resolver, 15165)
242
    http_proc = \
243
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
244
                                           httperrnodes, failnodes))
245
    result = proc(nodes, "vg_list", body,
246
                  constants.RPC_TMO_URGENT, NotImplemented,
247
                  _req_process_fn=http_proc)
248
    self.assertEqual(sorted(result.keys()), sorted(nodes))
249

    
250
    for name in nodes:
251
      lhresp = result[name]
252
      self.assertFalse(lhresp.offline)
253
      self.assertEqual(lhresp.node, name)
254
      self.assertEqual(lhresp.call, "vg_list")
255

    
256
      if name in httperrnodes:
257
        self.assert_(lhresp.fail_msg)
258
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
259
      elif name in failnodes:
260
        self.assert_(lhresp.fail_msg)
261
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
262
                          prereq=True, ecode=errors.ECODE_INVAL)
263
      else:
264
        self.assertFalse(lhresp.fail_msg)
265
        self.assertEqual(lhresp.payload, hash(name))
266
        lhresp.Raise("should not raise")
267

    
268
    self.assertEqual(http_proc.reqcount, len(nodes))
269

    
270
  def _GetInvalidResponseA(self, req):
271
    self.assertEqual(req.path, "/version")
272
    req.success = True
273
    req.resp_status_code = http.HTTP_OK
274
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
275
                                         "response", "!", 1, 2, 3))
276

    
277
  def _GetInvalidResponseB(self, req):
278
    self.assertEqual(req.path, "/version")
279
    req.success = True
280
    req.resp_status_code = http.HTTP_OK
281
    req.resp_body = serializer.DumpJson("invalid response")
282

    
283
  def testInvalidResponse(self):
284
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
285
    proc = rpc._RpcProcessor(resolver, 19978)
286

    
287
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
288
      http_proc = _FakeRequestProcessor(fn)
289
      host = "oqo7lanhly.example.com"
290
      body = {host: ""}
291
      result = proc([host], "version", body, 60, NotImplemented,
292
                    _req_process_fn=http_proc)
293
      self.assertEqual(result.keys(), [host])
294
      lhresp = result[host]
295
      self.assertFalse(lhresp.offline)
296
      self.assertEqual(lhresp.node, host)
297
      self.assert_(lhresp.fail_msg)
298
      self.assertFalse(lhresp.payload)
299
      self.assertEqual(lhresp.call, "version")
300
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
301
      self.assertEqual(http_proc.reqcount, 1)
302

    
303
  def _GetBodyTestResponse(self, test_data, req):
304
    self.assertEqual(req.host, "192.0.2.84")
305
    self.assertEqual(req.port, 18700)
306
    self.assertEqual(req.path, "/upload_file")
307
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
308
    req.success = True
309
    req.resp_status_code = http.HTTP_OK
310
    req.resp_body = serializer.DumpJson((True, None))
311

    
312
  def testResponseBody(self):
313
    test_data = {
314
      "Hello": "World",
315
      "xyz": range(10),
316
      }
317
    resolver = rpc._StaticResolver(["192.0.2.84"])
318
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
319
                                                     test_data))
320
    proc = rpc._RpcProcessor(resolver, 18700)
321
    host = "node19759"
322
    body = {host: serializer.DumpJson(test_data)}
323
    result = proc([host], "upload_file", body, 30, NotImplemented,
324
                  _req_process_fn=http_proc)
325
    self.assertEqual(result.keys(), [host])
326
    lhresp = result[host]
327
    self.assertFalse(lhresp.offline)
328
    self.assertEqual(lhresp.node, host)
329
    self.assertFalse(lhresp.fail_msg)
330
    self.assertEqual(lhresp.payload, None)
331
    self.assertEqual(lhresp.call, "upload_file")
332
    lhresp.Raise("should not raise")
333
    self.assertEqual(http_proc.reqcount, 1)
334

    
335

    
336
class TestSsconfResolver(unittest.TestCase):
337
  def testSsconfLookup(self):
338
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
339
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
340
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
341
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
342
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
343
                                 ssc=ssc, nslookup_fn=NotImplemented)
344
    self.assertEqual(result, zip(node_list, addr_list, node_list))
345

    
346
  def testNsLookup(self):
347
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
348
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
349
    ssc = GetFakeSimpleStoreClass(lambda _: [])
350
    node_addr_map = dict(zip(node_list, addr_list))
351
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
352
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
353
                                 ssc=ssc, nslookup_fn=nslookup_fn)
354
    self.assertEqual(result, zip(node_list, addr_list, node_list))
355

    
356
  def testDisabledSsconfIp(self):
357
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
358
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
359
    ssc = GetFakeSimpleStoreClass(_RaiseNotImplemented)
360
    node_addr_map = dict(zip(node_list, addr_list))
361
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
362
    result = rpc._SsconfResolver(False, node_list, NotImplemented,
363
                                 ssc=ssc, nslookup_fn=nslookup_fn)
364
    self.assertEqual(result, zip(node_list, addr_list, node_list))
365

    
366
  def testBothLookups(self):
367
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
368
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
369
    n = len(addr_list) / 2
370
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
371
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
372
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
373
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
374
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
375
                                 ssc=ssc, nslookup_fn=nslookup_fn)
376
    self.assertEqual(result, zip(node_list, addr_list, node_list))
377

    
378
  def testAddressLookupIPv6(self):
379
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
380
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
381
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
382
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
383
    result = rpc._SsconfResolver(True, node_list, NotImplemented,
384
                                 ssc=ssc, nslookup_fn=NotImplemented)
385
    self.assertEqual(result, zip(node_list, addr_list, node_list))
386

    
387

    
388
class TestStaticResolver(unittest.TestCase):
389
  def test(self):
390
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
391
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
392
    res = rpc._StaticResolver(addresses)
393
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses, nodes))
394

    
395
  def testWrongLength(self):
396
    res = rpc._StaticResolver([])
397
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
398

    
399

    
400
class TestNodeConfigResolver(unittest.TestCase):
401
  @staticmethod
402
  def _GetSingleOnlineNode(uuid):
403
    assert uuid == "node90-uuid"
404
    return objects.Node(name="node90.example.com",
405
                        uuid=uuid,
406
                        offline=False,
407
                        primary_ip="192.0.2.90")
408

    
409
  @staticmethod
410
  def _GetSingleOfflineNode(uuid):
411
    assert uuid == "node100-uuid"
412
    return objects.Node(name="node100.example.com",
413
                        uuid=uuid,
414
                        offline=True,
415
                        primary_ip="192.0.2.100")
416

    
417
  def testSingleOnline(self):
418
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
419
                                             NotImplemented,
420
                                             ["node90-uuid"], None),
421
                     [("node90.example.com", "192.0.2.90", "node90-uuid")])
422

    
423
  def testSingleOffline(self):
424
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
425
                                             NotImplemented,
426
                                             ["node100-uuid"], None),
427
                     [("node100.example.com", rpc._OFFLINE, "node100-uuid")])
428

    
429
  def testSingleOfflineWithAcceptOffline(self):
430
    fn = self._GetSingleOfflineNode
431
    assert fn("node100-uuid").offline
432
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
433
                                             ["node100-uuid"],
434
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
435
                     [("node100.example.com", "192.0.2.100", "node100-uuid")])
436
    for i in [False, True, "", "Hello", 0, 1]:
437
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
438
                        fn, NotImplemented, ["node100.example.com"], i)
439

    
440
  def testUnknownSingleNode(self):
441
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
442
                                             ["node110.example.com"], None),
443
                     [("node110.example.com", "node110.example.com",
444
                       "node110.example.com")])
445

    
446
  def testMultiEmpty(self):
447
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
448
                                             lambda: {},
449
                                             [], None),
450
                     [])
451

    
452
  def testMultiSomeOffline(self):
453
    nodes = dict(("node%s-uuid" % i,
454
                  objects.Node(name="node%s.example.com" % i,
455
                               offline=((i % 3) == 0),
456
                               primary_ip="192.0.2.%s" % i,
457
                               uuid="node%s-uuid" % i))
458
                  for i in range(1, 255))
459

    
460
    # Resolve no names
461
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
462
                                             lambda: nodes,
463
                                             [], None),
464
                     [])
465

    
466
    # Offline, online and unknown hosts
467
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
468
                                             lambda: nodes,
469
                                             ["node3-uuid",
470
                                              "node92-uuid",
471
                                              "node54-uuid",
472
                                              "unknown.example.com",],
473
                                             None), [
474
      ("node3.example.com", rpc._OFFLINE, "node3-uuid"),
475
      ("node92.example.com", "192.0.2.92", "node92-uuid"),
476
      ("node54.example.com", rpc._OFFLINE, "node54-uuid"),
477
      ("unknown.example.com", "unknown.example.com", "unknown.example.com"),
478
      ])
479

    
480

    
481
class TestCompress(unittest.TestCase):
482
  def test(self):
483
    for data in ["", "Hello", "Hello World!\nnew\nlines"]:
484
      self.assertEqual(rpc._Compress(NotImplemented, data),
485
                       (constants.RPC_ENCODING_NONE, data))
486

    
487
    for data in [512 * " ", 5242 * "Hello World!\n"]:
488
      compressed = rpc._Compress(NotImplemented, data)
489
      self.assertEqual(len(compressed), 2)
490
      self.assertEqual(backend._Decompress(compressed), data)
491

    
492
  def testDecompression(self):
493
    self.assertRaises(AssertionError, backend._Decompress, "")
494
    self.assertRaises(AssertionError, backend._Decompress, [""])
495
    self.assertRaises(AssertionError, backend._Decompress,
496
                      ("unknown compression", "data"))
497
    self.assertRaises(Exception, backend._Decompress,
498
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
499

    
500

    
501
class TestRpcClientBase(unittest.TestCase):
502
  def testNoHosts(self):
503
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_SLOW, [],
504
            None, None, NotImplemented)
505
    http_proc = _FakeRequestProcessor(NotImplemented)
506
    client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
507
                                _req_process_fn=http_proc)
508
    self.assertEqual(client._Call(cdef, [], []), {})
509

    
510
    # Test wrong number of arguments
511
    self.assertRaises(errors.ProgrammerError, client._Call,
512
                      cdef, [], [0, 1, 2])
513

    
514
  def testTimeout(self):
515
    def _CalcTimeout((arg1, arg2)):
516
      return arg1 + arg2
517

    
518
    def _VerifyRequest(exp_timeout, req):
519
      self.assertEqual(req.read_timeout, exp_timeout)
520

    
521
      req.success = True
522
      req.resp_status_code = http.HTTP_OK
523
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
524

    
525
    resolver = rpc._StaticResolver([
526
      "192.0.2.1",
527
      "192.0.2.2",
528
      ])
529

    
530
    nodes = [
531
      "node1.example.com",
532
      "node2.example.com",
533
      ]
534

    
535
    tests = [(100, None, 100), (30, None, 30)]
536
    tests.extend((_CalcTimeout, i, i + 300)
537
                 for i in [0, 5, 16485, 30516])
538

    
539
    for timeout, arg1, exp_timeout in tests:
540
      cdef = ("test_call", NotImplemented, None, timeout, [
541
        ("arg1", None, NotImplemented),
542
        ("arg2", None, NotImplemented),
543
        ], None, None, NotImplemented)
544

    
545
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
546
                                                       exp_timeout))
547
      client = rpc._RpcClientBase(resolver, NotImplemented,
548
                                  _req_process_fn=http_proc)
549
      result = client._Call(cdef, nodes, [arg1, 300])
550
      self.assertEqual(len(result), len(nodes))
551
      self.assertTrue(compat.all(not res.fail_msg and
552
                                 res.payload == hex(exp_timeout)
553
                                 for res in result.values()))
554

    
555
  def testArgumentEncoder(self):
556
    (AT1, AT2) = range(1, 3)
557

    
558
    resolver = rpc._StaticResolver([
559
      "192.0.2.5",
560
      "192.0.2.6",
561
      ])
562

    
563
    nodes = [
564
      "node5.example.com",
565
      "node6.example.com",
566
      ]
567

    
568
    encoders = {
569
      AT1: lambda _, value: hex(value),
570
      AT2: lambda _, value: hash(value),
571
      }
572

    
573
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
574
      ("arg0", None, NotImplemented),
575
      ("arg1", AT1, NotImplemented),
576
      ("arg1", AT2, NotImplemented),
577
      ], None, None, NotImplemented)
578

    
579
    def _VerifyRequest(req):
580
      req.success = True
581
      req.resp_status_code = http.HTTP_OK
582
      req.resp_body = serializer.DumpJson((True, req.post_data))
583

    
584
    http_proc = _FakeRequestProcessor(_VerifyRequest)
585

    
586
    for num in [0, 3796, 9032119]:
587
      client = rpc._RpcClientBase(resolver, encoders.get,
588
                                  _req_process_fn=http_proc)
589
      result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
590
      self.assertEqual(len(result), len(nodes))
591
      for res in result.values():
592
        self.assertFalse(res.fail_msg)
593
        self.assertEqual(serializer.LoadJson(res.payload),
594
                         ["foo", hex(num), hash("Hello%s" % num)])
595

    
596
  def testPostProc(self):
597
    def _VerifyRequest(nums, req):
598
      req.success = True
599
      req.resp_status_code = http.HTTP_OK
600
      req.resp_body = serializer.DumpJson((True, nums))
601

    
602
    resolver = rpc._StaticResolver([
603
      "192.0.2.90",
604
      "192.0.2.95",
605
      ])
606

    
607
    nodes = [
608
      "node90.example.com",
609
      "node95.example.com",
610
      ]
611

    
612
    def _PostProc(res):
613
      self.assertFalse(res.fail_msg)
614
      res.payload = sum(res.payload)
615
      return res
616

    
617
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [],
618
            None, _PostProc, NotImplemented)
619

    
620
    # Seeded random generator
621
    rnd = random.Random(20299)
622

    
623
    for i in [0, 4, 74, 1391]:
624
      nums = [rnd.randint(0, 1000) for _ in range(i)]
625
      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
626
      client = rpc._RpcClientBase(resolver, NotImplemented,
627
                                  _req_process_fn=http_proc)
628
      result = client._Call(cdef, nodes, [])
629
      self.assertEqual(len(result), len(nodes))
630
      for res in result.values():
631
        self.assertFalse(res.fail_msg)
632
        self.assertEqual(res.payload, sum(nums))
633

    
634
  def testPreProc(self):
635
    def _VerifyRequest(req):
636
      req.success = True
637
      req.resp_status_code = http.HTTP_OK
638
      req.resp_body = serializer.DumpJson((True, req.post_data))
639

    
640
    resolver = rpc._StaticResolver([
641
      "192.0.2.30",
642
      "192.0.2.35",
643
      ])
644

    
645
    nodes = [
646
      "node30.example.com",
647
      "node35.example.com",
648
      ]
649

    
650
    def _PreProc(node, data):
651
      self.assertEqual(len(data), 1)
652
      return data[0] + node
653

    
654
    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
655
      ("arg0", None, NotImplemented),
656
      ], _PreProc, None, NotImplemented)
657

    
658
    http_proc = _FakeRequestProcessor(_VerifyRequest)
659
    client = rpc._RpcClientBase(resolver, NotImplemented,
660
                                _req_process_fn=http_proc)
661

    
662
    for prefix in ["foo", "bar", "baz"]:
663
      result = client._Call(cdef, nodes, [prefix])
664
      self.assertEqual(len(result), len(nodes))
665
      for (idx, (node, res)) in enumerate(result.items()):
666
        self.assertFalse(res.fail_msg)
667
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
668

    
669
  def testResolverOptions(self):
670
    def _VerifyRequest(req):
671
      req.success = True
672
      req.resp_status_code = http.HTTP_OK
673
      req.resp_body = serializer.DumpJson((True, req.post_data))
674

    
675
    nodes = [
676
      "node30.example.com",
677
      "node35.example.com",
678
      ]
679

    
680
    def _Resolver(expected, hosts, options):
681
      self.assertEqual(hosts, nodes)
682
      self.assertEqual(options, expected)
683
      return zip(hosts, nodes, hosts)
684

    
685
    def _DynamicResolverOptions((arg0, )):
686
      return sum(arg0)
687

    
688
    tests = [
689
      (None, None, None),
690
      (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
691
      (False, None, False),
692
      (True, None, True),
693
      (0, None, 0),
694
      (_DynamicResolverOptions, [1, 2, 3], 6),
695
      (_DynamicResolverOptions, range(4, 19), 165),
696
      ]
697

    
698
    for (resolver_opts, arg0, expected) in tests:
699
      cdef = ("test_call", NotImplemented, resolver_opts,
700
              constants.RPC_TMO_NORMAL, [
701
        ("arg0", None, NotImplemented),
702
        ], None, None, NotImplemented)
703

    
704
      http_proc = _FakeRequestProcessor(_VerifyRequest)
705

    
706
      client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
707
                                  NotImplemented, _req_process_fn=http_proc)
708
      result = client._Call(cdef, nodes, [arg0])
709
      self.assertEqual(len(result), len(nodes))
710
      for (idx, (node, res)) in enumerate(result.items()):
711
        self.assertFalse(res.fail_msg)
712

    
713

    
714
class _FakeConfigForRpcRunner:
715
  GetAllNodesInfo = NotImplemented
716

    
717
  def __init__(self, cluster=NotImplemented):
718
    self._cluster = cluster
719

    
720
  def GetNodeInfo(self, name):
721
    return objects.Node(name=name)
722

    
723
  def GetMultiNodeInfo(self, names):
724
    return [(name, self.GetNodeInfo(name)) for name in names]
725

    
726
  def GetClusterInfo(self):
727
    return self._cluster
728

    
729
  def GetInstanceDiskParams(self, _):
730
    return constants.DISK_DT_DEFAULTS
731

    
732

    
733
class TestRpcRunner(unittest.TestCase):
734
  def testUploadFile(self):
735
    data = 1779 * "Hello World\n"
736

    
737
    tmpfile = tempfile.NamedTemporaryFile()
738
    tmpfile.write(data)
739
    tmpfile.flush()
740
    st = os.stat(tmpfile.name)
741

    
742
    nodes = [
743
      "node1.example.com",
744
      ]
745

    
746
    def _VerifyRequest(req):
747
      (uldata, ) = serializer.LoadJson(req.post_data)
748
      self.assertEqual(len(uldata), 7)
749
      self.assertEqual(uldata[0], tmpfile.name)
750
      self.assertEqual(list(uldata[1]), list(rpc._Compress(nodes[0], data)))
751
      self.assertEqual(uldata[2], st.st_mode)
752
      self.assertEqual(uldata[3], "user%s" % os.getuid())
753
      self.assertEqual(uldata[4], "group%s" % os.getgid())
754
      self.assertTrue(uldata[5] is not None)
755
      self.assertEqual(uldata[6], st.st_mtime)
756

    
757
      req.success = True
758
      req.resp_status_code = http.HTTP_OK
759
      req.resp_body = serializer.DumpJson((True, None))
760

    
761
    http_proc = _FakeRequestProcessor(_VerifyRequest)
762

    
763
    std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
764
                               _req_process_fn=http_proc,
765
                               _getents=mocks.FakeGetentResolver)
766

    
767
    cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
768
                                  _req_process_fn=http_proc,
769
                                  _getents=mocks.FakeGetentResolver)
770

    
771
    for runner in [std_runner, cfg_runner]:
772
      result = runner.call_upload_file(nodes, tmpfile.name)
773
      self.assertEqual(len(result), len(nodes))
774
      for (idx, (node, res)) in enumerate(result.items()):
775
        self.assertFalse(res.fail_msg)
776

    
777
  def testEncodeInstance(self):
778
    cluster = objects.Cluster(hvparams={
779
      constants.HT_KVM: {
780
        constants.HV_BLOCKDEV_PREFIX: "foo",
781
        },
782
      },
783
      beparams={
784
        constants.PP_DEFAULT: {
785
          constants.BE_MAXMEM: 8192,
786
          },
787
        },
788
      os_hvp={},
789
      osparams={
790
        "linux": {
791
          "role": "unknown",
792
          },
793
        })
794
    cluster.UpgradeConfig()
795

    
796
    inst = objects.Instance(name="inst1.example.com",
797
      hypervisor=constants.HT_FAKE,
798
      os="linux",
799
      hvparams={
800
        constants.HT_KVM: {
801
          constants.HV_BLOCKDEV_PREFIX: "bar",
802
          constants.HV_ROOT_PATH: "/tmp",
803
          },
804
        },
805
      beparams={
806
        constants.BE_MINMEM: 128,
807
        constants.BE_MAXMEM: 256,
808
        },
809
      nics=[
810
        objects.NIC(nicparams={
811
          constants.NIC_MODE: "mymode",
812
          }),
813
        ],
814
      disk_template=constants.DT_PLAIN,
815
      disks=[
816
        objects.Disk(dev_type=constants.DT_PLAIN, size=4096,
817
                     logical_id=("vg", "disk6120")),
818
        objects.Disk(dev_type=constants.DT_PLAIN, size=1024,
819
                     logical_id=("vg", "disk8508")),
820
        ])
821
    inst.UpgradeConfig()
822

    
823
    cfg = _FakeConfigForRpcRunner(cluster=cluster)
824
    runner = rpc.RpcRunner(cfg, None,
825
                           _req_process_fn=NotImplemented,
826
                           _getents=mocks.FakeGetentResolver)
827

    
828
    def _CheckBasics(result):
829
      self.assertEqual(result["name"], "inst1.example.com")
830
      self.assertEqual(result["os"], "linux")
831
      self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
832
      self.assertEqual(len(result["hvparams"]), 1)
833
      self.assertEqual(len(result["nics"]), 1)
834
      self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
835
                       "mymode")
836

    
837
    # Generic object serialization
838
    result = runner._encoder(NotImplemented, (rpc_defs.ED_OBJECT_DICT, inst))
839
    _CheckBasics(result)
840

    
841
    result = runner._encoder(NotImplemented,
842
                             (rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
843
    map(_CheckBasics, result)
844

    
845
    # Just an instance
846
    result = runner._encoder(NotImplemented, (rpc_defs.ED_INST_DICT, inst))
847
    _CheckBasics(result)
848
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
849
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
850
      constants.HV_BLOCKDEV_PREFIX: "bar",
851
      constants.HV_ROOT_PATH: "/tmp",
852
      })
853
    self.assertEqual(result["osparams"], {
854
      "role": "unknown",
855
      })
856

    
857
    # Instance with OS parameters
858
    result = runner._encoder(NotImplemented,
859
                             (rpc_defs.ED_INST_DICT_OSP_DP, (inst, {
860
                               "role": "webserver",
861
                               "other": "field",
862
                             })))
863
    _CheckBasics(result)
864
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
865
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
866
      constants.HV_BLOCKDEV_PREFIX: "bar",
867
      constants.HV_ROOT_PATH: "/tmp",
868
      })
869
    self.assertEqual(result["osparams"], {
870
      "role": "webserver",
871
      "other": "field",
872
      })
873

    
874
    # Instance with hypervisor and backend parameters
875
    result = runner._encoder(NotImplemented,
876
                             (rpc_defs.ED_INST_DICT_HVP_BEP_DP, (inst, {
877
      constants.HT_KVM: {
878
        constants.HV_BOOT_ORDER: "xyz",
879
        },
880
      }, {
881
      constants.BE_VCPUS: 100,
882
      constants.BE_MAXMEM: 4096,
883
      })))
884
    _CheckBasics(result)
885
    self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
886
    self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
887
    self.assertEqual(result["hvparams"][constants.HT_KVM], {
888
      constants.HV_BOOT_ORDER: "xyz",
889
      })
890
    self.assertEqual(result["disks"], [{
891
      "dev_type": constants.DT_PLAIN,
892
      "dynamic_params": {},
893
      "size": 4096,
894
      "logical_id": ("vg", "disk6120"),
895
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
896
      }, {
897
      "dev_type": constants.DT_PLAIN,
898
      "dynamic_params": {},
899
      "size": 1024,
900
      "logical_id": ("vg", "disk8508"),
901
      "params": constants.DISK_DT_DEFAULTS[inst.disk_template],
902
      }])
903

    
904
    self.assertTrue(compat.all(disk.params == {} for disk in inst.disks),
905
                    msg="Configuration objects were modified")
906

    
907

    
908
class TestLegacyNodeInfo(unittest.TestCase):
909
  KEY_BOOT = "bootid"
910
  KEY_NAME = "name"
911
  KEY_STORAGE_FREE = "storage_free"
912
  KEY_STORAGE_TOTAL = "storage_size"
913
  KEY_CPU_COUNT = "cpu_count"
914
  KEY_SPINDLES_FREE = "spindles_free"
915
  KEY_SPINDLES_TOTAL = "spindles_total"
916
  KEY_STORAGE_TYPE = "type" # key for storage type
917
  VAL_BOOT = 0
918
  VAL_VG_NAME = "xy"
919
  VAL_VG_FREE = 11
920
  VAL_VG_TOTAL = 12
921
  VAL_VG_TYPE = "lvm-vg"
922
  VAL_CPU_COUNT = 2
923
  VAL_PV_NAME = "ab"
924
  VAL_PV_FREE = 31
925
  VAL_PV_TOTAL = 32
926
  VAL_PV_TYPE = "lvm-pv"
927
  DICT_VG = {
928
    KEY_NAME: VAL_VG_NAME,
929
    KEY_STORAGE_FREE: VAL_VG_FREE,
930
    KEY_STORAGE_TOTAL: VAL_VG_TOTAL,
931
    KEY_STORAGE_TYPE: VAL_VG_TYPE,
932
    }
933
  DICT_HV = {KEY_CPU_COUNT: VAL_CPU_COUNT}
934
  DICT_SP = {
935
    KEY_STORAGE_TYPE: VAL_PV_TYPE,
936
    KEY_NAME: VAL_PV_NAME,
937
    KEY_STORAGE_FREE: VAL_PV_FREE,
938
    KEY_STORAGE_TOTAL: VAL_PV_TOTAL,
939
    }
940
  STD_LST = [VAL_BOOT, [DICT_VG, DICT_SP], [DICT_HV]]
941
  STD_DICT = {
942
    KEY_BOOT: VAL_BOOT,
943
    KEY_NAME: VAL_VG_NAME,
944
    KEY_STORAGE_FREE: VAL_VG_FREE,
945
    KEY_STORAGE_TOTAL: VAL_VG_TOTAL,
946
    KEY_SPINDLES_FREE: VAL_PV_FREE,
947
    KEY_SPINDLES_TOTAL: VAL_PV_TOTAL,
948
    KEY_CPU_COUNT: VAL_CPU_COUNT,
949
    }
950

    
951
  def testWithSpindles(self):
952
    result = rpc.MakeLegacyNodeInfo(self.STD_LST, constants.DT_PLAIN)
953
    self.assertEqual(result, self.STD_DICT)
954

    
955
  def testNoSpindles(self):
956
    my_lst = [self.VAL_BOOT, [self.DICT_VG], [self.DICT_HV]]
957
    result = rpc.MakeLegacyNodeInfo(my_lst, constants.DT_PLAIN)
958
    expected_dict = dict((k,v) for k, v in self.STD_DICT.iteritems())
959
    expected_dict[self.KEY_SPINDLES_FREE] = 0
960
    expected_dict[self.KEY_SPINDLES_TOTAL] = 0
961
    self.assertEqual(result, expected_dict)
962

    
963

    
964
if __name__ == "__main__":
965
  testutils.GanetiTestProgram()