Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ fce5efd1

History | View | Annotate | Download (16.4 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27

    
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import http
32
from ganeti import errors
33
from ganeti import serializer
34
from ganeti import objects
35

    
36
import testutils
37

    
38

    
39
class _FakeRequestProcessor:
40
  def __init__(self, response_fn):
41
    self._response_fn = response_fn
42
    self.reqcount = 0
43

    
44
  def __call__(self, reqs, lock_monitor_cb=None):
45
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
46
    for req in reqs:
47
      self.reqcount += 1
48
      self._response_fn(req)
49

    
50

    
51
def GetFakeSimpleStoreClass(fn):
52
  class FakeSimpleStore:
53
    GetNodePrimaryIPList = fn
54
    GetPrimaryIPFamily = lambda _: None
55

    
56
  return FakeSimpleStore
57

    
58

    
59
class TestRpcProcessor(unittest.TestCase):
60
  def _FakeAddressLookup(self, map):
61
    return lambda node_list: [map.get(node) for node in node_list]
62

    
63
  def _GetVersionResponse(self, req):
64
    self.assertEqual(req.host, "127.0.0.1")
65
    self.assertEqual(req.port, 24094)
66
    self.assertEqual(req.path, "/version")
67
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
68
    req.success = True
69
    req.resp_status_code = http.HTTP_OK
70
    req.resp_body = serializer.DumpJson((True, 123))
71

    
72
  def testVersionSuccess(self):
73
    resolver = rpc._StaticResolver(["127.0.0.1"])
74
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
75
    proc = rpc._RpcProcessor(resolver, 24094)
76
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
77
                  NotImplemented, _req_process_fn=http_proc)
78
    self.assertEqual(result.keys(), ["localhost"])
79
    lhresp = result["localhost"]
80
    self.assertFalse(lhresp.offline)
81
    self.assertEqual(lhresp.node, "localhost")
82
    self.assertFalse(lhresp.fail_msg)
83
    self.assertEqual(lhresp.payload, 123)
84
    self.assertEqual(lhresp.call, "version")
85
    lhresp.Raise("should not raise")
86
    self.assertEqual(http_proc.reqcount, 1)
87

    
88
  def _ReadTimeoutResponse(self, req):
89
    self.assertEqual(req.host, "192.0.2.13")
90
    self.assertEqual(req.port, 19176)
91
    self.assertEqual(req.path, "/version")
92
    self.assertEqual(req.read_timeout, 12356)
93
    req.success = True
94
    req.resp_status_code = http.HTTP_OK
95
    req.resp_body = serializer.DumpJson((True, -1))
96

    
97
  def testReadTimeout(self):
98
    resolver = rpc._StaticResolver(["192.0.2.13"])
99
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
100
    proc = rpc._RpcProcessor(resolver, 19176)
101
    host = "node31856"
102
    body = {host: ""}
103
    result = proc([host], "version", body, 12356, NotImplemented,
104
                  _req_process_fn=http_proc)
105
    self.assertEqual(result.keys(), [host])
106
    lhresp = result[host]
107
    self.assertFalse(lhresp.offline)
108
    self.assertEqual(lhresp.node, host)
109
    self.assertFalse(lhresp.fail_msg)
110
    self.assertEqual(lhresp.payload, -1)
111
    self.assertEqual(lhresp.call, "version")
112
    lhresp.Raise("should not raise")
113
    self.assertEqual(http_proc.reqcount, 1)
114

    
115
  def testOfflineNode(self):
116
    resolver = rpc._StaticResolver([rpc._OFFLINE])
117
    http_proc = _FakeRequestProcessor(NotImplemented)
118
    proc = rpc._RpcProcessor(resolver, 30668)
119
    host = "n17296"
120
    body = {host: ""}
121
    result = proc([host], "version", body, 60, NotImplemented,
122
                  _req_process_fn=http_proc)
123
    self.assertEqual(result.keys(), [host])
