Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ d9de612c

History | View | Annotate | Download (16.1 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27

    
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import http
32
from ganeti import errors
33
from ganeti import serializer
34
from ganeti import objects
35

    
36
import testutils
37

    
38

    
39
class _FakeRequestProcessor:
40
  def __init__(self, response_fn):
41
    self._response_fn = response_fn
42
    self.reqcount = 0
43

    
44
  def __call__(self, reqs, lock_monitor_cb=None):
45
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
46
    for req in reqs:
47
      self.reqcount += 1
48
      self._response_fn(req)
49

    
50

    
51
def GetFakeSimpleStoreClass(fn):
52
  class FakeSimpleStore:
53
    GetNodePrimaryIPList = fn
54
    GetPrimaryIPFamily = lambda _: None
55

    
56
  return FakeSimpleStore
57

    
58

    
59
class TestRpcProcessor(unittest.TestCase):
60
  def _FakeAddressLookup(self, map):
61
    return lambda node_list: [map.get(node) for node in node_list]
62

    
63
  def _GetVersionResponse(self, req):
64
    self.assertEqual(req.host, "127.0.0.1")
65
    self.assertEqual(req.port, 24094)
66
    self.assertEqual(req.path, "/version")
67
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
68
    req.success = True
69
    req.resp_status_code = http.HTTP_OK
70
    req.resp_body = serializer.DumpJson((True, 123))
71

    
72
  def testVersionSuccess(self):
73
    resolver = rpc._StaticResolver(["127.0.0.1"])
74
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
75
    proc = rpc._RpcProcessor(resolver, 24094)
76
    result = proc(["localhost"], "version", {"localhost": ""},
77
                  _req_process_fn=http_proc, read_timeout=60)
78
    self.assertEqual(result.keys(), ["localhost"])
79
    lhresp = result["localhost"]
80
    self.assertFalse(lhresp.offline)
81
    self.assertEqual(lhresp.node, "localhost")
82
    self.assertFalse(lhresp.fail_msg)
83
    self.assertEqual(lhresp.payload, 123)
84
    self.assertEqual(lhresp.call, "version")
85
    lhresp.Raise("should not raise")
86
    self.assertEqual(http_proc.reqcount, 1)
87

    
88
  def _ReadTimeoutResponse(self, req):
89
    self.assertEqual(req.host, "192.0.2.13")
90
    self.assertEqual(req.port, 19176)
91
    self.assertEqual(req.path, "/version")
92
    self.assertEqual(req.read_timeout, 12356)
93
    req.success = True
94
    req.resp_status_code = http.HTTP_OK
95
    req.resp_body = serializer.DumpJson((True, -1))
96

    
97
  def testReadTimeout(self):
98
    resolver = rpc._StaticResolver(["192.0.2.13"])
99
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
100
    proc = rpc._RpcProcessor(resolver, 19176)
101
    host = "node31856"
102
    body = {host: ""}
103
    result = proc([host], "version", body, _req_process_fn=http_proc,
104
                  read_timeout=12356)
105
    self.assertEqual(result.keys(), [host])
106
    lhresp = result[host]
107
    self.assertFalse(lhresp.offline)
108
    self.assertEqual(lhresp.node, host)
109
    self.assertFalse(lhresp.fail_msg)
110
    self.assertEqual(lhresp.payload, -1)
111
    self.assertEqual(lhresp.call, "version")
112
    lhresp.Raise("should not raise")
113
    self.assertEqual(http_proc.reqcount, 1)
114

    
115
  def testOfflineNode(self):
116
    resolver = rpc._StaticResolver([rpc._OFFLINE])
117
    http_proc = _FakeRequestProcessor(NotImplemented)
118
    proc = rpc._RpcProcessor(resolver, 30668)
119
    host = "n17296"
120
    body = {host: ""}
121
    result = proc([host], "version", body, _req_process_fn=http_proc,
122
                  read_timeout=60)
123
    self.assertEqual(result.keys(), [host])
124
    lhresp = result[host]
125
    self.assertTrue(lhresp.offline)
126
    self.assertEqual(lhresp.node, host)
127
    self.assertTrue(lhresp.fail_msg)
128
    self.assertFalse(lhresp.payload)
129
    self.assertEqual(lhresp.call, "version")
130

    
131
    # With a message
132
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
133

    
134
    # No message
135
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
136

    
137
    self.assertEqual(http_proc.reqcount, 0)
138

    
139
  def _GetMultiVersionResponse(self, req):
140
    self.assert_(req.host.startswith("node"))
141
    self.assertEqual(req.port, 23245)
142
    self.assertEqual(req.path, "/version")
143
    req.success = True
144
    req.resp_status_code = http.HTTP_OK
145
    req.resp_body = serializer.DumpJson((True, 987))
146

    
147
  def testMultiVersionSuccess(self):
148
    nodes = ["node%s" % i for i in range(50)]
149
    body = dict((n, "") for n in nodes)
150
    resolver = rpc._StaticResolver(nodes)
151
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
152
    proc = rpc._RpcProcessor(resolver, 23245)
153
    result = proc(nodes, "version", body, _req_process_fn=http_proc,
154
                  read_timeout=60)
155
    self.assertEqual(sorted(result.keys()), sorted(nodes))
156

    
157
    for name in nodes:
158
      lhresp = result[name]
159
      self.assertFalse(lhresp.offline)
160
      self.assertEqual(lhresp.node, name)
161
      self.assertFalse(lhresp.fail_msg)
162
      self.assertEqual(lhresp.payload, 987)
163
      self.assertEqual(lhresp.call, "version")
164
      lhresp.Raise("should not raise")
165

    
166
    self.assertEqual(http_proc.reqcount, len(nodes))
167

    
168
  def _GetVersionResponseFail(self, errinfo, req):
169
    self.assertEqual(req.path, "/version")
170
    req.success = True
171
    req.resp_status_code = http.HTTP_OK
172
    req.resp_body = serializer.DumpJson((False, errinfo))
173

    
174
  def testVersionFailure(self):
175
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
176
    proc = rpc._RpcProcessor(resolver, 5903)
177
    for errinfo in [None, "Unknown error"]:
178
      http_proc = \
179
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
180
                                             errinfo))
