Merge branch 'devel-2.4'
[ganeti-local] / lib / confd / client.py
index efcb68d..2ca2ed8 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
 #
 #
 
-# Copyright (C) 2009 Google Inc.
+# Copyright (C) 2009, 2010 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -60,6 +60,9 @@ from ganeti import serializer
 from ganeti import daemon # contains AsyncUDPSocket
 from ganeti import errors
 from ganeti import confd
 from ganeti import daemon # contains AsyncUDPSocket
 from ganeti import errors
 from ganeti import confd
+from ganeti import ssconf
+from ganeti import compat
+from ganeti import netutils
 
 
 class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
 
 
 class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
@@ -69,14 +72,14 @@ class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
   implement a non-asyncore based client library.
 
   """
   implement a non-asyncore based client library.
 
   """
-  def __init__(self, client):
+  def __init__(self, client, family):
     """Constructor for ConfdAsyncUDPClient
 
     @type client: L{ConfdClient}
     @param client: client library, to pass the datagrams to
 
     """
     """Constructor for ConfdAsyncUDPClient
 
     @type client: L{ConfdClient}
     @param client: client library, to pass the datagrams to
 
     """
-    daemon.AsyncUDPSocket.__init__(self)
+    daemon.AsyncUDPSocket.__init__(self, family)
     self.client = client
 
   # this method is overriding a daemon.AsyncUDPSocket method
     self.client = client
 
   # this method is overriding a daemon.AsyncUDPSocket method
@@ -84,6 +87,24 @@ class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
     self.client.HandleResponse(payload, ip, port)
 
 
     self.client.HandleResponse(payload, ip, port)
 
 
+class _Request(object):
+  """Request status structure.
+
+  @ivar request: the request data
+  @ivar args: any extra arguments for the callback
+  @ivar expiry: the expiry timestamp of the request
+  @ivar sent: the set of contacted peers
+  @ivar rcvd: the set of peers who replied
+
+  """
+  def __init__(self, request, args, expiry, sent):
+    self.request = request
+    self.args = args
+    self.expiry = expiry
+    self.sent = frozenset(sent)
+    self.rcvd = set()
+
+
 class ConfdClient:
   """Send queries to confd, and get back answers.
 
 class ConfdClient:
   """Send queries to confd, and get back answers.
 
@@ -91,6 +112,11 @@ class ConfdClient:
   getting back answers, this is an asynchronous library. It can either work
   through asyncore or with your own handling.
 
   getting back answers, this is an asynchronous library. It can either work
   through asyncore or with your own handling.
 
+  @type _requests: dict
+  @ivar _requests: dictionary indexes by salt, which contains data
+      about the outstanding requests; the values are objects of type
+      L{_Request}
+
   """
   def __init__(self, hmac_key, peers, callback, port=None, logger=None):
     """Constructor for ConfdClient
   """
   def __init__(self, hmac_key, peers, callback, port=None, logger=None):
     """Constructor for ConfdClient
@@ -111,16 +137,16 @@ class ConfdClient:
       raise errors.ProgrammerError("callback must be callable")
 
     self.UpdatePeerList(peers)
       raise errors.ProgrammerError("callback must be callable")
 
     self.UpdatePeerList(peers)
+    self._SetPeersAddressFamily()
     self._hmac_key = hmac_key
     self._hmac_key = hmac_key
-    self._socket = ConfdAsyncUDPClient(self)
+    self._socket = ConfdAsyncUDPClient(self, self._family)
     self._callback = callback
     self._confd_port = port
     self._logger = logger
     self._requests = {}
     self._callback = callback
     self._confd_port = port
     self._logger = logger
     self._requests = {}
-    self._expire_requests = []
 
     if self._confd_port is None:
 
     if self._confd_port is None:
-      self._confd_port = utils.GetDaemonPort(constants.CONFD)
+      self._confd_port = netutils.GetDaemonPort(constants.CONFD)
 
   def UpdatePeerList(self, peers):
     """Update the list of peers
 
   def UpdatePeerList(self, peers):
     """Update the list of peers
@@ -160,23 +186,18 @@ class ConfdClient:
 
     """
     now = time.time()
 
     """
     now = time.time()
