Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ abbf2cd9

History | View | Annotate | Download (16.1 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27

    
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import http
32
from ganeti import errors
33
from ganeti import serializer
34
from ganeti import objects
35

    
36
import testutils
37

    
38

    
39
class TestTimeouts(unittest.TestCase):
40
  def test(self):
41
    names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
42
             if name.startswith("call_")]
43
    self.assertEqual(len(names), len(rpc._TIMEOUTS))
44
    self.assertFalse([name for name in names
45
                      if not (rpc._TIMEOUTS[name] is None or
46
                              rpc._TIMEOUTS[name] > 0)])
47

    
48

    
49
class _FakeRequestProcessor:
50
  def __init__(self, response_fn):
51
    self._response_fn = response_fn
52
    self.reqcount = 0
53

    
54
  def __call__(self, reqs, lock_monitor_cb=None):
55
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
56
    for req in reqs:
57
      self.reqcount += 1
58
      self._response_fn(req)
59

    
60

    
61
def GetFakeSimpleStoreClass(fn):
62
  class FakeSimpleStore:
63
    GetNodePrimaryIPList = fn
64
    GetPrimaryIPFamily = lambda _: None
65

    
66
  return FakeSimpleStore
67

    
68

    
69
class TestRpcProcessor(unittest.TestCase):
70
  def _FakeAddressLookup(self, map):
71
    return lambda node_list: [map.get(node) for node in node_list]
72

    
73
  def _GetVersionResponse(self, req):
74
    self.assertEqual(req.host, "127.0.0.1")
75
    self.assertEqual(req.port, 24094)
76
    self.assertEqual(req.path, "/version")
77
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
78
    req.success = True
79
    req.resp_status_code = http.HTTP_OK
80
    req.resp_body = serializer.DumpJson((True, 123))
81

    
82
  def testVersionSuccess(self):
83
    resolver = rpc._StaticResolver(["127.0.0.1"])
84
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
85
    proc = rpc._RpcProcessor(resolver, 24094)
86
    result = proc(["localhost"], "version", None, _req_process_fn=http_proc)
87
    self.assertEqual(result.keys(), ["localhost"])
88
    lhresp = result["localhost"]
89
    self.assertFalse(lhresp.offline)
90
    self.assertEqual(lhresp.node, "localhost")
91
    self.assertFalse(lhresp.fail_msg)
92
    self.assertEqual(lhresp.payload, 123)
93
    self.assertEqual(lhresp.call, "version")
94
    lhresp.Raise("should not raise")
95
    self.assertEqual(http_proc.reqcount, 1)
96

    
97
  def _ReadTimeoutResponse(self, req):
98
    self.assertEqual(req.host, "192.0.2.13")
99
    self.assertEqual(req.port, 19176)
100
    self.assertEqual(req.path, "/version")
101
    self.assertEqual(req.read_timeout, 12356)
102
    req.success = True
103
    req.resp_status_code = http.HTTP_OK
104
    req.resp_body = serializer.DumpJson((True, -1))
105

    
106
  def testReadTimeout(self):
107
    resolver = rpc._StaticResolver(["192.0.2.13"])
108
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
109
    proc = rpc._RpcProcessor(resolver, 19176)
110
    result = proc(["node31856"], "version", None, _req_process_fn=http_proc,
111
                  read_timeout=12356)
112
    self.assertEqual(result.keys(), ["node31856"])
113
    lhresp = result["node31856"]
114
    self.assertFalse(lhresp.offline)
115
    self.assertEqual(lhresp.node, "node31856")
116
    self.assertFalse(lhresp.fail_msg)
117
    self.assertEqual(lhresp.payload, -1)
118
    self.assertEqual(lhresp.call, "version")
119
    lhresp.Raise("should not raise")
120
    self.assertEqual(http_proc.reqcount, 1)
121

    
122
  def testOfflineNode(self):
123
    resolver = rpc._StaticResolver([rpc._OFFLINE])
124
    http_proc = _FakeRequestProcessor(NotImplemented)
125
    proc = rpc._RpcProcessor(resolver, 30668)
126
    result = proc(["n17296"], "version", None, _req_process_fn=http_proc)
127
    self.assertEqual(result.keys(), ["n17296"])
128
    lhresp = result["n17296"]
129
    self.assertTrue(lhresp.offline)
130
    self.assertEqual(lhresp.node, "n17296")
131
    self.assertTrue(lhresp.fail_msg)
132
    self.assertFalse(lhresp.payload)
133
    self.assertEqual(lhresp.call, "version")
134

    
135
    # With a message
136
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
137

    
138
    # No message
139
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
140

    
141
    self.assertEqual(http_proc.reqcount, 0)
142

    
143
  def _GetMultiVersionResponse(self, req):
144
    self.assert_(req.host.startswith("node"))
145
    self.assertEqual(req.port, 23245)
146
    self.assertEqual(req.path, "/version")
147
    req.success = True
148
    req.resp_status_code = http.HTTP_OK
