Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ aea5caef

History | View | Annotate | Download (15.7 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27

    
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import http
32
from ganeti import errors
33
from ganeti import serializer
34
from ganeti import objects
35

    
36
import testutils
37

    
38

    
39
class TestTimeouts(unittest.TestCase):
40
  def test(self):
41
    names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
42
             if name.startswith("call_")]
43
    self.assertEqual(len(names), len(rpc._TIMEOUTS))
44
    self.assertFalse([name for name in names
45
                      if not (rpc._TIMEOUTS[name] is None or
46
                              rpc._TIMEOUTS[name] > 0)])
47

    
48

    
49
class FakeHttpPool:
50
  def __init__(self, response_fn):
51
    self._response_fn = response_fn
52
    self.reqcount = 0
53

    
54
  def ProcessRequests(self, reqs, lock_monitor_cb=None):
55
    for req in reqs:
56
      self.reqcount += 1
57
      self._response_fn(req)
58

    
59

    
60
def GetFakeSimpleStoreClass(fn):
61
  class FakeSimpleStore:
62
    GetNodePrimaryIPList = fn
63
    GetPrimaryIPFamily = lambda _: None
64

    
65
  return FakeSimpleStore
66

    
67

    
68
class TestRpcProcessor(unittest.TestCase):
69
  def _FakeAddressLookup(self, map):
70
    return lambda node_list: [map.get(node) for node in node_list]
71

    
72
  def _GetVersionResponse(self, req):
73
    self.assertEqual(req.host, "127.0.0.1")
74
    self.assertEqual(req.port, 24094)
75
    self.assertEqual(req.path, "/version")
76
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
77
    req.success = True
78
    req.resp_status_code = http.HTTP_OK
79
    req.resp_body = serializer.DumpJson((True, 123))
80

    
81
  def testVersionSuccess(self):
82
    resolver = rpc._StaticResolver(["127.0.0.1"])
83
    pool = FakeHttpPool(self._GetVersionResponse)
84
    proc = rpc._RpcProcessor(resolver, 24094)
85
    result = proc(["localhost"], "version", None, http_pool=pool)
86
    self.assertEqual(result.keys(), ["localhost"])
87
    lhresp = result["localhost"]
88
    self.assertFalse(lhresp.offline)
89
    self.assertEqual(lhresp.node, "localhost")
90
    self.assertFalse(lhresp.fail_msg)
91
    self.assertEqual(lhresp.payload, 123)
92
    self.assertEqual(lhresp.call, "version")
93
    lhresp.Raise("should not raise")
94
    self.assertEqual(pool.reqcount, 1)
95

    
96
  def _ReadTimeoutResponse(self, req):
97
    self.assertEqual(req.host, "192.0.2.13")
98
    self.assertEqual(req.port, 19176)
99
    self.assertEqual(req.path, "/version")
100
    self.assertEqual(req.read_timeout, 12356)
101
    req.success = True
102
    req.resp_status_code = http.HTTP_OK
103
    req.resp_body = serializer.DumpJson((True, -1))
104

    
105
  def testReadTimeout(self):
106
    resolver = rpc._StaticResolver(["192.0.2.13"])
107
    pool = FakeHttpPool(self._ReadTimeoutResponse)
108
    proc = rpc._RpcProcessor(resolver, 19176)
109
    result = proc(["node31856"], "version", None, http_pool=pool,
110
                  read_timeout=12356)
111
    self.assertEqual(result.keys(), ["node31856"])
112
    lhresp = result["node31856"]
113
    self.assertFalse(lhresp.offline)
114
    self.assertEqual(lhresp.node, "node31856")
115
    self.assertFalse(lhresp.fail_msg)
116
    self.assertEqual(lhresp.payload, -1)
117
    self.assertEqual(lhresp.call, "version")
118
    lhresp.Raise("should not raise")
119
    self.assertEqual(pool.reqcount, 1)
120

    
121
  def testOfflineNode(self):
122
    resolver = rpc._StaticResolver([rpc._OFFLINE])
123
    pool = FakeHttpPool(NotImplemented)
124
    proc = rpc._RpcProcessor(resolver, 30668)
125
    result = proc(["n17296"], "version", None, http_pool=pool)
126
    self.assertEqual(result.keys(), ["n17296"])
127
    lhresp = result["n17296"]
128
    self.assertTrue(lhresp.offline)
129
    self.assertEqual(lhresp.node, "n17296")
130
    self.assertTrue(lhresp.fail_msg)
131
    self.assertFalse(lhresp.payload)
132
    self.assertEqual(lhresp.call, "version")
133

    
134
    # With a message
135
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
136

    
137
    # No message
138
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
139

    
140
    self.assertEqual(pool.reqcount, 0)
141

    
142
  def _GetMultiVersionResponse(self, req):
143
    self.assert_(req.host.startswith("node"))
144
    self.assertEqual(req.port, 23245)
145
    self.assertEqual(req.path, "/version")
146
    req.success = True
147
    req.resp_status_code = http.HTTP_OK
