Revision 00267bfe

b/lib/rpc.py
46 46
from ganeti import netutils
47 47
from ganeti import ssconf
48 48
from ganeti import runtime
49
from ganeti import compat
49 50

  
50 51
# pylint has a bug here, doesn't see this import
51 52
import ganeti.http.client  # pylint: disable=W0611
......
77 78
_TIMEOUTS = {
78 79
}
79 80

  
81
#: Special value to describe an offline host
82
_OFFLINE = object()
83

  
80 84

  
81 85
def Init():
82 86
  """Initializes the module-global HTTP client manager.
......
285 289
    raise ec(*args) # pylint: disable=W0142
286 290

  
287 291

  
288
def _AddressLookup(node_list,
289
                   ssc=ssconf.SimpleStore,
290
                   nslookup_fn=netutils.Hostname.GetIP):
292
def _SsconfResolver(node_list,
293
                    ssc=ssconf.SimpleStore,
294
                    nslookup_fn=netutils.Hostname.GetIP):
291 295
  """Return addresses for given node names.
292 296

  
293 297
  @type node_list: list
......
296 300
  @param ssc: SimpleStore class that is used to obtain node->ip mappings
297 301
  @type nslookup_fn: callable
298 302
  @param nslookup_fn: function use to do NS lookup
299
  @rtype: list of addresses and/or None's
300
  @returns: List of corresponding addresses, if found
303
  @rtype: list of tuple; (string, string)
304
  @return: List of tuples containing node name and IP address
301 305

  
302 306
  """
303 307
  ss = ssc()
304 308
  iplist = ss.GetNodePrimaryIPList()
305 309
  family = ss.GetPrimaryIPFamily()
306
  addresses = []
307 310
  ipmap = dict(entry.split() for entry in iplist)
311

  
312
  result = []
308 313
  for node in node_list:
309
    address = ipmap.get(node)
310
    if address is None:
311
      address = nslookup_fn(node, family=family)
312
    addresses.append(address)
314
    ip = ipmap.get(node)
315
    if ip is None:
316
      ip = nslookup_fn(node, family=family)
317
    result.append((node, ip))
318

  
319
  return result
320

  
321

  
322
class _StaticResolver:
323
  def __init__(self, addresses):
324
    """Initializes this class.
325

  
326
    """
327
    self._addresses = addresses
313 328

  
314
  return addresses
329
  def __call__(self, hosts):
330
    """Returns static addresses for hosts.
315 331

  
332
    """
333
    assert len(hosts) == len(self._addresses)
334
    return zip(hosts, self._addresses)
316 335

  
317
class Client:
318
  """RPC Client class.
319 336

  
320
  This class, given a (remote) method name, a list of parameters and a
321
  list of nodes, will contact (in parallel) all nodes, and return a
322
  dict of results (key: node name, value: result).
337
def _CheckConfigNode(name, node):
338
  """Checks if a node is online.
323 339

  
324
  One current bug is that generic failure is still signaled by
325
  'False' result, which is not good. This overloading of values can
326
  cause bugs.
340
  @type name: string
341
  @param name: Node name
342
  @type node: L{objects.Node} or None
343
  @param node: Node object
327 344

  
328 345
  """
329
  def __init__(self, procedure, body, port, address_lookup_fn=_AddressLookup):
330
    assert procedure in _TIMEOUTS, ("New RPC call not declared in the"
331
                                    " timeouts table")
332
    self.procedure = procedure
333
    self.body = body
334
    self.port = port
335
    self._request = {}
336
    self._address_lookup_fn = address_lookup_fn
337

  
338
  def ConnectList(self, node_list, address_list=None, read_timeout=None):
339
    """Add a list of nodes to the target nodes.
346
  if node is None:
347
    # Depend on DNS for name resolution
348
    ip = name
349
  elif node.offline:
350
    ip = _OFFLINE
351
  else:
352
    ip = node.primary_ip
353
  return (name, ip)
340 354

  
341
    @type node_list: list
342
    @param node_list: the list of node names to connect
343
    @type address_list: list or None
344
    @keyword address_list: either None or a list with node addresses,
345
        which must have the same length as the node list
346
    @type read_timeout: int
347
    @param read_timeout: overwrites default timeout for operation
355

  
356
def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts):
357
  """Calculate node addresses using configuration.
