Small improvements for cluster verify
[ganeti-local] / lib / locking.py
index be9cd13..9f24771 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2006, 2007, 2008, 2009, 2010 Google Inc.
+# Copyright (C) 2006, 2007, 2008, 2009, 2010, 2011 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -32,7 +32,6 @@ import errno
 import weakref
 import logging
 import heapq
-import operator
 import itertools
 
 from ganeti import errors
@@ -247,7 +246,7 @@ class SingleNotifyPipeCondition(_BaseCondition):
       self._write_fd = None
     self._poller = None
 
-  def wait(self, timeout=None):
+  def wait(self, timeout):
     """Wait for a notification.
 
     @type timeout: float or None
@@ -314,7 +313,7 @@ class PipeCondition(_BaseCondition):
     self._waiters = set()
     self._single_condition = self._single_condition_class(self._lock)
 
-  def wait(self, timeout=None):
+  def wait(self, timeout):
     """Wait for a notification.
 
     @type timeout: float or None
@@ -434,9 +433,10 @@ class SharedLock(object):
 
     # Register with lock monitor
     if monitor:
+      logging.debug("Adding lock %s to monitor", name)
       monitor.RegisterLock(self)
 
-  def GetInfo(self, requested):
+  def GetLockInfo(self, requested):
     """Retrieves information for querying locks.
 
     @type requested: set
@@ -489,7 +489,7 @@ class SharedLock(object):
       else:
         pending = None
 
-      return (self.name, mode, owner_names, pending)
+      return [(self.name, mode, owner_names, pending)]
     finally:
       self.__lock.release()
 
@@ -541,6 +541,8 @@ class SharedLock(object):
     finally:
       self.__lock.release()
 
+  is_owned = _is_owned
+
   def _count_pending(self):
     """Returns the number of pending acquires.
 
@@ -562,7 +564,9 @@ class SharedLock(object):
     self.__lock.acquire()
     try:
       # Order is important: __find_first_pending_queue modifies __pending
