Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ 2e04d454

History | View | Annotate | Download (16 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27

    
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import http
32
from ganeti import errors
33
from ganeti import serializer
34
from ganeti import objects
35

    
36
import testutils
37

    
38

    
39
class _FakeRequestProcessor:
40
  def __init__(self, response_fn):
41
    self._response_fn = response_fn
42
    self.reqcount = 0
43

    
44
  def __call__(self, reqs, lock_monitor_cb=None):
45
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
46
    for req in reqs:
47
      self.reqcount += 1
48
      self._response_fn(req)
49

    
50

    
51
def GetFakeSimpleStoreClass(fn):
52
  class FakeSimpleStore:
53
    GetNodePrimaryIPList = fn
54
    GetPrimaryIPFamily = lambda _: None
55

    
56
  return FakeSimpleStore
57

    
58

    
59
class TestRpcProcessor(unittest.TestCase):
60
  def _FakeAddressLookup(self, map):
61
    return lambda node_list: [map.get(node) for node in node_list]
62

    
63
  def _GetVersionResponse(self, req):
64
    self.assertEqual(req.host, "127.0.0.1")
65
    self.assertEqual(req.port, 24094)
66
    self.assertEqual(req.path, "/version")
67
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
68
    req.success = True
69
    req.resp_status_code = http.HTTP_OK
70
    req.resp_body = serializer.DumpJson((True, 123))
71

    
72
  def testVersionSuccess(self):
73
    resolver = rpc._StaticResolver(["127.0.0.1"])
74
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
75
    proc = rpc._RpcProcessor(resolver, 24094)
76
    result = proc(["localhost"], "version", None, _req_process_fn=http_proc,
77
                  read_timeout=60)
78
    self.assertEqual(result.keys(), ["localhost"])
79
    lhresp = result["localhost"]
80
    self.assertFalse(lhresp.offline)
81
    self.assertEqual(lhresp.node, "localhost")
82
    self.assertFalse(lhresp.fail_msg)
83
    self.assertEqual(lhresp.payload, 123)
84
    self.assertEqual(lhresp.call, "version")
85
    lhresp.Raise("should not raise")
86
    self.assertEqual(http_proc.reqcount, 1)
87

    
88
  def _ReadTimeoutResponse(self, req):
89
    self.assertEqual(req.host, "192.0.2.13")
90
    self.assertEqual(req.port, 19176)
91
    self.assertEqual(req.path, "/version")
92
    self.assertEqual(req.read_timeout, 12356)
93
    req.success = True
94
    req.resp_status_code = http.HTTP_OK
95
    req.resp_body = serializer.DumpJson((True, -1))
96

    
97
  def testReadTimeout(self):
98
    resolver = rpc._StaticResolver(["192.0.2.13"])
99
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
100
    proc = rpc._RpcProcessor(resolver, 19176)
101
    result = proc(["node31856"], "version", None, _req_process_fn=http_proc,
102
                  read_timeout=12356)
103
    self.assertEqual(result.keys(), ["node31856"])
104
    lhresp = result["node31856"]
105
    self.assertFalse(lhresp.offline)
106
    self.assertEqual(lhresp.node, "node31856")
107
    self.assertFalse(lhresp.fail_msg)
108
    self.assertEqual(lhresp.payload, -1)
109
    self.assertEqual(lhresp.call, "version")
110
    lhresp.Raise("should not raise")
111
    self.assertEqual(http_proc.reqcount, 1)
112

    
113
  def testOfflineNode(self):
114
    resolver = rpc._StaticResolver([rpc._OFFLINE])
115
    http_proc = _FakeRequestProcessor(NotImplemented)
116
    proc = rpc._RpcProcessor(resolver, 30668)
117
    result = proc(["n17296"], "version", None, _req_process_fn=http_proc,
118
                  read_timeout=60)
119
    self.assertEqual(result.keys(), ["n17296"])
120
    lhresp = result["n17296"]
121
    self.assertTrue(lhresp.offline)
122
    self.assertEqual(lhresp.node, "n17296")
123
    self.assertTrue(lhresp.fail_msg)
124
    self.assertFalse(lhresp.payload)
125
    self.assertEqual(lhresp.call, "version")
126

    
127
    # With a message
128
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
129

    
130
    # No message
131
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
132

    
133
    self.assertEqual(http_proc.reqcount, 0)
134

    
135
  def _GetMultiVersionResponse(self, req):
136
    self.assert_(req.host.startswith("node"))
137
    self.assertEqual(req.port, 23245)
138
    self.assertEqual(req.path, "/version")
139
    req.success = True
140
    req.resp_status_code = http.HTTP_OK
141
    req.resp_body = serializer.DumpJson((True, 987))
142

    
143
  def testMultiVersionSuccess(self):
144
    nodes = ["node%s" % i for i in range(50)]
145
    resolver = rpc._StaticResolver(nodes)
146
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
147
    proc = rpc._RpcProcessor(resolver, 23245)
148
    result = proc(nodes, "version", None, _req_process_fn=http_proc,
149
                  read_timeout=60)
150
    self.assertEqual(sorted(result.keys()), sorted(nodes))
151

    
152
    for name in nodes:
153
      lhresp = result[name]
154
      self.assertFalse(lhresp.offline)
155
      self.assertEqual(lhresp.node, name)
156
      self.assertFalse(lhresp.fail_msg)
157
      self.assertEqual(lhresp.payload, 987)
158
      self.assertEqual(lhresp.call, "version")
159
      lhresp.Raise("should not raise")
160

    
161
    self.assertEqual(http_proc.reqcount, len(nodes))
162

    
163
  def _GetVersionResponseFail(self, errinfo, req):
164
    self.assertEqual(req.path, "/version")
165
    req.success = True
166
    req.resp_status_code = http.HTTP_OK
167
    req.resp_body = serializer.DumpJson((False, errinfo))
168

    
169
  def testVersionFailure(self):
170
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
171
    proc = rpc._RpcProcessor(resolver, 5903)
172
    for errinfo in [None, "Unknown error"]:
173
      http_proc = \
174
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
175
                                             errinfo))
