Modify utils.TcpPing to make source address optional
author Iustin Pop <iustin@google.com>
Mon, 10 Mar 2008 16:29:10 +0000 (16:29 +0000)
committer Iustin Pop <iustin@google.com>
Mon, 10 Mar 2008 16:29:10 +0000 (16:29 +0000)
This patch modifies TcpPing and its callers to make the source address
selection optional. Usually the kernel knows best which source address
to use; only in some cases do we need to enforce a given source address,
so it makes sense to make this optional. The new signature is
TcpPing(target, port, timeout=10, live_port_needed=False, source=None).

Reviewed-by: ultrotter

daemons/ganeti-master
daemons/ganeti-noded
lib/cmdlib.py
lib/utils.py
test/ganeti.utils_unittest.py

index 37ec91a..b24f43f 100755 (executable)
@@ -106,10 +106,9 @@ def StartMaster(master_netdev, master_ip, debug):
   """Starts the master.
 
   """
-  if utils.TcpPing(utils.HostInfo().name, master_ip,
-                   constants.DEFAULT_NODED_PORT):
-    if utils.TcpPing(constants.LOCALHOST_IP_ADDRESS, master_ip,
-                     constants.DEFAULT_NODED_PORT):
+  if utils.TcpPing(master_ip, constants.DEFAULT_NODED_PORT):
+    if utils.TcpPing(master_ip, constants.DEFAULT_NODED_PORT,
+                     source=constants.LOCALHOST_IP_ADDRESS):
       # we already have the ip:
       if debug:
         sys.stderr.write("Notice: already started.\n")
index f344474..ab10cb8 100755 (executable)
@@ -370,8 +370,8 @@ class ServerObject(pb.Avatar):
     """Do a TcpPing on the remote node.
 
     """
-    return utils.TcpPing(params[0], params[1], params[2],
-                         timeout=params[3], live_port_needed=params[4])
+    return utils.TcpPing(params[1], params[2], timeout=params[3],
+                         live_port_needed=params[4], source=params[0])
 
   @staticmethod
   def perspective_node_info(params):
index 1c5356e..2e26cfb 100644 (file)
@@ -515,8 +515,8 @@ class LUInitCluster(LogicalUnit):
 
     self.clustername = clustername = utils.HostInfo(self.op.cluster_name)
 
-    if not utils.TcpPing(constants.LOCALHOST_IP_ADDRESS, hostname.ip,
-                         constants.DEFAULT_NODED_PORT):
+    if not utils.TcpPing(hostname.ip, constants.DEFAULT_NODED_PORT,
+                         source=constants.LOCALHOST_IP_ADDRESS):
       raise errors.OpPrereqError("Inconsistency: this host's name resolves"
                                  " to %s,\nbut this ip address does not"
                                  " belong to this host."
@@ -527,8 +527,8 @@ class LUInitCluster(LogicalUnit):
       raise errors.OpPrereqError("Invalid secondary ip given")
     if (secondary_ip and
         secondary_ip != hostname.ip and
-        (not utils.TcpPing(constants.LOCALHOST_IP_ADDRESS, secondary_ip,
-                           constants.DEFAULT_NODED_PORT))):
+        (not utils.TcpPing(secondary_ip, constants.DEFAULT_NODED_PORT,
+                           source=constants.LOCALHOST_IP_ADDRESS))):
       raise errors.OpPrereqError("You gave %s as secondary IP,"
                                  " but it does not belong to this host." %
                                  secondary_ip)
@@ -1477,16 +1477,13 @@ class LUAddNode(LogicalUnit):
                                    " new node doesn't have one")
 
     # checks reachablity
-    if not utils.TcpPing(utils.HostInfo().name,
-                         primary_ip,
-                         constants.DEFAULT_NODED_PORT):
+    if not utils.TcpPing(primary_ip, constants.DEFAULT_NODED_PORT):
       raise errors.OpPrereqError("Node not reachable by ping")
 
     if not newbie_singlehomed:
       # check reachability from my secondary ip to newbie's secondary ip
-      if not utils.TcpPing(myself.secondary_ip,
-                           secondary_ip,
-                           constants.DEFAULT_NODED_PORT):
+      if not utils.TcpPing(secondary_ip, constants.DEFAULT_NODED_PORT,
+                           source=myself.secondary_ip):
         raise errors.OpPrereqError("Node secondary ip not reachable by TCP"
                                    " based ping to noded port")
 
