Merge branch 'devel-2.4'
[ganeti-local] / lib / confd / client.py
index 47c0009..2ca2ed8 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
 #
 #
 
-# Copyright (C) 2009 Google Inc.
+# Copyright (C) 2009, 2010 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -61,6 +61,8 @@ from ganeti import daemon # contains AsyncUDPSocket
 from ganeti import errors
 from ganeti import confd
 from ganeti import ssconf
 from ganeti import errors
 from ganeti import confd
 from ganeti import ssconf
+from ganeti import compat
+from ganeti import netutils
 
 
 class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
 
 
 class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
@@ -70,14 +72,14 @@ class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
   implement a non-asyncore based client library.
 
   """
   implement a non-asyncore based client library.
 
   """
-  def __init__(self, client):
+  def __init__(self, client, family):
     """Constructor for ConfdAsyncUDPClient
 
     @type client: L{ConfdClient}
     @param client: client library, to pass the datagrams to
 
     """
     """Constructor for ConfdAsyncUDPClient
 
     @type client: L{ConfdClient}
     @param client: client library, to pass the datagrams to
 
     """
-    daemon.AsyncUDPSocket.__init__(self)
+    daemon.AsyncUDPSocket.__init__(self, family)
     self.client = client
 
   # this method is overriding a daemon.AsyncUDPSocket method
     self.client = client
 
   # this method is overriding a daemon.AsyncUDPSocket method
@@ -91,12 +93,16 @@ class _Request(object):
   @ivar request: the request data
   @ivar args: any extra arguments for the callback
   @ivar expiry: the expiry timestamp of the request
   @ivar request: the request data
   @ivar args: any extra arguments for the callback
   @ivar expiry: the expiry timestamp of the request
+  @ivar sent: the set of contacted peers
+  @ivar rcvd: the set of peers who replied
 
   """
 
   """
-  def __init__(self, request, args, expiry):
+  def __init__(self, request, args, expiry, sent):
     self.request = request
     self.args = args
     self.expiry = expiry
     self.request = request
     self.args = args
     self.expiry = expiry
+    self.sent = frozenset(sent)
+    self.rcvd = set()
 
 
 class ConfdClient:
 
 
 class ConfdClient:
@@ -131,15 +137,16 @@ class ConfdClient:
       raise errors.ProgrammerError("callback must be callable")
 
     self.UpdatePeerList(peers)
       raise errors.ProgrammerError("callback must be callable")
 
     self.UpdatePeerList(peers)
+    self._SetPeersAddressFamily()
     self._hmac_key = hmac_key
     self._hmac_key = hmac_key
-    self._socket = ConfdAsyncUDPClient(self)
+    self._socket = ConfdAsyncUDPClient(self, self._family)
     self._callback = callback
     self._confd_port = port
     self._logger = logger
     self._requests = {}
 
     if self._confd_port is None:
     self._callback = callback
     self._confd_port = port
     self._logger = logger
     self._requests = {}
 
     if self._confd_port is None:
-      self._confd_port = utils.GetDaemonPort(constants.CONFD)
+      self._confd_port = netutils.GetDaemonPort(constants.CONFD)
 
   def UpdatePeerList(self, peers):
     """Update the list of peers
 
   def UpdatePeerList(self, peers):
     """Update the list of peers
@@ -190,7 +197,7 @@ class ConfdClient:
                                           )
         self._callback(client_reply)
 
                                           )
         self._callback(client_reply)
 
-  def SendRequest(self, request, args=None, coverage=None, async=True):
+  def SendRequest(self, request, args=None, coverage=0, async=True):
     """Send a confd request to some MCs
 
     @type request: L{objects.ConfdRequest}
     """Send a confd request to some MCs
 
     @type request: L{objects.ConfdRequest}
@@ -198,13 +205,19 @@ class ConfdClient:
     @type args: tuple
     @param args: additional callback arguments
     @type coverage: integer
     @type args: tuple
     @param args: additional callback arguments
     @type coverage: integer