-      return not (self.__find_first_pending_queue() or
+      (_, prioqueue) = self.__find_first_pending_queue()
+
+      return not (prioqueue or
                   self.__pending or
                   self.__pending_by_prio or
                   self.__pending_shared)
@@ -596,16 +600,15 @@ class SharedLock(object):
     while self.__pending:
       (priority, prioqueue) = self.__pending[0]
 
-      if not prioqueue:
-        heapq.heappop(self.__pending)
-        del self.__pending_by_prio[priority]
-        assert priority not in self.__pending_shared
-        continue
-
       if prioqueue:
-        return prioqueue
+        return (priority, prioqueue)
 
-    return None
+      # Remove empty queue
+      heapq.heappop(self.__pending)
+      del self.__pending_by_prio[priority]
+      assert priority not in self.__pending_shared
+
+    return (None, None)
 
   def __is_on_top(self, cond):
     """Checks whether the passed condition is on top of the queue.
@@ -613,7 +616,9 @@ class SharedLock(object):
     The caller must make sure the queue isn't empty.
 
     """
-    return cond == self.__find_first_pending_queue()[0]
+    (_, prioqueue) = self.__find_first_pending_queue()
+
+    return cond == prioqueue[0]
 
   def __acquire_unlocked(self, shared, timeout, priority):
     """Acquire a shared lock.
@@ -690,7 +695,9 @@ class SharedLock(object):
       if not wait_condition.has_waiting():
         prioqueue.remove(wait_condition)
         if wait_condition.shared:
-          del self.__pending_shared[priority]
+          # Remove from list of shared acquires if that wasn't already done
+          # while releasing (e.g. on lock deletion)
+          self.__pending_shared.pop(priority, None)
 
     return False
 
@@ -722,6 +729,48 @@ class SharedLock(object):
     finally:
       self.__lock.release()
 
+  def downgrade(self):
+    """Changes the lock mode from exclusive to shared.
+
+    Pending acquires in shared mode on the same priority will go ahead.
+
+    """
+    self.__lock.acquire()
+    try:
+      assert self.__is_owned(), "Lock must be owned"
+
+      if self.__is_exclusive():
+        # Switch from exclusive to shared; skipped if already in shared mode
+        self.__exc = None
+        self.__do_acquire(1)
+
+        # Important: pending shared acquires should only jump ahead if there
+        # was a transition from exclusive to shared; otherwise an owner of a
+        # shared lock could keep calling this function to push incoming
+        # shared acquires ahead of other waiters
+        (priority, prioqueue) = self.__find_first_pending_queue()
+        if prioqueue:
+          # Is there a pending shared acquire on this priority?
+          cond = self.__pending_shared.pop(priority, None)
+          if cond:
+            assert cond.shared
+            assert cond in prioqueue
+
+            # Ensure shared acquire is on top of queue
+            if len(prioqueue) > 1:
+              prioqueue.remove(cond)
+              prioqueue.insert(0, cond)
+
+            # Notify
+            cond.notifyAll()
+
+      assert not self.__is_exclusive()
+      assert self.__is_sharer()
+
+      return True
+    finally:
+      self.__lock.release()
+
   def release(self):
     """Release a Shared Lock.
 
@@ -741,9 +790,14 @@ class SharedLock(object):
         self.__shr.remove(threading.currentThread())
 
       # Notify topmost condition in queue
-      prioqueue = self.__find_first_pending_queue()
+      (priority, prioqueue) = self.__find_first_pending_queue()
       if prioqueue:
-        prioqueue[0].notifyAll()
+        cond = prioqueue[0]
+        cond.notifyAll()
+        if cond.shared:
+          # Prevent further shared acquires from sneaking in while waiters are
+          # notified
+          self.__pending_shared.pop(priority, None)
 
     finally:
       self.__lock.release()
@@ -845,8 +899,8 @@ class LockSet:
     # Lock monitor
     self.__monitor = monitor
 
-    # Used internally to guarantee coherency.
-    self.__lock = SharedLock(name)
+    # Used internally to guarantee coherency
+    self.__lock = SharedLock(self._GetLockName("[lockset]"), monitor=monitor)
 
     # The lockdict indexes the relationship name -> lock
     # The order-of-locking is implied by the alphabetical order of names
@@ -871,6 +925,21 @@ class LockSet:
     """
     return "%s/%s" % (self.name, mname)
 
+  def _get_lock(self):
+    """Returns the lockset-internal lock.
+
+    """
+    return self.__lock
+
+  def _get_lockdict(self):
+    """Returns the lockset-internal lock dictionary.
+
+    Accessing this structure is only safe in single-thread usage or when the
+    lockset-internal lock is held.
+
+    """
+    return self.__lockdict
+
   def _is_owned(self):
     """Is the current thread a current level owner?"""
     return threading.currentThread() in self.__owners
@@ -1112,6 +1181,42 @@ class LockSet:
 
     return acquired
 
+  def downgrade(self, names=None):
+    """Downgrade a set of resource locks from exclusive to shared mode.
+
+    The locks must have been acquired in exclusive mode.
+
+    """
+    assert self._is_owned(), ("downgrade on lockset %s while not owning any"
+                              " lock" % self.name)
+
+    # Support passing in a single resource to downgrade rather than many
+    if isinstance(names, basestring):
+      names = [names]
+
+    owned = self._list_owned()
+
+    if names is None:
+      names = owned
+    else:
+      names = set(names)
+      assert owned.issuperset(names), \
+        ("downgrade() on unheld resources %s (set %s)" %
+         (names.difference(owned), self.name))
+
+    for lockname in names:
+      self.__lockdict[lockname].downgrade()
+
+    # Do we own the lockset in exclusive mode?
+    if self.__lock._is_owned(shared=0):
+      # Have all locks been downgraded?
+      if not compat.any(lock._is_owned(shared=0)
+                        for lock in self.__lockdict.values()):
+        self.__lock.downgrade()
+        assert self.__lock._is_owned(shared=1)
+
+    return True
+
   def release(self, names=None):
     """Release a set of resource locks, at the same level.
 
@@ -1343,6 +1448,14 @@ class GanetiLockManager:
                               monitor=self._monitor),
       }
 