176
      result = proc(["aef9ur4i.example.com"], "version", None,
177
                    _req_process_fn=http_proc, read_timeout=60)
178
      self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
179
      lhresp = result["aef9ur4i.example.com"]
180
      self.assertFalse(lhresp.offline)
181
      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
182
      self.assert_(lhresp.fail_msg)
183
      self.assertFalse(lhresp.payload)
184
      self.assertEqual(lhresp.call, "version")
185
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
186
      self.assertEqual(http_proc.reqcount, 1)
187

    
188
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
189
    self.assertEqual(req.path, "/vg_list")
190
    self.assertEqual(req.port, 15165)
191

    
192
    if req.host in httperrnodes:
193
      req.success = False
194
      req.error = "Node set up for HTTP errors"
195

    
196
    elif req.host in failnodes:
197
      req.success = True
198
      req.resp_status_code = 404
199
      req.resp_body = serializer.DumpJson({
200
        "code": 404,
201
        "message": "Method not found",
202
        "explain": "Explanation goes here",
203
        })
204
    else:
205
      req.success = True
206
      req.resp_status_code = http.HTTP_OK
207
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
208

    
209
  def testHttpError(self):
210
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
211
    resolver = rpc._StaticResolver(nodes)
212

    
213
    httperrnodes = set(nodes[1::7])
214
    self.assertEqual(len(httperrnodes), 7)
215

    
216
    failnodes = set(nodes[2::3]) - httperrnodes
217
    self.assertEqual(len(failnodes), 14)
218

    
219
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
220

    
221
    proc = rpc._RpcProcessor(resolver, 15165)
222
    http_proc = \
223
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
224
                                           httperrnodes, failnodes))
225
    result = proc(nodes, "vg_list", None, _req_process_fn=http_proc,
226
                  read_timeout=rpc._TMO_URGENT)
227
    self.assertEqual(sorted(result.keys()), sorted(nodes))
228

    
229
    for name in nodes:
230
      lhresp = result[name]
231
      self.assertFalse(lhresp.offline)
232
      self.assertEqual(lhresp.node, name)
233
      self.assertEqual(lhresp.call, "vg_list")
234

    
235
      if name in httperrnodes:
236
        self.assert_(lhresp.fail_msg)
237
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
238
      elif name in failnodes:
239
        self.assert_(lhresp.fail_msg)
240
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
241
                          prereq=True, ecode=errors.ECODE_INVAL)
242
      else:
243
        self.assertFalse(lhresp.fail_msg)
244
        self.assertEqual(lhresp.payload, hash(name))
245
        lhresp.Raise("should not raise")
246

    
247
    self.assertEqual(http_proc.reqcount, len(nodes))
248

    
249
  def _GetInvalidResponseA(self, req):
250
    self.assertEqual(req.path, "/version")
251
    req.success = True
252
    req.resp_status_code = http.HTTP_OK
253
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
254
                                         "response", "!", 1, 2, 3))
255

    
256
  def _GetInvalidResponseB(self, req):
257
    self.assertEqual(req.path, "/version")
258
    req.success = True
259
    req.resp_status_code = http.HTTP_OK
260
    req.resp_body = serializer.DumpJson("invalid response")
261

    
262
  def testInvalidResponse(self):
263
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
264
    proc = rpc._RpcProcessor(resolver, 19978)
265

    
266
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
267
      http_proc = _FakeRequestProcessor(fn)
268
      result = proc(["oqo7lanhly.example.com"], "version", None,
269
                    _req_process_fn=http_proc, read_timeout=60)
270
      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
271
      lhresp = result["oqo7lanhly.example.com"]
272
      self.assertFalse(lhresp.offline)
273
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
274
      self.assert_(lhresp.fail_msg)