358

  
359
  """
360
  # Special case for single-host lookups
361
  if len(hosts) == 1:
362
    (name, ) = hosts
363
    return [_CheckConfigNode(name, single_node_fn(name))]
364
  else:
365
    all_nodes = all_nodes_fn()
366
    return [_CheckConfigNode(name, all_nodes.get(name, None))
367
            for name in hosts]
368

  
369

  
370
class _RpcProcessor:
371
  def __init__(self, resolver, port):
372
    """Initializes this class.
373

  
374
    @param resolver: callable accepting a list of hostnames, returning a list
375
      of tuples containing name and IP address (IP address can be the name or
376
      the special value L{_OFFLINE} to mark offline machines)
377
    @type port: int
378
    @param port: TCP port
348 379

  
349 380
    """
350
    if address_list is None:
351
      # Always use IP address instead of node name
352
      address_list = self._address_lookup_fn(node_list)
381
    self._resolver = resolver
382
    self._port = port
383

  
384
  @staticmethod
385
  def _PrepareRequests(hosts, port, procedure, body, read_timeout):
386
    """Prepares requests by sorting offline hosts into separate list.
353 387

  
354
    assert len(node_list) == len(address_list), \
355
           "Name and address lists must have the same length"
388
    """
389
    results = {}
390
    requests = {}
356 391

  
357
    for node, address in zip(node_list, address_list):
358
      self.ConnectNode(node, address, read_timeout=read_timeout)
392
    for (name, ip) in hosts:
393
      if ip is _OFFLINE:
394
        # Node is marked as offline
395
        results[name] = RpcResult(node=name, offline=True, call=procedure)
396
      else:
397
        requests[name] = \
398
          http.client.HttpClientRequest(str(ip), port,
399
                                        http.HTTP_PUT, str("/%s" % procedure),
400
                                        headers=_RPC_CLIENT_HEADERS,
401
                                        post_data=body,
402
                                        read_timeout=read_timeout)
359 403

  
360
  def ConnectNode(self, name, address=None, read_timeout=None):
361
    """Add a node to the target list.
404
    return (results, requests)
362 405

  
363
    @type name: str
364
    @param name: the node name
365
    @type address: str
366
    @param address: the node address, if known
367
    @type read_timeout: int
368
    @param read_timeout: overwrites default timeout for operation
406
  @staticmethod
407
  def _CombineResults(results, requests, procedure):
408
    """Combines pre-computed results for offline hosts with actual call results.
369 409

  
370 410
    """
371
    if address is None:
372
      # Always use IP address instead of node name
373
      address = self._address_lookup_fn([name])[0]
411
    for name, req in requests.items():
412
      if req.success and req.resp_status_code == http.HTTP_OK:
413
        host_result = RpcResult(data=serializer.LoadJson(req.resp_body),
414
                                node=name, call=procedure)
415
      else:
416
        # TODO: Better error reporting
417
        if req.error:
418
          msg = req.error
419
        else:
420
          msg = req.resp_body
374 421

  
375
    assert(address is not None)
422
        logging.error("RPC error in %s on node %s: %s", procedure, name, msg)
423
        host_result = RpcResult(data=msg, failed=True, node=name,
424
                                call=procedure)
376 425

  
377
    if read_timeout is None:
378
      read_timeout = _TIMEOUTS[self.procedure]
426
      results[name] = host_result
379 427

  
380
    self._request[name] = \
381
      http.client.HttpClientRequest(str(address), self.port,
382
                                    http.HTTP_PUT, str("/%s" % self.procedure),
383
                                    headers=_RPC_CLIENT_HEADERS,
384
                                    post_data=str(self.body),
385
                                    read_timeout=read_timeout)
428
    return results
386 429

  
387
  def GetResults(self, http_pool=None):
388
    """Call nodes and return results.
430
  def __call__(self, hosts, procedure, body, read_timeout=None, http_pool=None):
431
    """Makes an RPC request to a number of nodes.
389 432

  
390
    @rtype: list
391
    @return: List of RPC results
433
    @type hosts: sequence
434
    @param hosts: Hostnames
435
    @type procedure: string
436
    @param procedure: Request path
437
    @type body: string