148
    req.resp_body = serializer.DumpJson((True, 987))
149

    
150
  def testMultiVersionSuccess(self):
151
    nodes = ["node%s" % i for i in range(50)]
152
    resolver = rpc._StaticResolver(nodes)
153
    pool = FakeHttpPool(self._GetMultiVersionResponse)
154
    proc = rpc._RpcProcessor(resolver, 23245)
155
    result = proc(nodes, "version", None, http_pool=pool)
156
    self.assertEqual(sorted(result.keys()), sorted(nodes))
157

    
158
    for name in nodes:
159
      lhresp = result[name]
160
      self.assertFalse(lhresp.offline)
161
      self.assertEqual(lhresp.node, name)
162
      self.assertFalse(lhresp.fail_msg)
163
      self.assertEqual(lhresp.payload, 987)
164
      self.assertEqual(lhresp.call, "version")
165
      lhresp.Raise("should not raise")
166

    
167
    self.assertEqual(pool.reqcount, len(nodes))
168

    
169
  def _GetVersionResponseFail(self, errinfo, req):
170
    self.assertEqual(req.path, "/version")
171
    req.success = True
172
    req.resp_status_code = http.HTTP_OK
173
    req.resp_body = serializer.DumpJson((False, errinfo))
174

    
175
  def testVersionFailure(self):
176
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
177
    proc = rpc._RpcProcessor(resolver, 5903)
178
    for errinfo in [None, "Unknown error"]:
179
      pool = FakeHttpPool(compat.partial(self._GetVersionResponseFail, errinfo))
180
      result = proc(["aef9ur4i.example.com"], "version", None, http_pool=pool)
181
      self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
182
      lhresp = result["aef9ur4i.example.com"]
183
      self.assertFalse(lhresp.offline)
184
      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
185
      self.assert_(lhresp.fail_msg)
186
      self.assertFalse(lhresp.payload)
187
      self.assertEqual(lhresp.call, "version")
188
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
189
      self.assertEqual(pool.reqcount, 1)
190

    
191
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
192
    self.assertEqual(req.path, "/vg_list")
193
    self.assertEqual(req.port, 15165)
194

    
195
    if req.host in httperrnodes:
196
      req.success = False
197
      req.error = "Node set up for HTTP errors"
198

    
199
    elif req.host in failnodes:
200
      req.success = True
201
      req.resp_status_code = 404
202
      req.resp_body = serializer.DumpJson({
203
        "code": 404,
204
        "message": "Method not found",
205
        "explain": "Explanation goes here",
206
        })
207
    else:
208
      req.success = True
209
      req.resp_status_code = http.HTTP_OK
210
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
211

    
212
  def testHttpError(self):
213
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
214
    resolver = rpc._StaticResolver(nodes)
215

    
216
    httperrnodes = set(nodes[1::7])
217
    self.assertEqual(len(httperrnodes), 7)
218

    
219
    failnodes = set(nodes[2::3]) - httperrnodes
220
    self.assertEqual(len(failnodes), 14)
221

    
222
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
223

    
224
    proc = rpc._RpcProcessor(resolver, 15165)
225
    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
226
                                       httperrnodes, failnodes))
227
    result = proc(nodes, "vg_list", None, http_pool=pool)
228
    self.assertEqual(sorted(result.keys()), sorted(nodes))
229

    
230
    for name in nodes:
231
      lhresp = result[name]
232
      self.assertFalse(lhresp.offline)
233
      self.assertEqual(lhresp.node, name)
234
      self.assertEqual(lhresp.call, "vg_list")
235

    
236
      if name in httperrnodes:
237
        self.assert_(lhresp.fail_msg)
238
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
239
      elif name in failnodes:
240
        self.assert_(lhresp.fail_msg)
241
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
242
                          prereq=True, ecode=errors.ECODE_INVAL)
243
      else:
244
        self.assertFalse(lhresp.fail_msg)
245
        self.assertEqual(lhresp.payload, hash(name))
246
        lhresp.Raise("should not raise")
247

    
248
    self.assertEqual(pool.reqcount, len(nodes))
249

    
250
  def _GetInvalidResponseA(self, req):
251
    self.assertEqual(req.path, "/version")
252
    req.success = True
253
    req.resp_status_code = http.HTTP_OK
254
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
255
                                         "response", "!", 1, 2, 3))
256

    
257
  def _GetInvalidResponseB(self, req):
258
    self.assertEqual(req.path, "/version")
259
    req.success = True
260
    req.resp_status_code = http.HTTP_OK
261
    req.resp_body = serializer.DumpJson("invalid response")
262

    
263
  def testInvalidResponse(self):
264
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
265
    proc = rpc._RpcProcessor(resolver, 19978)
266

    
267
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
268
      pool = FakeHttpPool(fn)
269
      result = proc(["oqo7lanhly.example.com"], "version", None, http_pool=pool)
270
      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
271
      lhresp = result["oqo7lanhly.example.com"]
272
      self.assertFalse(lhresp.offline)