181
      host = "aef9ur4i.example.com"
182
      body = {host: ""}
183
      result = proc(body.keys(), "version", body,
184
                    _req_process_fn=http_proc, read_timeout=60)
185
      self.assertEqual(result.keys(), [host])
186
      lhresp = result[host]
187
      self.assertFalse(lhresp.offline)
188
      self.assertEqual(lhresp.node, host)
189
      self.assert_(lhresp.fail_msg)
190
      self.assertFalse(lhresp.payload)
191
      self.assertEqual(lhresp.call, "version")
192
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
193
      self.assertEqual(http_proc.reqcount, 1)
194

    
195
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
196
    self.assertEqual(req.path, "/vg_list")
197
    self.assertEqual(req.port, 15165)
198

    
199
    if req.host in httperrnodes:
200
      req.success = False
201
      req.error = "Node set up for HTTP errors"
202

    
203
    elif req.host in failnodes:
204
      req.success = True
205
      req.resp_status_code = 404
206
      req.resp_body = serializer.DumpJson({
207
        "code": 404,
208
        "message": "Method not found",
209
        "explain": "Explanation goes here",
210
        })
211
    else:
212
      req.success = True
213
      req.resp_status_code = http.HTTP_OK
214
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
215

    
216
  def testHttpError(self):
217
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
218
    body = dict((n, "") for n in nodes)
219
    resolver = rpc._StaticResolver(nodes)
220

    
221
    httperrnodes = set(nodes[1::7])
222
    self.assertEqual(len(httperrnodes), 7)
223

    
224
    failnodes = set(nodes[2::3]) - httperrnodes
225
    self.assertEqual(len(failnodes), 14)
226

    
227
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
228

    
229
    proc = rpc._RpcProcessor(resolver, 15165)
230
    http_proc = \
231
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
232
                                           httperrnodes, failnodes))
233
    result = proc(nodes, "vg_list", body, _req_process_fn=http_proc,
234
                  read_timeout=rpc._TMO_URGENT)
235
    self.assertEqual(sorted(result.keys()), sorted(nodes))
236

    
237
    for name in nodes:
238
      lhresp = result[name]
239
      self.assertFalse(lhresp.offline)
240
      self.assertEqual(lhresp.node, name)
241
      self.assertEqual(lhresp.call, "vg_list")
242

    
243
      if name in httperrnodes:
244
        self.assert_(lhresp.fail_msg)
245
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
246
      elif name in failnodes:
247
        self.assert_(lhresp.fail_msg)
248
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
249
                          prereq=True, ecode=errors.ECODE_INVAL)
250
      else:
251
        self.assertFalse(lhresp.fail_msg)
252
        self.assertEqual(lhresp.payload, hash(name))
253
        lhresp.Raise("should not raise")
254

    
255
    self.assertEqual(http_proc.reqcount, len(nodes))
256

    
257
  def _GetInvalidResponseA(self, req):
258
    self.assertEqual(req.path, "/version")
259
    req.success = True
260
    req.resp_status_code = http.HTTP_OK
261
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
262
                                         "response", "!", 1, 2, 3))
263

    
264
  def _GetInvalidResponseB(self, req):
265
    self.assertEqual(req.path, "/version")
266
    req.success = True
267
    req.resp_status_code = http.HTTP_OK
268
    req.resp_body = serializer.DumpJson("invalid response")
269

    
270
  def testInvalidResponse(self):
271
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
272
    proc = rpc._RpcProcessor(resolver, 19978)
273

    
274
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
275
      http_proc = _FakeRequestProcessor(fn)
276
      host = "oqo7lanhly.example.com"
277
      body = {host: ""}
278
      result = proc([host], "version", body,
279
                    _req_process_fn=http_proc, read_timeout=60)
280
      self.assertEqual(result.keys(), [host])
281
      lhresp = result[host]
282
      self.assertFalse(lhresp.offline)
