Merge remote-tracking branch 'origin/stable-2.8'
diff --git a/lib/workerpool.py b/lib/workerpool.py
index 54b3fb7..6b558ce 100644
--- a/lib/workerpool.py
+++ b/lib/workerpool.py
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2008 Google Inc.
+# Copyright (C) 2008, 2009, 2010 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 
 """
 
 
 """
 
-import collections
 import logging
 import threading
+import heapq
+import itertools
+
+from ganeti import compat
+from ganeti import errors
+
+
+_TERMINATE = object()
+_DEFAULT_PRIORITY = 0
+
+
+class DeferTask(Exception):
+  """Special exception class to defer a task.
+
+  This class can be raised by L{BaseWorker.RunTask} to defer the execution of a
+  task. Optionally, the priority of the task can be changed.
+
+  """
+  def __init__(self, priority=None):
+    """Initializes this class.
+
+    @type priority: number
+    @param priority: New task priority (None means no change)
+
+    """
+    Exception.__init__(self)
+    self.priority = priority
+
+
+class NoSuchTask(Exception):
+  """Exception raised when a task can't be found.
+
+  """
 
 
 class BaseWorker(threading.Thread, object):
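The DeferTask exception and priority constants introduced above give a RunTask implementation a way to push its task back onto the queue instead of finishing it. A minimal sketch of how that might be used; the "resource" argument and its TryAcquire/Process/Release methods are invented for illustration and are not part of ganeti:

    from ganeti import workerpool

    class ResourceWorker(workerpool.BaseWorker):
      def RunTask(self, resource):
        # "resource" is a hypothetical object used only for this example
        if not resource.TryAcquire():
          # Re-queue the task; a larger number means a lower priority, so the
          # retry is scheduled behind work that is already waiting
          raise workerpool.DeferTask(priority=self.GetCurrentPriority() + 1)
        try:
          resource.Process()
        finally:
          resource.Release()

Raising DeferTask with priority=None (the default) keeps the current priority; the worker's run() loop re-queues the task under its old task ID, which is why task IDs must stay unique for as long as a deferred task can still come back.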
@@ -34,6 +66,7 @@ class BaseWorker(threading.Thread, object):
   Users of a worker pool must override RunTask in a subclass.
 
   """
+  # pylint: disable=W0212
   def __init__(self, pool, worker_id):
     """Constructor for BaseWorker thread.
 
@@ -41,16 +74,58 @@ class BaseWorker(threading.Thread, object):
     @param worker_id: identifier for this worker
 
     """
-    super(BaseWorker, self).__init__()
+    super(BaseWorker, self).__init__(name=worker_id)
     self.pool = pool
-    self.worker_id = worker_id
+    self._worker_id = worker_id
     self._current_task = None
 
+    assert self.getName() == worker_id
+
   def ShouldTerminate(self):
-    """Returns whether a worker should terminate.
+    """Returns whether this worker should terminate.
+
+    Should only be called from within L{RunTask}.
+
+    """
+    self.pool._lock.acquire()
+    try:
+      assert self._HasRunningTaskUnlocked()
+      return self.pool._ShouldWorkerTerminateUnlocked(self)
+    finally:
+      self.pool._lock.release()
+
+  def GetCurrentPriority(self):
+    """Returns the priority of the current task.
+
+    Should only be called from within L{RunTask}.
+
+    """
+    self.pool._lock.acquire()
+    try:
+      assert self._HasRunningTaskUnlocked()
+
+      (priority, _, _, _) = self._current_task
+
+      return priority
+    finally:
+      self.pool._lock.release()
+
+  def SetTaskName(self, taskname):
+    """Sets the name of the current task.
+
+    Should only be called from within L{RunTask}.
+
+    @type taskname: string
+    @param taskname: Task's name
 
     """
-    return self.pool.ShouldWorkerTerminate(self)
+    if taskname:
+      name = "%s/%s" % (self._worker_id, taskname)
+    else:
+      name = self._worker_id
+
+    # Set thread name
+    self.setName(name)
 
   def _HasRunningTaskUnlocked(self):
     """Returns whether this worker is currently running a task.
@@ -58,13 +133,19 @@ class BaseWorker(threading.Thread, object):
     """
     return (self._current_task is not None)
 
-  def HasRunningTask(self):
-    """Returns whether this worker is currently running a task.
+  def _GetCurrentOrderAndTaskId(self):
+    """Returns the order and task ID of the current task.
+
+    Should only be called from within L{RunTask}.
 
     """
     self.pool._lock.acquire()
     try:
-      return self._HasRunningTaskUnlocked()
+      assert self._HasRunningTaskUnlocked()
+
+      (_, order_id, task_id, _) = self._current_task
+
+      return (order_id, task_id)
     finally:
       self.pool._lock.release()
 
@@ -76,62 +157,79 @@ class BaseWorker(threading.Thread, object):
     """
     pool = self.pool
 
-    assert not self.HasRunningTask()
-
     while True:
+      assert self._current_task is None
+
+      defer = None
       try:
-        # We wait on lock to be told either terminate or do a task.
+        # Wait on lock to be told either to terminate or to do a task
         pool._lock.acquire()
         try:
-          if pool._ShouldWorkerTerminateUnlocked(self):
-            break
+          task = pool._WaitForTaskUnlocked(self)
 
-          # We only wait if there's no task for us.
-          if not pool._tasks:
-            logging.debug("Worker %s: waiting for tasks", self.worker_id)
+          if task is _TERMINATE:
+            # Told to terminate
+            break
 
-            # wait() releases the lock and sleeps until notified
-            pool._pool_to_worker.wait()
+          if task is None:
+            # Spurious notification, ignore
+            continue
 
-            logging.debug("Worker %s: notified while waiting", self.worker_id)
+          self._current_task = task
 
-            # Were we woken up in order to terminate?
-            if pool._ShouldWorkerTerminateUnlocked(self):
-              break
+          # No longer needed, dispose of reference
+          del task
 
-            if not pool._tasks:
-              # Spurious notification, ignore
-              continue
+          assert self._HasRunningTaskUnlocked()
 
-          # Get task from queue and tell pool about it
-          try:
-            self._current_task = pool._tasks.popleft()
-          finally:
-            pool._worker_to_pool.notifyAll()
         finally:
           pool._lock.release()
 
-        # Run the actual task
+        (priority, _, _, args) = self._current_task
         try:
-          logging.debug("Worker %s: starting task %r",
-                        self.worker_id, self._current_task)
-          self.RunTask(*self._current_task)
-          logging.debug("Worker %s: done with task %r",
-                        self.worker_id, self._current_task)
-        except:
-          logging.error("Worker %s: Caught unhandled exception",
-                        self.worker_id, exc_info=True)
+          # Run the actual task
+          assert defer is None
+          logging.debug("Starting task %r, priority %s", args, priority)
+          assert self.getName() == self._worker_id
+          try:
+            self.RunTask(*args) # pylint: disable=W0142
+          finally:
+            self.SetTaskName(None)
+          logging.debug("Done with task %r, priority %s", args, priority)
+        except DeferTask, err:
+          defer = err
+
+          if defer.priority is None:
+            # Use same priority
+            defer.priority = priority
+
+          logging.debug("Deferring task %r, new priority %s",
+                        args, defer.priority)
+
+          assert self._HasRunningTaskUnlocked()
+        except: # pylint: disable=W0702
+          logging.exception("Caught unhandled exception")
+
+        assert self._HasRunningTaskUnlocked()
       finally:
         # Notify pool
         pool._lock.acquire()
         try:
+          if defer:
+            assert self._current_task
+            # Schedule again for later run
+            (_, _, task_id, args) = self._current_task
+            pool._AddTaskUnlocked(args, defer.priority, task_id)
+
           if self._current_task:
             self._current_task = None
             pool._worker_to_pool.notifyAll()
         finally:
           pool._lock.release()
 
-    logging.debug("Worker %s: terminates", self.worker_id)
+      assert not self._HasRunningTaskUnlocked()
+
+    logging.debug("Terminates")
 
   def RunTask(self, *args):
     """Function called to start a task.
@@ -151,8 +249,20 @@ class WorkerPool(object):
   added to the pool. Due to the nature of threading, they're not
   guaranteed to finish in the same order.
 
+  @type _tasks: list of tuples
+  @ivar _tasks: Each tuple has the format (priority, order ID, task ID,
+    arguments). Priority and order ID are numeric and essentially control the
+    sort order. The order ID is an increasing number denoting the order in
+    which tasks are added to the queue. The task ID is controlled by the user
+    of the worker pool; see L{AddTask} for details. The task arguments are
+    C{None} for abandoned tasks, otherwise a sequence of arguments to be
+    passed to L{BaseWorker.RunTask}. The list must fulfill the heap property
+    (for use by the C{heapq} module).
+  @type _taskdata: dict; (task IDs as keys, tuples as values)
+  @ivar _taskdata: Mapping from task IDs to entries in L{_tasks}
+
   """
   """
-  def __init__(self, num_workers, worker_class):
+  def __init__(self, name, num_workers, worker_class):
     """Constructor for worker pool.
 
     @param num_workers: number of workers to be started
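The _tasks queue documented in the hunk above is an ordinary list kept in heap order, so the scheduling rule follows directly from how heapq compares the [priority, order ID, task ID, args] entries: lower priority numbers are served first, and the monotonically increasing order ID breaks ties in insertion order. A standalone illustration (not ganeti code):

    import heapq

    tasks = []
    heapq.heappush(tasks, [10, 0, None, ("cleanup",)])     # low priority, queued first
    heapq.heappush(tasks, [0, 1, None, ("urgent",)])       # high priority (small number)
    heapq.heappush(tasks, [0, 2, None, ("also urgent",)])  # same priority, queued later

    while tasks:
      (priority, order, _, args) = heapq.heappop(tasks)
      print("priority %s, order %s: %r" % (priority, order, args))
    # priority 0, order 1: ('urgent',)
    # priority 0, order 2: ('also urgent',)
    # priority 10, order 0: ('cleanup',)

The entries are built as lists rather than tuples so that ChangeTaskPriority can later blank out the args field of a superseded entry in place without disturbing the heap.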
@@ -167,66 +277,263 @@ class WorkerPool(object):
     self._pool_to_worker = threading.Condition(self._lock)
     self._worker_to_pool = threading.Condition(self._lock)
     self._worker_class = worker_class
+    self._name = name
     self._last_worker_id = 0
     self._workers = []
     self._quiescing = False
+    self._active = True
 
     # Terminating workers
     self._termworkers = []
 
     # Queued tasks
-    self._tasks = collections.deque()
+    self._counter = itertools.count()
+    self._tasks = []
+    self._taskdata = {}
 
     # Start workers
     self.Resize(num_workers)
 
   # TODO: Implement dynamic resizing?
 
-  def AddTask(self, *args):
+  def _WaitWhileQuiescingUnlocked(self):
+    """Wait until the worker pool has finished quiescing.
+
+    """
+    while self._quiescing:
+      self._pool_to_pool.wait()
+
+  def _AddTaskUnlocked(self, args, priority, task_id):
+    """Adds a task to the internal queue.
+
+    @type args: sequence
+    @param args: Arguments passed to L{BaseWorker.RunTask}
+    @type priority: number
+    @param priority: Task priority
+    @param task_id: Task ID
+
+    """
+    assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
+    assert isinstance(priority, (int, long)), "Priority must be numeric"
+    assert task_id is None or isinstance(task_id, (int, long)), \
+      "Task ID must be numeric or None"
+
+    task = [priority, self._counter.next(), task_id, args]
+
+    if task_id is not None:
+      assert task_id not in self._taskdata
+      # Keep a reference to change priority later if necessary
+      self._taskdata[task_id] = task
+
+    # A counter is used to ensure elements are processed in their incoming
+    # order. For processing they're sorted by priority and then counter.
+    heapq.heappush(self._tasks, task)
+
+    # Notify a waiting worker
+    self._pool_to_worker.notify()
+
+  def AddTask(self, args, priority=_DEFAULT_PRIORITY, task_id=None):
     """Adds a task to the queue.
 
+    @type args: sequence
     @param args: arguments passed to L{BaseWorker.RunTask}
+    @type priority: number
+    @param priority: Task priority
+    @param task_id: Task ID
+    @note: The task ID can be essentially anything that can be used as a
+      dictionary key. Callers, however, must ensure a task ID is unique while a
+      task is in the pool or while it might return to the pool after being
+      deferred with L{DeferTask}.
 
     """
     self._lock.acquire()
     try:
 
     """
     self._lock.acquire()
     try:
-      # Don't add new tasks while we're quiescing
-      while self._quiescing:
-        self._pool_to_pool.wait()
+      self._WaitWhileQuiescingUnlocked()
+      self._AddTaskUnlocked(args, priority, task_id)
+    finally:
+      self._lock.release()
 
-      # Add task to internal queue
-      self._tasks.append(args)
+  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY, task_id=None):
+    """Add a list of tasks to the queue.
 
-      # Wake one idling worker up
-      self._pool_to_worker.notify()
+    @type tasks: list of tuples
+    @param tasks: list of args passed to L{BaseWorker.RunTask}
+    @type priority: number or list of numbers
+    @param priority: Priority for all added tasks or a list with the priority
+                     for each task
+    @type task_id: list
+    @param task_id: List with the ID for each task
+    @note: See L{AddTask} for a note on task IDs.
+
+    """
+    assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
+           "Each task must be a sequence"
+    assert (isinstance(priority, (int, long)) or
+            compat.all(isinstance(prio, (int, long)) for prio in priority)), \
+           "Priority must be numeric or be a list of numeric values"
+    assert task_id is None or isinstance(task_id, (tuple, list)), \
+           "Task IDs must be in a sequence"
+
+    if isinstance(priority, (int, long)):
+      priority = [priority] * len(tasks)
+    elif len(priority) != len(tasks):
+      raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
+                                   " number of tasks (%s)" %
+                                   (len(priority), len(tasks)))
+
+    if task_id is None:
+      task_id = [None] * len(tasks)
+    elif len(task_id) != len(tasks):
+      raise errors.ProgrammerError("Number of task IDs (%s) doesn't match"
+                                   " number of tasks (%s)" %
+                                   (len(task_id), len(tasks)))
+
+    self._lock.acquire()
+    try:
+      self._WaitWhileQuiescingUnlocked()
+
+      assert compat.all(isinstance(prio, (int, long)) for prio in priority)
+      assert len(tasks) == len(priority)
+      assert len(tasks) == len(task_id)
+
+      for (args, prio, tid) in zip(tasks, priority, task_id):
+        self._AddTaskUnlocked(args, prio, tid)
     finally:
       self._lock.release()
 
-  def _ShouldWorkerTerminateUnlocked(self, worker):
-    """Returns whether a worker should terminate.
+  def ChangeTaskPriority(self, task_id, priority):
+    """Changes a task's priority.
+
+    @param task_id: Task ID
+    @type priority: number
+    @param priority: New task priority
+    @raise NoSuchTask: When the task referred to by C{task_id} cannot be found
+      (it may never have existed, may have already been processed, or is
+      currently running)
 
     """
-    return (worker in self._termworkers)
+    assert isinstance(priority, (int, long)), "Priority must be numeric"
 
-  def ShouldWorkerTerminate(self, worker):
-    """Returns whether a worker should terminate.
+    self._lock.acquire()
+    try:
+      logging.debug("About to change priority of task %s to %s",
+                    task_id, priority)
+
+      # Find old task
+      oldtask = self._taskdata.get(task_id, None)
+      if oldtask is None:
+        msg = "Task '%s' was not found" % task_id
+        logging.debug(msg)
+        raise NoSuchTask(msg)
+
+      # Prepare new task
+      newtask = [priority] + oldtask[1:]
+
+      # Mark old entry as abandoned (this doesn't change the sort order and
+      # therefore doesn't invalidate the heap property of L{self._tasks}).
+      # See also <http://docs.python.org/library/heapq.html#priority-queue-
+      # implementation-notes>.
+      oldtask[-1] = None
+
+      # Change reference to new task entry and forget the old one
+      assert task_id is not None
+      self._taskdata[task_id] = newtask
+
+      # Add a new task with the old number and arguments
+      heapq.heappush(self._tasks, newtask)
+
+      # Notify a waiting worker
+      self._pool_to_worker.notify()
+    finally:
+      self._lock.release()
+
+  def SetActive(self, active):
+    """Enable/disable processing of tasks.
+
+    This is different from L{Quiesce} in the sense that this function just
+    changes an internal flag and doesn't wait for the queue to be empty. Tasks
+    already being processed continue normally, but no new tasks will be
+    started. New tasks can still be added.
+
+    @type active: bool
+    @param active: Whether tasks should be processed
 
     """
     self._lock.acquire()
     try:
-      return self._ShouldWorkerTerminateUnlocked(worker)
+      self._active = active
+
+      if active:
+        # Tell all workers to continue processing
+        self._pool_to_worker.notifyAll()
     finally:
       self._lock.release()
 
+  def _WaitForTaskUnlocked(self, worker):
+    """Waits for a task for a worker.
+
+    @type worker: L{BaseWorker}
+    @param worker: Worker thread
+
+    """
+    while True:
+      if self._ShouldWorkerTerminateUnlocked(worker):
+        return _TERMINATE
+
+      # If there's a pending task, return it immediately
+      if self._active and self._tasks:
+        # Get task from queue and tell pool about it
+        try:
+          task = heapq.heappop(self._tasks)
+        finally:
+          self._worker_to_pool.notifyAll()
+
+        (_, _, task_id, args) = task
+
+        # If the priority was changed, "args" is None
+        if args is None:
+          # Try again
+          logging.debug("Found abandoned task (%r)", task)
+          continue
+
+        # Delete reference
+        if task_id is not None:
+          del self._taskdata[task_id]
+
+        return task
+
+      logging.debug("Waiting for tasks")
+
+      # wait() releases the lock and sleeps until notified
+      self._pool_to_worker.wait()
+
+      logging.debug("Notified while waiting")
+
+  def _ShouldWorkerTerminateUnlocked(self, worker):
+    """Returns whether a worker should terminate.
+
+    """
+    return (worker in self._termworkers)
+
   def _HasRunningTasksUnlocked(self):
     """Checks whether there's a task running in a worker.
 
     """
     for worker in self._workers + self._termworkers:
-      if worker._HasRunningTaskUnlocked():
+      if worker._HasRunningTaskUnlocked(): # pylint: disable=W0212
         return True
     return False
 
+  def HasRunningTasks(self):
+    """Checks whether there's at least one task running.
+
+    """
+    self._lock.acquire()
+    try:
+      return self._HasRunningTasksUnlocked()
+    finally:
+      self._lock.release()
+
   def Quiesce(self):
     """Waits until the task queue is empty.
 
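ChangeTaskPriority above never removes the old heap entry: it blanks out the entry's args field, pushes a fresh entry, and lets _WaitForTaskUnlocked skip the abandoned one when it is popped. This is the lazy-deletion scheme described in the heapq documentation's priority-queue notes. A self-contained sketch of the same idea, illustrative only and not ganeti code:

    import heapq
    import itertools

    queue = []
    entries = {}                  # task ID -> heap entry
    counter = itertools.count()   # insertion order, breaks priority ties

    def add(task_id, priority, args):
      entry = [priority, next(counter), task_id, args]
      entries[task_id] = entry
      heapq.heappush(queue, entry)

    def change_priority(task_id, priority):
      old = entries[task_id]
      new = [priority] + old[1:]  # keep order ID, task ID and arguments
      old[-1] = None              # abandon old entry; heap order is untouched
      entries[task_id] = new
      heapq.heappush(queue, new)

    def pop():
      while queue:
        (_, _, task_id, args) = heapq.heappop(queue)
        if args is None:          # abandoned entry, skip it
          continue
        del entries[task_id]
        return args
      return None

    add("backup", 10, ("run backup",))
    change_priority("backup", 0)
    print(pop())                  # ('run backup',), served at the new priority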
@@ -252,7 +559,8 @@ class WorkerPool(object):
 
     """
     self._last_worker_id += 1
-    return self._last_worker_id
+
+    return "%s%d" % (self._name, self._last_worker_id)
 
   def _ResizeUnlocked(self, num_workers):
     """Changes the number of workers.
@@ -305,7 +613,7 @@ class WorkerPool(object):
 
     elif current_count < num_workers:
       # Create (num_workers - current_count) new workers
-      for _ in xrange(num_workers - current_count):
+      for _ in range(num_workers - current_count):
         worker = self._worker_class(self, self._NewWorkerIdUnlocked())
         self._workers.append(worker)
         worker.start()
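Taken together, the interface changes in this diff look roughly as follows from the caller's side. This is a hedged sketch: EchoWorker is an invented example class, and TerminateWorkers() is assumed to be the pool's existing shutdown helper, which this diff does not touch:

    import logging

    from ganeti import workerpool

    class EchoWorker(workerpool.BaseWorker):
      def RunTask(self, message):
        self.SetTaskName("echo")   # thread name becomes e.g. "Demo1/echo"
        logging.info("%s got %r", self.getName(), message)

    pool = workerpool.WorkerPool("Demo", 3, EchoWorker)   # name, worker count, worker class
    try:
      pool.AddTask(("hello",))                            # default priority 0
      pool.AddTask(("important",), priority=-10)          # runs before priority 0
      pool.AddManyTasks([("a",), ("b",)], priority=[5, 1])
      pool.Quiesce()                                      # wait for the queue to drain
    finally:
      pool.TerminateWorkers()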