Merge branch 'devel-2.7'
[ganeti-local] / lib / workerpool.py
index 91fb106..6b558ce 100644 (file)
@@ -26,6 +26,7 @@
 import logging
 import threading
 import heapq
 import logging
 import threading
 import heapq
+import itertools
 
 from ganeti import compat
 from ganeti import errors
 
 from ganeti import compat
 from ganeti import errors
@@ -53,6 +54,12 @@ class DeferTask(Exception):
     self.priority = priority
 
 
     self.priority = priority
 
 
+class NoSuchTask(Exception):
+  """Exception raised when a task can't be found.
+
+  """
+
+
 class BaseWorker(threading.Thread, object):
   """Base worker class for worker pools.
 
 class BaseWorker(threading.Thread, object):
   """Base worker class for worker pools.
 
@@ -97,7 +104,7 @@ class BaseWorker(threading.Thread, object):
     try:
       assert self._HasRunningTaskUnlocked()
 
     try:
       assert self._HasRunningTaskUnlocked()
 
-      (priority, _, _) = self._current_task
+      (priority, _, _, _) = self._current_task
 
       return priority
     finally:
 
       return priority
     finally:
@@ -126,6 +133,22 @@ class BaseWorker(threading.Thread, object):
     """
     return (self._current_task is not None)
 
     """
     return (self._current_task is not None)
 
+  def _GetCurrentOrderAndTaskId(self):
+    """Returns the order and task ID of the current task.
+
+    Should only be called from within L{RunTask}.
+
+    """
+    self.pool._lock.acquire()
+    try:
+      assert self._HasRunningTaskUnlocked()
+
+      (_, order_id, task_id, _) = self._current_task
+
+      return (order_id, task_id)
+    finally:
+      self.pool._lock.release()
+
   def run(self):
     """Main thread function.
 
   def run(self):
     """Main thread function.
 
@@ -162,7 +185,7 @@ class BaseWorker(threading.Thread, object):
         finally:
           pool._lock.release()
 
         finally:
           pool._lock.release()
 
-        (priority, _, args) = self._current_task
+        (priority, _, _, args) = self._current_task
         try:
           # Run the actual task
           assert defer is None
         try:
           # Run the actual task
           assert defer is None
@@ -195,8 +218,8 @@ class BaseWorker(threading.Thread, object):
           if defer:
             assert self._current_task
             # Schedule again for later run
           if defer:
             assert self._current_task
             # Schedule again for later run
-            (_, _, args) = self._current_task
-            pool._AddTaskUnlocked(args, defer.priority)
+            (_, _, task_id, args) = self._current_task
+            pool._AddTaskUnlocked(args, defer.priority, task_id)
 
           if self._current_task:
             self._current_task = None
 
           if self._current_task:
             self._current_task = None
@@ -226,6 +249,18 @@ class WorkerPool(object):
   added to the pool. Due to the nature of threading, they're not
   guaranteed to finish in the same order.
 
   added to the pool. Due to the nature of threading, they're not
   guaranteed to finish in the same order.
 
+  @type _tasks: list of tuples
+  @ivar _tasks: Each tuple has the format (priority, order ID, task ID,
+    arguments). Priority and order ID are numeric and essentially control the
+    sort order. The order ID is an increasing number denoting the order in
+    which tasks are added to the queue. The task ID is controlled by user of
+    workerpool, see L{AddTask} for details. The task arguments are C{None} for
+    abandoned tasks, otherwise a sequence of arguments to be passed to
+    L{BaseWorker.RunTask}). The list must fulfill the heap property (for use by
+    the C{heapq} module).
+  @type _taskdata: dict; (task IDs as keys, tuples as values)
+  @ivar _taskdata: Mapping from task IDs to entries in L{_tasks}
+
   """
   def __init__(self, name, num_workers, worker_class):
     """Constructor for worker pool.
   """
   def __init__(self, name, num_workers, worker_class):
     """Constructor for worker pool.
@@ -252,8 +287,9 @@ class WorkerPool(object):
     self._termworkers = []
 
     # Queued tasks
     self._termworkers = []
 
     # Queued tasks
