Merge branch 'devel-2.4'
[ganeti-local] / lib / confd / client.py
index 9bd9ef7..2ca2ed8 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2009 Google Inc.
+# Copyright (C) 2009, 2010 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -61,6 +61,8 @@ from ganeti import daemon # contains AsyncUDPSocket
 from ganeti import errors
 from ganeti import confd
 from ganeti import ssconf
+from ganeti import compat
+from ganeti import netutils
 
 
 class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
@@ -70,14 +72,14 @@ class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
   implement a non-asyncore based client library.
 
   """
-  def __init__(self, client):
+  def __init__(self, client, family):
     """Constructor for ConfdAsyncUDPClient
 
     @type client: L{ConfdClient}
     @param client: client library, to pass the datagrams to
 
     """
-    daemon.AsyncUDPSocket.__init__(self)
+    daemon.AsyncUDPSocket.__init__(self, family)
     self.client = client
 
   # this method is overriding a daemon.AsyncUDPSocket method
@@ -135,15 +137,16 @@ class ConfdClient:
       raise errors.ProgrammerError("callback must be callable")
 
     self.UpdatePeerList(peers)
+    self._SetPeersAddressFamily()
     self._hmac_key = hmac_key
-    self._socket = ConfdAsyncUDPClient(self)
+    self._socket = ConfdAsyncUDPClient(self, self._family)
     self._callback = callback
     self._confd_port = port
     self._logger = logger
     self._requests = {}
 
     if self._confd_port is None:
-      self._confd_port = utils.GetDaemonPort(constants.CONFD)
+      self._confd_port = netutils.GetDaemonPort(constants.CONFD)
 
   def UpdatePeerList(self, peers):
     """Update the list of peers
@@ -383,6 +386,18 @@ class ConfdClient:
       else:
         return MISSING
 
+  def _SetPeersAddressFamily(self):
+    if not self._peers:
+      raise errors.ConfdClientError("Peer list empty")
+    try:
+      peer = self._peers[0]
+      self._family = netutils.IPAddress.GetAddressFamily(peer)
+      for peer in self._peers[1:]:
+        if netutils.IPAddress.GetAddressFamily(peer) != self._family:
+          raise errors.ConfdClientError("Peers must be of same address family")
+    except errors.IPAddressError:
+      raise errors.ConfdClientError("Peer address %s invalid" % peer)
+
 
 # UPCALL_REPLY: server reply upcall
 # has all ConfdUpcallPayload fields populated
@@ -581,7 +596,7 @@ class ConfdCountingCallback:
     """Have all the registered queries received at least an answer?
 
     """
-    return utils.all(self._answers.values())
+    return compat.all(self._answers.values())
 
   def _HandleExpire(self, up):
     # if we have no answer we have received none, before the expiration.