Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ f863d7aa

History | View | Annotate | Download (15.9 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27

    
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import http
32
from ganeti import errors
33
from ganeti import serializer
34
from ganeti import objects
35

    
36
import testutils
37

    
38

    
39
class _FakeRequestProcessor:
40
  def __init__(self, response_fn):
41
    self._response_fn = response_fn
42
    self.reqcount = 0
43

    
44
  def __call__(self, reqs, lock_monitor_cb=None):
45
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
46
    for req in reqs:
47
      self.reqcount += 1
48
      self._response_fn(req)
49

    
50

    
51
def GetFakeSimpleStoreClass(fn):
52
  class FakeSimpleStore:
53
    GetNodePrimaryIPList = fn
54
    GetPrimaryIPFamily = lambda _: None
55

    
56
  return FakeSimpleStore
57

    
58

    
59
class TestRpcProcessor(unittest.TestCase):
60
  def _FakeAddressLookup(self, map):
61
    return lambda node_list: [map.get(node) for node in node_list]
62

    
63
  def _GetVersionResponse(self, req):
64
    self.assertEqual(req.host, "127.0.0.1")
65
    self.assertEqual(req.port, 24094)
66
    self.assertEqual(req.path, "/version")
67
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
68
    req.success = True
69
    req.resp_status_code = http.HTTP_OK
70
    req.resp_body = serializer.DumpJson((True, 123))
71

    
72
  def testVersionSuccess(self):
73
    resolver = rpc._StaticResolver(["127.0.0.1"])
74
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
75
    proc = rpc._RpcProcessor(resolver, 24094)
76
    result = proc(["localhost"], "version", {"localhost": ""}, 60,
77
                  _req_process_fn=http_proc)
78
    self.assertEqual(result.keys(), ["localhost"])
79
    lhresp = result["localhost"]
80
    self.assertFalse(lhresp.offline)
81
    self.assertEqual(lhresp.node, "localhost")
82
    self.assertFalse(lhresp.fail_msg)
83
    self.assertEqual(lhresp.payload, 123)
84
    self.assertEqual(lhresp.call, "version")
85
    lhresp.Raise("should not raise")
86
    self.assertEqual(http_proc.reqcount, 1)
87

    
88
  def _ReadTimeoutResponse(self, req):
89
    self.assertEqual(req.host, "192.0.2.13")
90
    self.assertEqual(req.port, 19176)
91
    self.assertEqual(req.path, "/version")
92
    self.assertEqual(req.read_timeout, 12356)
93
    req.success = True
94
    req.resp_status_code = http.HTTP_OK
95
    req.resp_body = serializer.DumpJson((True, -1))
96

    
97
  def testReadTimeout(self):
98
    resolver = rpc._StaticResolver(["192.0.2.13"])
99
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
100
    proc = rpc._RpcProcessor(resolver, 19176)
101
    host = "node31856"
102
    body = {host: ""}
103
    result = proc([host], "version", body, 12356, _req_process_fn=http_proc)
104
    self.assertEqual(result.keys(), [host])
105
    lhresp = result[host]
106
    self.assertFalse(lhresp.offline)
107
    self.assertEqual(lhresp.node, host)
108
    self.assertFalse(lhresp.fail_msg)
109
    self.assertEqual(lhresp.payload, -1)
110
    self.assertEqual(lhresp.call, "version")
111
    lhresp.Raise("should not raise")
112
    self.assertEqual(http_proc.reqcount, 1)
113

    
114
  def testOfflineNode(self):
115
    resolver = rpc._StaticResolver([rpc._OFFLINE])
116
    http_proc = _FakeRequestProcessor(NotImplemented)
117
    proc = rpc._RpcProcessor(resolver, 30668)
118
    host = "n17296"
119
    body = {host: ""}
120
    result = proc([host], "version", body, 60, _req_process_fn=http_proc)
121
    self.assertEqual(result.keys(), [host])
122
    lhresp = result[host]
123
    self.assertTrue(lhresp.offline)
124
    self.assertEqual(lhresp.node, host)
125
    self.assertTrue(lhresp.fail_msg)
126
    self.assertFalse(lhresp.payload)
127
    self.assertEqual(lhresp.call, "version")
128

    
129
    # With a message
130
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
131

    
132
    # No message
133
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
134

    
135
    self.assertEqual(http_proc.reqcount, 0)
136

    
137
  def _GetMultiVersionResponse(self, req):
138
    self.assert_(req.host.startswith("node"))
139
    self.assertEqual(req.port, 23245)
140
    self.assertEqual(req.path, "/version")
141
    req.success = True
142
    req.resp_status_code = http.HTTP_OK
143
    req.resp_body = serializer.DumpJson((True, 987))
144

    
145
  def testMultiVersionSuccess(self):
146
    nodes = ["node%s" % i for i in range(50)]
147
    body = dict((n, "") for n in nodes)
148
    resolver = rpc._StaticResolver(nodes)
149
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
150
    proc = rpc._RpcProcessor(resolver, 23245)
151
    result = proc(nodes, "version", body, 60, _req_process_fn=http_proc,)
152
    self.assertEqual(sorted(result.keys()), sorted(nodes))
153

    
154
    for name in nodes:
155
      lhresp = result[name]
156
      self.assertFalse(lhresp.offline)
157
      self.assertEqual(lhresp.node, name)
158
      self.assertFalse(lhresp.fail_msg)
159
      self.assertEqual(lhresp.payload, 987)
160
      self.assertEqual(lhresp.call, "version")
161
      lhresp.Raise("should not raise")
162

    
163
    self.assertEqual(http_proc.reqcount, len(nodes))
164

    
165
  def _GetVersionResponseFail(self, errinfo, req):
166
    self.assertEqual(req.path, "/version")
167
    req.success = True
168
    req.resp_status_code = http.HTTP_OK
169
    req.resp_body = serializer.DumpJson((False, errinfo))
170

    
171
  def testVersionFailure(self):
172
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
173
    proc = rpc._RpcProcessor(resolver, 5903)
174
    for errinfo in [None, "Unknown error"]:
175
      http_proc = \
176
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
177
                                             errinfo))