+  def AddToLockMonitor(self, provider):
+    """Registers a new lock with the monitor.
+
+    See L{LockMonitor.RegisterLock}.
+
+    """
+    return self._monitor.RegisterLock(provider)
+
   def QueryLocks(self, fields):
     """Queries information from all locks.
 
@@ -1448,6 +1561,22 @@ class GanetiLockManager:
     return self.__keyring[level].acquire(names, shared=shared, timeout=timeout,
                                          priority=priority)
 
+  def downgrade(self, level, names=None):
+    """Downgrade a set of resource locks from exclusive to shared mode.
+
+    You must have acquired the locks in exclusive mode.
+
+    @type level: member of locking.LEVELS
+    @param level: the level at which the locks shall be downgraded
+    @type names: list of strings, or None
+    @param names: the names of the locks which shall be downgraded
+        (defaults to all the locks acquired at the level)
+
+    """
+    assert level in LEVELS, "Invalid locking level %s" % level
+
+    return self.__keyring[level].downgrade(names=names)
+
   def release(self, level, names=None):
     """Release a set of resource locks, at the same level.
 
@@ -1517,15 +1646,17 @@ class GanetiLockManager:
     return self.__keyring[level].remove(names)
 
 
-def _MonitorSortKey((num, item)):
+def _MonitorSortKey((item, idx, num)):
   """Sorting key function.
 
-  Sort by name, then by incoming order.
+  Sort by name, registration order and then order of information. This provides
+  a stable sort order over different providers, even if they return the same
+  name.
 
   """
   (name, _, _, _) = item
 
-  return (utils.NiceSortKey(name), num)
+  return (utils.NiceSortKey(name), num, idx)
 
 
 class LockMonitor(object):
@@ -1545,12 +1676,19 @@ class LockMonitor(object):
     self._locks = weakref.WeakKeyDictionary()
 
   @ssynchronized(_LOCK_ATTR)
-  def RegisterLock(self, lock):
+  def RegisterLock(self, provider):
     """Registers a new lock.
 
+    @param provider: Object with a callable method named C{GetLockInfo}, taking
+      a single C{set} containing the requested information items
+    @note: It would be nicer to only receive the function generating the
+      requested information but, as it turns out, weak references to bound
+      methods (e.g. C{self.GetLockInfo}) are tricky; there are several
+      workarounds, but none of the ones I found works properly in combination
+      with a standard C{WeakKeyDictionary}
+
     """
-    logging.debug("Registering lock %s", lock.name)
-    assert lock not in self._locks, "Duplicate lock registration"
+    assert provider not in self._locks, "Duplicate registration"
 
     # There used to be a check for duplicate names here. As it turned out, when
     # a lock is re-created with the same name in a very short timeframe, the
@@ -1558,14 +1696,22 @@ class LockMonitor(object):
     # By keeping track of the order of incoming registrations, a stable sort
     # ordering can still be guaranteed.
 
-    self._locks[lock] = self._counter.next()
+    self._locks[provider] = self._counter.next()
 
-  @ssynchronized(_LOCK_ATTR)
   def _GetLockInfo(self, requested):
-    """Get information from all locks while the monitor lock is held.
+    """Get information from all locks.
 
     """
-    return [(num, lock.GetInfo(requested)) for lock, num in self._locks.items()]
+    # Must hold lock while getting consistent list of tracked items
+    self._lock.acquire(shared=1)
+    try:
+      items = self._locks.items()
+    finally:
+      self._lock.release()
+
+    return [(info, idx, num)
+            for (provider, num) in items
+            for (idx, info) in enumerate(provider.GetLockInfo(requested))]
 
   def _Query(self, fields):
     """Queries information from all locks.
@@ -1582,7 +1728,7 @@ class LockMonitor(object):
                       key=_MonitorSortKey)
 
     # Extract lock information and build query data
-    return (qobj, query.LockQueryData(map(operator.itemgetter(1), lockinfo)))
+    return (qobj, query.LockQueryData(map(compat.fst, lockinfo)))
 
   def QueryLocks(self, fields):
     """Queries information from all locks.