-    while self._expire_requests:
-      expire_time, rsalt = self._expire_requests[0]
-      if now >= expire_time:
-        self._expire_requests.pop(0)
-        (request, args) = self._requests[rsalt]
+    for rsalt, rq in self._requests.items():
+      if now >= rq.expiry:
         del self._requests[rsalt]
         client_reply = ConfdUpcallPayload(salt=rsalt,
                                           type=UPCALL_EXPIRE,
         del self._requests[rsalt]
         client_reply = ConfdUpcallPayload(salt=rsalt,
                                           type=UPCALL_EXPIRE,
-                                          orig_request=request,
-                                          extra_args=args,
+                                          orig_request=rq.request,
+                                          extra_args=rq.args,
                                           client=self,
                                           )
         self._callback(client_reply)
                                           client=self,
                                           )
         self._callback(client_reply)
-      else:
-        break
 
 
-  def SendRequest(self, request, args=None, coverage=None, async=True):
+  def SendRequest(self, request, args=None, coverage=0, async=True):
     """Send a confd request to some MCs
 
     @type request: L{objects.ConfdRequest}
     """Send a confd request to some MCs
 
     @type request: L{objects.ConfdRequest}
@@ -184,13 +205,19 @@ class ConfdClient:
     @type args: tuple
     @param args: additional callback arguments
     @type coverage: integer
     @type args: tuple
     @param args: additional callback arguments
     @type coverage: integer
-    @param coverage: number of remote nodes to contact
+    @param coverage: number of remote nodes to contact; if default
+        (0), it will use a reasonable default
+        (L{ganeti.constants.CONFD_DEFAULT_REQ_COVERAGE}), if -1 is
+        passed, it will use the maximum number of peers, otherwise the
+        number passed in will be used
     @type async: boolean
     @param async: handle the write asynchronously
 
     """
     @type async: boolean
     @param async: handle the write asynchronously
 
     """
-    if coverage is None:
+    if coverage == 0:
       coverage = min(len(self._peers), constants.CONFD_DEFAULT_REQ_COVERAGE)
       coverage = min(len(self._peers), constants.CONFD_DEFAULT_REQ_COVERAGE)