438
    @param body: Request body
439
    @type read_timeout: int or None
440
    @param read_timeout: Read timeout for request
392 441

  
393 442
    """
443
    assert procedure in _TIMEOUTS, "RPC call not declared in the timeouts table"
444

  
394 445
    if not http_pool:
395 446
      http_pool = _thread_local.GetHttpClientPool()
396 447

  
397
    http_pool.ProcessRequests(self._request.values())
398

  
399
    results = {}
448
    if read_timeout is None:
449
      read_timeout = _TIMEOUTS[procedure]
400 450

  
401
    for name, req in self._request.iteritems():
402
      if req.success and req.resp_status_code == http.HTTP_OK:
403
        results[name] = RpcResult(data=serializer.LoadJson(req.resp_body),
404
                                  node=name, call=self.procedure)
405
        continue
451
    (results, requests) = \
452
      self._PrepareRequests(self._resolver(hosts), self._port, procedure,
453
                            str(body), read_timeout)
406 454

  
407
      # TODO: Better error reporting
408
      if req.error:
409
        msg = req.error
410
      else:
411
        msg = req.resp_body
455
    http_pool.ProcessRequests(requests.values())
412 456

  
413
      logging.error("RPC error in %s from node %s: %s",
414
                    self.procedure, name, msg)
415
      results[name] = RpcResult(data=msg, failed=True, node=name,
416
                                call=self.procedure)
457
    assert not frozenset(results).intersection(requests)
417 458

  
418
    return results
459
    return self._CombineResults(results, requests, procedure)
419 460

  
420 461

  
421 462
def _EncodeImportExportIO(ieio, ieioargs):
......
445 486

  
446 487
    """
447 488
    self._cfg = context.cfg
448
    self.port = netutils.GetDaemonPort(constants.NODED)
489
    self._proc = _RpcProcessor(compat.partial(_NodeConfigResolver,
490
                                              self._cfg.GetNodeInfo,
491
                                              self._cfg.GetAllNodesInfo),
492
                               netutils.GetDaemonPort(constants.NODED))
449 493

  
450 494
  def _InstDict(self, instance, hvp=None, bep=None, osp=None):
451 495
    """Convert the given instance to a dict.
......
483 527
        nic['nicparams'])
484 528
    return idict
485 529

  
486
  def _ConnectList(self, client, node_list, call, read_timeout=None):
487
    """Helper for computing node addresses.
488

  
489
    @type client: L{ganeti.rpc.Client}
490
    @param client: a C{Client} instance
491
    @type node_list: list
492
    @param node_list: the node list we should connect
493
    @type call: string
494
    @param call: the name of the remote procedure call, for filling in
495
        correctly any eventual offline nodes' results
496
    @type read_timeout: int
497
    @param read_timeout: overwrites the default read timeout for the
498
        given operation
499

  
500
    """
501
    all_nodes = self._cfg.GetAllNodesInfo()
502
    name_list = []
503
    addr_list = []
504
    skip_dict = {}
505
    for node in node_list:
506
      if node in all_nodes:
507
        if all_nodes[node].offline:
508
          skip_dict[node] = RpcResult(node=node, offline=True, call=call)
509
          continue
510
        val = all_nodes[node].primary_ip
511
      else:
512
        val = None
513
      addr_list.append(val)
514
      name_list.append(node)
515
    if name_list:
516
      client.ConnectList(name_list, address_list=addr_list,
517
                         read_timeout=read_timeout)
518
    return skip_dict
519

  
520
  def _ConnectNode(self, client, node, call, read_timeout=None):
521
    """Helper for computing one node's address.
522

  
523
    @type client: L{ganeti.rpc.Client}
524
    @param client: a C{Client} instance
525
    @type node: str
526
    @param node: the node we should connect
527
    @type call: string
528
    @param call: the name of the remote procedure call, for filling in
529
        correctly any eventual offline nodes' results
530
    @type read_timeout: int
531
    @param read_timeout: overwrites the default read timeout for the
532
        given operation
