Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ 890ea4ce

History | View | Annotate | Download (17 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27

    
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import rpc_defs
32
from ganeti import http
33
from ganeti import errors
34
from ganeti import serializer
35
from ganeti import objects
36

    
37
import testutils
38

    
39

    
40
class _FakeRequestProcessor:
41
  def __init__(self, response_fn):
42
    self._response_fn = response_fn
43
    self.reqcount = 0
44

    
45
  def __call__(self, reqs, lock_monitor_cb=None):
46
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
47
    for req in reqs:
48
      self.reqcount += 1
49
      self._response_fn(req)
50

    
51

    
52
def GetFakeSimpleStoreClass(fn):
53
  class FakeSimpleStore:
54
    GetNodePrimaryIPList = fn
55
    GetPrimaryIPFamily = lambda _: None
56

    
57
  return FakeSimpleStore
58

    
59

    
60
class TestRpcProcessor(unittest.TestCase):
61
  def _FakeAddressLookup(self, map):
62
    return lambda node_list: [map.get(node) for node in node_list]
63

    
64
  def _GetVersionResponse(self, req):
65
    self.assertEqual(req.host, "127.0.0.1")
66
    self.assertEqual(req.port, 24094)
67
    self.assertEqual(req.path, "/version")
68
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
69
    req.success = True
70
    req.resp_status_code = http.HTTP_OK
71
    req.resp_body = serializer.DumpJson((True, 123))
72

    
73
  def testVersionSuccess(self):
74
    resolver = rpc._StaticResolver(["127.0.0.1"])
75
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
76
    proc = rpc._RpcProcessor(resolver, 24094)
77
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
78
                  NotImplemented, _req_process_fn=http_proc)
79
    self.assertEqual(result.keys(), ["localhost"])
80
    lhresp = result["localhost"]
81
    self.assertFalse(lhresp.offline)
82
    self.assertEqual(lhresp.node, "localhost")
83
    self.assertFalse(lhresp.fail_msg)
84
    self.assertEqual(lhresp.payload, 123)
85
    self.assertEqual(lhresp.call, "version")
86
    lhresp.Raise("should not raise")
87
    self.assertEqual(http_proc.reqcount, 1)
88

    
89
  def _ReadTimeoutResponse(self, req):
90
    self.assertEqual(req.host, "192.0.2.13")
91
    self.assertEqual(req.port, 19176)
92
    self.assertEqual(req.path, "/version")
93
    self.assertEqual(req.read_timeout, 12356)
94
    req.success = True
95
    req.resp_status_code = http.HTTP_OK
96
    req.resp_body = serializer.DumpJson((True, -1))
97

    
98
  def testReadTimeout(self):
99
    resolver = rpc._StaticResolver(["192.0.2.13"])
100
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
101
    proc = rpc._RpcProcessor(resolver, 19176)
102
    host = "node31856"
103
    body = {host: ""}
104
    result = proc([host], "version", body, 12356, NotImplemented,
105
                  _req_process_fn=http_proc)
106
    self.assertEqual(result.keys(), [host])
107
    lhresp = result[host]
108
    self.assertFalse(lhresp.offline)
109
    self.assertEqual(lhresp.node, host)
110
    self.assertFalse(lhresp.fail_msg)
111
    self.assertEqual(lhresp.payload, -1)
112
    self.assertEqual(lhresp.call, "version")
113
    lhresp.Raise("should not raise")
114
    self.assertEqual(http_proc.reqcount, 1)
115

    
116
  def testOfflineNode(self):
117
    resolver = rpc._StaticResolver([rpc._OFFLINE])
118
    http_proc = _FakeRequestProcessor(NotImplemented)
119
    proc = rpc._RpcProcessor(resolver, 30668)
120
    host = "n17296"
121
    body = {host: ""}
122
    result = proc([host], "version", body, 60, NotImplemented,
123
                  _req_process_fn=http_proc)
124
    self.assertEqual(result.keys(), [host])
125
    lhresp = result[host]
126
    self.assertTrue(lhresp.offline)
