Add a special lexer for sphinx/pygments
[ganeti-local] / lib / confd / client.py
index e944431..900b5f7 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2009 Google Inc.
+# Copyright (C) 2009, 2010 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -29,7 +29,8 @@ This way the client library doesn't ever need to "wait" on a particular answer,
 and can proceed even if some udp packets are lost. It's up to the user to
 reschedule queries if they haven't received responses and they need them.
 
-Example usage:
+Example usage::
+
   client = ConfdClient(...) # includes callback specification
   req = confd_client.ConfdClientRequest(type=constants.CONFD_REQ_PING)
   client.SendRequest(req)
@@ -43,7 +44,12 @@ You can use the provided ConfdFilterCallback to act as a filter, only passing
 confirming what you already got.
 
 """
-import socket
+
+# pylint: disable=E0203
+
+# E0203: Access to member %r before its definition, since we use
+# objects.py which doesn't explicitely initialise its members
+
 import time
 import random
 
@@ -54,6 +60,9 @@ from ganeti import serializer
 from ganeti import daemon # contains AsyncUDPSocket
 from ganeti import errors
 from ganeti import confd
+from ganeti import ssconf
+from ganeti import compat
+from ganeti import netutils
 
 
 class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
@@ -63,14 +72,14 @@ class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
   implement a non-asyncore based client library.
 
   """
-  def __init__(self, client):
+  def __init__(self, client, family):
     """Constructor for ConfdAsyncUDPClient
 
     @type client: L{ConfdClient}
     @param client: client library, to pass the datagrams to
 
     """
-    daemon.AsyncUDPSocket.__init__(self)
+    daemon.AsyncUDPSocket.__init__(self, family)
     self.client = client
 
   # this method is overriding a daemon.AsyncUDPSocket method
@@ -78,6 +87,24 @@ class ConfdAsyncUDPClient(daemon.AsyncUDPSocket):
     self.client.HandleResponse(payload, ip, port)
 
 
+class _Request(object):
+  """Request status structure.
+
+  @ivar request: the request data
+  @ivar args: any extra arguments for the callback
+  @ivar expiry: the expiry timestamp of the request
+  @ivar sent: the set of contacted peers
+  @ivar rcvd: the set of peers who replied
+
+  """
+  def __init__(self, request, args, expiry, sent):
+    self.request = request
+    self.args = args
+    self.expiry = expiry
+    self.sent = frozenset(sent)
+    self.rcvd = set()
+
+
 class ConfdClient:
   """Send queries to confd, and get back answers.
 
@@ -85,6 +112,11 @@ class ConfdClient:
   getting back answers, this is an asynchronous library. It can either work
   through asyncore or with your own handling.
 
+  @type _requests: dict
+  @ivar _requests: dictionary indexes by salt, which contains data
+      about the outstanding requests; the values are objects of type
+      L{_Request}
+
   """
   def __init__(self, hmac_key, peers, callback, port=None, logger=None):
     """Constructor for ConfdClient
@@ -96,25 +128,25 @@ class ConfdClient:
     @type callback: f(L{ConfdUpcallPayload})
     @param callback: function to call when getting answers
     @type port: integer
-    @keyword port: confd port (default: use GetDaemonPort)
-    @type logger: L{logging.Logger}
-    @keyword logger: optional logger for internal conditions
+    @param port: confd port (default: use GetDaemonPort)
+    @type logger: logging.Logger
+    @param logger: optional logger for internal conditions
 
     """
     if not callable(callback):
       raise errors.ProgrammerError("callback must be callable")
 
     self.UpdatePeerList(peers)
+    self._SetPeersAddressFamily()
     self._hmac_key = hmac_key
-    self._socket = ConfdAsyncUDPClient(self)
+    self._socket = ConfdAsyncUDPClient(self, self._family)
     self._callback = callback
     self._confd_port = port
     self._logger = logger
     self._requests = {}
-    self._expire_requests = []
 
     if self._confd_port is None:
-      self._confd_port = utils.GetDaemonPort(constants.CONFD)
+      self._confd_port = netutils.GetDaemonPort(constants.CONFD)
 
   def UpdatePeerList(self, peers):
     """Update the list of peers
@@ -123,9 +155,12 @@ class ConfdClient:
     @param peers: list of peer nodes
 
     """
+    # we are actually called from init, so:
+    # pylint: disable=W0201
     if not isinstance(peers, list):
       raise errors.ProgrammerError("peers must be a list")