533

  
534
    """
535
    node_info = self._cfg.GetNodeInfo(node)
536
    if node_info is not None:
537
      if node_info.offline:
538
        return RpcResult(node=node, offline=True, call=call)
539
      addr = node_info.primary_ip
540
    else:
541
      addr = None
542
    client.ConnectNode(node, address=addr, read_timeout=read_timeout)
543

  
544 530
  def _MultiNodeCall(self, node_list, procedure, args, read_timeout=None):
545 531
    """Helper for making a multi-node call
546 532

  
547 533
    """
548 534
    body = serializer.DumpJson(args, indent=False)
549
    c = Client(procedure, body, self.port)
550
    skip_dict = self._ConnectList(c, node_list, procedure,
551
                                  read_timeout=read_timeout)
552
    skip_dict.update(c.GetResults())
553
    return skip_dict
535
    return self._proc(node_list, procedure, body, read_timeout=read_timeout)
554 536

  
555
  @classmethod
556
  def _StaticMultiNodeCall(cls, node_list, procedure, args,
537
  @staticmethod
538
  def _StaticMultiNodeCall(node_list, procedure, args,
557 539
                           address_list=None, read_timeout=None):
558 540
    """Helper for making a multi-node static call
559 541

  
560 542
    """
561 543
    body = serializer.DumpJson(args, indent=False)
562
    c = Client(procedure, body, netutils.GetDaemonPort(constants.NODED))
563
    c.ConnectList(node_list, address_list=address_list,
564
                  read_timeout=read_timeout)
565
    return c.GetResults()
544

  
545
    if address_list is None:
546
      resolver = _SsconfResolver
547
    else:
548
      # Caller provided an address list
549
      resolver = _StaticResolver(address_list)
550

  
551
    proc = _RpcProcessor(resolver,
552
                         netutils.GetDaemonPort(constants.NODED))
553
    return proc(node_list, procedure, body, read_timeout=read_timeout)
566 554

  
567 555
  def _SingleNodeCall(self, node, procedure, args, read_timeout=None):
568 556
    """Helper for making a single-node call
569 557

  
570 558
    """
571 559
    body = serializer.DumpJson(args, indent=False)
572
    c = Client(procedure, body, self.port)
573
    result = self._ConnectNode(c, node, procedure, read_timeout=read_timeout)
574
    if result is None:
575
      # we did connect, node is not offline
576
      result = c.GetResults()[node]
577
    return result
560
    return self._proc([node], procedure, body, read_timeout=read_timeout)[node]
578 561

  
579 562
  @classmethod
580 563
  def _StaticSingleNodeCall(cls, node, procedure, args, read_timeout=None):
......
582 565

  
583 566
    """
584 567
    body = serializer.DumpJson(args, indent=False)
585
    c = Client(procedure, body, netutils.GetDaemonPort(constants.NODED))
586
    c.ConnectNode(node, read_timeout=read_timeout)
587
    return c.GetResults()[node]
568
    proc = _RpcProcessor(_SsconfResolver,
569
                         netutils.GetDaemonPort(constants.NODED))
570
    return proc([node], procedure, body, read_timeout=read_timeout)[node]
588 571

  
589 572
  #
590 573
  # Begin RPC calls
b/test/ganeti.rpc_unittest.py
31 31
from ganeti import http
32 32
from ganeti import errors
33 33
from ganeti import serializer
34
from ganeti import objects
34 35

  
35 36
import testutils
36 37

  
......
64 65
  return FakeSimpleStore
65 66

  
66 67

  
67
class TestClient(unittest.TestCase):
68
class TestRpcProcessor(unittest.TestCase):
68 69
  def _FakeAddressLookup(self, map):
69 70
    return lambda node_list: [map.get(node) for node in node_list]
70 71

  
71 72
  def _GetVersionResponse(self, req):
72
    self.assertEqual(req.host, "localhost")
73
    self.assertEqual(req.host, "127.0.0.1")
73 74
    self.assertEqual(req.port, 24094)
74 75
    self.assertEqual(req.path, "/version")
76
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
75 77
    req.success = True
76 78
    req.resp_status_code = http.HTTP_OK
77 79
    req.resp_body = serializer.DumpJson((True, 123))
78 80

  
79 81
  def testVersionSuccess(self):
80
    fn = self._FakeAddressLookup({"localhost": "localhost"})
81
    client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
82
    client.ConnectNode("localhost")
82
    resolver = rpc._StaticResolver(["127.0.0.1"])