124
    lhresp = result[host]
125
    self.assertTrue(lhresp.offline)
126
    self.assertEqual(lhresp.node, host)
127
    self.assertTrue(lhresp.fail_msg)
128
    self.assertFalse(lhresp.payload)
129
    self.assertEqual(lhresp.call, "version")
130

    
131
    # With a message
132
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
133

    
134
    # No message
135
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
136

    
137
    self.assertEqual(http_proc.reqcount, 0)
138

    
139
  def _GetMultiVersionResponse(self, req):
140
    self.assert_(req.host.startswith("node"))
141
    self.assertEqual(req.port, 23245)
142
    self.assertEqual(req.path, "/version")
143
    req.success = True
144
    req.resp_status_code = http.HTTP_OK
145
    req.resp_body = serializer.DumpJson((True, 987))
146

    
147
  def testMultiVersionSuccess(self):
148
    nodes = ["node%s" % i for i in range(50)]
149
    body = dict((n, "") for n in nodes)
150
    resolver = rpc._StaticResolver(nodes)
151
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
152
    proc = rpc._RpcProcessor(resolver, 23245)
153
    result = proc(nodes, "version", body, 60, NotImplemented,
154
                  _req_process_fn=http_proc)
155
    self.assertEqual(sorted(result.keys()), sorted(nodes))
156

    
157
    for name in nodes:
158
      lhresp = result[name]
159
      self.assertFalse(lhresp.offline)
160
      self.assertEqual(lhresp.node, name)
161
      self.assertFalse(lhresp.fail_msg)
162
      self.assertEqual(lhresp.payload, 987)
163
      self.assertEqual(lhresp.call, "version")
164
      lhresp.Raise("should not raise")
165

    
166
    self.assertEqual(http_proc.reqcount, len(nodes))
167

    
168
  def _GetVersionResponseFail(self, errinfo, req):
169
    self.assertEqual(req.path, "/version")
170
    req.success = True
171
    req.resp_status_code = http.HTTP_OK
172
    req.resp_body = serializer.DumpJson((False, errinfo))
173

    
174
  def testVersionFailure(self):
175
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
176
    proc = rpc._RpcProcessor(resolver, 5903)
177
    for errinfo in [None, "Unknown error"]:
178
      http_proc = \
179
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
180
                                             errinfo))
181
      host = "aef9ur4i.example.com"
182
      body = {host: ""}
183
      result = proc(body.keys(), "version", body, 60, NotImplemented,
184
                    _req_process_fn=http_proc)
185
      self.assertEqual(result.keys(), [host])
186
      lhresp = result[host]
187
      self.assertFalse(lhresp.offline)
188
      self.assertEqual(lhresp.node, host)
189
      self.assert_(lhresp.fail_msg)
190
      self.assertFalse(lhresp.payload)
191
      self.assertEqual(lhresp.call, "version")
192
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
193
      self.assertEqual(http_proc.reqcount, 1)
194

    
195
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
196
    self.assertEqual(req.path, "/vg_list")
197
    self.assertEqual(req.port, 15165)
198

    
199
    if req.host in httperrnodes:
200
      req.success = False
201
      req.error = "Node set up for HTTP errors"
202

    
203
    elif req.host in failnodes:
204
      req.success = True
205
      req.resp_status_code = 404
206
      req.resp_body = serializer.DumpJson({
207
        "code": 404,
208
        "message": "Method not found",
209
        "explain": "Explanation goes here",
210
        })
211
    else:
212
      req.success = True
213
      req.resp_status_code = http.HTTP_OK
214
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
215

    
216
  def testHttpError(self):
217
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
218
    body = dict((n, "") for n in nodes)
219
    resolver = rpc._StaticResolver(nodes)
220

    
221
    httperrnodes = set(nodes[1::7])
222
    self.assertEqual(len(httperrnodes), 7)