-    self._counter = 0
+    self._counter = itertools.count()
     self._tasks = []
     self._tasks = []
+    self._taskdata = {}
 
     # Start workers
     self.Resize(num_workers)
 
     # Start workers
     self.Resize(num_workers)
@@ -267,45 +303,57 @@ class WorkerPool(object):
     while self._quiescing:
       self._pool_to_pool.wait()
 
     while self._quiescing:
       self._pool_to_pool.wait()
 
-  def _AddTaskUnlocked(self, args, priority):
+  def _AddTaskUnlocked(self, args, priority, task_id):
     """Adds a task to the internal queue.
 
     @type args: sequence
     @param args: Arguments passed to L{BaseWorker.RunTask}
     @type priority: number
     @param priority: Task priority
     """Adds a task to the internal queue.
 
     @type args: sequence
     @param args: Arguments passed to L{BaseWorker.RunTask}
     @type priority: number
     @param priority: Task priority
+    @param task_id: Task ID
 
     """
     assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
     assert isinstance(priority, (int, long)), "Priority must be numeric"
 
     """
     assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
     assert isinstance(priority, (int, long)), "Priority must be numeric"
+    assert task_id is None or isinstance(task_id, (int, long)), \
+      "Task ID must be numeric or None"
+
+    task = [priority, self._counter.next(), task_id, args]
 
 
-    # This counter is used to ensure elements are processed in their
-    # incoming order. For processing they're sorted by priority and then
-    # counter.
-    self._counter += 1
+    if task_id is not None:
+      assert task_id not in self._taskdata
+      # Keep a reference to change priority later if necessary
+      self._taskdata[task_id] = task
 
 
-    heapq.heappush(self._tasks, (priority, self._counter, args))
+    # A counter is used to ensure elements are processed in their incoming
+    # order. For processing they're sorted by priority and then counter.
+    heapq.heappush(self._tasks, task)
 
     # Notify a waiting worker
     self._pool_to_worker.notify()
 
 
     # Notify a waiting worker
     self._pool_to_worker.notify()
 
-  def AddTask(self, args, priority=_DEFAULT_PRIORITY):
+  def AddTask(self, args, priority=_DEFAULT_PRIORITY, task_id=None):
     """Adds a task to the queue.
 
     @type args: sequence
     @param args: arguments passed to L{BaseWorker.RunTask}
     @type priority: number
     @param priority: Task priority
     """Adds a task to the queue.
 
     @type args: sequence
     @param args: arguments passed to L{BaseWorker.RunTask}
     @type priority: number
     @param priority: Task priority
+    @param task_id: Task ID
+    @note: The task ID can be essentially anything that can be used as a
+      dictionary key. Callers, however, must ensure a task ID is unique while a
+      task is in the pool or while it might return to the pool due to deferring
+      using L{DeferTask}.
 
     """
     self._lock.acquire()
     try:
       self._WaitWhileQuiescingUnlocked()
 
     """
     self._lock.acquire()
     try:
       self._WaitWhileQuiescingUnlocked()
-      self._AddTaskUnlocked(args, priority)
+      self._AddTaskUnlocked(args, priority, task_id)
     finally:
       self._lock.release()
 
     finally:
       self._lock.release()
 
-  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY):
+  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY, task_id=None):
     """Add a list of tasks to the queue.
 
     @type tasks: list of tuples
     """Add a list of tasks to the queue.
 
     @type tasks: list of tuples
@@ -313,14 +361,18 @@ class WorkerPool(object):
     @type priority: number or list of numbers
     @param priority: Priority for all added tasks or a list with the priority
                      for each task
     @type priority: number or list of numbers
     @param priority: Priority for all added tasks or a list with the priority
                      for each task
+    @type task_id: list
+    @param task_id: List with the ID for each task
+    @note: See L{AddTask} for a note on task IDs.
 
     """
     assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
 
     """
     assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