83 83
    pool = FakeHttpPool(self._GetVersionResponse)
84
    result = client.GetResults(http_pool=pool)
84
    proc = rpc._RpcProcessor(resolver, 24094)
85
    result = proc(["localhost"], "version", None, http_pool=pool)
85 86
    self.assertEqual(result.keys(), ["localhost"])
86 87
    lhresp = result["localhost"]
87 88
    self.assertFalse(lhresp.offline)
......
92 93
    lhresp.Raise("should not raise")
93 94
    self.assertEqual(pool.reqcount, 1)
94 95

  
96
  def _ReadTimeoutResponse(self, req):
97
    self.assertEqual(req.host, "192.0.2.13")
98
    self.assertEqual(req.port, 19176)
99
    self.assertEqual(req.path, "/version")
100
    self.assertEqual(req.read_timeout, 12356)
101
    req.success = True
102
    req.resp_status_code = http.HTTP_OK
103
    req.resp_body = serializer.DumpJson((True, -1))
104

  
105
  def testReadTimeout(self):
106
    resolver = rpc._StaticResolver(["192.0.2.13"])
107
    pool = FakeHttpPool(self._ReadTimeoutResponse)
108
    proc = rpc._RpcProcessor(resolver, 19176)
109
    result = proc(["node31856"], "version", None, http_pool=pool,
110
                  read_timeout=12356)
111
    self.assertEqual(result.keys(), ["node31856"])
112
    lhresp = result["node31856"]
113
    self.assertFalse(lhresp.offline)
114
    self.assertEqual(lhresp.node, "node31856")
115
    self.assertFalse(lhresp.fail_msg)
116
    self.assertEqual(lhresp.payload, -1)
117
    self.assertEqual(lhresp.call, "version")
118
    lhresp.Raise("should not raise")
119
    self.assertEqual(pool.reqcount, 1)
120

  
121
  def testOfflineNode(self):
122
    resolver = rpc._StaticResolver([rpc._OFFLINE])
123
    pool = FakeHttpPool(NotImplemented)
124
    proc = rpc._RpcProcessor(resolver, 30668)
125
    result = proc(["n17296"], "version", None, http_pool=pool)
126
    self.assertEqual(result.keys(), ["n17296"])
127
    lhresp = result["n17296"]
128
    self.assertTrue(lhresp.offline)
129
    self.assertEqual(lhresp.node, "n17296")
130
    self.assertTrue(lhresp.fail_msg)
131
    self.assertFalse(lhresp.payload)
132
    self.assertEqual(lhresp.call, "version")
133

  
134
    # With a message
135
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
136

  
137
    # No message
138
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
139

  
140
    self.assertEqual(pool.reqcount, 0)
141

  
95 142
  def _GetMultiVersionResponse(self, req):
96 143
    self.assert_(req.host.startswith("node"))
97 144
    self.assertEqual(req.port, 23245)
......
102 149

  
103 150
  def testMultiVersionSuccess(self):
104 151
    nodes = ["node%s" % i for i in range(50)]
105
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
106
    client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
107
    client.ConnectList(nodes)
108

  
152
    resolver = rpc._StaticResolver(nodes)
109 153
    pool = FakeHttpPool(self._GetMultiVersionResponse)
110
    result = client.GetResults(http_pool=pool)
154
    proc = rpc._RpcProcessor(resolver, 23245)
155
    result = proc(nodes, "version", None, http_pool=pool)
111 156
    self.assertEqual(sorted(result.keys()), sorted(nodes))
112 157

  
113 158
    for name in nodes:
......
121 166

  
122 167
    self.assertEqual(pool.reqcount, len(nodes))
123 168

  
124
  def _GetVersionResponseFail(self, req):
169
  def _GetVersionResponseFail(self, errinfo, req):
125 170
    self.assertEqual(req.path, "/version")
126 171
    req.success = True
127 172
    req.resp_status_code = http.HTTP_OK
128
    req.resp_body = serializer.DumpJson((False, "Unknown error"))
173
    req.resp_body = serializer.DumpJson((False, errinfo))
129 174

  
130 175
  def testVersionFailure(self):
131
    lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
132
    fn = self._FakeAddressLookup(lookup_map)
133
    client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