@@ -3074,8 +3071,7 @@ class LUCreateInstance(LogicalUnit):
                                  " adding an instance in start mode")
 
     if self.op.ip_check:
-      if utils.TcpPing(utils.HostInfo().name, hostname1.ip,
-                       constants.DEFAULT_NODED_PORT):
+      if utils.TcpPing(hostname1.ip, constants.DEFAULT_NODED_PORT):
         raise errors.OpPrereqError("IP %s of instance %s already in use" %
                                    (hostname1.ip, instance_name))
 
index 6cbd268..f178499 100644 (file)
@@ -908,25 +908,29 @@ def ShellQuoteArgs(args):
   return ' '.join([ShellQuote(i) for i in args])
 
 
-
-def TcpPing(source, target, port, timeout=10, live_port_needed=False):
+def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
   """Simple ping implementation using TCP connect(2).
 
-  Try to do a TCP connect(2) from the specified source IP to the specified
-  target IP and the specified target port. If live_port_needed is set to true,
-  requires the remote end to accept the connection. The timeout is specified
-  in seconds and defaults to 10 seconds
+  Try to do a TCP connect(2) from an optional source IP to the
+  specified target IP and the specified target port. If the optional
+  parameter live_port_needed is set to true, requires the remote end
+  to accept the connection. The timeout is specified in seconds and
+  defaults to 10 seconds. If the source optional argument is not
+  passed, the source address selection is left to the kernel,
+  otherwise we try to connect using the passed address (failures to
+  bind other than EADDRNOTAVAIL will be ignored).
 
   """
   sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
   sucess = False
 
-  try:
-    sock.bind((source, 0))
-  except socket.error, (errcode, errstring):
-    if errcode == errno.EADDRNOTAVAIL:
-      success = False
+  if source is not None:
+    try:
+      sock.bind((source, 0))
+    except socket.error, (errcode, errstring):
+      if errcode == errno.EADDRNOTAVAIL:
+        success = False
 
   sock.settimeout(timeout)
 
index 560c7ac..d57e0c0 100755 (executable)
@@ -544,12 +544,20 @@ class TestTcpPing(unittest.TestCase):
 
   def testTcpPingToLocalHostAccept(self):
     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
-                         constants.LOCALHOST_IP_ADDRESS,
                          self.listenerport,
                          timeout=10,
-                         live_port_needed=True),
+                         live_port_needed=True,
+                         source=constants.LOCALHOST_IP_ADDRESS,
+                         ),
                  "failed to connect to test listener")
 
+    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
+                         self.listenerport,
+                         timeout=10,
+                         live_port_needed=True,
+                         ),
+                 "failed to connect to test listener (no source)")
+
 
 class TestTcpPingDeaf(unittest.TestCase):
   """Testcase for TCP version of ping - against non listen(2)ing port"""
@@ -565,20 +573,36 @@ class TestTcpPingDeaf(unittest.TestCase):
 
   def testTcpPingToLocalHostAcceptDeaf(self):
     self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
-                        constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
-                        live_port_needed=True), # need successful connect(2)
+                        live_port_needed=True,
+                        source=constants.LOCALHOST_IP_ADDRESS,
+                        ), # need successful connect(2)
                 "successfully connected to deaf listener")
 
+    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
+                        self.deaflistenerport,
+                        timeout=constants.TCP_PING_TIMEOUT,
+                        live_port_needed=True,
+                        ), # need successful connect(2)
+                "successfully connected to deaf listener (no source addr)")
+
   def testTcpPingToLocalHostNoAccept(self):
     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
-                         constants.LOCALHOST_IP_ADDRESS,
                          self.deaflistenerport,
                          timeout=constants.TCP_PING_TIMEOUT,
-                         live_port_needed=False), # ECONNREFUSED is OK
+                         live_port_needed=False,
+                         source=constants.LOCALHOST_IP_ADDRESS,
+                         ), # ECONNREFUSED is OK
                  "failed to ping alive host on deaf port")
 
+    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
+                         self.deaflistenerport,
+                         timeout=constants.TCP_PING_TIMEOUT,
+                         live_port_needed=False,
+                         ), # ECONNREFUSED is OK
+                 "failed to ping alive host on deaf port (no source addr)")
+
 
 class TestListVisibleFiles(unittest.TestCase):
   """Test case for ListVisibleFiles"""