Revision eb202c13

b/lib/rpc.py
1 1
#
2 2
#
3 3

  
4
# Copyright (C) 2006, 2007 Google Inc.
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5 5
#
6 6
# This program is free software; you can redistribute it and/or modify
7 7
# it under the terms of the GNU General Public License as published by
......
44 44
from ganeti import constants
45 45
from ganeti import errors
46 46
from ganeti import netutils
47
from ganeti import ssconf
47 48

  
48 49
# pylint has a bug here, doesn't see this import
49 50
import ganeti.http.client  # pylint: disable-msg=W0611
......
256 257
    raise ec(*args) # pylint: disable-msg=W0142
257 258

  
258 259

  
260
def _AddressLookup(node_list,
261
                   ssc=ssconf.SimpleStore,
262
                   nslookup_fn=netutils.HostInfo.LookupHostname):
263
  """Return addresses for given node names.
264

  
265
  @type node_list: list
266
  @param node_list: List of node names
267
  @type ssc: class
268
  @param ssc: SimpleStore class that is used to obtain node->ip mappings
269
  @type lookup_fn: callable
270
  @param lookup_fn: function use to do NS lookup
271
  @rtype: list of addresses and/or None's
272
  @returns: List of corresponding addresses, if found
273

  
274
  """
275
  def _NSLookup(name):
276
    _, _, addrs = nslookup_fn(name)
277
    return addrs[0]
278

  
279
  addresses = []
280
  try:
281
    iplist = ssc().GetNodePrimaryIPList()
282
    ipmap = dict(entry.split() for entry in iplist)
283
    for node in node_list:
284
      address = ipmap.get(node)
285
      if address is None:
286
        address = _NSLookup(node)
287
      addresses.append(address)
288
  except errors.ConfigurationError:
289
    # Address not found in so we do a NS lookup
290
    addresses = [_NSLookup(node) for node in node_list]
291

  
292
  return addresses
293

  
294

  
259 295
class Client:
260 296
  """RPC Client class.
261 297

  
......
268 304
  cause bugs.
269 305

  
270 306
  """
271
  def __init__(self, procedure, body, port):
307
  def __init__(self, procedure, body, port, address_lookup_fn=_AddressLookup):
272 308
    assert procedure in _TIMEOUTS, ("New RPC call not declared in the"
273 309
                                    " timeouts table")
274 310
    self.procedure = procedure
275 311
    self.body = body
276 312
    self.port = port
277 313
    self._request = {}
314
    self._address_lookup_fn = address_lookup_fn
278 315

  
279 316
  def ConnectList(self, node_list, address_list=None, read_timeout=None):
280 317
    """Add a list of nodes to the target nodes.
......
285 322
    @keyword address_list: either None or a list with node addresses,
286 323
        which must have the same length as the node list
287 324
    @type read_timeout: int
288
    @param read_timeout: overwrites the default read timeout for the
289
        given operation
325
    @param read_timeout: overwrites default timeout for operation
290 326

  
291 327
    """
292 328
    if address_list is None:
293
      address_list = [None for _ in node_list]
294
    else:
295
      assert len(node_list) == len(address_list), \
296
             "Name and address lists should have the same length"
329
      # Always use IP address instead of node name
330
      address_list = self._address_lookup_fn(node_list)
331

  
332
    assert len(node_list) == len(address_list), \
333
           "Name and address lists must have the same length"
334

  
297 335
    for node, address in zip(node_list, address_list):
298 336
      self.ConnectNode(node, address, read_timeout=read_timeout)
299 337

  
......
303 341
    @type name: str
304 342
    @param name: the node name
305 343
    @type address: str
306
    @keyword address: the node address, if known
344
    @param address: the node address, if known
345
    @type read_timeout: int
346
    @param read_timeout: overwrites default timeout for operation