127
    self.assertEqual(lhresp.node, host)
128
    self.assertTrue(lhresp.fail_msg)
129
    self.assertFalse(lhresp.payload)
130
    self.assertEqual(lhresp.call, "version")
131

    
132
    # With a message
133
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
134

    
135
    # No message
136
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
137

    
138
    self.assertEqual(http_proc.reqcount, 0)
139

    
140
  def _GetMultiVersionResponse(self, req):
141
    self.assert_(req.host.startswith("node"))
142
    self.assertEqual(req.port, 23245)
143
    self.assertEqual(req.path, "/version")
144
    req.success = True
145
    req.resp_status_code = http.HTTP_OK
146
    req.resp_body = serializer.DumpJson((True, 987))
147

    
148
  def testMultiVersionSuccess(self):
149
    nodes = ["node%s" % i for i in range(50)]
150
    body = dict((n, "") for n in nodes)
151
    resolver = rpc._StaticResolver(nodes)
152
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
153
    proc = rpc._RpcProcessor(resolver, 23245)
154
    result = proc(nodes, "version", body, 60, NotImplemented,
155
                  _req_process_fn=http_proc)
156
    self.assertEqual(sorted(result.keys()), sorted(nodes))
157

    
158
    for name in nodes:
159
      lhresp = result[name]
160
      self.assertFalse(lhresp.offline)
161
      self.assertEqual(lhresp.node, name)
162
      self.assertFalse(lhresp.fail_msg)
163
      self.assertEqual(lhresp.payload, 987)
164
      self.assertEqual(lhresp.call, "version")
165
      lhresp.Raise("should not raise")
166

    
167
    self.assertEqual(http_proc.reqcount, len(nodes))
168

    
169
  def _GetVersionResponseFail(self, errinfo, req):
170
    self.assertEqual(req.path, "/version")
171
    req.success = True
172
    req.resp_status_code = http.HTTP_OK
173
    req.resp_body = serializer.DumpJson((False, errinfo))
174

    
175
  def testVersionFailure(self):
176
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
177
    proc = rpc._RpcProcessor(resolver, 5903)
178
    for errinfo in [None, "Unknown error"]:
179
      http_proc = \
180
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
181
                                             errinfo))
182
      host = "aef9ur4i.example.com"
183
      body = {host: ""}
184
      result = proc(body.keys(), "version", body, 60, NotImplemented,
185
                    _req_process_fn=http_proc)
186
      self.assertEqual(result.keys(), [host])
187
      lhresp = result[host]
188
      self.assertFalse(lhresp.offline)
189
      self.assertEqual(lhresp.node, host)
190
      self.assert_(lhresp.fail_msg)
191
      self.assertFalse(lhresp.payload)
192
      self.assertEqual(lhresp.call, "version")
193
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
194
      self.assertEqual(http_proc.reqcount, 1)
195

    
196
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
197
    self.assertEqual(req.path, "/vg_list")
198
    self.assertEqual(req.port, 15165)
199

    
200
    if req.host in httperrnodes:
201
      req.success = False
202
      req.error = "Node set up for HTTP errors"
203

    
204
    elif req.host in failnodes:
205
      req.success = True
206
      req.resp_status_code = 404
207
      req.resp_body = serializer.DumpJson({
208
        "code": 404,
209
        "message": "Method not found",
210
        "explain": "Explanation goes here",
211
        })
212
    else:
213
      req.success = True
214
      req.resp_status_code = http.HTTP_OK
215
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
216

    
217
  def testHttpError(self):
218
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
219
    body = dict((n, "") for n in nodes)
220
    resolver = rpc._StaticResolver(nodes)
221

    
222
    httperrnodes = set(nodes[1::7])
223
    self.assertEqual(len(httperrnodes), 7)
224

    
225
    failnodes = set(nodes[2::3]) - httperrnodes
226
    self.assertEqual(len(failnodes), 14)
227

    
228
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
229

    
230
    proc = rpc._RpcProcessor(resolver, 15165)
231
    http_proc = \
232
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
233
                                           httperrnodes, failnodes))