-    self._peers = peers
+    # make a copy of peers, since we're going to shuffle the list, later
+    self._peers = list(peers)
 
   def _PackRequest(self, request, now=None):
     """Prepare a request to be sent on the wire.
@@ -136,7 +171,7 @@ class ConfdClient:
     """
     if now is None:
       now = time.time()
-    tstamp = '%d' % now
+    tstamp = "%d" % now
     req = serializer.DumpSignedJson(request.ToDict(), self._hmac_key, tstamp)
     return confd.PackMagic(req)
 
@@ -151,35 +186,38 @@ class ConfdClient:
 
     """
     now = time.time()
-    while self._expire_requests:
-      expire_time, rsalt = self._expire_requests[0]
-      if now >= expire_time:
-        self._expire_requests.pop(0)
-        (request, args) = self._requests[rsalt]
+    for rsalt, rq in self._requests.items():
+      if now >= rq.expiry:
         del self._requests[rsalt]
         client_reply = ConfdUpcallPayload(salt=rsalt,
                                           type=UPCALL_EXPIRE,
-                                          orig_request=request,
-                                          extra_args=args,
+                                          orig_request=rq.request,
+                                          extra_args=rq.args,
                                           client=self,
                                           )
         self._callback(client_reply)
-      else:
-        break
 
-  def SendRequest(self, request, args=None, coverage=None):
+  def SendRequest(self, request, args=None, coverage=0, async=True):
     """Send a confd request to some MCs
 
     @type request: L{objects.ConfdRequest}
     @param request: the request to send
     @type args: tuple
-    @keyword args: additional callback arguments
+    @param args: additional callback arguments
     @type coverage: integer
-    @keyword coverage: number of remote nodes to contact
+    @param coverage: number of remote nodes to contact; if default
+        (0), it will use a reasonable default
+        (L{ganeti.constants.CONFD_DEFAULT_REQ_COVERAGE}), if -1 is
+        passed, it will use the maximum number of peers, otherwise the
+        number passed in will be used
+    @type async: boolean
+    @param async: handle the write asynchronously
 
     """
-    if coverage is None:
+    if coverage == 0:
       coverage = min(len(self._peers), constants.CONFD_DEFAULT_REQ_COVERAGE)
+    elif coverage == -1:
+      coverage = len(self._peers)
 
     if coverage > len(self._peers):
       raise errors.ConfdClientError("Not enough MCs known to provide the"
@@ -207,9 +245,12 @@ class ConfdClient:
       except errors.UdpDataSizeError:
         raise errors.ConfdClientError("Request too big")
 
-    self._requests[request.rsalt] = (request, args)
     expire_time = now + constants.CONFD_CLIENT_EXPIRE_TIMEOUT
-    self._expire_requests.append((expire_time, request.rsalt))
+    self._requests[request.rsalt] = _Request(request, args, expire_time,
+                                             targets)
+
+    if not async:
+      self.FlushSendQueue()
 
   def HandleResponse(self, payload, ip, port):
     """Asynchronous handler for a confd reply
@@ -226,19 +267,21 @@ class ConfdClient:
         return
 
       try:
-        (request, args) = self._requests[salt]
+        rq = self._requests[salt]
       except KeyError:
         if self._logger:
           self._logger.debug("Discarding unknown (expired?) reply: %s" % err)
         return
 
+      rq.rcvd.add(ip)
+
       client_reply = ConfdUpcallPayload(salt=salt,
                                         type=UPCALL_REPLY,
                                         server_reply=answer,
-                                        orig_request=request,
+                                        orig_request=rq.request,
                                         server_ip=ip,
                                         server_port=port,
-                                        extra_args=args,
+                                        extra_args=rq.args,
                                         client=self,
                                        )
       self._callback(client_reply)
@@ -246,6 +289,115 @@ class ConfdClient:
     finally:
       self.ExpireRequests()
 