-    @param coverage: number of remote nodes to contact
+    @param coverage: number of remote nodes to contact; if default
+        (0), it will use a reasonable default
+        (L{ganeti.constants.CONFD_DEFAULT_REQ_COVERAGE}), if -1 is
+        passed, it will use the maximum number of peers, otherwise the
+        number passed in will be used
     @type async: boolean
     @param async: handle the write asynchronously
 
     """
     @type async: boolean
     @param async: handle the write asynchronously
 
     """
-    if coverage is None:
+    if coverage == 0:
       coverage = min(len(self._peers), constants.CONFD_DEFAULT_REQ_COVERAGE)
       coverage = min(len(self._peers), constants.CONFD_DEFAULT_REQ_COVERAGE)
+    elif coverage == -1:
+      coverage = len(self._peers)
 
     if coverage > len(self._peers):
       raise errors.ConfdClientError("Not enough MCs known to provide the"
 
     if coverage > len(self._peers):
       raise errors.ConfdClientError("Not enough MCs known to provide the"
@@ -233,7 +246,8 @@ class ConfdClient:
         raise errors.ConfdClientError("Request too big")
 
     expire_time = now + constants.CONFD_CLIENT_EXPIRE_TIMEOUT
         raise errors.ConfdClientError("Request too big")
 
     expire_time = now + constants.CONFD_CLIENT_EXPIRE_TIMEOUT
-    self._requests[request.rsalt] = _Request(request, args, expire_time)
+    self._requests[request.rsalt] = _Request(request, args, expire_time,
+                                             targets)
 
     if not async:
       self.FlushSendQueue()
 
     if not async:
       self.FlushSendQueue()
@@ -259,6 +273,8 @@ class ConfdClient:
           self._logger.debug("Discarding unknown (expired?) reply: %s" % err)
         return
 
           self._logger.debug("Discarding unknown (expired?) reply: %s" % err)
         return
 
+      rq.rcvd.add(ip)
+
       client_reply = ConfdUpcallPayload(salt=salt,
                                         type=UPCALL_REPLY,
                                         server_reply=answer,
       client_reply = ConfdUpcallPayload(salt=salt,
                                         type=UPCALL_REPLY,
                                         server_reply=answer,
@@ -293,6 +309,95 @@ class ConfdClient:
     """
     return self._socket.process_next_packet(timeout=timeout)
 
     """
     return self._socket.process_next_packet(timeout=timeout)
 
+  @staticmethod
+  def _NeededReplies(peer_cnt):
+    """Compute the minimum safe number of replies for a query.
+
+    The algorithm is designed to work well for both small and big
+    number of peers:
+        - for less than three, we require all responses
+        - for less than five, we allow one miss
+        - otherwise, half the number plus one
+
+    This guarantees that we progress monotonically: 1->1, 2->2, 3->2,
+    4->2, 5->3, 6->3, 7->4, etc.
+
+    @type peer_cnt: int
+    @param peer_cnt: the number of peers contacted
+    @rtype: int
+    @return: the number of replies which should give a safe coverage
+
+    """
+    if peer_cnt < 3:
+      return peer_cnt
+    elif peer_cnt < 5:
+      return peer_cnt - 1
+    else:
+      return int(peer_cnt/2) + 1
+
+  def WaitForReply(self, salt, timeout=constants.CONFD_CLIENT_EXPIRE_TIMEOUT):
+    """Wait for replies to a given request.
+
+    This method will wait until either the timeout expires or a
+    minimum number (computed using L{_NeededReplies}) of replies are
+    received for the given salt. It is useful when doing synchronous
+    calls to this library.
+
+    @param salt: the salt of the request we want responses for
+    @param timeout: the maximum timeout (should be less or equal to
+        L{ganeti.constants.CONFD_CLIENT_EXPIRE_TIMEOUT}
+    @rtype: tuple
+    @return: a tuple of (timed_out, sent_cnt, recv_cnt); if the
+        request is unknown, timed_out will be true and the counters
+        will be zero
+
+    """
+    def _CheckResponse():
+      if salt not in self._requests:
+        # expired?
+        if self._logger:
+          self._logger.debug("Discarding unknown/expired request: %s" % salt)
+        return MISSING
+      rq = self._requests[salt]
+      if len(rq.rcvd) >= expected:
+        # already got all replies
+        return (False, len(rq.sent), len(rq.rcvd))
+      # else wait, using default timeout
+      self.ReceiveReply()
+      raise utils.RetryAgain()
+
+    MISSING = (True, 0, 0)
+
+    if salt not in self._requests:
+      return MISSING
+    # extend the expire time with the current timeout, so that we
+    # don't get the request expired from under us
+    rq = self._requests[salt]
+    rq.expiry += timeout
+    sent = len(rq.sent)
+    expected = self._NeededReplies(sent)
+
+    try:
+      return utils.Retry(_CheckResponse, 0, timeout)
+    except utils.RetryTimeout:
+      if salt in self._requests:
+        rq = self._requests[salt]
+        return (True, len(rq.sent), len(rq.rcvd))
+      else:
+        return MISSING
+
+  def _SetPeersAddressFamily(self):
+    if not self._peers:
+      raise errors.ConfdClientError("Peer list empty")
+    try:
+      peer = self._peers[0]
+      self._family = netutils.IPAddress.GetAddressFamily(peer)
+      for peer in self._peers[1:]:
+        if netutils.IPAddress.GetAddressFamily(peer) != self._family:
+          raise errors.ConfdClientError("Peers must be of same address family")
+    except errors.IPAddressError:
+      raise errors.ConfdClientError("Peer address %s invalid" % peer)
+
 
 # UPCALL_REPLY: server reply upcall
 # has all ConfdUpcallPayload fields populated
 
 # UPCALL_REPLY: server reply upcall
 # has all ConfdUpcallPayload fields populated
@@ -491,7 +596,7 @@ class ConfdCountingCallback:
     """Have all the registered queries received at least an answer?
 
     """
     """Have all the registered queries received at least an answer?
 
     """
-    return utils.all(self._answers.values())
+    return compat.all(self._answers.values())
 
   def _HandleExpire(self, up):
     # if we have no answer we have received none, before the expiration.
 
   def _HandleExpire(self, up):
     # if we have no answer we have received none, before the expiration.
@@ -523,6 +628,53 @@ class ConfdCountingCallback:
     self._callback(up)
 
 
     self._callback(up)
 
 
+class StoreResultCallback:
+  """Callback that simply stores the most recent answer.
+
+  @ivar _answers: dict of salt to (have_answer, reply)
+
+  """
+  _NO_KEY = (False, None)
+
+  def __init__(self):
+    """Constructor for StoreResultCallback
+
+    """
+    # answers contains a dict of salt -> best result
+    self._answers = {}
+
+  def GetResponse(self, salt):
+    """Return the best match for a salt
+
+    """
+    return self._answers.get(salt, self._NO_KEY)
+
+  def _HandleExpire(self, up):
+    """Expiration handler.
+
+    """
+    if up.salt in self._answers and self._answers[up.salt] == self._NO_KEY:
+      del self._answers[up.salt]
+
+  def _HandleReply(self, up):
+    """Handle a single confd reply, and decide whether to filter it.
+
+    """
+    self._answers[up.salt] = (True, up)
+
+  def __call__(self, up):
+    """Filtering callback
+
+    @type up: L{ConfdUpcallPayload}
+    @param up: upper callback
+
+    """
+    if up.type == UPCALL_REPLY:
+      self._HandleReply(up)
+    elif up.type == UPCALL_EXPIRE:
+      self._HandleExpire(up)
+
+
 def GetConfdClient(callback):
   """Return a client configured using the given callback.
 
 def GetConfdClient(callback):
   """Return a client configured using the given callback.