rapi.client, http.client: Format url correctly when using IPv6
[ganeti-local] / lib / rpc.py
index f5c4333..e31a529 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2006, 2007 Google Inc.
+# Copyright (C) 2006, 2007, 2010 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -44,6 +44,7 @@ from ganeti import serializer
 from ganeti import constants
 from ganeti import errors
 from ganeti import netutils
+from ganeti import ssconf
 
 # pylint has a bug here, doesn't see this import
 import ganeti.http.client  # pylint: disable-msg=W0611
@@ -151,6 +152,23 @@ def _RpcTimeout(secs):
   return decorator
 
 
+def RunWithRPC(fn):
+  """RPC-wrapper decorator.
+
+  When applied to a function, it runs it with the RPC system
+  initialized, and it shutsdown the system afterwards. This means the
+  function must be called without RPC being initialized.
+
+  """
+  def wrapper(*args, **kwargs):
+    Init()
+    try:
+      return fn(*args, **kwargs)
+    finally:
+      Shutdown()
+  return wrapper
+
+
 class RpcResult(object):
   """RPC Result class.
 
@@ -239,6 +257,33 @@ class RpcResult(object):
     raise ec(*args) # pylint: disable-msg=W0142
 
 
+def _AddressLookup(node_list,
+                   ssc=ssconf.SimpleStore,
+                   nslookup_fn=netutils.Hostname.GetIP):
+  """Return addresses for given node names.
+
+  @type node_list: list
+  @param node_list: List of node names
+  @type ssc: class
+  @param ssc: SimpleStore class that is used to obtain node->ip mappings
+  @type lookup_fn: callable
+  @param lookup_fn: function use to do NS lookup
+  @rtype: list of addresses and/or None's
+  @returns: List of corresponding addresses, if found
+
+  """
+  iplist = ssc().GetNodePrimaryIPList()
+  addresses = []
+  ipmap = dict(entry.split() for entry in iplist)
+  for node in node_list:
+    address = ipmap.get(node)
+    if address is None:
+      address = nslookup_fn(node)
+    addresses.append(address)
+
+  return addresses
+
+
 class Client:
   """RPC Client class.
 
@@ -251,13 +296,14 @@ class Client:
   cause bugs.
 
   """
-  def __init__(self, procedure, body, port):
+  def __init__(self, procedure, body, port, address_lookup_fn=_AddressLookup):
     assert procedure in _TIMEOUTS, ("New RPC call not declared in the"
                                     " timeouts table")
     self.procedure = procedure
     self.body = body
     self.port = port
     self._request = {}
+    self._address_lookup_fn = address_lookup_fn
 
   def ConnectList(self, node_list, address_list=None, read_timeout=None):
     """Add a list of nodes to the target nodes.
@@ -268,15 +314,16 @@ class Client:
     @keyword address_list: either None or a list with node addresses,
         which must have the same length as the node list
     @type read_timeout: int
-    @param read_timeout: overwrites the default read timeout for the
-        given operation
+    @param read_timeout: overwrites default timeout for operation
 
     """
     if address_list is None:
-      address_list = [None for _ in node_list]
-    else:
-      assert len(node_list) == len(address_list), \
-             "Name and address lists should have the same length"
+      # Always use IP address instead of node name
+      address_list = self._address_lookup_fn(node_list)
+
+    assert len(node_list) == len(address_list), \
+           "Name and address lists must have the same length"
+
     for node, address in zip(node_list, address_list):
       self.ConnectNode(node, address, read_timeout=read_timeout)
 
@@ -286,11 +333,16 @@ class Client:
     @type name: str
     @param name: the node name
     @type address: str
-    @keyword address: the node address, if known
+    @param address: the node address, if known
+    @type read_timeout: int
+    @param read_timeout: overwrites default timeout for operation
 
     """
     if address is None:
-      address = name
+      # Always use IP address instead of node name
+      address = self._address_lookup_fn([name])[0]
+
+    assert(address is not None)
 
     if read_timeout is None:
       read_timeout = _TIMEOUTS[self.procedure]