Merge branch 'devel-2.4'
[ganeti-local] / lib / confd / client.py
index e08141b..2ca2ed8 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2009 Google Inc.
+# Copyright (C) 2009, 2010 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -61,6 +61,8 @@ from ganeti import daemon # contains AsyncUDPSocket
 from ganeti import errors
 from ganeti import confd
 from ganeti import ssconf
+from ganeti import compat
+from ganeti import netutils
 
 
 class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
@@ -70,14 +72,14 @@ class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
   implement a non-asyncore based client library.
 
   """
-  def __init__(self, client):
+  def __init__(self, client, family):
     """Constructor for ConfdAsyncUDPClient
 
     @type client: L{ConfdClient}
     @param client: client library, to pass the datagrams to
 
     """
-    daemon.AsyncUDPSocket.__init__(self)
+    daemon.AsyncUDPSocket.__init__(self, family)
     self.client = client
 
   # this method is overriding a daemon.AsyncUDPSocket method
@@ -135,15 +137,16 @@ class ConfdClient:
       raise errors.ProgrammerError("callback must be callable")
 
     self.UpdatePeerList(peers)
+    self._SetPeersAddressFamily()
     self._hmac_key = hmac_key
-    self._socket = ConfdAsyncUDPClient(self)
+    self._socket = ConfdAsyncUDPClient(self, self._family)
     self._callback = callback
     self._confd_port = port
     self._logger = logger
     self._requests = {}
 
     if self._confd_port is None:
-      self._confd_port = utils.GetDaemonPort(constants.CONFD)
+      self._confd_port = netutils.GetDaemonPort(constants.CONFD)
 
   def UpdatePeerList(self, peers):
     """Update the list of peers
@@ -194,7 +197,7 @@ class ConfdClient:
                                           )
         self._callback(client_reply)
 
-  def SendRequest(self, request, args=None, coverage=None, async=True):
+  def SendRequest(self, request, args=None, coverage=0, async=True):
     """Send a confd request to some MCs
 
     @type request: L{objects.ConfdRequest}
@@ -202,13 +205,19 @@ class ConfdClient:
     @type args: tuple
     @param args: additional callback arguments
     @type coverage: integer
-    @param coverage: number of remote nodes to contact
+    @param coverage: number of remote nodes to contact; if default
+        (0), it will use a reasonable default
+        (L{ganeti.constants.CONFD_DEFAULT_REQ_COVERAGE}), if -1 is
+        passed, it will use the maximum number of peers, otherwise the
+        number passed in will be used
     @type async: boolean
     @param async: handle the write asynchronously
 
     """
-    if coverage is None:
+    if coverage == 0:
       coverage = min(len(self._peers), constants.CONFD_DEFAULT_REQ_COVERAGE)
+    elif coverage == -1:
+      coverage = len(self._peers)
 
     if coverage > len(self._peers):
       raise errors.ConfdClientError("Not enough MCs known to provide the"
@@ -377,6 +386,18 @@ class ConfdClient:
       else:
         return MISSING
 
+  def _SetPeersAddressFamily(self):
+    if not self._peers:
+      raise errors.ConfdClientError("Peer list empty")
+    try:
+      peer = self._peers[0]
+      self._family = netutils.IPAddress.GetAddressFamily(peer)
+      for peer in self._peers[1:]:
+        if netutils.IPAddress.GetAddressFamily(peer) != self._family:
+          raise errors.ConfdClientError("Peers must be of same address family")
+    except errors.IPAddressError:
+      raise errors.ConfdClientError("Peer address %s invalid" % peer)
+
 
 # UPCALL_REPLY: server reply upcall
 # has all ConfdUpcallPayload fields populated
@@ -575,7 +596,7 @@ class ConfdCountingCallback:
     """Have all the registered queries received at least an answer?
 
     """
-    return utils.all(self._answers.values())
+    return compat.all(self._answers.values())
 
   def _HandleExpire(self, up):
     # if we have no answer we have received none, before the expiration.