149
    req.resp_body = serializer.DumpJson((True, 987))
150

    
151
  def testMultiVersionSuccess(self):
152
    nodes = ["node%s" % i for i in range(50)]
153
    resolver = rpc._StaticResolver(nodes)
154
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
155
    proc = rpc._RpcProcessor(resolver, 23245)
156
    result = proc(nodes, "version", None, _req_process_fn=http_proc)
157
    self.assertEqual(sorted(result.keys()), sorted(nodes))
158

    
159
    for name in nodes:
160
      lhresp = result[name]
161
      self.assertFalse(lhresp.offline)
162
      self.assertEqual(lhresp.node, name)
163
      self.assertFalse(lhresp.fail_msg)
164
      self.assertEqual(lhresp.payload, 987)
165
      self.assertEqual(lhresp.call, "version")
166
      lhresp.Raise("should not raise")
167

    
168
    self.assertEqual(http_proc.reqcount, len(nodes))
169

    
170
  def _GetVersionResponseFail(self, errinfo, req):
171
    self.assertEqual(req.path, "/version")
172
    req.success = True
173
    req.resp_status_code = http.HTTP_OK
174
    req.resp_body = serializer.DumpJson((False, errinfo))
175

    
176
  def testVersionFailure(self):
177
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
178
    proc = rpc._RpcProcessor(resolver, 5903)
179
    for errinfo in [None, "Unknown error"]:
180
      http_proc = \
181
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
182
                                             errinfo))
183
      result = proc(["aef9ur4i.example.com"], "version", None,
184
                    _req_process_fn=http_proc)
185
      self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
186
      lhresp = result["aef9ur4i.example.com"]
187
      self.assertFalse(lhresp.offline)
188
      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
189
      self.assert_(lhresp.fail_msg)
190
      self.assertFalse(lhresp.payload)
191
      self.assertEqual(lhresp.call, "version")
192
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
193
      self.assertEqual(http_proc.reqcount, 1)
194

    
195
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
196
    self.assertEqual(req.path, "/vg_list")
197
    self.assertEqual(req.port, 15165)
198

    
199
    if req.host in httperrnodes:
200
      req.success = False
201
      req.error = "Node set up for HTTP errors"
202

    
203
    elif req.host in failnodes:
204
      req.success = True
205
      req.resp_status_code = 404
206
      req.resp_body = serializer.DumpJson({
207
        "code": 404,
208
        "message": "Method not found",
209
        "explain": "Explanation goes here",
210
        })
211
    else:
212
      req.success = True
213
      req.resp_status_code = http.HTTP_OK
214
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
215

    
216
  def testHttpError(self):
217
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
218
    resolver = rpc._StaticResolver(nodes)
219

    
220
    httperrnodes = set(nodes[1::7])
221
    self.assertEqual(len(httperrnodes), 7)
222

    
223
    failnodes = set(nodes[2::3]) - httperrnodes
224
    self.assertEqual(len(failnodes), 14)
225

    
226
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
227

    
228
    proc = rpc._RpcProcessor(resolver, 15165)
229
    http_proc = \
230
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
231
                                           httperrnodes, failnodes))
232
    result = proc(nodes, "vg_list", None, _req_process_fn=http_proc)
233
    self.assertEqual(sorted(result.keys()), sorted(nodes))
234

    
235
    for name in nodes:
236
      lhresp = result[name]
237
      self.assertFalse(lhresp.offline)
238
      self.assertEqual(lhresp.node, name)
239
      self.assertEqual(lhresp.call, "vg_list")
240

    
241
      if name in httperrnodes:
242
        self.assert_(lhresp.fail_msg)
243
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
244
      elif name in failnodes:
245
        self.assert_(lhresp.fail_msg)
246
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
247
                          prereq=True, ecode=errors.ECODE_INVAL)
248
      else:
249
        self.assertFalse(lhresp.fail_msg)
250
        self.assertEqual(lhresp.payload, hash(name))
251
        lhresp.Raise("should not raise")
252

    
253
    self.assertEqual(http_proc.reqcount, len(nodes))
254

    
255
  def _GetInvalidResponseA(self, req):
256
    self.assertEqual(req.path, "/version")
257
    req.success = True
258
    req.resp_status_code = http.HTTP_OK
259
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
260
                                         "response", "!", 1, 2, 3))
261

    
262
  def _GetInvalidResponseB(self, req):
263
    self.assertEqual(req.path, "/version")
264
    req.success = True
265
    req.resp_status_code = http.HTTP_OK
266
    req.resp_body = serializer.DumpJson("invalid response")
267

    
268
  def testInvalidResponse(self):
269
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
270
    proc = rpc._RpcProcessor(resolver, 19978)
271

    
272
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
273
      http_proc = _FakeRequestProcessor(fn)
274
      result = proc(["oqo7lanhly.example.com"], "version", None,
275
                    _req_process_fn=http_proc)
276
      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
277
      lhresp = result["oqo7lanhly.example.com"]