275
      self.assertFalse(lhresp.payload)
276
      self.assertEqual(lhresp.call, "version")
277
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
278
      self.assertEqual(http_proc.reqcount, 1)
279

    
280
  def _GetBodyTestResponse(self, test_data, req):
281
    self.assertEqual(req.host, "192.0.2.84")
282
    self.assertEqual(req.port, 18700)
283
    self.assertEqual(req.path, "/upload_file")
284
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
285
    req.success = True
286
    req.resp_status_code = http.HTTP_OK
287
    req.resp_body = serializer.DumpJson((True, None))
288

    
289
  def testResponseBody(self):
290
    test_data = {
291
      "Hello": "World",
292
      "xyz": range(10),
293
      }
294
    resolver = rpc._StaticResolver(["192.0.2.84"])
295
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
296
                                                     test_data))
297
    proc = rpc._RpcProcessor(resolver, 18700)
298
    body = serializer.DumpJson(test_data)
299
    result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc,
300
                  read_timeout=30)
301
    self.assertEqual(result.keys(), ["node19759"])
302
    lhresp = result["node19759"]
303
    self.assertFalse(lhresp.offline)
304
    self.assertEqual(lhresp.node, "node19759")
305
    self.assertFalse(lhresp.fail_msg)
306
    self.assertEqual(lhresp.payload, None)
307
    self.assertEqual(lhresp.call, "upload_file")
308
    lhresp.Raise("should not raise")
309
    self.assertEqual(http_proc.reqcount, 1)
310

    
311

    
312
class TestSsconfResolver(unittest.TestCase):
313
  def testSsconfLookup(self):
314
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
315
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
316
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
317
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
318
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
319
    self.assertEqual(result, zip(node_list, addr_list))
320

    
321
  def testNsLookup(self):
322
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
323
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
324
    ssc = GetFakeSimpleStoreClass(lambda _: [])
325
    node_addr_map = dict(zip(node_list, addr_list))
326
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
327
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
328
    self.assertEqual(result, zip(node_list, addr_list))
329

    
330
  def testBothLookups(self):
331
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
332
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
333
    n = len(addr_list) / 2
334
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
335
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
336
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
337
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
338
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
339
    self.assertEqual(result, zip(node_list, addr_list))
340

    
341
  def testAddressLookupIPv6(self):
342
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
343
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
344
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
345
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
346
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
347
    self.assertEqual(result, zip(node_list, addr_list))
348

    
349

    
350
class TestStaticResolver(unittest.TestCase):
351
  def test(self):
352
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
353
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
354
    res = rpc._StaticResolver(addresses)
355
    self.assertEqual(res(nodes), zip(nodes, addresses))
356

    
357
  def testWrongLength(self):
358
    res = rpc._StaticResolver([])
359
    self.assertRaises(AssertionError, res, ["abc"])
360

    
361

    
362
class TestNodeConfigResolver(unittest.TestCase):
363
  @staticmethod
364
  def _GetSingleOnlineNode(name):
365
    assert name == "node90.example.com"
366
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
367

    
368
  @staticmethod
369
  def _GetSingleOfflineNode(name):
370
    assert name == "node100.example.com"
371
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
372

    
373
  def testSingleOnline(self):
374
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
375
                                             NotImplemented,
376
                                             ["node90.example.com"]),
377
                     [("node90.example.com", "192.0.2.90")])
378

    
379
  def testSingleOffline(self):
380
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
381
                                             NotImplemented,
382
                                             ["node100.example.com"]),
383
                     [("node100.example.com", rpc._OFFLINE)])
384

    
385
  def testUnknownSingleNode(self):
386
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
387
                                             ["node110.example.com"]),
388
                     [("node110.example.com", "node110.example.com")])
389

    
390
  def testMultiEmpty(self):
391
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
392
                                             lambda: {},
393
                                             []),
394
                     [])
395

    
396
  def testMultiSomeOffline(self):
397
    nodes = dict(("node%s.example.com" % i,
398
                  objects.Node(name="node%s.example.com" % i,
399
                               offline=((i % 3) == 0),
400
                               primary_ip="192.0.2.%s" % i))
401
                  for i in range(1, 255))
402

    
403
    # Resolve no names
404
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
405
                                             lambda: nodes,
406
                                             []),
407
                     [])
408

    
409
    # Offline, online and unknown hosts
410
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
411
                                             lambda: nodes,
412
                                             ["node3.example.com",
413
                                              "node92.example.com",
414
                                              "node54.example.com",
415
                                              "unknown.example.com",]), [
416
      ("node3.example.com", rpc._OFFLINE),
417
      ("node92.example.com", "192.0.2.92"),
418
      ("node54.example.com", rpc._OFFLINE),
419
      ("unknown.example.com", "unknown.example.com"),
420
      ])
421

    
422

    
423
if __name__ == "__main__":
424
  testutils.GanetiTestProgram()