283
      self.assertEqual(lhresp.node, host)
284
      self.assert_(lhresp.fail_msg)
285
      self.assertFalse(lhresp.payload)
286
      self.assertEqual(lhresp.call, "version")
287
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
288
      self.assertEqual(http_proc.reqcount, 1)
289

    
290
  def _GetBodyTestResponse(self, test_data, req):
291
    self.assertEqual(req.host, "192.0.2.84")
292
    self.assertEqual(req.port, 18700)
293
    self.assertEqual(req.path, "/upload_file")
294
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
295
    req.success = True
296
    req.resp_status_code = http.HTTP_OK
297
    req.resp_body = serializer.DumpJson((True, None))
298

    
299
  def testResponseBody(self):
300
    test_data = {
301
      "Hello": "World",
302
      "xyz": range(10),
303
      }
304
    resolver = rpc._StaticResolver(["192.0.2.84"])
305
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
306
                                                     test_data))
307
    proc = rpc._RpcProcessor(resolver, 18700)
308
    host = "node19759"
309
    body = {host: serializer.DumpJson(test_data)}
310
    result = proc([host], "upload_file", body, _req_process_fn=http_proc,
311
                  read_timeout=30)
312
    self.assertEqual(result.keys(), [host])
313
    lhresp = result[host]
314
    self.assertFalse(lhresp.offline)
315
    self.assertEqual(lhresp.node, host)
316
    self.assertFalse(lhresp.fail_msg)
317
    self.assertEqual(lhresp.payload, None)
318
    self.assertEqual(lhresp.call, "upload_file")
319
    lhresp.Raise("should not raise")
320
    self.assertEqual(http_proc.reqcount, 1)
321

    
322

    
323
class TestSsconfResolver(unittest.TestCase):
324
  def testSsconfLookup(self):
325
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
326
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
327
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
328
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
329
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
330
    self.assertEqual(result, zip(node_list, addr_list))
331

    
332
  def testNsLookup(self):
333
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
334
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
335
    ssc = GetFakeSimpleStoreClass(lambda _: [])
336
    node_addr_map = dict(zip(node_list, addr_list))
337
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
338
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
339
    self.assertEqual(result, zip(node_list, addr_list))
340

    
341
  def testBothLookups(self):
342
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
343
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
344
    n = len(addr_list) / 2
345
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
346
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
347
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
348
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
349
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
350
    self.assertEqual(result, zip(node_list, addr_list))
351

    
352
  def testAddressLookupIPv6(self):
353
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
354
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
355
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
356
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
357
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
358
    self.assertEqual(result, zip(node_list, addr_list))
359

    
360

    
361
class TestStaticResolver(unittest.TestCase):
362
  def test(self):
363
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
364
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
365
    res = rpc._StaticResolver(addresses)
366
    self.assertEqual(res(nodes), zip(nodes, addresses))
367

    
368
  def testWrongLength(self):
369
    res = rpc._StaticResolver([])
370
    self.assertRaises(AssertionError, res, ["abc"])
371

    
372

    
373
class TestNodeConfigResolver(unittest.TestCase):
374
  @staticmethod
375
  def _GetSingleOnlineNode(name):
376
    assert name == "node90.example.com"
377
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
378

    
379
  @staticmethod
380
  def _GetSingleOfflineNode(name):
381
    assert name == "node100.example.com"
382
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
383

    
384
  def testSingleOnline(self):
385
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
386
                                             NotImplemented,
387
                                             ["node90.example.com"]),
388
                     [("node90.example.com", "192.0.2.90")])
389

    
390
  def testSingleOffline(self):
391
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
392
                                             NotImplemented,
393
                                             ["node100.example.com"]),
394
                     [("node100.example.com", rpc._OFFLINE)])
395

    
396
  def testUnknownSingleNode(self):
397
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
398
                                             ["node110.example.com"]),
399
                     [("node110.example.com", "node110.example.com")])
400

    
401
  def testMultiEmpty(self):
402
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
403
                                             lambda: {},
404
                                             []),
405
                     [])
406

    
407
  def testMultiSomeOffline(self):
408
    nodes = dict(("node%s.example.com" % i,
409
                  objects.Node(name="node%s.example.com" % i,
410
                               offline=((i % 3) == 0),
411
                               primary_ip="192.0.2.%s" % i))
412
                  for i in range(1, 255))
413

    
414
    # Resolve no names
415
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
416
                                             lambda: nodes,
417
                                             []),
418
                     [])
419

    
420
    # Offline, online and unknown hosts
421
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
422
                                             lambda: nodes,
423
                                             ["node3.example.com",
424
                                              "node92.example.com",
425
                                              "node54.example.com",
426
                                              "unknown.example.com",]), [
427
      ("node3.example.com", rpc._OFFLINE),
428
      ("node92.example.com", "192.0.2.92"),
429
      ("node54.example.com", rpc._OFFLINE),
430
      ("unknown.example.com", "unknown.example.com"),
431
      ])
432

    
433

    
434
if __name__ == "__main__":
435
  testutils.GanetiTestProgram()