278
      self.assertFalse(lhresp.offline)
279
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
280
      self.assert_(lhresp.fail_msg)
281
      self.assertFalse(lhresp.payload)
282
      self.assertEqual(lhresp.call, "version")
283
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
284
      self.assertEqual(http_proc.reqcount, 1)
285

    
286
  def _GetBodyTestResponse(self, test_data, req):
287
    self.assertEqual(req.host, "192.0.2.84")
288
    self.assertEqual(req.port, 18700)
289
    self.assertEqual(req.path, "/upload_file")
290
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
291
    req.success = True
292
    req.resp_status_code = http.HTTP_OK
293
    req.resp_body = serializer.DumpJson((True, None))
294

    
295
  def testResponseBody(self):
296
    test_data = {
297
      "Hello": "World",
298
      "xyz": range(10),
299
      }
300
    resolver = rpc._StaticResolver(["192.0.2.84"])
301
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
302
                                                     test_data))
303
    proc = rpc._RpcProcessor(resolver, 18700)
304
    body = serializer.DumpJson(test_data)
305
    result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc)
306
    self.assertEqual(result.keys(), ["node19759"])
307
    lhresp = result["node19759"]
308
    self.assertFalse(lhresp.offline)
309
    self.assertEqual(lhresp.node, "node19759")
310
    self.assertFalse(lhresp.fail_msg)
311
    self.assertEqual(lhresp.payload, None)
312
    self.assertEqual(lhresp.call, "upload_file")
313
    lhresp.Raise("should not raise")
314
    self.assertEqual(http_proc.reqcount, 1)
315

    
316

    
317
class TestSsconfResolver(unittest.TestCase):
318
  def testSsconfLookup(self):
319
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
320
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
321
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
322
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
323
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
324
    self.assertEqual(result, zip(node_list, addr_list))
325

    
326
  def testNsLookup(self):
327
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
328
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
329
    ssc = GetFakeSimpleStoreClass(lambda _: [])
330
    node_addr_map = dict(zip(node_list, addr_list))
331
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
332
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
333
    self.assertEqual(result, zip(node_list, addr_list))
334

    
335
  def testBothLookups(self):
336
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
337
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
338
    n = len(addr_list) / 2
339
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
340
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
341
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
342
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
343
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
344
    self.assertEqual(result, zip(node_list, addr_list))
345

    
346
  def testAddressLookupIPv6(self):
347
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
348
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
349
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
350
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
351
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
352
    self.assertEqual(result, zip(node_list, addr_list))
353

    
354

    
355
class TestStaticResolver(unittest.TestCase):
356
  def test(self):
357
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
358
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
359
    res = rpc._StaticResolver(addresses)
360
    self.assertEqual(res(nodes), zip(nodes, addresses))
361

    
362
  def testWrongLength(self):
363
    res = rpc._StaticResolver([])
364
    self.assertRaises(AssertionError, res, ["abc"])
365

    
366

    
367
class TestNodeConfigResolver(unittest.TestCase):
368
  @staticmethod
369
  def _GetSingleOnlineNode(name):
370
    assert name == "node90.example.com"
371
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
372

    
373
  @staticmethod
374
  def _GetSingleOfflineNode(name):
375
    assert name == "node100.example.com"
376
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
377

    
378
  def testSingleOnline(self):
379
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
380
                                             NotImplemented,
381
                                             ["node90.example.com"]),
382
                     [("node90.example.com", "192.0.2.90")])
383

    
384
  def testSingleOffline(self):
385
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
386
                                             NotImplemented,
387
                                             ["node100.example.com"]),
388
                     [("node100.example.com", rpc._OFFLINE)])
389

    
390
  def testUnknownSingleNode(self):
391
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
392
                                             ["node110.example.com"]),
393
                     [("node110.example.com", "node110.example.com")])
394

    
395
  def testMultiEmpty(self):
396
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
397
                                             lambda: {},
398
                                             []),
399
                     [])
400

    
401
  def testMultiSomeOffline(self):
402
    nodes = dict(("node%s.example.com" % i,
403
                  objects.Node(name="node%s.example.com" % i,
404
                               offline=((i % 3) == 0),
405
                               primary_ip="192.0.2.%s" % i))
406
                  for i in range(1, 255))
407

    
408
    # Resolve no names
409
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
410
                                             lambda: nodes,
411
                                             []),
412
                     [])
413

    
414
    # Offline, online and unknown hosts
415
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
416
                                             lambda: nodes,
417
                                             ["node3.example.com",
418
                                              "node92.example.com",
419
                                              "node54.example.com",
420
                                              "unknown.example.com",]), [
421
      ("node3.example.com", rpc._OFFLINE),
422
      ("node92.example.com", "192.0.2.92"),
423
      ("node54.example.com", rpc._OFFLINE),
424
      ("unknown.example.com", "unknown.example.com"),
425
      ])
426

    
427

    
428
if __name__ == "__main__":
429
  testutils.GanetiTestProgram()