Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ 120e7e77

History | View | Annotate | Download (17.9 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27

    
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import rpc_defs
32
from ganeti import http
33
from ganeti import errors
34
from ganeti import serializer
35
from ganeti import objects
36
from ganeti import backend
37

    
38
import testutils
39

    
40

    
41
class _FakeRequestProcessor:
42
  def __init__(self, response_fn):
43
    self._response_fn = response_fn
44
    self.reqcount = 0
45

    
46
  def __call__(self, reqs, lock_monitor_cb=None):
47
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
48
    for req in reqs:
49
      self.reqcount += 1
50
      self._response_fn(req)
51

    
52

    
53
def GetFakeSimpleStoreClass(fn):
54
  class FakeSimpleStore:
55
    GetNodePrimaryIPList = fn
56
    GetPrimaryIPFamily = lambda _: None
57

    
58
  return FakeSimpleStore
59

    
60

    
61
class TestRpcProcessor(unittest.TestCase):
62
  def _FakeAddressLookup(self, map):
63
    return lambda node_list: [map.get(node) for node in node_list]
64

    
65
  def _GetVersionResponse(self, req):
66
    self.assertEqual(req.host, "127.0.0.1")
67
    self.assertEqual(req.port, 24094)
68
    self.assertEqual(req.path, "/version")
69
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
70
    req.success = True
71
    req.resp_status_code = http.HTTP_OK
72
    req.resp_body = serializer.DumpJson((True, 123))
73

    
74
  def testVersionSuccess(self):
75
    resolver = rpc._StaticResolver(["127.0.0.1"])
76
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
77
    proc = rpc._RpcProcessor(resolver, 24094)
78
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
79
                  NotImplemented, _req_process_fn=http_proc)
80
    self.assertEqual(result.keys(), ["localhost"])
81
    lhresp = result["localhost"]
82
    self.assertFalse(lhresp.offline)
83
    self.assertEqual(lhresp.node, "localhost")
84
    self.assertFalse(lhresp.fail_msg)
85
    self.assertEqual(lhresp.payload, 123)
86
    self.assertEqual(lhresp.call, "version")
87
    lhresp.Raise("should not raise")
88
    self.assertEqual(http_proc.reqcount, 1)
89

    
90
  def _ReadTimeoutResponse(self, req):
91
    self.assertEqual(req.host, "192.0.2.13")
92
    self.assertEqual(req.port, 19176)
93
    self.assertEqual(req.path, "/version")
94
    self.assertEqual(req.read_timeout, 12356)
95
    req.success = True
96
    req.resp_status_code = http.HTTP_OK
97
    req.resp_body = serializer.DumpJson((True, -1))
98

    
99
  def testReadTimeout(self):
100
    resolver = rpc._StaticResolver(["192.0.2.13"])
101
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
102
    proc = rpc._RpcProcessor(resolver, 19176)
103
    host = "node31856"
104
    body = {host: ""}
105
    result = proc([host], "version", body, 12356, NotImplemented,
106
                  _req_process_fn=http_proc)
107
    self.assertEqual(result.keys(), [host])
108
    lhresp = result[host]
109
    self.assertFalse(lhresp.offline)
110
    self.assertEqual(lhresp.node, host)
111
    self.assertFalse(lhresp.fail_msg)
112
    self.assertEqual(lhresp.payload, -1)
113
    self.assertEqual(lhresp.call, "version")
114
    lhresp.Raise("should not raise")
115
    self.assertEqual(http_proc.reqcount, 1)
116

    
117
  def testOfflineNode(self):
118
    resolver = rpc._StaticResolver([rpc._OFFLINE])
119
    http_proc = _FakeRequestProcessor(NotImplemented)
120
    proc = rpc._RpcProcessor(resolver, 30668)
121
    host = "n17296"
122
    body = {host: ""}
123
    result = proc([host], "version", body, 60, NotImplemented,
124
                  _req_process_fn=http_proc)
125
    self.assertEqual(result.keys(), [host])
126
    lhresp = result[host]
127
    self.assertTrue(lhresp.offline)
128
    self.assertEqual(lhresp.node, host)
129
    self.assertTrue(lhresp.fail_msg)
130
    self.assertFalse(lhresp.payload)