234
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
235
                  _req_process_fn=http_proc)
236
    self.assertEqual(sorted(result.keys()), sorted(nodes))
237

    
238
    for name in nodes:
239
      lhresp = result[name]
240
      self.assertFalse(lhresp.offline)
241
      self.assertEqual(lhresp.node, name)
242
      self.assertEqual(lhresp.call, "vg_list")
243

    
244
      if name in httperrnodes:
245
        self.assert_(lhresp.fail_msg)
246
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
247
      elif name in failnodes:
248
        self.assert_(lhresp.fail_msg)
249
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
250
                          prereq=True, ecode=errors.ECODE_INVAL)
251
      else:
252
        self.assertFalse(lhresp.fail_msg)
253
        self.assertEqual(lhresp.payload, hash(name))
254
        lhresp.Raise("should not raise")
255

    
256
    self.assertEqual(http_proc.reqcount, len(nodes))
257

    
258
  def _GetInvalidResponseA(self, req):
259
    self.assertEqual(req.path, "/version")
260
    req.success = True
261
    req.resp_status_code = http.HTTP_OK
262
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
263
                                         "response", "!", 1, 2, 3))
264

    
265
  def _GetInvalidResponseB(self, req):
266
    self.assertEqual(req.path, "/version")
267
    req.success = True
268
    req.resp_status_code = http.HTTP_OK
269
    req.resp_body = serializer.DumpJson("invalid response")
270

    
271
  def testInvalidResponse(self):
272
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
273
    proc = rpc._RpcProcessor(resolver, 19978)
274

    
275
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
276
      http_proc = _FakeRequestProcessor(fn)
277
      host = "oqo7lanhly.example.com"
278
      body = {host: ""}
279
      result = proc([host], "version", body, 60, NotImplemented,
280
                    _req_process_fn=http_proc)
281
      self.assertEqual(result.keys(), [host])
282
      lhresp = result[host]
283
      self.assertFalse(lhresp.offline)
284
      self.assertEqual(lhresp.node, host)
285
      self.assert_(lhresp.fail_msg)
286
      self.assertFalse(lhresp.payload)
287
      self.assertEqual(lhresp.call, "version")
288
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
289
      self.assertEqual(http_proc.reqcount, 1)
290

    
291
  def _GetBodyTestResponse(self, test_data, req):
292
    self.assertEqual(req.host, "192.0.2.84")
293
    self.assertEqual(req.port, 18700)
294
    self.assertEqual(req.path, "/upload_file")
295
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
296
    req.success = True
297
    req.resp_status_code = http.HTTP_OK
298
    req.resp_body = serializer.DumpJson((True, None))
299

    
300
  def testResponseBody(self):
301
    test_data = {
302
      "Hello": "World",
303
      "xyz": range(10),
304
      }
305
    resolver = rpc._StaticResolver(["192.0.2.84"])
306
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
307
                                                     test_data))
308
    proc = rpc._RpcProcessor(resolver, 18700)
309
    host = "node19759"
310
    body = {host: serializer.DumpJson(test_data)}
311
    result = proc([host], "upload_file", body, 30, NotImplemented,
312
                  _req_process_fn=http_proc)
313
    self.assertEqual(result.keys(), [host])
314
    lhresp = result[host]
315
    self.assertFalse(lhresp.offline)
316
    self.assertEqual(lhresp.node, host)
317
    self.assertFalse(lhresp.fail_msg)
318
    self.assertEqual(lhresp.payload, None)
319
    self.assertEqual(lhresp.call, "upload_file")
320
    lhresp.Raise("should not raise")
321
    self.assertEqual(http_proc.reqcount, 1)
322

    
323

    
324
class TestSsconfResolver(unittest.TestCase):
325
  def testSsconfLookup(self):
326
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
327
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
328
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
329
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
330
    result = rpc._SsconfResolver(node_list, NotImplemented,
331
                                 ssc=ssc, nslookup_fn=NotImplemented)
332
    self.assertEqual(result, zip(node_list, addr_list))
333

    
334
  def testNsLookup(self):