223

    
224
    failnodes = set(nodes[2::3]) - httperrnodes
225
    self.assertEqual(len(failnodes), 14)
226

    
227
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
228

    
229
    proc = rpc._RpcProcessor(resolver, 15165)
230
    http_proc = \
231
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
232
                                           httperrnodes, failnodes))
233
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
234
                  _req_process_fn=http_proc)
235
    self.assertEqual(sorted(result.keys()), sorted(nodes))
236

    
237
    for name in nodes:
238
      lhresp = result[name]
239
      self.assertFalse(lhresp.offline)
240
      self.assertEqual(lhresp.node, name)
241
      self.assertEqual(lhresp.call, "vg_list")
242

    
243
      if name in httperrnodes:
244
        self.assert_(lhresp.fail_msg)
245
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
246
      elif name in failnodes:
247
        self.assert_(lhresp.fail_msg)
248
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
249
                          prereq=True, ecode=errors.ECODE_INVAL)
250
      else:
251
        self.assertFalse(lhresp.fail_msg)
252
        self.assertEqual(lhresp.payload, hash(name))
253
        lhresp.Raise("should not raise")
254

    
255
    self.assertEqual(http_proc.reqcount, len(nodes))
256

    
257
  def _GetInvalidResponseA(self, req):
258
    self.assertEqual(req.path, "/version")
259
    req.success = True
260
    req.resp_status_code = http.HTTP_OK
261
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
262
                                         "response", "!", 1, 2, 3))
263

    
264
  def _GetInvalidResponseB(self, req):
265
    self.assertEqual(req.path, "/version")
266
    req.success = True
267
    req.resp_status_code = http.HTTP_OK
268
    req.resp_body = serializer.DumpJson("invalid response")
269

    
270
  def testInvalidResponse(self):
271
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
272
    proc = rpc._RpcProcessor(resolver, 19978)
273

    
274
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
275
      http_proc = _FakeRequestProcessor(fn)
276
      host = "oqo7lanhly.example.com"
277
      body = {host: ""}
278
      result = proc([host], "version", body, 60, NotImplemented,
279
                    _req_process_fn=http_proc)
280
      self.assertEqual(result.keys(), [host])
281
      lhresp = result[host]
282
      self.assertFalse(lhresp.offline)
283
      self.assertEqual(lhresp.node, host)
284
      self.assert_(lhresp.fail_msg)
285
      self.assertFalse(lhresp.payload)
286
      self.assertEqual(lhresp.call, "version")
287
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
288
      self.assertEqual(http_proc.reqcount, 1)
289

    
290
  def _GetBodyTestResponse(self, test_data, req):
291
    self.assertEqual(req.host, "192.0.2.84")
292
    self.assertEqual(req.port, 18700)
293
    self.assertEqual(req.path, "/upload_file")
294
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
295
    req.success = True
296
    req.resp_status_code = http.HTTP_OK
297
    req.resp_body = serializer.DumpJson((True, None))
298

    
299
  def testResponseBody(self):
300
    test_data = {
301
      "Hello": "World",
302
      "xyz": range(10),
303
      }
304
    resolver = rpc._StaticResolver(["192.0.2.84"])
305
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
306
                                                     test_data))
307
    proc = rpc._RpcProcessor(resolver, 18700)
308
    host = "node19759"
309
    body = {host: serializer.DumpJson(test_data)}
310
    result = proc([host], "upload_file", body, 30, NotImplemented,
311
                  _req_process_fn=http_proc)
312
    self.assertEqual(result.keys(), [host])
313
    lhresp = result[host]
314
    self.assertFalse(lhresp.offline)
315
    self.assertEqual(lhresp.node, host)
316
    self.assertFalse(lhresp.fail_msg)
317
    self.assertEqual(lhresp.payload, None)
318
    self.assertEqual(lhresp.call, "upload_file")
319
    lhresp.Raise("should not raise")
320
    self.assertEqual(http_proc.reqcount, 1)