178
      host = "aef9ur4i.example.com"
179
      body = {host: ""}
180
      result = proc(body.keys(), "version", body, 60,
181
                    _req_process_fn=http_proc)
182
      self.assertEqual(result.keys(), [host])
183
      lhresp = result[host]
184
      self.assertFalse(lhresp.offline)
185
      self.assertEqual(lhresp.node, host)
186
      self.assert_(lhresp.fail_msg)
187
      self.assertFalse(lhresp.payload)
188
      self.assertEqual(lhresp.call, "version")
189
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
190
      self.assertEqual(http_proc.reqcount, 1)
191

    
192
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
193
    self.assertEqual(req.path, "/vg_list")
194
    self.assertEqual(req.port, 15165)
195

    
196
    if req.host in httperrnodes:
197
      req.success = False
198
      req.error = "Node set up for HTTP errors"
199

    
200
    elif req.host in failnodes:
201
      req.success = True
202
      req.resp_status_code = 404
203
      req.resp_body = serializer.DumpJson({
204
        "code": 404,
205
        "message": "Method not found",
206
        "explain": "Explanation goes here",
207
        })
208
    else:
209
      req.success = True
210
      req.resp_status_code = http.HTTP_OK
211
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
212

    
213
  def testHttpError(self):
214
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
215
    body = dict((n, "") for n in nodes)
216
    resolver = rpc._StaticResolver(nodes)
217

    
218
    httperrnodes = set(nodes[1::7])
219
    self.assertEqual(len(httperrnodes), 7)
220

    
221
    failnodes = set(nodes[2::3]) - httperrnodes
222
    self.assertEqual(len(failnodes), 14)
223

    
224
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
225

    
226
    proc = rpc._RpcProcessor(resolver, 15165)
227
    http_proc = \
228
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
229
                                           httperrnodes, failnodes))
230
    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT,
231
                  _req_process_fn=http_proc)
232
    self.assertEqual(sorted(result.keys()), sorted(nodes))
233

    
234
    for name in nodes:
235
      lhresp = result[name]
236
      self.assertFalse(lhresp.offline)
237
      self.assertEqual(lhresp.node, name)
238
      self.assertEqual(lhresp.call, "vg_list")
239

    
240
      if name in httperrnodes:
241
        self.assert_(lhresp.fail_msg)
242
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
243
      elif name in failnodes:
244
        self.assert_(lhresp.fail_msg)
245
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
246
                          prereq=True, ecode=errors.ECODE_INVAL)
247
      else:
248
        self.assertFalse(lhresp.fail_msg)
249
        self.assertEqual(lhresp.payload, hash(name))
250
        lhresp.Raise("should not raise")
251

    
252
    self.assertEqual(http_proc.reqcount, len(nodes))
253

    
254
  def _GetInvalidResponseA(self, req):
255
    self.assertEqual(req.path, "/version")
256
    req.success = True
257
    req.resp_status_code = http.HTTP_OK
258
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
259
                                         "response", "!", 1, 2, 3))
260

    
261
  def _GetInvalidResponseB(self, req):
262
    self.assertEqual(req.path, "/version")
263
    req.success = True
264
    req.resp_status_code = http.HTTP_OK
265
    req.resp_body = serializer.DumpJson("invalid response")
266

    
267
  def testInvalidResponse(self):
268
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
269
    proc = rpc._RpcProcessor(resolver, 19978)
270

    
271
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
272
      http_proc = _FakeRequestProcessor(fn)
273
      host = "oqo7lanhly.example.com"
274
      body = {host: ""}
275
      result = proc([host], "version", body, 60,
276
                    _req_process_fn=http_proc)
277
      self.assertEqual(result.keys(), [host])
278
      lhresp = result[host]
279
      self.assertFalse(lhresp.offline)
280
      self.assertEqual(lhresp.node, host)
