Fix the downgrade function of cfgupgrade
diff --git a/lib/workerpool.py b/lib/workerpool.py
index 8127329..6b558ce 100644
--- a/lib/workerpool.py
+++ b/lib/workerpool.py
 
 """
 
-import collections
 import logging
 import threading
+import heapq
+import itertools
 
 from ganeti import compat
+from ganeti import errors
 
 
 _TERMINATE = object()
+_DEFAULT_PRIORITY = 0
+
+
+class DeferTask(Exception):
+  """Special exception class to defer a task.
+
+  This exception can be raised by L{BaseWorker.RunTask} to defer the execution
+  of a task. Optionally, the priority of the task can be changed.
+
+  """
+  def __init__(self, priority=None):
+    """Initializes this class.
+
+    @type priority: number
+    @param priority: New task priority (None means no change)
+
+    """
+    Exception.__init__(self)
+    self.priority = priority
+
+
+class NoSuchTask(Exception):
+  """Exception raised when a task can't be found.
+
+  """
 
 
 class BaseWorker(threading.Thread, object):
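The new DeferTask exception is meant to be raised from within BaseWorker.RunTask; the pool then re-queues the task, optionally at a different priority. A minimal sketch of how a subclass might use it, assuming a hypothetical task object (only DeferTask and GetCurrentPriority come from this patch; CanRunNow/Execute are made up):

# Hedged sketch: a hypothetical worker that defers tasks which are not
# ready yet. The task object and its methods are illustrative only.
from ganeti import workerpool

class ExampleWorker(workerpool.BaseWorker):
  def RunTask(self, task): # pylint: disable=W0221
    if not task.CanRunNow():
      # Re-queue the task; increasing the numeric priority value moves it
      # behind other tasks, since the pool pops the smallest value first
      raise workerpool.DeferTask(priority=self.GetCurrentPriority() + 1)
    task.Execute()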
@@ -39,7 +66,7 @@ class BaseWorker(threading.Thread, object):
   Users of a worker pool must override RunTask in a subclass.
 
   """
-  # pylint: disable-msg=W0212
+  # pylint: disable=W0212
   def __init__(self, pool, worker_id):
     """Constructor for BaseWorker thread.
 
@@ -49,8 +76,11 @@ class BaseWorker(threading.Thread, object):
     """
     super(BaseWorker, self).__init__(name=worker_id)
     self.pool = pool
+    self._worker_id = worker_id
     self._current_task = None
 
+    assert self.getName() == worker_id
+
   def ShouldTerminate(self):
     """Returns whether this worker should terminate.
 
@@ -64,12 +94,61 @@ class BaseWorker(threading.Thread, object):
     finally:
       self.pool._lock.release()
 
+  def GetCurrentPriority(self):
+    """Returns the priority of the current task.
+
+    Should only be called from within L{RunTask}.
+
+    """
+    self.pool._lock.acquire()
+    try:
+      assert self._HasRunningTaskUnlocked()
+
+      (priority, _, _, _) = self._current_task
+
+      return priority
+    finally:
+      self.pool._lock.release()
+
+  def SetTaskName(self, taskname):
+    """Sets the name of the current task.
+
+    Should only be called from within L{RunTask}.
+
+    @type taskname: string
+    @param taskname: Task's name
+
+    """
+    if taskname:
+      name = "%s/%s" % (self._worker_id, taskname)
+    else:
+      name = self._worker_id
+
+    # Set thread name
+    self.setName(name)
+
   def _HasRunningTaskUnlocked(self):
     """Returns whether this worker is currently running a task.
 
     """
     return (self._current_task is not None)
 
+  def _GetCurrentOrderAndTaskId(self):
+    """Returns the order and task ID of the current task.
+
+    Should only be called from within L{RunTask}.
+
+    """
+    self.pool._lock.acquire()
+    try:
+      assert self._HasRunningTaskUnlocked()
+
+      (_, order_id, task_id, _) = self._current_task
+
+      return (order_id, task_id)
+    finally:
+      self.pool._lock.release()
+
   def run(self):
     """Main thread function.
 
@@ -80,6 +159,8 @@ class BaseWorker(threading.Thread, object):
 
     while True:
       assert self._current_task is None
+
+      defer = None
       try:
         # Wait on lock to be told either to terminate or to do a task
         pool._lock.acquire()
@@ -104,12 +185,29 @@ class BaseWorker(threading.Thread, object):
         finally:
           pool._lock.release()
 
-        # Run the actual task
+        (priority, _, _, args) = self._current_task
         try:
-          logging.debug("Starting task %r", self._current_task)
-          self.RunTask(*self._current_task)
-          logging.debug("Done with task %r", self._current_task)
-        except: # pylint: disable-msg=W0702
+          # Run the actual task
+          assert defer is None
+          logging.debug("Starting task %r, priority %s", args, priority)
+          assert self.getName() == self._worker_id
+          try:
+            self.RunTask(*args) # pylint: disable=W0142
+          finally:
+            self.SetTaskName(None)
+          logging.debug("Done with task %r, priority %s", args, priority)
+        except DeferTask, err:
+          defer = err
+
+          if defer.priority is None:
+            # Use same priority
+            defer.priority = priority
+
+          logging.debug("Deferring task %r, new priority %s",
+                        args, defer.priority)
+
+          assert self._HasRunningTaskUnlocked()
+        except: # pylint: disable=W0702
           logging.exception("Caught unhandled exception")
 
         assert self._HasRunningTaskUnlocked()
@@ -117,6 +215,12 @@ class BaseWorker(threading.Thread, object):
         # Notify pool
         pool._lock.acquire()
         try:
+          if defer:
+            assert self._current_task
+            # Schedule again for later run
+            (_, _, task_id, args) = self._current_task
+            pool._AddTaskUnlocked(args, defer.priority, task_id)
+
           if self._current_task:
             self._current_task = None
             pool._worker_to_pool.notifyAll()
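SetTaskName is likewise intended to be called from within RunTask, so log messages produced on the worker thread carry the task's name; the finally block above resets the thread name once the task is done. A sketch, again with a hypothetical task object:

# Hedged sketch: naming the worker thread after the current task so that
# thread names in log output identify what is being processed. The task
# object and its "op_name" attribute are hypothetical.
class NamedWorker(workerpool.BaseWorker):
  def RunTask(self, task): # pylint: disable=W0221
    # Thread name becomes "<worker id>/<task name>" until RunTask finishes;
    # the pool's run() loop then resets it to the plain worker ID
    self.SetTaskName(task.op_name)
    task.Execute()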
@@ -145,6 +249,18 @@ class WorkerPool(object):
   added to the pool. Due to the nature of threading, they're not
   guaranteed to finish in the same order.
 
+  @type _tasks: list of tuples
+  @ivar _tasks: Each tuple has the format (priority, order ID, task ID,
+    arguments). Priority and order ID are numeric and essentially control the
+    sort order. The order ID is an increasing number denoting the order in
+    which tasks are added to the queue. The task ID is controlled by the user
+    of the workerpool; see L{AddTask} for details. The task arguments are
+    C{None} for abandoned tasks, otherwise a sequence of arguments to be
+    passed to L{BaseWorker.RunTask}. The list must fulfill the heap property
+    (for use by the C{heapq} module).
+  @type _taskdata: dict; (task IDs as keys, tuples as values)
+  @ivar _taskdata: Mapping from task IDs to entries in L{_tasks}
+
   """
   def __init__(self, name, num_workers, worker_class):
     """Constructor for worker pool.
@@ -165,12 +281,15 @@ class WorkerPool(object):
     self._last_worker_id = 0
     self._workers = []
     self._quiescing = False
+    self._active = True
 
     # Terminating workers
     self._termworkers = []
 
     # Queued tasks
-    self._tasks = collections.deque()
+    self._counter = itertools.count()
+    self._tasks = []
+    self._taskdata = {}
 
     # Start workers
     self.Resize(num_workers)
@@ -184,44 +303,169 @@ class WorkerPool(object):
     while self._quiescing:
       self._pool_to_pool.wait()
 
-  def _AddTaskUnlocked(self, args):
+  def _AddTaskUnlocked(self, args, priority, task_id):
+    """Adds a task to the internal queue.
+
+    @type args: sequence
+    @param args: Arguments passed to L{BaseWorker.RunTask}
+    @type priority: number
+    @param priority: Task priority
+    @param task_id: Task ID
+
+    """
     assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
+    assert isinstance(priority, (int, long)), "Priority must be numeric"
+    assert task_id is None or isinstance(task_id, (int, long)), \
+      "Task ID must be numeric or None"
+
+    task = [priority, self._counter.next(), task_id, args]
 
-    self._tasks.append(args)
+    if task_id is not None:
+      assert task_id not in self._taskdata
+      # Keep a reference to change priority later if necessary
+      self._taskdata[task_id] = task
+
+    # A counter is used to ensure elements are processed in their incoming
+    # order. For processing they're sorted by priority and then counter.
+    heapq.heappush(self._tasks, task)
 
     # Notify a waiting worker
     self._pool_to_worker.notify()
 
-  def AddTask(self, args):
+  def AddTask(self, args, priority=_DEFAULT_PRIORITY, task_id=None):
     """Adds a task to the queue.
 
     @type args: sequence
     @param args: arguments passed to L{BaseWorker.RunTask}
+    @type priority: number
+    @param priority: Task priority
+    @param task_id: Task ID
+    @note: The task ID can be essentially anything usable as a dictionary key.
+      Callers, however, must ensure a task ID is unique while the task is in
+      the pool or may still return to the pool after being deferred via
+      L{DeferTask}.
 
     """
     self._lock.acquire()
     try:
       self._WaitWhileQuiescingUnlocked()
-      self._AddTaskUnlocked(args)
+      self._AddTaskUnlocked(args, priority, task_id)
     finally:
       self._lock.release()
 
-  def AddManyTasks(self, tasks):
+  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY, task_id=None):
     """Add a list of tasks to the queue.
 
     @type tasks: list of tuples
     @param tasks: list of args passed to L{BaseWorker.RunTask}
+    @type priority: number or list of numbers
+    @param priority: Priority for all added tasks or a list with the priority
+                     for each task
+    @type task_id: list
+    @param task_id: List with the ID for each task
+    @note: See L{AddTask} for a note on task IDs.
 
     """
     assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
-      "Each task must be a sequence"
+           "Each task must be a sequence"
+    assert (isinstance(priority, (int, long)) or
+            compat.all(isinstance(prio, (int, long)) for prio in priority)), \
+           "Priority must be numeric or be a list of numeric values"
+    assert task_id is None or isinstance(task_id, (tuple, list)), \
+           "Task IDs must be in a sequence"
+
+    if isinstance(priority, (int, long)):
+      priority = [priority] * len(tasks)
+    elif len(priority) != len(tasks):
+      raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
+                                   " number of tasks (%s)" %
+                                   (len(priority), len(tasks)))
+
+    if task_id is None:
+      task_id = [None] * len(tasks)
+    elif len(task_id) != len(tasks):
+      raise errors.ProgrammerError("Number of task IDs (%s) doesn't match"
+                                   " number of tasks (%s)" %
+                                   (len(task_id), len(tasks)))
 
     self._lock.acquire()
     try:
       self._WaitWhileQuiescingUnlocked()
 
-      for args in tasks:
-        self._AddTaskUnlocked(args)
+      assert compat.all(isinstance(prio, (int, long)) for prio in priority)
+      assert len(tasks) == len(priority)
+      assert len(tasks) == len(task_id)
+
+      for (args, prio, tid) in zip(tasks, priority, task_id):
+        self._AddTaskUnlocked(args, prio, tid)
+    finally:
+      self._lock.release()
+
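A usage sketch for the extended signatures, assuming the hypothetical pool and worker class from the earlier sketch; a task ID only needs to be passed when the caller wants to change the task's priority later:

# Hedged usage sketch; the pool name, worker class and payloads are
# hypothetical.
pool = workerpool.WorkerPool("example", 4, ExampleWorker)

# An urgent task; the numerically smallest priority runs first. The task
# ID (17) allows its priority to be changed later.
pool.AddTask(("urgent", ), priority=-10, task_id=17)

# Several tasks at the default priority; a list of per-task priorities
# (and a list of task IDs) of the same length could be passed instead.
pool.AddManyTasks([("a", ), ("b", ), ("c", )])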
+  def ChangeTaskPriority(self, task_id, priority):
+    """Changes a task's priority.
+
+    @param task_id: Task ID
+    @type priority: number
+    @param priority: New task priority
+    @raise NoSuchTask: When the task referred to by C{task_id} cannot be
+      found (it may never have existed, may have already been processed, or
+      may currently be running)
+
+    """
+    assert isinstance(priority, (int, long)), "Priority must be numeric"
+
+    self._lock.acquire()
+    try:
+      logging.debug("About to change priority of task %s to %s",
+                    task_id, priority)
+
+      # Find old task
+      oldtask = self._taskdata.get(task_id, None)
+      if oldtask is None:
+        msg = "Task '%s' was not found" % task_id
+        logging.debug(msg)
+        raise NoSuchTask(msg)
+
+      # Prepare new task
+      newtask = [priority] + oldtask[1:]
+
+      # Mark old entry as abandoned (this doesn't change the sort order and
+      # therefore doesn't invalidate the heap property of L{self._tasks}).
+      # See also <http://docs.python.org/library/heapq.html#priority-queue-
+      # implementation-notes>.
+      oldtask[-1] = None
+
+      # Change reference to new task entry and forget the old one
+      assert task_id is not None
+      self._taskdata[task_id] = newtask
+
+      # Add a new task with the old number and arguments
+      heapq.heappush(self._tasks, newtask)
+
+      # Notify a waiting worker
+      self._pool_to_worker.notify()
+    finally:
+      self._lock.release()
+
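Continuing the sketch above, a caller holding a task ID can bump the priority of a still-queued task, and has to be prepared for NoSuchTask if the task was already picked up:

# Hedged usage sketch, continuing the hypothetical task ID 17 from above.
try:
  pool.ChangeTaskPriority(17, -50)
except workerpool.NoSuchTask:
  # The task already ran, is currently running, or never existed
  pass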
+  def SetActive(self, active):
+    """Enable/disable processing of tasks.
+
+    This is different from L{Quiesce} in the sense that this function just
+    changes an internal flag and doesn't wait for the queue to be empty. Tasks
+    already being processed continue normally, but no new tasks will be
+    started. New tasks can still be added.
+
+    @type active: bool
+    @param active: Whether tasks should be processed
+
+    """
+    self._lock.acquire()
+    try:
+      self._active = active
+
+      if active:
+        # Tell all workers to continue processing
+        self._pool_to_worker.notifyAll()
     finally:
       self._lock.release()
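SetActive only flips a flag: workers finish their current task but pick up no new ones while the pool is inactive, and tasks can still be queued. A usage sketch with the hypothetical pool from above:

# Hedged usage sketch: pause task processing without draining the queue.
pool.SetActive(False)
pool.AddTask(("queued-while-paused", ))  # still accepted, just not started
pool.SetActive(True)                     # workers are notified and resume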
 
@@ -232,11 +476,32 @@ class WorkerPool(object):
     @param worker: Worker thread
 
     """
-    if self._ShouldWorkerTerminateUnlocked(worker):
-      return _TERMINATE
+    while True:
+      if self._ShouldWorkerTerminateUnlocked(worker):
+        return _TERMINATE
+
+      # If there's a pending task, return it immediately
+      if self._active and self._tasks:
+        # Get task from queue and tell pool about it
+        try:
+          task = heapq.heappop(self._tasks)
+        finally:
+          self._worker_to_pool.notifyAll()
+
+        (_, _, task_id, args) = task
+
+        # If the priority was changed, "args" is None
+        if args is None:
+          # Try again
+          logging.debug("Found abandoned task (%r)", task)
+          continue
+
+        # Delete reference
+        if task_id is not None:
+          del self._taskdata[task_id]
+
+        return task
 
-    # We only wait if there's no task for us.
-    if not self._tasks:
       logging.debug("Waiting for tasks")
 
       # wait() releases the lock and sleeps until notified
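The skipping of entries whose arguments are None implements the usual heapq "lazy deletion" pattern referenced in ChangeTaskPriority: entries are never removed from the middle of the heap, only marked abandoned and dropped when popped. A standalone sketch of that pattern:

# Standalone illustration of the lazy-deletion pattern; not part of the
# patch itself.
import heapq

heap = [[0, 1, "id-a", ("args-a", )], [5, 2, "id-b", ("args-b", )]]
heapq.heapify(heap)

heap[0][-1] = None       # abandon the entry in place; the heap stays valid

while heap:
  (_, _, _, args) = heapq.heappop(heap)
  if args is None:
    continue             # skip abandoned entries
  print args             # only ('args-b',) is printed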
@@ -244,20 +509,6 @@ class WorkerPool(object):
 
       logging.debug("Notified while waiting")
 
-      # Were we woken up in order to terminate?
-      if self._ShouldWorkerTerminateUnlocked(worker):
-        return _TERMINATE
-
-      if not self._tasks:
-        # Spurious notification, ignore
-        return None
-
-    # Get task from queue and tell pool about it
-    try:
-      return self._tasks.popleft()
-    finally:
-      self._worker_to_pool.notifyAll()
-
   def _ShouldWorkerTerminateUnlocked(self, worker):
     """Returns whether a worker should terminate.
 
@@ -269,10 +520,20 @@ class WorkerPool(object):
 
     """
     for worker in self._workers + self._termworkers:
-      if worker._HasRunningTaskUnlocked(): # pylint: disable-msg=W0212
+      if worker._HasRunningTaskUnlocked(): # pylint: disable=W0212
         return True
     return False
 
+  def HasRunningTasks(self):
+    """Checks whether there's at least one task running.
+
+    """
+    self._lock.acquire()
+    try:
+      return self._HasRunningTasksUnlocked()
+    finally:
+      self._lock.release()
+
   def Quiesce(self):
     """Waits until the task queue is empty.