+    elif coverage == -1:
+      coverage = len(self._peers)
 
     if coverage > len(self._peers):
       raise errors.ConfdClientError("Not enough MCs known to provide the"
 
     if coverage > len(self._peers):
       raise errors.ConfdClientError("Not enough MCs known to provide the"
@@ -218,9 +245,9 @@ class ConfdClient:
       except errors.UdpDataSizeError:
         raise errors.ConfdClientError("Request too big")
 
       except errors.UdpDataSizeError:
         raise errors.ConfdClientError("Request too big")
 
-    self._requests[request.rsalt] = (request, args)
     expire_time = now + constants.CONFD_CLIENT_EXPIRE_TIMEOUT
     expire_time = now + constants.CONFD_CLIENT_EXPIRE_TIMEOUT
-    self._expire_requests.append((expire_time, request.rsalt))
+    self._requests[request.rsalt] = _Request(request, args, expire_time,
+                                             targets)
 
     if not async:
       self.FlushSendQueue()
 
     if not async:
       self.FlushSendQueue()
@@ -240,19 +267,21 @@ class ConfdClient:
         return
 
       try:
         return
 
       try:
-        (request, args) = self._requests[salt]
+        rq = self._requests[salt]
       except KeyError:
         if self._logger:
           self._logger.debug("Discarding unknown (expired?) reply: %s" % err)
         return
 
       except KeyError:
         if self._logger:
           self._logger.debug("Discarding unknown (expired?) reply: %s" % err)
         return
 
+      rq.rcvd.add(ip)
+
       client_reply = ConfdUpcallPayload(salt=salt,
                                         type=UPCALL_REPLY,
                                         server_reply=answer,
       client_reply = ConfdUpcallPayload(salt=salt,
                                         type=UPCALL_REPLY,
                                         server_reply=answer,
-                                        orig_request=request,
+                                        orig_request=rq.request,
                                         server_ip=ip,
                                         server_port=port,
                                         server_ip=ip,
                                         server_port=port,
-                                        extra_args=args,
+                                        extra_args=rq.args,
                                         client=self,
                                        )
       self._callback(client_reply)
                                         client=self,
                                        )
       self._callback(client_reply)
@@ -280,6 +309,95 @@ class ConfdClient:
     """
     return self._socket.process_next_packet(timeout=timeout)
 
     """
     return self._socket.process_next_packet(timeout=timeout)
 
+  @staticmethod
+  def _NeededReplies(peer_cnt):
+    """Compute the minimum safe number of replies for a query.
+
+    The algorithm is designed to work well for both small and big
+    number of peers:
+        - for less than three, we require all responses
+        - for less than five, we allow one miss
+        - otherwise, half the number plus one
+
+    This guarantees that we progress monotonically: 1->1, 2->2, 3->2,
+    4->2, 5->3, 6->3, 7->4, etc.
+
+    @type peer_cnt: int
+    @param peer_cnt: the number of peers contacted
+    @rtype: int
+    @return: the number of replies which should give a safe coverage
+
+    """
+    if peer_cnt < 3:
+      return peer_cnt
+    elif peer_cnt < 5:
+      return peer_cnt - 1
+    else:
+      return int(peer_cnt/2) + 1
+
+  def WaitForReply(self, salt, timeout=constants.CONFD_CLIENT_EXPIRE_TIMEOUT):
+    """Wait for replies to a given request.
+
+    This method will wait until either the timeout expires or a
+    minimum number (computed using L{_NeededReplies}) of replies are
+    received for the given salt. It is useful when doing synchronous
+    calls to this library.
+
+    @param salt: the salt of the request we want responses for
+    @param timeout: the maximum timeout (should be less or equal to
+        L{ganeti.constants.CONFD_CLIENT_EXPIRE_TIMEOUT}
+    @rtype: tuple
+    @return: a tuple of (timed_out, sent_cnt, recv_cnt); if the
+        request is unknown, timed_out will be true and the counters
+        will be zero
+
+    """
+    def _CheckResponse():
+      if salt not in self._requests:
+        # expired?
+        if self._logger:
+          self._logger.debug("Discarding unknown/expired request: %s" % salt)
+        return MISSING
+      rq = self._requests[salt]
+      if len(rq.rcvd) >= expected:
+        # already got all replies
+        return (False, len(rq.sent), len(rq.rcvd))
+      # else wait, using default timeout
+      self.ReceiveReply()
+      raise utils.RetryAgain()
+
+    MISSING = (True, 0, 0)
+
+    if salt not in self._requests:
+      return MISSING
+    # extend the expire time with the current timeout, so that we
+    # don't get the request expired from under us
+    rq = self._requests[salt]
+    rq.expiry += timeout
+    sent = len(rq.sent)
+    expected = self._NeededReplies(sent)
+
+    try:
+      return utils.Retry(_CheckResponse, 0, timeout)
+    except utils.RetryTimeout:
+      if salt in self._requests:
+        rq = self._requests[salt]
+        return (True, len(rq.sent), len(rq.rcvd))
+      else:
+        return MISSING
+
+  def _SetPeersAddressFamily(self):
+    if not self._peers:
+      raise errors.ConfdClientError("Peer list empty")
+    try:
+      peer = self._peers[0]
+      self._family = netutils.IPAddress.GetAddressFamily(peer)
+      for peer in self._peers[1:]:
+        if netutils.IPAddress.GetAddressFamily(peer) != self._family:
+          raise errors.ConfdClientError("Peers must be of same address family")
+    except errors.IPAddressError:
+      raise errors.ConfdClientError("Peer address %s invalid" % peer)
+
 
 # UPCALL_REPLY: server reply upcall
 # has all ConfdUpcallPayload fields populated
 
 # UPCALL_REPLY: server reply upcall
 # has all ConfdUpcallPayload fields populated
@@ -346,6 +464,12 @@ class ConfdClientRequest(objects.ConfdRequest):
 class ConfdFilterCallback:
   """Callback that calls another callback, but filters duplicate results.
 
 class ConfdFilterCallback:
   """Callback that calls another callback, but filters duplicate results.
 
+  @ivar consistent: a dictionary indexed by salt; for each salt, if
+      all responses ware identical, this will be True; this is the
+      expected state on a healthy cluster; on inconsistent or
+      partitioned clusters, this might be False, if we see answers
+      with the same serial but different contents
+
   """
   def __init__(self, callback, logger=None):
     """Constructor for ConfdFilterCallback
   """
   def __init__(self, callback, logger=None):
     """Constructor for ConfdFilterCallback
@@ -363,6 +487,7 @@ class ConfdFilterCallback:
     self._logger = logger
     # answers contains a dict of salt -> answer
     self._answers = {}
     self._logger = logger
     # answers contains a dict of salt -> answer
     self._answers = {}
+    self.consistent = {}
 
   def _LogFilter(self, salt, new_reply, old_reply):
     if not self._logger:
 
   def _LogFilter(self, salt, new_reply, old_reply):
     if not self._logger:
@@ -387,6 +512,8 @@ class ConfdFilterCallback:
     # if we have no answer we have received none, before the expiration.
     if up.salt in self._answers:
       del self._answers[up.salt]
     # if we have no answer we have received none, before the expiration.
     if up.salt in self._answers:
       del self._answers[up.salt]
+    if up.salt in self.consistent:
+      del self.consistent[up.salt]
 
   def _HandleReply(self, up):
     """Handle a single confd reply, and decide whether to filter it.
 
   def _HandleReply(self, up):
     """Handle a single confd reply, and decide whether to filter it.
@@ -398,6 +525,8 @@ class ConfdFilterCallback:
     """
     filter_upcall = False
     salt = up.salt
     """
     filter_upcall = False
     salt = up.salt
+    if salt not in self.consistent:
+      self.consistent[salt] = True
     if salt not in self._answers:
       # first answer for a query (don't filter, and record)
       self._answers[salt] = up.server_reply
     if salt not in self._answers:
       # first answer for a query (don't filter, and record)
       self._answers[salt] = up.server_reply
@@ -412,6 +541,9 @@ class ConfdFilterCallback:
       # else: different content, pass up a second answer
     else:
       # older or same-version answer (duplicate or outdated, filter)
       # else: different content, pass up a second answer
     else:
       # older or same-version answer (duplicate or outdated, filter)
+      if (up.server_reply.serial == self._answers[salt].serial and
+          up.server_reply.answer != self._answers[salt].answer):
+        self.consistent[salt] = False
       filter_upcall = True
       self._LogFilter(salt, up.server_reply, self._answers[salt])
 
       filter_upcall = True
       self._LogFilter(salt, up.server_reply, self._answers[salt])
 
@@ -464,7 +596,7 @@ class ConfdCountingCallback:
     """Have all the registered queries received at least an answer?
 
     """
     """Have all the registered queries received at least an answer?
 
     """
-    return utils.all(self._answers.values())
+    return compat.all(self._answers.values())
 
   def _HandleExpire(self, up):
     # if we have no answer we have received none, before the expiration.
 
   def _HandleExpire(self, up):
     # if we have no answer we have received none, before the expiration.
@@ -494,3 +626,68 @@ class ConfdCountingCallback:
     elif up.type == UPCALL_EXPIRE:
       self._HandleExpire(up)
     self._callback(up)
     elif up.type == UPCALL_EXPIRE:
       self._HandleExpire(up)
     self._callback(up)
+
+
+class StoreResultCallback:
+  """Callback that simply stores the most recent answer.
+
+  @ivar _answers: dict of salt to (have_answer, reply)
+
+  """
+  _NO_KEY = (False, None)
+
+  def __init__(self):
+    """Constructor for StoreResultCallback
+
+    """
+    # answers contains a dict of salt -> best result
+    self._answers = {}
+
+  def GetResponse(self, salt):
+    """Return the best match for a salt
+
+    """
+    return self._answers.get(salt, self._NO_KEY)
+
+  def _HandleExpire(self, up):
+    """Expiration handler.
+
+    """
+    if up.salt in self._answers and self._answers[up.salt] == self._NO_KEY:
+      del self._answers[up.salt]
+
+  def _HandleReply(self, up):
+    """Handle a single confd reply, and decide whether to filter it.
+
+    """
+    self._answers[up.salt] = (True, up)
+
+  def __call__(self, up):
+    """Filtering callback
+
+    @type up: L{ConfdUpcallPayload}
+    @param up: upper callback
+
+    """
+    if up.type == UPCALL_REPLY:
+      self._HandleReply(up)
+    elif up.type == UPCALL_EXPIRE:
+      self._HandleExpire(up)
+
+
+def GetConfdClient(callback):
+  """Return a client configured using the given callback.
+
+  This is handy to abstract the MC list and HMAC key reading.
+
+  @attention: This should only be called on nodes which are part of a
+      cluster, since it depends on a valid (ganeti) data directory;
+      for code running outside of a cluster, you need to create the
+      client manually
+
+  """
+  ss = ssconf.SimpleStore()
+  mc_file = ss.KeyToFilename(constants.SS_MASTER_CANDIDATES_IPS)
+  mc_list = utils.ReadFile(mc_file).splitlines()
+  hmac_key = utils.ReadFile(constants.CONFD_HMAC_KEY)
+  return ConfdClient(hmac_key, mc_list, callback)