Revision 00267bfe test/ganeti.rpc_unittest.py

b/test/ganeti.rpc_unittest.py
31 31
from ganeti import http
32 32
from ganeti import errors
33 33
from ganeti import serializer
34
from ganeti import objects
34 35

  
35 36
import testutils
36 37

  
......
64 65
  return FakeSimpleStore
65 66

  
66 67

  
67
class TestClient(unittest.TestCase):
68
class TestRpcProcessor(unittest.TestCase):
68 69
  def _FakeAddressLookup(self, map):
69 70
    return lambda node_list: [map.get(node) for node in node_list]
70 71

  
71 72
  def _GetVersionResponse(self, req):
72
    self.assertEqual(req.host, "localhost")
73
    self.assertEqual(req.host, "127.0.0.1")
73 74
    self.assertEqual(req.port, 24094)
74 75
    self.assertEqual(req.path, "/version")
76
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
75 77
    req.success = True
76 78
    req.resp_status_code = http.HTTP_OK
77 79
    req.resp_body = serializer.DumpJson((True, 123))
78 80

  
79 81
  def testVersionSuccess(self):
80
    fn = self._FakeAddressLookup({"localhost": "localhost"})
81
    client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
82
    client.ConnectNode("localhost")
82
    resolver = rpc._StaticResolver(["127.0.0.1"])
83 83
    pool = FakeHttpPool(self._GetVersionResponse)
84
    result = client.GetResults(http_pool=pool)
84
    proc = rpc._RpcProcessor(resolver, 24094)
85
    result = proc(["localhost"], "version", None, http_pool=pool)
85 86
    self.assertEqual(result.keys(), ["localhost"])
86 87
    lhresp = result["localhost"]
87 88
    self.assertFalse(lhresp.offline)
......
92 93
    lhresp.Raise("should not raise")
93 94
    self.assertEqual(pool.reqcount, 1)
94 95

  
96
  def _ReadTimeoutResponse(self, req):
97
    self.assertEqual(req.host, "192.0.2.13")
98
    self.assertEqual(req.port, 19176)
99
    self.assertEqual(req.path, "/version")
100
    self.assertEqual(req.read_timeout, 12356)
101
    req.success = True
102
    req.resp_status_code = http.HTTP_OK
103
    req.resp_body = serializer.DumpJson((True, -1))
104

  
105
  def testReadTimeout(self):
106
    resolver = rpc._StaticResolver(["192.0.2.13"])
107
    pool = FakeHttpPool(self._ReadTimeoutResponse)
108
    proc = rpc._RpcProcessor(resolver, 19176)
109
    result = proc(["node31856"], "version", None, http_pool=pool,
110
                  read_timeout=12356)
111
    self.assertEqual(result.keys(), ["node31856"])
112
    lhresp = result["node31856"]
113
    self.assertFalse(lhresp.offline)
114
    self.assertEqual(lhresp.node, "node31856")
115
    self.assertFalse(lhresp.fail_msg)
116
    self.assertEqual(lhresp.payload, -1)
117
    self.assertEqual(lhresp.call, "version")
118
    lhresp.Raise("should not raise")
119
    self.assertEqual(pool.reqcount, 1)
120

  
121
  def testOfflineNode(self):
122
    resolver = rpc._StaticResolver([rpc._OFFLINE])
123
    pool = FakeHttpPool(NotImplemented)
124
    proc = rpc._RpcProcessor(resolver, 30668)
125
    result = proc(["n17296"], "version", None, http_pool=pool)
126
    self.assertEqual(result.keys(), ["n17296"])
127
    lhresp = result["n17296"]
128
    self.assertTrue(lhresp.offline)
129
    self.assertEqual(lhresp.node, "n17296")
130
    self.assertTrue(lhresp.fail_msg)
131
    self.assertFalse(lhresp.payload)
132
    self.assertEqual(lhresp.call, "version")
133

  
134
    # With a message
135
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
136

  
137
    # No message
138
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
139

  
140
    self.assertEqual(pool.reqcount, 0)
141

  
95 142
  def _GetMultiVersionResponse(self, req):
96 143
    self.assert_(req.host.startswith("node"))
97 144
    self.assertEqual(req.port, 23245)
......
102 149

  
103 150
  def testMultiVersionSuccess(self):
104 151
    nodes = ["node%s" % i for i in range(50)]
105
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
106
    client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
107
    client.ConnectList(nodes)
108

  
152
    resolver = rpc._StaticResolver(nodes)
109 153
    pool = FakeHttpPool(self._GetMultiVersionResponse)
110
    result = client.GetResults(http_pool=pool)
154
    proc = rpc._RpcProcessor(resolver, 23245)