281
      self.assert_(lhresp.fail_msg)
282
      self.assertFalse(lhresp.payload)
283
      self.assertEqual(lhresp.call, "version")
284
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
285
      self.assertEqual(http_proc.reqcount, 1)
286

    
287
  def _GetBodyTestResponse(self, test_data, req):
288
    self.assertEqual(req.host, "192.0.2.84")
289
    self.assertEqual(req.port, 18700)
290
    self.assertEqual(req.path, "/upload_file")
291
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
292
    req.success = True
293
    req.resp_status_code = http.HTTP_OK
294
    req.resp_body = serializer.DumpJson((True, None))
295

    
296
  def testResponseBody(self):
297
    test_data = {
298
      "Hello": "World",
299
      "xyz": range(10),
300
      }
301
    resolver = rpc._StaticResolver(["192.0.2.84"])
302
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
303
                                                     test_data))
304
    proc = rpc._RpcProcessor(resolver, 18700)
305
    host = "node19759"
306
    body = {host: serializer.DumpJson(test_data)}
307
    result = proc([host], "upload_file", body, 30, _req_process_fn=http_proc)
308
    self.assertEqual(result.keys(), [host])
309
    lhresp = result[host]
310
    self.assertFalse(lhresp.offline)
311
    self.assertEqual(lhresp.node, host)
312
    self.assertFalse(lhresp.fail_msg)
313
    self.assertEqual(lhresp.payload, None)
314
    self.assertEqual(lhresp.call, "upload_file")
315
    lhresp.Raise("should not raise")
316
    self.assertEqual(http_proc.reqcount, 1)
317

    
318

    
319
class TestSsconfResolver(unittest.TestCase):
320
  def testSsconfLookup(self):
321
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
322
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
323
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
324
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
325
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
326
    self.assertEqual(result, zip(node_list, addr_list))
327

    
328
  def testNsLookup(self):
329
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
330
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
331
    ssc = GetFakeSimpleStoreClass(lambda _: [])
332
    node_addr_map = dict(zip(node_list, addr_list))
333
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
334
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
335
    self.assertEqual(result, zip(node_list, addr_list))
336

    
337
  def testBothLookups(self):
338
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
339
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
340
    n = len(addr_list) / 2
341
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
342
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
343
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
344
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
345
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
346
    self.assertEqual(result, zip(node_list, addr_list))
347

    
348
  def testAddressLookupIPv6(self):
349
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
350
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
351
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
352
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
353
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
354
    self.assertEqual(result, zip(node_list, addr_list))
355

    
356

    
357
class TestStaticResolver(unittest.TestCase):
358
  def test(self):
359
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
360
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
361
    res = rpc._StaticResolver(addresses)
362
    self.assertEqual(res(nodes), zip(nodes, addresses))
363

    
364
  def testWrongLength(self):
365
    res = rpc._StaticResolver([])
366
    self.assertRaises(AssertionError, res, ["abc"])
367

    
368

    
369
class TestNodeConfigResolver(unittest.TestCase):
370
  @staticmethod
371
  def _GetSingleOnlineNode(name):
372
    assert name == "node90.example.com"
373
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
374

    
375
  @staticmethod
376
  def _GetSingleOfflineNode(name):
377
    assert name == "node100.example.com"
378
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
379

    
380
  def testSingleOnline(self):
381
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
382
                                             NotImplemented,
383
                                             ["node90.example.com"]),
384
                     [("node90.example.com", "192.0.2.90")])
385

    
386
  def testSingleOffline(self):
387
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
388
                                             NotImplemented,
389
                                             ["node100.example.com"]),
390
                     [("node100.example.com", rpc._OFFLINE)])
391

    
392
  def testUnknownSingleNode(self):
393
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
394
                                             ["node110.example.com"]),
395
                     [("node110.example.com", "node110.example.com")])
396

    
397
  def testMultiEmpty(self):
398
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
399
                                             lambda: {},
400
                                             []),
401
                     [])
402

    
403
  def testMultiSomeOffline(self):
404
    nodes = dict(("node%s.example.com" % i,
405
                  objects.Node(name="node%s.example.com" % i,
406
                               offline=((i % 3) == 0),
407
                               primary_ip="192.0.2.%s" % i))
408
                  for i in range(1, 255))
409

    
410
    # Resolve no names
411
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
412
                                             lambda: nodes,
413
                                             []),
414
                     [])
415

    
416
    # Offline, online and unknown hosts
417
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
418
                                             lambda: nodes,
419
                                             ["node3.example.com",
420
                                              "node92.example.com",
421
                                              "node54.example.com",
422
                                              "unknown.example.com",]), [
423
      ("node3.example.com", rpc._OFFLINE),
424
      ("node92.example.com", "192.0.2.92"),
425
      ("node54.example.com", rpc._OFFLINE),
426
      ("unknown.example.com", "unknown.example.com"),
427
      ])
428

    
429

    
430
if __name__ == "__main__":
431
  testutils.GanetiTestProgram()