321

    
322

    
323
class TestSsconfResolver(unittest.TestCase):
324
  def testSsconfLookup(self):
325
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
326
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
327
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
328
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
329
    result = rpc._SsconfResolver(node_list, NotImplemented,
330
                                 ssc=ssc, nslookup_fn=NotImplemented)
331
    self.assertEqual(result, zip(node_list, addr_list))
332

    
333
  def testNsLookup(self):
334
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
335
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
336
    ssc = GetFakeSimpleStoreClass(lambda _: [])
337
    node_addr_map = dict(zip(node_list, addr_list))
338
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
339
    result = rpc._SsconfResolver(node_list, NotImplemented,
340
                                 ssc=ssc, nslookup_fn=nslookup_fn)
341
    self.assertEqual(result, zip(node_list, addr_list))
342

    
343
  def testBothLookups(self):
344
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
345
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
346
    n = len(addr_list) / 2
347
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
348
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
349
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
350
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
351
    result = rpc._SsconfResolver(node_list, NotImplemented,
352
                                 ssc=ssc, nslookup_fn=nslookup_fn)
353
    self.assertEqual(result, zip(node_list, addr_list))
354

    
355
  def testAddressLookupIPv6(self):
356
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
357
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
358
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
359
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
360
    result = rpc._SsconfResolver(node_list, NotImplemented,
361
                                 ssc=ssc, nslookup_fn=NotImplemented)
362
    self.assertEqual(result, zip(node_list, addr_list))
363

    
364

    
365
class TestStaticResolver(unittest.TestCase):
366
  def test(self):
367
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
368
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
369
    res = rpc._StaticResolver(addresses)
370
    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
371

    
372
  def testWrongLength(self):
373
    res = rpc._StaticResolver([])
374
    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
375

    
376

    
377
class TestNodeConfigResolver(unittest.TestCase):
378
  @staticmethod
379
  def _GetSingleOnlineNode(name):
380
    assert name == "node90.example.com"
381
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
382

    
383
  @staticmethod
384
  def _GetSingleOfflineNode(name):
385
    assert name == "node100.example.com"
386
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
387

    
388
  def testSingleOnline(self):
389
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
390
                                             NotImplemented,
391
                                             ["node90.example.com"], None),
392
                     [("node90.example.com", "192.0.2.90")])
393

    
394
  def testSingleOffline(self):
395
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
396
                                             NotImplemented,
397
                                             ["node100.example.com"], None),
398
                     [("node100.example.com", rpc._OFFLINE)])
399

    
400
  def testUnknownSingleNode(self):
401
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
402
                                             ["node110.example.com"], None),
403
                     [("node110.example.com", "node110.example.com")])
404

    
405
  def testMultiEmpty(self):
406
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
407
                                             lambda: {},
408
                                             [], None),
409
                     [])
410

    
411
  def testMultiSomeOffline(self):
412
    nodes = dict(("node%s.example.com" % i,
413
                  objects.Node(name="node%s.example.com" % i,
414
                               offline=((i % 3) == 0),
415
                               primary_ip="192.0.2.%s" % i))
416
                  for i in range(1, 255))
417

    
418
    # Resolve no names
419
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
420
                                             lambda: nodes,
421
                                             [], None),
422
                     [])
423

    
424
    # Offline, online and unknown hosts
425
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
426
                                             lambda: nodes,
427
                                             ["node3.example.com",
428
                                              "node92.example.com",
429
                                              "node54.example.com",
430
                                              "unknown.example.com",],
431
                                             None), [
432
      ("node3.example.com", rpc._OFFLINE),
433
      ("node92.example.com", "192.0.2.92"),
434
      ("node54.example.com", rpc._OFFLINE),
435
      ("unknown.example.com", "unknown.example.com"),
436
      ])
437

    
438

    
439
if __name__ == "__main__":
440
  testutils.GanetiTestProgram()