155
    result = proc(nodes, "version", None, http_pool=pool)
111 156
    self.assertEqual(sorted(result.keys()), sorted(nodes))
112 157

  
113 158
    for name in nodes:
......
121 166

  
122 167
    self.assertEqual(pool.reqcount, len(nodes))
123 168

  
124
  def _GetVersionResponseFail(self, req):
169
  def _GetVersionResponseFail(self, errinfo, req):
125 170
    self.assertEqual(req.path, "/version")
126 171
    req.success = True
127 172
    req.resp_status_code = http.HTTP_OK
128
    req.resp_body = serializer.DumpJson((False, "Unknown error"))
173
    req.resp_body = serializer.DumpJson((False, errinfo))
129 174

  
130 175
  def testVersionFailure(self):
131
    lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
132
    fn = self._FakeAddressLookup(lookup_map)
133
    client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
134
    client.ConnectNode("aef9ur4i.example.com")
135
    pool = FakeHttpPool(self._GetVersionResponseFail)
136
    result = client.GetResults(http_pool=pool)
137
    self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
138
    lhresp = result["aef9ur4i.example.com"]
139
    self.assertFalse(lhresp.offline)
140
    self.assertEqual(lhresp.node, "aef9ur4i.example.com")
141
    self.assert_(lhresp.fail_msg)
142
    self.assertFalse(lhresp.payload)
143
    self.assertEqual(lhresp.call, "version")
144
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
145
    self.assertEqual(pool.reqcount, 1)
176
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
177
    proc = rpc._RpcProcessor(resolver, 5903)
178
    for errinfo in [None, "Unknown error"]:
179
      pool = FakeHttpPool(compat.partial(self._GetVersionResponseFail, errinfo))
180
      result = proc(["aef9ur4i.example.com"], "version", None, http_pool=pool)
181
      self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
182
      lhresp = result["aef9ur4i.example.com"]
183
      self.assertFalse(lhresp.offline)
184
      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
185
      self.assert_(lhresp.fail_msg)
186
      self.assertFalse(lhresp.payload)
187
      self.assertEqual(lhresp.call, "version")
188
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
189
      self.assertEqual(pool.reqcount, 1)
146 190

  
147 191
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
148 192
    self.assertEqual(req.path, "/vg_list")
......
167 211

  
168 212
  def testHttpError(self):
169 213
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
170
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
214
    resolver = rpc._StaticResolver(nodes)
171 215

  
172 216
    httperrnodes = set(nodes[1::7])
173 217
    self.assertEqual(len(httperrnodes), 7)
......
177 221

  
178 222
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
179 223

  
180
    client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
181
    client.ConnectList(nodes)
182

  
224
    proc = rpc._RpcProcessor(resolver, 15165)
183 225
    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
184 226
                                       httperrnodes, failnodes))
185
    result = client.GetResults(http_pool=pool)
227
    result = proc(nodes, "vg_list", None, http_pool=pool)
186 228
    self.assertEqual(sorted(result.keys()), sorted(nodes))
187 229

  
188 230
    for name in nodes:
......
219 261
    req.resp_body = serializer.DumpJson("invalid response")
220 262

  
221 263
  def testInvalidResponse(self):
222
    lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
223
    fn = self._FakeAddressLookup(lookup_map)
224
    client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
264
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
265
    proc = rpc._RpcProcessor(resolver, 19978)
266

  
225 267
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
226
      client.ConnectNode("oqo7lanhly.example.com")
227 268
      pool = FakeHttpPool(fn)
228
      result = client.GetResults(http_pool=pool)
269
      result = proc(["oqo7lanhly.example.com"], "version", None, http_pool=pool)
229 270
      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
230 271
      lhresp = result["oqo7lanhly.example.com"]
231 272
      self.assertFalse(lhresp.offline)
......
236 277
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
237 278
      self.assertEqual(pool.reqcount, 1)
238 279

  
239
  def testAddressLookupSimpleStore(self):
280
  def _GetBodyTestResponse(self, test_data, req):
281
    self.assertEqual(req.host, "192.0.2.84")
282
    self.assertEqual(req.port, 18700)
283
    self.assertEqual(req.path, "/upload_file")
284
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
285
    req.success = True
286
    req.resp_status_code = http.HTTP_OK
287
    req.resp_body = serializer.DumpJson((True, None))
288

  
289
  def testResponseBody(self):
290
    test_data = {
291
      "Hello": "World",
292
      "xyz": range(10),
293
      }
294
    resolver = rpc._StaticResolver(["192.0.2.84"])
295
    pool = FakeHttpPool(compat.partial(self._GetBodyTestResponse, test_data))
296
    proc = rpc._RpcProcessor(resolver, 18700)