131
    self.assertEqual(lhresp.call, "version")
132

    
133
    # With a message
134
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
135

    
136
    # No message
137
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
138

    
139
    self.assertEqual(http_proc.reqcount, 0)
140

    
141
  def _GetMultiVersionResponse(self, req):
142
    self.assert_(req.host.startswith("node"))
143
    self.assertEqual(req.port, 23245)
144
    self.assertEqual(req.path, "/version")
145
    req.success = True
146
    req.resp_status_code = http.HTTP_OK
147
    req.resp_body = serializer.DumpJson((True, 987))
148

    
149
  def testMultiVersionSuccess(self):
150
    nodes = ["node%s" % i for i in range(50)]
151
    body = dict((n, "") for n in nodes)
152
    resolver = rpc._StaticResolver(nodes)
153
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
154
    proc = rpc._RpcProcessor(resolver, 23245)
155
    result = proc(nodes, "version", body, 60, NotImplemented,
156
                  _req_process_fn=http_proc)
157
    self.assertEqual(sorted(result.keys()), sorted(nodes))
158

    
159
    for name in nodes:
160
      lhresp = result[name]
161
      self.assertFalse(lhresp.offline)
162
      self.assertEqual(lhresp.node, name)
163
      self.assertFalse(lhresp.fail_msg)
164
      self.assertEqual(lhresp.payload, 987)
165
      self.assertEqual(lhresp.call, "version")
166
      lhresp.Raise("should not raise")
167

    
168
    self.assertEqual(http_proc.reqcount, len(nodes))
169

    
170
  def _GetVersionResponseFail(self, errinfo, req):
171
    self.assertEqual(req.path, "/version")
172
    req.success = True
173
    req.resp_status_code = http.HTTP_OK
174
    req.resp_body = serializer.DumpJson((False, errinfo))
175

    
176
  def testVersionFailure(self):
177
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
178
    proc = rpc._RpcProcessor(resolver, 5903)
179
    for errinfo in [None, "Unknown error"]:
180
      http_proc = \
181
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
182
                                             errinfo))
183
      host = "aef9ur4i.example.com"
184
      body = {host: ""}
185
      result = proc(body.keys(), "version", body, 60, NotImplemented,
186
                    _req_process_fn=http_proc)
187
      self.assertEqual(result.keys(), [host])
188
      lhresp = result[host]
189
      self.assertFalse(lhresp.offline)
190
      self.assertEqual(lhresp.node, host)
191
      self.assert_(lhresp.fail_msg)
192
      self.assertFalse(lhresp.payload)
193
      self.assertEqual(lhresp.call, "version")
194
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
195
      self.assertEqual(http_proc.reqcount, 1)
196

    
197
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
198
    self.assertEqual(req.path, "/vg_list")
199
    self.assertEqual(req.port, 15165)
200

    
201
    if req.host in httperrnodes:
202
      req.success = False
203
      req.error = "Node set up for HTTP errors"
204

    
205
    elif req.host in failnodes:
206
      req.success = True
207
      req.resp_status_code = 404
208
      req.resp_body = serializer.DumpJson({
209
        "code": 404,
210
        "message": "Method not found",
211
        "explain": "Explanation goes here",
212
        })
213
    else:
214
      req.success = True
215
      req.resp_status_code = http.HTTP_OK
216
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
217

    
218
  def testHttpError(self):
219
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
220
    body = dict((n, "") for n in nodes)
221
    resolver = rpc._StaticResolver(nodes)
222

    
223
    httperrnodes = set(nodes[1::7])
224
    self.assertEqual(len(httperrnodes), 7)
225

    
226
    failnodes = set(nodes[2::3]) - httperrnodes
227
    self.assertEqual(len(failnodes), 14)
228

    
229
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
230

    
231
    proc = rpc._RpcProcessor(resolver, 15165)
232
    http_proc = \
233
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
234
                                           httperrnodes, failnodes))
235
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
236
                  _req_process_fn=http_proc)
237
    self.assertEqual(sorted(result.keys()), sorted(nodes))
238

    
239
    for name in nodes:
240
      lhresp = result[name]
241
      self.assertFalse(lhresp.offline)
242
      self.assertEqual(lhresp.node, name)
243
      self.assertEqual(lhresp.call, "vg_list")
