# Copyright (C) 2008, 2009, 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.

"""Base classes for worker pools.

"""

import heapq
import itertools
import logging
import threading

from ganeti import compat
from ganeti import errors


# Unique sentinel handed to a worker by WorkerPool._WaitForTaskUnlocked to
# tell it to leave its main loop; compared by identity in BaseWorker.run
_TERMINATE = object()

# Priority used by AddTask/AddManyTasks when the caller gives none; must be an
# integer because WorkerPool._AddTaskUnlocked asserts isinstance(int, long)
_DEFAULT_PRIORITY = 0

class DeferTask(Exception):
  """Special exception class to defer a task.

  This class can be raised by L{BaseWorker.RunTask} to defer the execution of a
  task. Optionally, the priority of the task can be changed.

  When caught by the worker loop, the task's arguments are put back into the
  pool's queue with L{priority}; a priority of C{None} is replaced by the
  task's current priority before re-queueing.

  """
  def __init__(self, priority=None):
    """Initializes this class.

    @type priority: number
    @param priority: New task priority (None means no change)

    """
    Exception.__init__(self)
    self.priority = priority

57 class BaseWorker(threading.Thread, object):
58 """Base worker class for worker pools.
60 Users of a worker pool must override RunTask in a subclass.
63 # pylint: disable=W0212
64 def __init__(self, pool, worker_id):
65 """Constructor for BaseWorker thread.
67 @param pool: the parent worker pool
68 @param worker_id: identifier for this worker
71 super(BaseWorker, self).__init__(name=worker_id)
73 self._worker_id = worker_id
74 self._current_task = None
76 assert self.getName() == worker_id
78 def ShouldTerminate(self):
79 """Returns whether this worker should terminate.
81 Should only be called from within L{RunTask}.
84 self.pool._lock.acquire()
86 assert self._HasRunningTaskUnlocked()
87 return self.pool._ShouldWorkerTerminateUnlocked(self)
89 self.pool._lock.release()
91 def GetCurrentPriority(self):
92 """Returns the priority of the current task.
94 Should only be called from within L{RunTask}.
97 self.pool._lock.acquire()
99 assert self._HasRunningTaskUnlocked()
101 (priority, _, _) = self._current_task
105 self.pool._lock.release()
107 def SetTaskName(self, taskname):
108 """Sets the name of the current task.
110 Should only be called from within L{RunTask}.
112 @type taskname: string
113 @param taskname: Task's name
117 name = "%s/%s" % (self._worker_id, taskname)
119 name = self._worker_id
124 def _HasRunningTaskUnlocked(self):
125 """Returns whether this worker is currently running a task.
128 return (self._current_task is not None)
131 """Main thread function.
133 Waits for new tasks to show up in the queue.
139 assert self._current_task is None
143 # Wait on lock to be told either to terminate or to do a task
146 task = pool._WaitForTaskUnlocked(self)
148 if task is _TERMINATE:
153 # Spurious notification, ignore
156 self._current_task = task
158 # No longer needed, dispose of reference
161 assert self._HasRunningTaskUnlocked()
166 (priority, _, args) = self._current_task
168 # Run the actual task
170 logging.debug("Starting task %r, priority %s", args, priority)
171 assert self.getName() == self._worker_id
173 self.RunTask(*args) # pylint: disable=W0142
175 self.SetTaskName(None)
176 logging.debug("Done with task %r, priority %s", args, priority)
177 except DeferTask, err:
180 if defer.priority is None:
182 defer.priority = priority
184 logging.debug("Deferring task %r, new priority %s",
185 args, defer.priority)
187 assert self._HasRunningTaskUnlocked()
188 except: # pylint: disable=W0702
189 logging.exception("Caught unhandled exception")
191 assert self._HasRunningTaskUnlocked()
197 assert self._current_task
198 # Schedule again for later run
199 (_, _, args) = self._current_task
200 pool._AddTaskUnlocked(args, defer.priority)
202 if self._current_task:
203 self._current_task = None
204 pool._worker_to_pool.notifyAll()
208 assert not self._HasRunningTaskUnlocked()
210 logging.debug("Terminates")
212 def RunTask(self, *args):
213 """Function called to start a task.
215 This needs to be implemented by child classes.
218 raise NotImplementedError()
class WorkerPool(object):
  """Worker pool with a queue.

  This class is thread-safe.

  Tasks are started in order of their priority (lower values first); tasks
  with the same priority are started in the order in which they were added
  to the pool. Due to the nature of threading, they're not guaranteed to
  finish in the same order.

  """
  def __init__(self, name, num_workers, worker_class):
    """Constructor for worker pool.

    @param name: prefix for worker identifiers (which are also the worker
        thread names)
    @param num_workers: number of workers to be started
        (dynamic resizing is not yet implemented)
    @param worker_class: the class to be instantiated for workers;
        should derive from L{BaseWorker}

    """
    # Some of these variables are accessed by BaseWorker
    self._lock = threading.Lock()
    self._pool_to_pool = threading.Condition(self._lock)
    self._pool_to_worker = threading.Condition(self._lock)
    self._worker_to_pool = threading.Condition(self._lock)
    self._worker_class = worker_class
    self._name = name
    self._last_worker_id = 0
    self._workers = []
    self._quiescing = False
    self._active = True

    # Terminating workers
    self._termworkers = []

    # Queued tasks: heap of (priority, counter, args) tuples
    self._counter = itertools.count()
    self._tasks = []

    # Start workers
    self.Resize(num_workers)

  # TODO: Implement dynamic resizing?

  def _WaitWhileQuiescingUnlocked(self):
    """Wait until the worker pool has finished quiescing.

    Must be called with the pool lock held.

    """
    while self._quiescing:
      self._pool_to_pool.wait()

  def _AddTaskUnlocked(self, args, priority):
    """Adds a task to the internal queue.

    Must be called with the pool lock held.

    @type args: sequence
    @param args: Arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority

    """
    assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
    assert isinstance(priority, (int, long)), "Priority must be numeric"

    # A counter is used to ensure elements are processed in their incoming
    # order. For processing they're sorted by priority and then counter.
    heapq.heappush(self._tasks, (priority, self._counter.next(), args))

    # Notify a waiting worker
    self._pool_to_worker.notify()

  def AddTask(self, args, priority=_DEFAULT_PRIORITY):
    """Adds a task to the queue.

    Blocks while the pool is quiescing.

    @type args: sequence
    @param args: arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority

    """
    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()
      self._AddTaskUnlocked(args, priority)
    finally:
      self._lock.release()

  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY):
    """Add a list of tasks to the queue.

    Blocks while the pool is quiescing.

    @type tasks: list of tuples
    @param tasks: list of args passed to L{BaseWorker.RunTask}
    @type priority: number or list of numbers
    @param priority: Priority for all added tasks or a list with the priority
        for each task
    @raise errors.ProgrammerError: if a priority list doesn't have one entry
        per task

    """
    assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
      "Each task must be a sequence"

    assert (isinstance(priority, (int, long)) or
            compat.all(isinstance(prio, (int, long)) for prio in priority)), \
           "Priority must be numeric or be a list of numeric values"

    if isinstance(priority, (int, long)):
      priority = [priority] * len(tasks)
    elif len(priority) != len(tasks):
      raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(priority), len(tasks)))

    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()

      assert compat.all(isinstance(prio, (int, long)) for prio in priority)
      assert len(tasks) == len(priority)

      for args, prio in zip(tasks, priority):
        self._AddTaskUnlocked(args, prio)
    finally:
      self._lock.release()

  def SetActive(self, active):
    """Enable/disable processing of tasks.

    This is different from L{Quiesce} in the sense that this function just
    changes an internal flag and doesn't wait for the queue to be empty. Tasks
    already being processed continue normally, but no new tasks will be
    started. New tasks can still be added.

    @type active: bool
    @param active: Whether tasks should be processed

    """
    self._lock.acquire()
    try:
      self._active = active

      if active:
        # Tell all workers to continue processing
        self._pool_to_worker.notifyAll()
    finally:
      self._lock.release()

  def _WaitForTaskUnlocked(self, worker):
    """Waits for a task for a worker.

    Must be called with the pool lock held.

    @type worker: L{BaseWorker}
    @param worker: Worker thread
    @return: C{_TERMINATE} if the worker should exit, a queued
        C{(priority, counter, args)} tuple if one was taken, or C{None}
        after a wake-up that didn't yield a task (the caller retries)

    """
    if self._ShouldWorkerTerminateUnlocked(worker):
      return _TERMINATE

    # If there's a pending task, return it immediately
    if self._active and self._tasks:
      # Get task from queue and tell pool about it
      try:
        task = heapq.heappop(self._tasks)
      finally:
        self._worker_to_pool.notifyAll()

      return task

    logging.debug("Waiting for tasks")

    # wait() releases the lock and sleeps until notified
    self._pool_to_worker.wait()

    logging.debug("Notified while waiting")

    # Let the caller loop and look again
    return None

  def _ShouldWorkerTerminateUnlocked(self, worker):
    """Returns whether a worker should terminate.

    """
    return (worker in self._termworkers)

  def _HasRunningTasksUnlocked(self):
    """Checks whether there's a task running in a worker.

    """
    for worker in self._workers + self._termworkers:
      if worker._HasRunningTaskUnlocked(): # pylint: disable=W0212
        return True
    return False

  def HasRunningTasks(self):
    """Checks whether there's at least one task running.

    """
    self._lock.acquire()
    try:
      return self._HasRunningTasksUnlocked()
    finally:
      self._lock.release()

  def Quiesce(self):
    """Waits until the task queue is empty.

    While quiescing, L{AddTask} and L{AddManyTasks} block.

    """
    self._lock.acquire()
    try:
      self._quiescing = True

      # Wait while there are tasks pending or running
      while self._tasks or self._HasRunningTasksUnlocked():
        self._worker_to_pool.wait()

    finally:
      self._quiescing = False

      # Make sure AddTask/AddManyTasks continue in case they were waiting
      self._pool_to_pool.notifyAll()

      self._lock.release()

  def _NewWorkerIdUnlocked(self):
    """Return an identifier for a new worker.

    """
    self._last_worker_id += 1

    return "%s%d" % (self._name, self._last_worker_id)

  def _ResizeUnlocked(self, num_workers):
    """Changes the number of workers.

    Must be called with the pool lock held. When shrinking, the lock is
    temporarily released while waiting for the terminating workers to exit.
    Only growing the pool or shrinking it to zero workers is supported.

    """
    assert num_workers >= 0, "num_workers must be >= 0"

    logging.debug("Resizing to %s workers", num_workers)

    current_count = len(self._workers)

    if current_count == num_workers:
      # Nothing to do
      pass

    elif current_count > num_workers:
      if num_workers == 0:
        # Create copy of list to iterate over while lock isn't held.
        termworkers = self._workers[:]
        del self._workers[:]
      else:
        # TODO: Implement partial downsizing
        raise NotImplementedError()

      self._termworkers += termworkers

      # Notify workers that something has changed
      self._pool_to_worker.notifyAll()

      # Join all terminating workers; the lock must be released so they can
      # finish their current task and leave their main loop
      self._lock.release()
      try:
        for worker in termworkers:
          logging.debug("Waiting for thread %s", worker.getName())
          worker.join()
      finally:
        self._lock.acquire()

      # Remove terminated threads. This could be done in a more efficient way
      # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
      # don't leave zombie threads around.
      for worker in termworkers:
        assert worker in self._termworkers, ("Worker not in list of"
                                             " terminating workers")
        if not worker.isAlive():
          self._termworkers.remove(worker)

      assert not self._termworkers, "Zombie worker detected"

    elif current_count < num_workers:
      # Create (num_workers - current_count) new workers
      for _ in range(num_workers - current_count):
        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
        self._workers.append(worker)
        worker.start()

  def Resize(self, num_workers):
    """Changes the number of workers in the pool.

    @param num_workers: the new number of workers

    """
    self._lock.acquire()
    try:
      return self._ResizeUnlocked(num_workers)
    finally:
      self._lock.release()

  def TerminateWorkers(self):
    """Terminate all worker threads.

    Unstarted tasks will be ignored.

    """
    logging.debug("Terminating all workers")

    self._lock.acquire()
    try:
      self._ResizeUnlocked(0)

      if self._tasks:
        logging.debug("There are %s tasks left", len(self._tasks))
    finally:
      self._lock.release()

    logging.debug("All workers terminated")