297
    body = serializer.DumpJson(test_data)
298
    result = proc(["node19759"], "upload_file", body, http_pool=pool)
299
    self.assertEqual(result.keys(), ["node19759"])
300
    lhresp = result["node19759"]
301
    self.assertFalse(lhresp.offline)
302
    self.assertEqual(lhresp.node, "node19759")
303
    self.assertFalse(lhresp.fail_msg)
304
    self.assertEqual(lhresp.payload, None)
305
    self.assertEqual(lhresp.call, "upload_file")
306
    lhresp.Raise("should not raise")
307
    self.assertEqual(pool.reqcount, 1)
308

  
309

  
310
class TestSsconfResolver(unittest.TestCase):
311
  def testSsconfLookup(self):
240 312
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
241 313
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
242
    node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
314
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
243 315
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
244
    result = rpc._AddressLookup(node_list, ssc=ssc)
245
    self.assertEqual(result, addr_list)
316
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
317
    self.assertEqual(result, zip(node_list, addr_list))
246 318

  
247
  def testAddressLookupNSLookup(self):
319
  def testNsLookup(self):
248 320
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
249 321
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
250 322
    ssc = GetFakeSimpleStoreClass(lambda _: [])
251 323
    node_addr_map = dict(zip(node_list, addr_list))
252 324
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
253
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
254
    self.assertEqual(result, addr_list)
325
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
326
    self.assertEqual(result, zip(node_list, addr_list))
255 327

  
256
  def testAddressLookupBoth(self):
328
  def testBothLookups(self):
257 329
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
258 330
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
259 331
    n = len(addr_list) / 2
260
    node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])]
332
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
261 333
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
262 334
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
263 335
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
264
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
265
    self.assertEqual(result, addr_list)
336
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
337
    self.assertEqual(result, zip(node_list, addr_list))
266 338

  
267 339
  def testAddressLookupIPv6(self):
268
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 13)]
269
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
270
    node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
340
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
341
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
342
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
271 343
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
272
    result = rpc._AddressLookup(node_list, ssc=ssc)
273
    self.assertEqual(result, addr_list)
344
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
345
    self.assertEqual(result, zip(node_list, addr_list))
346

  
347

  
348
class TestStaticResolver(unittest.TestCase):
349
  def test(self):
350
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
351
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
352
    res = rpc._StaticResolver(addresses)
353
    self.assertEqual(res(nodes), zip(nodes, addresses))
354

  
355
  def testWrongLength(self):
356
    res = rpc._StaticResolver([])
357
    self.assertRaises(AssertionError, res, ["abc"])
358

  
359

  
360
class TestNodeConfigResolver(unittest.TestCase):
361
  @staticmethod
362
  def _GetSingleOnlineNode(name):
363
    assert name == "node90.example.com"
364
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
365

  
366
  @staticmethod
367
  def _GetSingleOfflineNode(name):
368
    assert name == "node100.example.com"
369
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
370

  
371
  def testSingleOnline(self):
372
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
373
                                             NotImplemented,
374
                                             ["node90.example.com"]),
375
                     [("node90.example.com", "192.0.2.90")])
376

  
377
  def testSingleOffline(self):
378
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
379
                                             NotImplemented,
380
                                             ["node100.example.com"]),
381
                     [("node100.example.com", rpc._OFFLINE)])
382

  
383
  def testUnknownSingleNode(self):
384
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
385
                                             ["node110.example.com"]),
386
                     [("node110.example.com", "node110.example.com")])
387

  
388
  def testMultiEmpty(self):
389
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
390
                                             lambda: {},
391
                                             []),
392
                     [])
393

  
394
  def testMultiSomeOffline(self):
395
    nodes = dict(("node%s.example.com" % i,
396
                  objects.Node(name="node%s.example.com" % i,
397
                               offline=((i % 3) == 0),
398
                               primary_ip="192.0.2.%s" % i))
399
                  for i in range(1, 255))
400

  
401
    # Resolve no names
402
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
403
                                             lambda: nodes,
404
                                             []),
405
                     [])
406

  
407
    # Offline, online and unknown hosts
408
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
409
                                             lambda: nodes,
410
                                             ["node3.example.com",
411
                                              "node92.example.com",
412
                                              "node54.example.com",
413
                                              "unknown.example.com",]), [
414
      ("node3.example.com", rpc._OFFLINE),
415
      ("node92.example.com", "192.0.2.92"),
416
      ("node54.example.com", rpc._OFFLINE),
417
      ("unknown.example.com", "unknown.example.com"),
418
      ])
274 419

  
275 420

  
276 421
if __name__ == "__main__":

Also available in: Unified diff