244

    
245
      if name in httperrnodes:
246
        self.assert_(lhresp.fail_msg)
247
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
248
      elif name in failnodes:
249
        self.assert_(lhresp.fail_msg)
250
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
251
                          prereq=True, ecode=errors.ECODE_INVAL)
252
      else:
253
        self.assertFalse(lhresp.fail_msg)
254
        self.assertEqual(lhresp.payload, hash(name))
255
        lhresp.Raise("should not raise")
256

    
257
    self.assertEqual(http_proc.reqcount, len(nodes))
258

    
259
  def _GetInvalidResponseA(self, req):
260
    self.assertEqual(req.path, "/version")
261
    req.success = True
262
    req.resp_status_code = http.HTTP_OK
263
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
264
                                         "response", "!", 1, 2, 3))
265

    
266
  def _GetInvalidResponseB(self, req):
267
    self.assertEqual(req.path, "/version")
268
    req.success = True
269
    req.resp_status_code = http.HTTP_OK
270
    req.resp_body = serializer.DumpJson("invalid response")
271

    
272
  def testInvalidResponse(self):
273
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
274
    proc = rpc._RpcProcessor(resolver, 19978)
275

    
276
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
277
      http_proc = _FakeRequestProcessor(fn)
278
      host = "oqo7lanhly.example.com"
279
      body = {host: ""}
280
      result = proc([host], "version", body, 60, NotImplemented,
281
                    _req_process_fn=http_proc)
282
      self.assertEqual(result.keys(), [host])
283
      lhresp = result[host]
284
      self.assertFalse(lhresp.offline)
285
      self.assertEqual(lhresp.node, host)
286
      self.assert_(lhresp.fail_msg)
287
      self.assertFalse(lhresp.payload)
288
      self.assertEqual(lhresp.call, "version")
289
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
290
      self.assertEqual(http_proc.reqcount, 1)
291

    
292
  def _GetBodyTestResponse(self, test_data, req):
293
    self.assertEqual(req.host, "192.0.2.84")
294
    self.assertEqual(req.port, 18700)
295
    self.assertEqual(req.path, "/upload_file")
296
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
297
    req.success = True
298
    req.resp_status_code = http.HTTP_OK
299
    req.resp_body = serializer.DumpJson((True, None))
300

    
301
  def testResponseBody(self):
302
    test_data = {
303
      "Hello": "World",
304
      "xyz": range(10),
305
      }
306
    resolver = rpc._StaticResolver(["192.0.2.84"])
307
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
308
                                                     test_data))
309
    proc = rpc._RpcProcessor(resolver, 18700)
310
    host = "node19759"
311
    body = {host: serializer.DumpJson(test_data)}
312
    result = proc([host], "upload_file", body, 30, NotImplemented,
313
                  _req_process_fn=http_proc)
314
    self.assertEqual(result.keys(), [host])
315
    lhresp = result[host]
316
    self.assertFalse(lhresp.offline)
317
    self.assertEqual(lhresp.node, host)
318
    self.assertFalse(lhresp.fail_msg)
319
    self.assertEqual(lhresp.payload, None)
320
    self.assertEqual(lhresp.call, "upload_file")
321
    lhresp.Raise("should not raise")
322
    self.assertEqual(http_proc.reqcount, 1)
323

    
324

    
325
class TestSsconfResolver(unittest.TestCase):
326
  def testSsconfLookup(self):
327
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
328
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
329
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
330
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
331
    result = rpc._SsconfResolver(node_list, NotImplemented,
332
                                 ssc=ssc, nslookup_fn=NotImplemented)
333
    self.assertEqual(result, zip(node_list, addr_list))
334

    
335
  def testNsLookup(self):
336
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
337
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
338
    ssc = GetFakeSimpleStoreClass(lambda _: [])
339
    node_addr_map = dict(zip(node_list, addr_list))
340
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
341
    result = rpc._SsconfResolver(node_list, NotImplemented,
342
                                 ssc=ssc, nslookup_fn=nslookup_fn)
343
    self.assertEqual(result, zip(node_list, addr_list))
344

    
345
  def testBothLookups(self):
346
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
347
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
348
    n = len(addr_list) / 2
