4 # Copyright (C) 2008, 2009, 2010 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
22 """Base classes for worker pools.
import heapq
import itertools
import logging
import numbers
import threading

from ganeti import compat
from ganeti import errors
class DeferTask(Exception):
  """Exception a task raises to ask for its own postponement.

  L{BaseWorker.RunTask} may raise this class instead of finishing; the
  worker then queues the task again. A different priority can be requested
  for the next run.

  """
  def __init__(self, priority=None):
    """Stores the priority requested for the next run.

    @type priority: number
    @param priority: Priority to use when the task is queued again (C{None}
      keeps the priority the task already has)

    """
    self.priority = priority
    Exception.__init__(self)
class NoSuchTask(Exception):
  """Raised when a task ID does not refer to any task known to the pool.

  """
class BaseWorker(threading.Thread, object):
  """Base worker class for worker pools.

  Users of a worker pool must override L{RunTask} in a subclass. A worker is
  a thread that repeatedly asks its pool for a task and runs it.

  @ivar pool: The pool this worker belongs to
  @ivar _worker_id: Identifier of this worker; also the thread name while no
    task name is set
  @ivar _current_task: Queue entry (priority, order ID, task ID, arguments)
    being processed, or C{None} while idle; read and written while holding
    the pool's lock

  """
  # pylint: disable=W0212
  def __init__(self, pool, worker_id):
    """Constructor for BaseWorker thread.

    @param pool: the parent worker pool
    @param worker_id: identifier for this worker

    """
    super(BaseWorker, self).__init__(name=worker_id)
    self.pool = pool
    self._worker_id = worker_id
    self._current_task = None

    assert self.name == worker_id

  def ShouldTerminate(self):
    """Returns whether this worker should terminate.

    Should only be called from within L{RunTask}.

    """
    with self.pool._lock:
      assert self._HasRunningTaskUnlocked()
      return self.pool._ShouldWorkerTerminateUnlocked(self)

  def GetCurrentPriority(self):
    """Returns the priority of the current task.

    Should only be called from within L{RunTask}.

    """
    with self.pool._lock:
      assert self._HasRunningTaskUnlocked()

      (priority, _, _, _) = self._current_task

      return priority

  def SetTaskName(self, taskname):
    """Sets the name of the current task.

    Should only be called from within L{RunTask}. The task name is appended
    to the worker ID to form the thread name; an empty name or C{None}
    resets the thread name to the plain worker ID.

    @type taskname: string
    @param taskname: Task's name

    """
    if taskname:
      name = "%s/%s" % (self._worker_id, taskname)
    else:
      name = self._worker_id

    # Set thread name
    self.name = name

  def _HasRunningTaskUnlocked(self):
    """Returns whether this worker is currently running a task.

    """
    return (self._current_task is not None)

  def _GetCurrentOrderAndTaskId(self):
    """Returns the order and task ID of the current task.

    Should only be called from within L{RunTask}.

    @rtype: tuple
    @return: (order ID, task ID)

    """
    with self.pool._lock:
      assert self._HasRunningTaskUnlocked()

      (_, order_id, task_id, _) = self._current_task

      return (order_id, task_id)

  def run(self):
    """Main thread function.

    Waits for new tasks to show up in the queue, runs them through
    L{RunTask} and tells the pool when a task is finished. A task raising
    L{DeferTask} is queued again; any other exception is logged and the task
    is treated as finished.

    """
    pool = self.pool

    while True:
      assert self._current_task is None

      defer = None
      try:
        # Wait on lock to be told either to terminate or to do a task
        with pool._lock:
          task = pool._WaitForTaskUnlocked(self)

          if task is _TERMINATE:
            # Told to terminate
            break

          if task is None:
            # Spurious notification, ignore
            continue

          self._current_task = task

          # No longer needed, dispose of reference
          del task

          assert self._HasRunningTaskUnlocked()

        (priority, _, _, args) = self._current_task
        try:
          # Run the actual task
          assert defer is None
          logging.debug("Starting task %r, priority %s", args, priority)
          assert self.name == self._worker_id
          try:
            self.RunTask(*args) # pylint: disable=W0142
          finally:
            # Drop whatever task name RunTask may have set
            self.SetTaskName(None)
          logging.debug("Done with task %r, priority %s", args, priority)
        except DeferTask as err:
          defer = err

          if defer.priority is None:
            # Use same priority
            defer.priority = priority

          logging.debug("Deferring task %r, new priority %s",
                        args, defer.priority)

          assert self._HasRunningTaskUnlocked()
        except: # pylint: disable=W0702
          # Deliberately broad: a failing task must not kill the worker
          logging.exception("Caught unhandled exception")

        assert self._HasRunningTaskUnlocked()
      finally:
        # Notify pool; this also runs for "break" and "continue" above, in
        # which case there's neither a deferral nor a current task
        with pool._lock:
          if defer:
            assert self._current_task
            # Schedule again for later run
            (_, _, task_id, args) = self._current_task
            pool._AddTaskUnlocked(args, defer.priority, task_id)

          if self._current_task:
            self._current_task = None
            pool._worker_to_pool.notify_all()

      assert not self._HasRunningTaskUnlocked()

    logging.debug("Terminates")

  def RunTask(self, *args):
    """Function called to start a task.

    This needs to be implemented by child classes.

    """
    raise NotImplementedError()
class WorkerPool(object):
  """Worker pool with a queue.

  This class is thread-safe.

  Tasks are guaranteed to be started in the order in which they're
  added to the pool. Due to the nature of threading, they're not
  guaranteed to finish in the same order.

  @type _tasks: list of lists
  @ivar _tasks: Each entry has the format (priority, order ID, task ID,
    arguments). Priority and order ID are numeric and essentially control the
    sort order. The order ID is an increasing number denoting the order in
    which tasks are added to the queue. The task ID is controlled by user of
    workerpool, see L{AddTask} for details. The task arguments are C{None} for
    abandoned tasks, otherwise a sequence of arguments to be passed to
    L{BaseWorker.RunTask}). The list must fulfill the heap property (for use by
    the C{heapq} module).
  @type _taskdata: dict; (task IDs as keys, entries as values)
  @ivar _taskdata: Mapping from task IDs to entries in L{_tasks}

  """
  def __init__(self, name, num_workers, worker_class):
    """Constructor for worker pool.

    @param name: prefix for worker IDs (see L{_NewWorkerIdUnlocked})
    @param num_workers: number of workers to be started
                        (dynamic resizing is not yet implemented)
    @param worker_class: the class to be instantiated for workers;
                         should derive from L{BaseWorker}

    """
    # Some of these variables are accessed by BaseWorker
    self._lock = threading.Lock()
    self._pool_to_pool = threading.Condition(self._lock)
    self._pool_to_worker = threading.Condition(self._lock)
    self._worker_to_pool = threading.Condition(self._lock)
    self._worker_class = worker_class
    self._name = name
    self._last_worker_id = 0
    self._workers = []
    self._quiescing = False
    self._active = True

    # Terminating workers
    self._termworkers = []

    # Queued tasks
    self._counter = itertools.count()
    self._tasks = []
    self._taskdata = {}

    # Start workers
    self.Resize(num_workers)

  # TODO: Implement dynamic resizing?

  def _WaitWhileQuiescingUnlocked(self):
    """Wait until the worker pool has finished quiescing.

    """
    while self._quiescing:
      self._pool_to_pool.wait()

  def _AddTaskUnlocked(self, args, priority, task_id):
    """Adds a task to the internal queue.

    @type args: sequence
    @param args: Arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority
    @param task_id: Task ID

    """
    assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
    assert isinstance(priority, numbers.Integral), "Priority must be numeric"
    # NOTE(review): this restricts task IDs to integers although the note on
    # AddTask allows any dictionary key -- confirm which one is intended
    assert task_id is None or isinstance(task_id, numbers.Integral), \
      "Task ID must be numeric or None"

    task = [priority, next(self._counter), task_id, args]

    if task_id is not None:
      assert task_id not in self._taskdata
      # Keep a reference to change priority later if necessary
      self._taskdata[task_id] = task

    # A counter is used to ensure elements are processed in their incoming
    # order. For processing they're sorted by priority and then counter.
    heapq.heappush(self._tasks, task)

    # Notify a waiting worker
    self._pool_to_worker.notify()

  def AddTask(self, args, priority=_DEFAULT_PRIORITY, task_id=None):
    """Adds a task to the queue.

    Blocks while the pool is quiescing.

    @type args: sequence
    @param args: arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority
    @param task_id: Task ID
    @note: The task ID can be essentially anything that can be used as a
      dictionary key. Callers, however, must ensure a task ID is unique while a
      task is in the pool or while it might return to the pool due to deferring
      using L{DeferTask}.

    """
    with self._lock:
      self._WaitWhileQuiescingUnlocked()
      self._AddTaskUnlocked(args, priority, task_id)

  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY, task_id=None):
    """Add a list of tasks to the queue.

    Blocks while the pool is quiescing. All tasks are added while holding
    the pool's lock.

    @type tasks: list of tuples
    @param tasks: list of args passed to L{BaseWorker.RunTask}
    @type priority: number or list of numbers
    @param priority: Priority for all added tasks or a list with the priority
                     for each task
    @type task_id: list
    @param task_id: List with the ID for each task
    @note: See L{AddTask} for a note on task IDs.
    @raise errors.ProgrammerError: When the number of priorities or task IDs
      doesn't match the number of tasks

    """
    assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
      "Each task must be a sequence"
    assert (isinstance(priority, numbers.Integral) or
            compat.all(isinstance(prio, numbers.Integral)
                       for prio in priority)), \
      "Priority must be numeric or be a list of numeric values"
    assert task_id is None or isinstance(task_id, (tuple, list)), \
      "Task IDs must be in a sequence"

    if isinstance(priority, numbers.Integral):
      # One priority for all tasks
      priority = [priority] * len(tasks)
    elif len(priority) != len(tasks):
      raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(priority), len(tasks)))

    if task_id is None:
      task_id = [None] * len(tasks)
    elif len(task_id) != len(tasks):
      raise errors.ProgrammerError("Number of task IDs (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(task_id), len(tasks)))

    with self._lock:
      self._WaitWhileQuiescingUnlocked()

      assert compat.all(isinstance(prio, numbers.Integral)
                        for prio in priority)
      assert len(tasks) == len(priority)
      assert len(tasks) == len(task_id)

      for (args, prio, tid) in zip(tasks, priority, task_id):
        self._AddTaskUnlocked(args, prio, tid)

  def ChangeTaskPriority(self, task_id, priority):
    """Changes a task's priority.

    @param task_id: Task ID
    @type priority: number
    @param priority: New task priority
    @raise NoSuchTask: When the task referred by C{task_id} can not be found
      (it may never have existed, may have already been processed, or is
      currently running)

    """
    assert isinstance(priority, numbers.Integral), "Priority must be numeric"

    with self._lock:
      logging.debug("About to change priority of task %s to %s",
                    task_id, priority)

      # Find old task
      oldtask = self._taskdata.get(task_id, None)
      if oldtask is None:
        msg = "Task '%s' was not found" % task_id
        logging.debug(msg)
        raise NoSuchTask(msg)

      if oldtask[0] == priority:
        # Nothing to change. Queueing a second entry with the same priority
        # and order ID would make heapq compare the remaining fields of both
        # entries, i.e. None against the argument sequence, which raises
        # TypeError on Python 3.
        return

      # Prepare new task
      newtask = [priority] + oldtask[1:]

      # Mark old entry as abandoned (this doesn't change the sort order and
      # therefore doesn't invalidate the heap property of L{self._tasks}).
      # See also <http://docs.python.org/library/heapq.html#priority-queue-
      # implementation-notes>.
      oldtask[-1] = None

      # Change reference to new task entry and forget the old one
      assert task_id is not None
      self._taskdata[task_id] = newtask

      # Add a new task with the old number and arguments
      heapq.heappush(self._tasks, newtask)

      # Notify a waiting worker
      self._pool_to_worker.notify()

  def SetActive(self, active):
    """Enable/disable processing of tasks.

    This is different from L{Quiesce} in the sense that this function just
    changes an internal flag and doesn't wait for the queue to be empty. Tasks
    already being processed continue normally, but no new tasks will be
    started. New tasks can still be added.

    @type active: bool
    @param active: Whether tasks should be processed

    """
    with self._lock:
      self._active = active

      if active:
        # Tell all workers to continue processing
        self._pool_to_worker.notify_all()

  def _WaitForTaskUnlocked(self, worker):
    """Waits for a task for a worker.

    @type worker: L{BaseWorker}
    @param worker: Worker thread
    @return: Queue entry to be processed, or C{_TERMINATE} if the worker
      should terminate

    """
    while True:
      if self._ShouldWorkerTerminateUnlocked(worker):
        return _TERMINATE

      # If there's a pending task, return it immediately
      if self._active and self._tasks:
        # Get task from queue and tell pool about it
        try:
          task = heapq.heappop(self._tasks)
        finally:
          self._worker_to_pool.notify_all()

        (_, _, task_id, args) = task

        # If the priority was changed, "args" is None
        if args is None:
          # Try again
          logging.debug("Found abandoned task (%r)", task)
          continue

        # Delete reference
        if task_id is not None:
          del self._taskdata[task_id]

        return task

      logging.debug("Waiting for tasks")

      # wait() releases the lock and sleeps until notified
      self._pool_to_worker.wait()

      logging.debug("Notified while waiting")

  def _ShouldWorkerTerminateUnlocked(self, worker):
    """Returns whether a worker should terminate.

    """
    return (worker in self._termworkers)

  def _HasRunningTasksUnlocked(self):
    """Checks whether there's a task running in a worker.

    """
    for worker in self._workers + self._termworkers:
      if worker._HasRunningTaskUnlocked(): # pylint: disable=W0212
        return True
    return False

  def HasRunningTasks(self):
    """Checks whether there's at least one task running.

    """
    with self._lock:
      return self._HasRunningTasksUnlocked()

  def Quiesce(self):
    """Waits until the task queue is empty.

    While waiting, L{AddTask} and L{AddManyTasks} block.

    """
    with self._lock:
      try:
        self._quiescing = True

        # Wait while there are tasks pending or running
        while self._tasks or self._HasRunningTasksUnlocked():
          self._worker_to_pool.wait()

      finally:
        self._quiescing = False

        # Make sure AddTasks continues in case it was waiting
        self._pool_to_pool.notify_all()

  def _NewWorkerIdUnlocked(self):
    """Return an identifier for a new worker.

    """
    self._last_worker_id += 1

    return "%s%d" % (self._name, self._last_worker_id)

  def _ResizeUnlocked(self, num_workers):
    """Changes the number of workers.

    Must be called with the pool's lock held. When shrinking, the lock is
    released while waiting for the workers to finish and acquired again
    afterwards.

    @param num_workers: the new number of workers
    @raise NotImplementedError: When asked to shrink to anything but zero
      workers

    """
    assert num_workers >= 0, "num_workers must be >= 0"

    logging.debug("Resizing to %s workers", num_workers)

    current_count = len(self._workers)

    if current_count == num_workers:
      # Nothing to do
      pass

    elif current_count > num_workers:
      if num_workers == 0:
        # Create copy of list to iterate over while lock isn't held.
        termworkers = self._workers[:]
        del self._workers[:]
      else:
        # TODO: Implement partial downsizing
        raise NotImplementedError()

      self._termworkers += termworkers

      # Notify workers that something has changed
      self._pool_to_worker.notify_all()

      # Join all terminating workers; they need the lock to notice they
      # should terminate, hence it must not be held here
      self._lock.release()
      try:
        for worker in termworkers:
          logging.debug("Waiting for thread %s", worker.name)
          worker.join()
      finally:
        self._lock.acquire()

      # Remove terminated threads. This could be done in a more efficient way
      # (del self._termworkers[:]), but checking worker.is_alive() makes sure
      # we don't leave zombie threads around.
      for worker in termworkers:
        assert worker in self._termworkers, ("Worker not in list of"
                                             " terminating workers")
        if not worker.is_alive():
          self._termworkers.remove(worker)

      assert not self._termworkers, "Zombie worker detected"

    elif current_count < num_workers:
      # Create (num_workers - current_count) new workers
      for _ in range(num_workers - current_count):
        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
        self._workers.append(worker)
        worker.start()

  def Resize(self, num_workers):
    """Changes the number of workers in the pool.

    @param num_workers: the new number of workers

    """
    with self._lock:
      return self._ResizeUnlocked(num_workers)

  def TerminateWorkers(self):
    """Terminate all worker threads.

    Unstarted tasks will be ignored.

    """
    logging.debug("Terminating all workers")

    with self._lock:
      self._ResizeUnlocked(0)

      if self._tasks:
        logging.debug("There are %s tasks left", len(self._tasks))

    logging.debug("All workers terminated")