+  def FlushSendQueue(self):
+    """Send out all pending requests.
+
+    Can be used for synchronous client use.
+
+    """
+    while self._socket.writable():
+      self._socket.handle_write()
+
+  def ReceiveReply(self, timeout=1):
+    """Receive one reply.
+
+    @type timeout: float
+    @param timeout: how long to wait for the reply
+    @rtype: boolean
+    @return: True if some data has been handled, False otherwise
+
+    """
+    return self._socket.process_next_packet(timeout=timeout)
+
+  @staticmethod
+  def _NeededReplies(peer_cnt):
+    """Compute the minimum safe number of replies for a query.
+
+    The algorithm is designed to work well for both small and big
+    number of peers:
+        - for less than three, we require all responses
+        - for less than five, we allow one miss
+        - otherwise, half the number plus one
+
+    This guarantees that we progress monotonically: 1->1, 2->2, 3->2,
+    4->2, 5->3, 6->3, 7->4, etc.
+
+    @type peer_cnt: int
+    @param peer_cnt: the number of peers contacted
+    @rtype: int
+    @return: the number of replies which should give a safe coverage
+
+    """
+    if peer_cnt < 3:
+      return peer_cnt
+    elif peer_cnt < 5:
+      return peer_cnt - 1
+    else:
+      return int(peer_cnt / 2) + 1
+
+  def WaitForReply(self, salt, timeout=constants.CONFD_CLIENT_EXPIRE_TIMEOUT):
+    """Wait for replies to a given request.
+
+    This method will wait until either the timeout expires or a
+    minimum number (computed using L{_NeededReplies}) of replies are
+    received for the given salt. It is useful when doing synchronous
+    calls to this library.
+
+    @param salt: the salt of the request we want responses for
+    @param timeout: the maximum timeout (should be less or equal to
+        L{ganeti.constants.CONFD_CLIENT_EXPIRE_TIMEOUT}
+    @rtype: tuple
+    @return: a tuple of (timed_out, sent_cnt, recv_cnt); if the
+        request is unknown, timed_out will be true and the counters
+        will be zero
+
+    """
+    def _CheckResponse():
+      if salt not in self._requests:
+        # expired?
+        if self._logger:
+          self._logger.debug("Discarding unknown/expired request: %s" % salt)
+        return MISSING
+      rq = self._requests[salt]
+      if len(rq.rcvd) >= expected:
+        # already got all replies
+        return (False, len(rq.sent), len(rq.rcvd))
+      # else wait, using default timeout
+      self.ReceiveReply()
+      raise utils.RetryAgain()
+
+    MISSING = (True, 0, 0)
+
+    if salt not in self._requests:
+      return MISSING
+    # extend the expire time with the current timeout, so that we
+    # don't get the request expired from under us
+    rq = self._requests[salt]
+    rq.expiry += timeout
+    sent = len(rq.sent)
+    expected = self._NeededReplies(sent)
+
+    try:
+      return utils.Retry(_CheckResponse, 0, timeout)
+    except utils.RetryTimeout:
+      if salt in self._requests:
+        rq = self._requests[salt]
+        return (True, len(rq.sent), len(rq.rcvd))
+      else:
+        return MISSING
+
+  def _SetPeersAddressFamily(self):
+    if not self._peers:
+      raise errors.ConfdClientError("Peer list empty")
+    try:
+      peer = self._peers[0]
+      self._family = netutils.IPAddress.GetAddressFamily(peer)
+      for peer in self._peers[1:]:
+        if netutils.IPAddress.GetAddressFamily(peer) != self._family:
+          raise errors.ConfdClientError("Peers must be of same address family")
+    except errors.IPAddressError:
+      raise errors.ConfdClientError("Peer address %s invalid" % peer)
+
 
 # UPCALL_REPLY: server reply upcall
 # has all ConfdUpcallPayload fields populated
@@ -312,14 +464,20 @@ class ConfdClientRequest(objects.ConfdRequest):
 class ConfdFilterCallback:
   """Callback that calls another callback, but filters duplicate results.
 
+  @ivar consistent: a dictionary indexed by salt; for each salt, if
+      all responses ware identical, this will be True; this is the
+      expected state on a healthy cluster; on inconsistent or
+      partitioned clusters, this might be False, if we see answers
+      with the same serial but different contents
+
   """
   def __init__(self, callback, logger=None):
     """Constructor for ConfdFilterCallback
 
     @type callback: f(L{ConfdUpcallPayload})
     @param callback: function to call when getting answers
-    @type logger: L{logging.Logger}
-    @keyword logger: optional logger for internal conditions
+    @type logger: logging.Logger
+    @param logger: optional logger for internal conditions
 
     """
     if not callable(callback):
@@ -329,6 +487,7 @@ class ConfdFilterCallback:
     self._logger = logger
     # answers contains a dict of salt -> answer
     self._answers = {}
+    self.consistent = {}
 
   def _LogFilter(self, salt, new_reply, old_reply):
     if not self._logger:
@@ -353,6 +512,8 @@ class ConfdFilterCallback:
     # if we have no answer we have received none, before the expiration.
     if up.salt in self._answers:
       del self._answers[up.salt]
+    if up.salt in self.consistent:
+      del self.consistent[up.salt]
 
   def _HandleReply(self, up):
     """Handle a single confd reply, and decide whether to filter it.
@@ -364,6 +525,8 @@ class ConfdFilterCallback:
     """
     filter_upcall = False
     salt = up.salt