349
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
350
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
351
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
352
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
353
    result = rpc._SsconfResolver(node_list, NotImplemented,
354
                                 ssc=ssc, nslookup_fn=nslookup_fn)
355
    self.assertEqual(result, zip(node_list, addr_list))
356

    
357
  def testAddressLookupIPv6(self):
358
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
359
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
360
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
361
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
362
    result = rpc._SsconfResolver(node_list, NotImplemented,
363
                                 ssc=ssc, nslookup_fn=NotImplemented)
364
    self.assertEqual(result, zip(node_list, addr_list))
365

    
366

    
367
class TestStaticResolver(unittest.TestCase):
368
  def test(self):
369
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
370
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
371
    res = rpc._StaticResolver(addresses)
372
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
373

    
374
  def testWrongLength(self):
375
    res = rpc._StaticResolver([])
376
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
377

    
378

    
379
class TestNodeConfigResolver(unittest.TestCase):
380
  @staticmethod
381
  def _GetSingleOnlineNode(name):
382
    assert name == "node90.example.com"
383
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
384

    
385
  @staticmethod
386
  def _GetSingleOfflineNode(name):
387
    assert name == "node100.example.com"
388
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
389

    
390
  def testSingleOnline(self):
391
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
392
                                             NotImplemented,
393
                                             ["node90.example.com"], None),
394
                     [("node90.example.com", "192.0.2.90")])
395

    
396
  def testSingleOffline(self):
397
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
398
                                             NotImplemented,
399
                                             ["node100.example.com"], None),
400
                     [("node100.example.com", rpc._OFFLINE)])
401

    
402
  def testSingleOfflineWithAcceptOffline(self):
403
    fn = self._GetSingleOfflineNode
404
    assert fn("node100.example.com").offline
405
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
406
                                             ["node100.example.com"],
407
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
408
                     [("node100.example.com", "192.0.2.100")])
409
    for i in [False, True, "", "Hello", 0, 1]:
410
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
411
                        fn, NotImplemented, ["node100.example.com"], i)
412

    
413
  def testUnknownSingleNode(self):
414
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
415
                                             ["node110.example.com"], None),
416
                     [("node110.example.com", "node110.example.com")])
417

    
418
  def testMultiEmpty(self):
419
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
420
                                             lambda: {},
421
                                             [], None),
422
                     [])
423

    
424
  def testMultiSomeOffline(self):
425
    nodes = dict(("node%s.example.com" % i,
426
                  objects.Node(name="node%s.example.com" % i,
427
                               offline=((i % 3) == 0),
428
                               primary_ip="192.0.2.%s" % i))
429
                  for i in range(1, 255))
430

    
431
    # Resolve no names
432
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
433
                                             lambda: nodes,
434
                                             [], None),
435
                     [])
436

    
437
    # Offline, online and unknown hosts
438
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
439
                                             lambda: nodes,
440
                                             ["node3.example.com",
441
                                              "node92.example.com",
442
                                              "node54.example.com",
443
                                              "unknown.example.com",],
444
                                             None), [
445
      ("node3.example.com", rpc._OFFLINE),
446
      ("node92.example.com", "192.0.2.92"),
447
      ("node54.example.com", rpc._OFFLINE),
448
      ("unknown.example.com", "unknown.example.com"),
449
      ])
450

    
451

    
452
class TestCompress(unittest.TestCase):
453
  def test(self):
454
    for data in ["", "Hello", "Hello World!\nnew\nlines"]:
455
      self.assertEqual(rpc._Compress(data),
456
                       (constants.RPC_ENCODING_NONE, data))
457

    
458
    for data in [512 * " ", 5242 * "Hello World!\n"]:
459
      compressed = rpc._Compress(data)
460
      self.assertEqual(len(compressed), 2)
461
      self.assertEqual(backend._Decompress(compressed), data)
462

    
463
  def testDecompression(self):
464
    self.assertRaises(AssertionError, backend._Decompress, "")
465
    self.assertRaises(AssertionError, backend._Decompress, [""])
466
    self.assertRaises(AssertionError, backend._Decompress,
467
                      ("unknown compression", "data"))
468
    self.assertRaises(Exception, backend._Decompress,
469
                      (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
470

    
471

    
472
if __name__ == "__main__":
473
  testutils.GanetiTestProgram()