335
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
336
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
337
    ssc = GetFakeSimpleStoreClass(lambda _: [])
338
    node_addr_map = dict(zip(node_list, addr_list))
339
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
340
    result = rpc._SsconfResolver(node_list, NotImplemented,
341
                                 ssc=ssc, nslookup_fn=nslookup_fn)
342
    self.assertEqual(result, zip(node_list, addr_list))
343

    
344
  def testBothLookups(self):
345
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
346
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
347
    n = len(addr_list) / 2
348
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
349
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
350
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
351
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
352
    result = rpc._SsconfResolver(node_list, NotImplemented,
353
                                 ssc=ssc, nslookup_fn=nslookup_fn)
354
    self.assertEqual(result, zip(node_list, addr_list))
355

    
356
  def testAddressLookupIPv6(self):
357
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
358
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
359
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
360
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
361
    result = rpc._SsconfResolver(node_list, NotImplemented,
362
                                 ssc=ssc, nslookup_fn=NotImplemented)
363
    self.assertEqual(result, zip(node_list, addr_list))
364

    
365

    
366
class TestStaticResolver(unittest.TestCase):
367
  def test(self):
368
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
369
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
370
    res = rpc._StaticResolver(addresses)
371
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
372

    
373
  def testWrongLength(self):
374
    res = rpc._StaticResolver([])
375
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
376

    
377

    
378
class TestNodeConfigResolver(unittest.TestCase):
379
  @staticmethod
380
  def _GetSingleOnlineNode(name):
381
    assert name == "node90.example.com"
382
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
383

    
384
  @staticmethod
385
  def _GetSingleOfflineNode(name):
386
    assert name == "node100.example.com"
387
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
388

    
389
  def testSingleOnline(self):
390
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
391
                                             NotImplemented,
392
                                             ["node90.example.com"], None),
393
                     [("node90.example.com", "192.0.2.90")])
394

    
395
  def testSingleOffline(self):
396
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
397
                                             NotImplemented,
398
                                             ["node100.example.com"], None),
399
                     [("node100.example.com", rpc._OFFLINE)])
400

    
401
  def testSingleOfflineWithAcceptOffline(self):
402
    fn = self._GetSingleOfflineNode
403
    assert fn("node100.example.com").offline
404
    self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
405
                                             ["node100.example.com"],
406
                                             rpc_defs.ACCEPT_OFFLINE_NODE),
407
                     [("node100.example.com", "192.0.2.100")])
408
    for i in [False, True, "", "Hello", 0, 1]:
409
      self.assertRaises(AssertionError, rpc._NodeConfigResolver,
410
                        fn, NotImplemented, ["node100.example.com"], i)
411

    
412
  def testUnknownSingleNode(self):
413
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
414
                                             ["node110.example.com"], None),
415
                     [("node110.example.com", "node110.example.com")])
416

    
417
  def testMultiEmpty(self):
418
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
419
                                             lambda: {},
420
                                             [], None),
421
                     [])
422

    
423
  def testMultiSomeOffline(self):
424
    nodes = dict(("node%s.example.com" % i,
425
                  objects.Node(name="node%s.example.com" % i,
426
                               offline=((i % 3) == 0),
427
                               primary_ip="192.0.2.%s" % i))
428
                  for i in range(1, 255))
429

    
430
    # Resolve no names
431
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
432
                                             lambda: nodes,
433
                                             [], None),
434
                     [])
435

    
436
    # Offline, online and unknown hosts
437
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
438
                                             lambda: nodes,
439
                                             ["node3.example.com",
440
                                              "node92.example.com",
441
                                              "node54.example.com",
442
                                              "unknown.example.com",],
443
                                             None), [
444
      ("node3.example.com", rpc._OFFLINE),
445
      ("node92.example.com", "192.0.2.92"),
446
      ("node54.example.com", rpc._OFFLINE),
447
      ("unknown.example.com", "unknown.example.com"),
448
      ])
449

    
450

    
451
if __name__ == "__main__":
452
  testutils.GanetiTestProgram()