-      "Each task must be a sequence"
-
+           "Each task must be a sequence"
     assert (isinstance(priority, (int, long)) or
             compat.all(isinstance(prio, (int, long)) for prio in priority)), \
            "Priority must be numeric or be a list of numeric values"
     assert (isinstance(priority, (int, long)) or
             compat.all(isinstance(prio, (int, long)) for prio in priority)), \
            "Priority must be numeric or be a list of numeric values"
+    assert task_id is None or isinstance(task_id, (tuple, list)), \
+           "Task IDs must be in a sequence"
 
     if isinstance(priority, (int, long)):
       priority = [priority] * len(tasks)
 
     if isinstance(priority, (int, long)):
       priority = [priority] * len(tasks)
@@ -329,15 +381,69 @@ class WorkerPool(object):
                                    " number of tasks (%s)" %
                                    (len(priority), len(tasks)))
 
                                    " number of tasks (%s)" %
                                    (len(priority), len(tasks)))
 
+    if task_id is None:
+      task_id = [None] * len(tasks)
+    elif len(task_id) != len(tasks):
+      raise errors.ProgrammerError("Number of task IDs (%s) doesn't match"
+                                   " number of tasks (%s)" %
+                                   (len(task_id), len(tasks)))
+
     self._lock.acquire()
     try:
       self._WaitWhileQuiescingUnlocked()
 
       assert compat.all(isinstance(prio, (int, long)) for prio in priority)
       assert len(tasks) == len(priority)
     self._lock.acquire()
     try:
       self._WaitWhileQuiescingUnlocked()
 
       assert compat.all(isinstance(prio, (int, long)) for prio in priority)
       assert len(tasks) == len(priority)
+      assert len(tasks) == len(task_id)
 
 
-      for args, priority in zip(tasks, priority):
-        self._AddTaskUnlocked(args, priority)
+      for (args, prio, tid) in zip(tasks, priority, task_id):
+        self._AddTaskUnlocked(args, prio, tid)
+    finally:
+      self._lock.release()
+
+  def ChangeTaskPriority(self, task_id, priority):
+    """Changes a task's priority.
+
+    @param task_id: Task ID
+    @type priority: number
+    @param priority: New task priority
+    @raise NoSuchTask: When the task referred by C{task_id} can not be found
+      (it may never have existed, may have already been processed, or is
+      currently running)
+
+    """
+    assert isinstance(priority, (int, long)), "Priority must be numeric"
+
+    self._lock.acquire()
+    try:
+      logging.debug("About to change priority of task %s to %s",
+                    task_id, priority)
+
+      # Find old task
+      oldtask = self._taskdata.get(task_id, None)
+      if oldtask is None:
+        msg = "Task '%s' was not found" % task_id
+        logging.debug(msg)
+        raise NoSuchTask(msg)
+
+      # Prepare new task
+      newtask = [priority] + oldtask[1:]
+
+      # Mark old entry as abandoned (this doesn't change the sort order and
+      # therefore doesn't invalidate the heap property of L{self._tasks}).
+      # See also <http://docs.python.org/library/heapq.html#priority-queue-
+      # implementation-notes>.
+      oldtask[-1] = None
+
+      # Change reference to new task entry and forget the old one
+      assert task_id is not None
+      self._taskdata[task_id] = newtask
+
+      # Add a new task with the old number and arguments
+      heapq.heappush(self._tasks, newtask)
+
+      # Notify a waiting worker
+      self._pool_to_worker.notify()
     finally:
       self._lock.release()
 
     finally:
       self._lock.release()
 
@@ -382,6 +488,18 @@ class WorkerPool(object):
         finally:
           self._worker_to_pool.notifyAll()
 
         finally:
           self._worker_to_pool.notifyAll()
 
+        (_, _, task_id, args) = task
+
+        # If the priority was changed, "args" is None
+        if args is None:
+          # Try again
+          logging.debug("Found abandoned task (%r)", task)
+          continue
+
+        # Delete reference
+        if task_id is not None:
+          del self._taskdata[task_id]
+
         return task
 
       logging.debug("Waiting for tasks")
         return task
 
       logging.debug("Waiting for tasks")