134
    client.ConnectNode("aef9ur4i.example.com")
135
    pool = FakeHttpPool(self._GetVersionResponseFail)
136
    result = client.GetResults(http_pool=pool)
137
    self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
138
    lhresp = result["aef9ur4i.example.com"]
139
    self.assertFalse(lhresp.offline)
140
    self.assertEqual(lhresp.node, "aef9ur4i.example.com")
141
    self.assert_(lhresp.fail_msg)
142
    self.assertFalse(lhresp.payload)
143
    self.assertEqual(lhresp.call, "version")
144
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
145
    self.assertEqual(pool.reqcount, 1)
176
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
177
    proc = rpc._RpcProcessor(resolver, 5903)
178
    for errinfo in [None, "Unknown error"]:
179
      pool = FakeHttpPool(compat.partial(self._GetVersionResponseFail, errinfo))
180
      result = proc(["aef9ur4i.example.com"], "version", None, http_pool=pool)
181
      self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
182
      lhresp = result["aef9ur4i.example.com"]
183
      self.assertFalse(lhresp.offline)
184
      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
185
      self.assert_(lhresp.fail_msg)
186
      self.assertFalse(lhresp.payload)
187
      self.assertEqual(lhresp.call, "version")
188
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
189
      self.assertEqual(pool.reqcount, 1)
146 190

  
147 191
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
148 192
    self.assertEqual(req.path, "/vg_list")
......
167 211

  
168 212
  def testHttpError(self):
169 213
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
170
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
214
    resolver = rpc._StaticResolver(nodes)
171 215

  
172 216
    httperrnodes = set(nodes[1::7])
173 217
    self.assertEqual(len(httperrnodes), 7)
......
177 221

  
178 222
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
179 223

  
180
    client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
181
    client.ConnectList(nodes)
182

  
224
    proc = rpc._RpcProcessor(resolver, 15165)
183 225
    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
184 226
                                       httperrnodes, failnodes))
185
    result = client.GetResults(http_pool=pool)
227
    result = proc(nodes, "vg_list", None, http_pool=pool)
186 228
    self.assertEqual(sorted(result.keys()), sorted(nodes))
187 229

  
188 230
    for name in nodes:
......
219 261
    req.resp_body = serializer.DumpJson("invalid response")
220 262

  
221 263
  def testInvalidResponse(self):
222
    lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
223
    fn = self._FakeAddressLookup(lookup_map)
224
    client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
264
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
265
    proc = rpc._RpcProcessor(resolver, 19978)
266

  
225 267
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
226
      client.ConnectNode("oqo7lanhly.example.com")
227 268
      pool = FakeHttpPool(fn)
228
      result = client.GetResults(http_pool=pool)
269
      result = proc(["oqo7lanhly.example.com"], "version", None, http_pool=pool)
229 270
      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
230 271
      lhresp = result["oqo7lanhly.example.com"]
231 272
      self.assertFalse(lhresp.offline)
......
236 277
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
237 278
      self.assertEqual(pool.reqcount, 1)
238 279

  
239
  def testAddressLookupSimpleStore(self):
280
  def _GetBodyTestResponse(self, test_data, req):
281
    self.assertEqual(req.host, "192.0.2.84")
282
    self.assertEqual(req.port, 18700)
283
    self.assertEqual(req.path, "/upload_file")
284
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
285
    req.success = True
286
    req.resp_status_code = http.HTTP_OK
287
    req.resp_body = serializer.DumpJson((True, None))
288

  
289
  def testResponseBody(self):
290
    test_data = {
291
      "Hello": "World",
292
      "xyz": range(10),
293
      }
294
    resolver = rpc._StaticResolver(["192.0.2.84"])
295
    pool = FakeHttpPool(compat.partial(self._GetBodyTestResponse, test_data))
296
    proc = rpc._RpcProcessor(resolver, 18700)
297
    body = serializer.DumpJson(test_data)
298
    result = proc(["node19759"], "upload_file", body, http_pool=pool)
299
    self.assertEqual(result.keys(), ["node19759"])
300
    lhresp = result["node19759"]
301
    self.assertFalse(lhresp.offline)
302
    self.assertEqual(lhresp.node, "node19759")
303
    self.assertFalse(lhresp.fail_msg)