+    if salt not in self.consistent:
+      self.consistent[salt] = True
     if salt not in self._answers:
       # first answer for a query (don't filter, and record)
       self._answers[salt] = up.server_reply
@@ -378,6 +541,9 @@ class ConfdFilterCallback:
       # else: different content, pass up a second answer
     else:
       # older or same-version answer (duplicate or outdated, filter)
+      if (up.server_reply.serial == self._answers[salt].serial and
+          up.server_reply.answer != self._answers[salt].answer):
+        self.consistent[salt] = False
       filter_upcall = True
       self._LogFilter(salt, up.server_reply, self._answers[salt])
 
@@ -399,3 +565,129 @@ class ConfdFilterCallback:
     if not filter_upcall:
       self._callback(up)
 
+
+class ConfdCountingCallback:
+  """Callback that calls another callback, and counts the answers
+
+  """
+  def __init__(self, callback, logger=None):
+    """Constructor for ConfdCountingCallback
+
+    @type callback: f(L{ConfdUpcallPayload})
+    @param callback: function to call when getting answers
+    @type logger: logging.Logger
+    @param logger: optional logger for internal conditions
+
+    """
+    if not callable(callback):
+      raise errors.ProgrammerError("callback must be callable")
+
+    self._callback = callback
+    self._logger = logger
+    # answers contains a dict of salt -> count
+    self._answers = {}
+
+  def RegisterQuery(self, salt):
+    if salt in self._answers:
+      raise errors.ProgrammerError("query already registered")
+    self._answers[salt] = 0
+
+  def AllAnswered(self):
+    """Have all the registered queries received at least an answer?
+
+    """
+    return compat.all(self._answers.values())
+
+  def _HandleExpire(self, up):
+    # if we have no answer we have received none, before the expiration.
+    if up.salt in self._answers:
+      del self._answers[up.salt]
+
+  def _HandleReply(self, up):
+    """Handle a single confd reply, and decide whether to filter it.
+
+    @rtype: boolean
+    @return: True if the reply should be filtered, False if it should be passed
+             on to the up-callback
+
+    """
+    if up.salt in self._answers:
+      self._answers[up.salt] += 1
+
+  def __call__(self, up):
+    """Filtering callback
+
+    @type up: L{ConfdUpcallPayload}
+    @param up: upper callback
+
+    """
+    if up.type == UPCALL_REPLY:
+      self._HandleReply(up)
+    elif up.type == UPCALL_EXPIRE:
+      self._HandleExpire(up)
+    self._callback(up)
+
+
+class StoreResultCallback:
+  """Callback that simply stores the most recent answer.
+
+  @ivar _answers: dict of salt to (have_answer, reply)
+
+  """
+  _NO_KEY = (False, None)
+
+  def __init__(self):
+    """Constructor for StoreResultCallback
+
+    """
+    # answers contains a dict of salt -> best result
+    self._answers = {}
+
+  def GetResponse(self, salt):
+    """Return the best match for a salt
+
+    """
+    return self._answers.get(salt, self._NO_KEY)
+
+  def _HandleExpire(self, up):
+    """Expiration handler.
+
+    """
+    if up.salt in self._answers and self._answers[up.salt] == self._NO_KEY:
+      del self._answers[up.salt]
+
+  def _HandleReply(self, up):
+    """Handle a single confd reply, and decide whether to filter it.
+
+    """
+    self._answers[up.salt] = (True, up)
+
+  def __call__(self, up):
+    """Filtering callback
+
+    @type up: L{ConfdUpcallPayload}
+    @param up: upper callback
+
+    """
+    if up.type == UPCALL_REPLY:
+      self._HandleReply(up)
+    elif up.type == UPCALL_EXPIRE:
+      self._HandleExpire(up)
+
+
+def GetConfdClient(callback):
+  """Return a client configured using the given callback.
+
+  This is handy to abstract the MC list and HMAC key reading.
+
+  @attention: This should only be called on nodes which are part of a
+      cluster, since it depends on a valid (ganeti) data directory;
+      for code running outside of a cluster, you need to create the
+      client manually
+
+  """
+  ss = ssconf.SimpleStore()
+  mc_file = ss.KeyToFilename(constants.SS_MASTER_CANDIDATES_IPS)
+  mc_list = utils.ReadFile(mc_file).splitlines()
+  hmac_key = utils.ReadFile(constants.CONFD_HMAC_KEY)
+  return ConfdClient(hmac_key, mc_list, callback)