273
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
274
      self.assert_(lhresp.fail_msg)
275
      self.assertFalse(lhresp.payload)
276
      self.assertEqual(lhresp.call, "version")
277
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
278
      self.assertEqual(pool.reqcount, 1)
279

    
280
  def _GetBodyTestResponse(self, test_data, req):
281
    self.assertEqual(req.host, "192.0.2.84")
282
    self.assertEqual(req.port, 18700)
283
    self.assertEqual(req.path, "/upload_file")
284
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
285
    req.success = True
286
    req.resp_status_code = http.HTTP_OK
287
    req.resp_body = serializer.DumpJson((True, None))
288

    
289
  def testResponseBody(self):
290
    test_data = {
291
      "Hello": "World",
292
      "xyz": range(10),
293
      }
294
    resolver = rpc._StaticResolver(["192.0.2.84"])
295
    pool = FakeHttpPool(compat.partial(self._GetBodyTestResponse, test_data))
296
    proc = rpc._RpcProcessor(resolver, 18700)
297
    body = serializer.DumpJson(test_data)
298
    result = proc(["node19759"], "upload_file", body, http_pool=pool)
299
    self.assertEqual(result.keys(), ["node19759"])
300
    lhresp = result["node19759"]
301
    self.assertFalse(lhresp.offline)
302
    self.assertEqual(lhresp.node, "node19759")
303
    self.assertFalse(lhresp.fail_msg)
304
    self.assertEqual(lhresp.payload, None)
305
    self.assertEqual(lhresp.call, "upload_file")
306
    lhresp.Raise("should not raise")
307
    self.assertEqual(pool.reqcount, 1)
308

    
309

    
310
class TestSsconfResolver(unittest.TestCase):
311
  def testSsconfLookup(self):
312
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
313
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
314
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
315
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
316
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
317
    self.assertEqual(result, zip(node_list, addr_list))
318

    
319
  def testNsLookup(self):
320
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
321
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
322
    ssc = GetFakeSimpleStoreClass(lambda _: [])
323
    node_addr_map = dict(zip(node_list, addr_list))
324
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
325
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
326
    self.assertEqual(result, zip(node_list, addr_list))
327

    
328
  def testBothLookups(self):
329
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
330
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
331
    n = len(addr_list) / 2
332
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
333
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
334
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
335
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
336
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
337
    self.assertEqual(result, zip(node_list, addr_list))
338

    
339
  def testAddressLookupIPv6(self):
340
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
341
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
342
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
343
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
344
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
345
    self.assertEqual(result, zip(node_list, addr_list))
346

    
347

    
348
class TestStaticResolver(unittest.TestCase):
349
  def test(self):
350
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
351
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
352
    res = rpc._StaticResolver(addresses)
353
    self.assertEqual(res(nodes), zip(nodes, addresses))
354

    
355
  def testWrongLength(self):
356
    res = rpc._StaticResolver([])
357
    self.assertRaises(AssertionError, res, ["abc"])
358

    
359

    
360
class TestNodeConfigResolver(unittest.TestCase):
361
  @staticmethod
362
  def _GetSingleOnlineNode(name):
363
    assert name == "node90.example.com"
364
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
365

    
366
  @staticmethod
367
  def _GetSingleOfflineNode(name):
368
    assert name == "node100.example.com"
369
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
370

    
371
  def testSingleOnline(self):
372
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
373
                                             NotImplemented,
374
                                             ["node90.example.com"]),
375
                     [("node90.example.com", "192.0.2.90")])
376

    
377
  def testSingleOffline(self):
378
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
379
                                             NotImplemented,
380
                                             ["node100.example.com"]),
381
                     [("node100.example.com", rpc._OFFLINE)])
382

    
383
  def testUnknownSingleNode(self):
384
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
385
                                             ["node110.example.com"]),
386
                     [("node110.example.com", "node110.example.com")])
387

    
388
  def testMultiEmpty(self):
389
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
390
                                             lambda: {},
391
                                             []),
392
                     [])
393

    
394
  def testMultiSomeOffline(self):
395
    nodes = dict(("node%s.example.com" % i,
396
                  objects.Node(name="node%s.example.com" % i,
397
                               offline=((i % 3) == 0),
398
                               primary_ip="192.0.2.%s" % i))
399
                  for i in range(1, 255))
400

    
401
    # Resolve no names
402
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
403
                                             lambda: nodes,
404
                                             []),
405
                     [])
406

    
407
    # Offline, online and unknown hosts
408
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
409
                                             lambda: nodes,
410
                                             ["node3.example.com",
411
                                              "node92.example.com",
412
                                              "node54.example.com",
413
                                              "unknown.example.com",]), [
414
      ("node3.example.com", rpc._OFFLINE),
415
      ("node92.example.com", "192.0.2.92"),
416
      ("node54.example.com", rpc._OFFLINE),
417
      ("unknown.example.com", "unknown.example.com"),
418
      ])
419

    
420

    
421
if __name__ == "__main__":
422
  testutils.GanetiTestProgram()