304
    self.assertEqual(lhresp.payload, None)
305
    self.assertEqual(lhresp.call, "upload_file")
306
    lhresp.Raise("should not raise")
307
    self.assertEqual(pool.reqcount, 1)
308

  
309

  
310
class TestSsconfResolver(unittest.TestCase):
311
  def testSsconfLookup(self):
240 312
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
241 313
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
242
    node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
314
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
243 315
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
244
    result = rpc._AddressLookup(node_list, ssc=ssc)
245
    self.assertEqual(result, addr_list)
316
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
317
    self.assertEqual(result, zip(node_list, addr_list))
246 318

  
247
  def testAddressLookupNSLookup(self):
319
  def testNsLookup(self):
248 320
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
249 321
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
250 322
    ssc = GetFakeSimpleStoreClass(lambda _: [])
251 323
    node_addr_map = dict(zip(node_list, addr_list))
252 324
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
253
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
254
    self.assertEqual(result, addr_list)
325
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
326
    self.assertEqual(result, zip(node_list, addr_list))
255 327

  
256
  def testAddressLookupBoth(self):
328
  def testBothLookups(self):
257 329
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
258 330
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
259 331
    n = len(addr_list) / 2
260
    node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])]
332
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
261 333
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
262 334
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
263 335
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
264
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
265
    self.assertEqual(result, addr_list)
336
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
337
    self.assertEqual(result, zip(node_list, addr_list))
266 338

  
267 339
  def testAddressLookupIPv6(self):
268
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 13)]
269
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
270
    node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
340
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
341
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
342
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
271 343
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
272
    result = rpc._AddressLookup(node_list, ssc=ssc)
273
    self.assertEqual(result, addr_list)
344
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
345
    self.assertEqual(result, zip(node_list, addr_list))
346

  
347

  
348
class TestStaticResolver(unittest.TestCase):
349
  def test(self):
350
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
351
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
352
    res = rpc._StaticResolver(addresses)
353
    self.assertEqual(res(nodes), zip(nodes, addresses))
354

  
355
  def testWrongLength(self):
356
    res = rpc._StaticResolver([])
357
    self.assertRaises(AssertionError, res, ["abc"])
358

  
359

  
360
class TestNodeConfigResolver(unittest.TestCase):
361
  @staticmethod
362
  def _GetSingleOnlineNode(name):
363
    assert name == "node90.example.com"
364
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
365

  
366
  @staticmethod
367
  def _GetSingleOfflineNode(name):
368
    assert name == "node100.example.com"
369
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
370

  
371
  def testSingleOnline(self):
372
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
373
                                             NotImplemented,
374
                                             ["node90.example.com"]),
375
                     [("node90.example.com", "192.0.2.90")])
376

  
377
  def testSingleOffline(self):
378
    self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
379
                                             NotImplemented,
380
                                             ["node100.example.com"]),
381
                     [("node100.example.com", rpc._OFFLINE)])
382

  
383
  def testUnknownSingleNode(self):
384
    self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
385
                                             ["node110.example.com"]),
386
                     [("node110.example.com", "node110.example.com")])
387

  
388
  def testMultiEmpty(self):
389
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
390
                                             lambda: {},
391
                                             []),
392
                     [])
393

  
394
  def testMultiSomeOffline(self):
395
    nodes = dict(("node%s.example.com" % i,
396
                  objects.Node(name="node%s.example.com" % i,
397
                               offline=((i % 3) == 0),
398
                               primary_ip="192.0.2.%s" % i))
399
                  for i in range(1, 255))
400

  
401
    # Resolve no names
402
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
403
                                             lambda: nodes,
404
                                             []),
405
                     [])
406

  
407
    # Offline, online and unknown hosts
408
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
409
                                             lambda: nodes,
410
                                             ["node3.example.com",
411
                                              "node92.example.com",
412
                                              "node54.example.com",
413
                                              "unknown.example.com",]), [
414
      ("node3.example.com", rpc._OFFLINE),
415
      ("node92.example.com", "192.0.2.92"),
416
      ("node54.example.com", rpc._OFFLINE),
417
      ("unknown.example.com", "unknown.example.com"),
418
      ])
274 419

  
275 420

  
276 421
if __name__ == "__main__":

Also available in: Unified diff