307 347

  
308 348
    """
309 349
    if address is None:
310
      address = name
350
      # Always use IP address instead of node name
351
      address = self._address_lookup_fn([name])[0]
352

  
353
    assert(address is not None)
311 354

  
312 355
    if read_timeout is None:
313 356
      read_timeout = _TIMEOUTS[self.procedure]
b/test/ganeti.rpc_unittest.py
56 56
      self._response_fn(req)
57 57

  
58 58

  
59
def GetFakeSimpleStoreClass(fn):
60
  class FakeSimpleStore:
61
    GetNodePrimaryIPList = fn
62

  
63
  return FakeSimpleStore
64

  
65

  
59 66
class TestClient(unittest.TestCase):
67
  def _FakeAddressLookup(self, map):
68
    return lambda node_list: [map.get(node) for node in node_list]
69

  
60 70
  def _GetVersionResponse(self, req):
61 71
    self.assertEqual(req.host, "localhost")
62 72
    self.assertEqual(req.port, 24094)
......
66 76
    req.resp_body = serializer.DumpJson((True, 123))
67 77

  
68 78
  def testVersionSuccess(self):
69
    client = rpc.Client("version", None, 24094)
79
    fn = self._FakeAddressLookup({"localhost": "localhost"})
80
    client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
70 81
    client.ConnectNode("localhost")
71 82
    pool = FakeHttpPool(self._GetVersionResponse)
72 83
    result = client.GetResults(http_pool=pool)
......
90 101

  
91 102
  def testMultiVersionSuccess(self):
92 103
    nodes = ["node%s" % i for i in range(50)]
93
    client = rpc.Client("version", None, 23245)
104
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
105
    client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
94 106
    client.ConnectList(nodes)
95 107

  
96 108
    pool = FakeHttpPool(self._GetMultiVersionResponse)
......
115 127
    req.resp_body = serializer.DumpJson((False, "Unknown error"))
116 128

  
117 129
  def testVersionFailure(self):
118
    client = rpc.Client("version", None, 5903)
130
    lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
131
    fn = self._FakeAddressLookup(lookup_map)
132
    client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
119 133
    client.ConnectNode("aef9ur4i.example.com")
120 134
    pool = FakeHttpPool(self._GetVersionResponseFail)
121 135
    result = client.GetResults(http_pool=pool)
......
152 166

  
153 167
  def testHttpError(self):
154 168
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
169
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
155 170

  
156 171
    httperrnodes = set(nodes[1::7])
157 172
    self.assertEqual(len(httperrnodes), 7)
......
161 176

  
162 177
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
163 178

  
164
    client = rpc.Client("vg_list", None, 15165)
179
    client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
165 180
    client.ConnectList(nodes)
166 181

  
167 182
    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
......
203 218
    req.resp_body = serializer.DumpJson("invalid response")
204 219

  
205 220
  def testInvalidResponse(self):
206
    client = rpc.Client("version", None, 19978)
221
    lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
222
    fn = self._FakeAddressLookup(lookup_map)
223
    client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
207 224
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
208 225
      client.ConnectNode("oqo7lanhly.example.com")
209 226
      pool = FakeHttpPool(fn)
......
218 235
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
219 236
      self.assertEqual(pool.reqcount, 1)
220 237

  
238
  def testAddressLookupSimpleStore(self):
239
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
240
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
241
    node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
242
    ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list)
243
    result = rpc._AddressLookup(node_list, ssc=ssc)
244
    self.assertEqual(result, addr_list)
245

  
246
  def testAddressLookupNSLookup(self):
247
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
248
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
249
    ssc = GetFakeSimpleStoreClass(lambda s: [])
250
    node_addr_map = dict(zip(node_list, addr_list))
251
    nslookup_fn = lambda name: (None, None, [node_addr_map.get(name)])
252
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
253
    self.assertEqual(result, addr_list)
254

  
255
  def testAddressLookupBoth(self):
256
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
257
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
258
    n = len(addr_list) / 2
259
    node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])]
260
    ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list)
261
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
262
    nslookup_fn = lambda name: (None, None, [node_addr_map.get(name)])
263
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
264
    self.assertEqual(result, addr_list)
265

  
221 266

  
222 267
if __name__ == "__main__":
223 268
  testutils.GanetiTestProgram()

Also available in: Unified diff