4 # Copyright (C) 2008, 2009, 2010 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Base classes for worker pools.
30 from ganeti import compat
31 from ganeti import errors
class DeferTask(Exception):
    """Special exception class to defer a task.

    This class can be raised by L{BaseWorker.RunTask} to defer the execution
    of a task. Optionally, the priority of the task can be changed.

    """
    def __init__(self, priority=None):
        """Initializes this class.

        @type priority: number
        @param priority: New task priority (None means no change)

        """
        # No message is passed on to the base class; the only payload of this
        # exception is the requested priority.
        Exception.__init__(self)
        self.priority = priority
56 class BaseWorker(threading.Thread, object):
57 """Base worker class for worker pools.
59 Users of a worker pool must override RunTask in a subclass.
62 # pylint: disable-msg=W0212
63 def __init__(self, pool, worker_id):
64 """Constructor for BaseWorker thread.
66 @param pool: the parent worker pool
67 @param worker_id: identifier for this worker
70 super(BaseWorker, self).__init__(name=worker_id)
72 self._worker_id = worker_id
73 self._current_task = None
75 assert self.getName() == worker_id
77 def ShouldTerminate(self):
78 """Returns whether this worker should terminate.
80 Should only be called from within L{RunTask}.
83 self.pool._lock.acquire()
85 assert self._HasRunningTaskUnlocked()
86 return self.pool._ShouldWorkerTerminateUnlocked(self)
88 self.pool._lock.release()
90 def GetCurrentPriority(self):
91 """Returns the priority of the current task.
93 Should only be called from within L{RunTask}.
96 self.pool._lock.acquire()
98 assert self._HasRunningTaskUnlocked()
100 (priority, _, _) = self._current_task
104 self.pool._lock.release()
106 def SetTaskName(self, taskname):
107 """Sets the name of the current task.
109 Should only be called from within L{RunTask}.
111 @type taskname: string
112 @param taskname: Task's name
116 name = "%s/%s" % (self._worker_id, taskname)
118 name = self._worker_id
123 def _HasRunningTaskUnlocked(self):
124 """Returns whether this worker is currently running a task.
127 return (self._current_task is not None)
130 """Main thread function.
132 Waits for new tasks to show up in the queue.
138 assert self._current_task is None
142 # Wait on lock to be told either to terminate or to do a task
145 task = pool._WaitForTaskUnlocked(self)
147 if task is _TERMINATE:
152 # Spurious notification, ignore
155 self._current_task = task
157 # No longer needed, dispose of reference
160 assert self._HasRunningTaskUnlocked()
165 (priority, _, args) = self._current_task
167 # Run the actual task
169 logging.debug("Starting task %r, priority %s", args, priority)
170 assert self.getName() == self._worker_id
172 self.RunTask(*args) # pylint: disable-msg=W0142
174 self.SetTaskName(None)
175 logging.debug("Done with task %r, priority %s", args, priority)
176 except DeferTask, err:
179 if defer.priority is None:
181 defer.priority = priority
183 logging.debug("Deferring task %r, new priority %s",
184 args, defer.priority)
186 assert self._HasRunningTaskUnlocked()
187 except: # pylint: disable-msg=W0702
188 logging.exception("Caught unhandled exception")
190 assert self._HasRunningTaskUnlocked()
196 assert self._current_task
197 # Schedule again for later run
198 (_, _, args) = self._current_task
199 pool._AddTaskUnlocked(args, defer.priority)
201 if self._current_task:
202 self._current_task = None
203 pool._worker_to_pool.notifyAll()
207 assert not self._HasRunningTaskUnlocked()
209 logging.debug("Terminates")
211 def RunTask(self, *args):
212 """Function called to start a task.
214 This needs to be implemented by child classes.
217 raise NotImplementedError()
class WorkerPool(object):
    """Worker pool with a queue.

    This class is thread-safe.

    Tasks are kept in a heap ordered by priority and, within the same
    priority, by the order in which they were added; the task with the
    numerically lowest priority value is started first. Tasks of equal
    priority are guaranteed to be started in the order in which they're
    added to the pool. Due to the nature of threading, they're not
    guaranteed to finish in the same order.

    """
    # NOTE(review): the reviewed text of this class had lost its indentation
    # and a number of short lines (docstring terminators, the "Quiesce"
    # header, "try"/"finally"/"else" keywords, "return" statements, lock
    # acquire/release calls and a few one-line statements such as the
    # counter increment and the worker start/join calls). They were restored
    # from the surrounding statements and comments; confirm against version
    # control before merging.

    def __init__(self, name, num_workers, worker_class):
        """Constructor for worker pool.

        @param name: prefix used to build the worker identifiers
        @param num_workers: number of workers to be started
            (dynamic resizing is not yet implemented)
        @param worker_class: the class to be instantiated for workers;
            should derive from L{BaseWorker}

        """
        # Some of these variables are accessed by BaseWorker
        self._lock = threading.Lock()
        # All three conditions share the pool lock; they only differ in who
        # notifies whom
        self._pool_to_pool = threading.Condition(self._lock)
        self._pool_to_worker = threading.Condition(self._lock)
        self._worker_to_pool = threading.Condition(self._lock)
        self._worker_class = worker_class
        self._name = name
        self._last_worker_id = 0
        self._workers = []
        self._quiescing = False

        # Terminating workers
        self._termworkers = []

        # Queued tasks
        self._counter = 0
        self._tasks = []

        # Start workers
        self.Resize(num_workers)

    # TODO: Implement dynamic resizing?

    def _WaitWhileQuiescingUnlocked(self):
        """Wait until the worker pool has finished quiescing.

        Must be called with the pool lock held.

        """
        while self._quiescing:
            self._pool_to_pool.wait()

    def _AddTaskUnlocked(self, args, priority):
        """Adds a task to the internal queue.

        @type args: sequence
        @param args: Arguments passed to L{BaseWorker.RunTask}
        @type priority: number
        @param priority: Task priority

        """
        assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
        assert isinstance(priority, (int, long)), "Priority must be numeric"

        # This counter is used to ensure elements are processed in their
        # incoming order. For processing they're sorted by priority and then
        # counter.
        self._counter += 1

        heapq.heappush(self._tasks, (priority, self._counter, args))

        # Notify a waiting worker
        self._pool_to_worker.notify()

    def AddTask(self, args, priority=_DEFAULT_PRIORITY):
        """Adds a task to the queue.

        Blocks while the pool is quiescing.

        @type args: sequence
        @param args: arguments passed to L{BaseWorker.RunTask}
        @type priority: number
        @param priority: Task priority

        """
        self._lock.acquire()
        try:
            self._WaitWhileQuiescingUnlocked()
            self._AddTaskUnlocked(args, priority)
        finally:
            self._lock.release()

    def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY):
        """Add a list of tasks to the queue.

        Blocks while the pool is quiescing.

        @type tasks: list of tuples
        @param tasks: list of args passed to L{BaseWorker.RunTask}
        @type priority: number or list of numbers
        @param priority: Priority for all added tasks or a list with the
            priority for each task
        @raise errors.ProgrammerError: if a list of priorities is given whose
            length differs from the number of tasks

        """
        assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
            "Each task must be a sequence"

        assert (isinstance(priority, (int, long)) or
                compat.all(isinstance(prio, (int, long))
                           for prio in priority)), \
            "Priority must be numeric or be a list of numeric values"

        if isinstance(priority, (int, long)):
            priority = [priority] * len(tasks)
        elif len(priority) != len(tasks):
            raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
                                         " number of tasks (%s)" %
                                         (len(priority), len(tasks)))

        self._lock.acquire()
        try:
            self._WaitWhileQuiescingUnlocked()

            assert compat.all(isinstance(prio, (int, long))
                              for prio in priority)
            assert len(tasks) == len(priority)

            # A separate loop name avoids rebinding the "priority" argument
            # while its list is still being iterated
            for (args, prio) in zip(tasks, priority):
                self._AddTaskUnlocked(args, prio)
        finally:
            self._lock.release()

    def _WaitForTaskUnlocked(self, worker):
        """Waits for a task for a worker.

        Must be called with the pool lock held.

        @type worker: L{BaseWorker}
        @param worker: Worker thread
        @return: C{_TERMINATE} if the worker should terminate, None after a
            notification that left the queue empty, otherwise the next
            (priority, counter, args) tuple from the queue

        """
        if self._ShouldWorkerTerminateUnlocked(worker):
            return _TERMINATE

        # We only wait if there's no task for us.
        if not self._tasks:
            logging.debug("Waiting for tasks")

            # wait() releases the lock and sleeps until notified
            self._pool_to_worker.wait()

            logging.debug("Notified while waiting")

            # Were we woken up in order to terminate?
            if self._ShouldWorkerTerminateUnlocked(worker):
                return _TERMINATE

            if not self._tasks:
                # Spurious notification, ignore
                return None

        # Get task from queue and tell pool about it
        try:
            return heapq.heappop(self._tasks)
        finally:
            self._worker_to_pool.notifyAll()

    def _ShouldWorkerTerminateUnlocked(self, worker):
        """Returns whether a worker should terminate.

        @rtype: boolean

        """
        return (worker in self._termworkers)

    def _HasRunningTasksUnlocked(self):
        """Checks whether there's a task running in a worker.

        Terminating workers are taken into account as well.

        @rtype: boolean

        """
        for worker in self._workers + self._termworkers:
            if worker._HasRunningTaskUnlocked(): # pylint: disable-msg=W0212
                return True
        return False

    def Quiesce(self):
        """Waits until the task queue is empty.

        While waiting, L{AddTask} and L{AddManyTasks} block; they are woken
        up again once no task is pending or running anymore.

        """
        self._lock.acquire()
        try:
            self._quiescing = True

            # Wait while there are tasks pending or running
            while self._tasks or self._HasRunningTasksUnlocked():
                self._worker_to_pool.wait()

        finally:
            self._quiescing = False

            # Make sure AddTasks continues in case it was waiting
            self._pool_to_pool.notifyAll()

            self._lock.release()

    def _NewWorkerIdUnlocked(self):
        """Return an identifier for a new worker.

        @rtype: string
        @return: The pool name followed by a running number

        """
        self._last_worker_id += 1

        return "%s%d" % (self._name, self._last_worker_id)

    def _ResizeUnlocked(self, num_workers):
        """Changes the number of workers.

        Must be called with the pool lock held. When shutting down workers
        the lock is released while waiting for the threads to finish and
        taken again afterwards.

        @param num_workers: the new number of workers
        @raise NotImplementedError: when asked to reduce the number of
            workers to anything but zero

        """
        assert num_workers >= 0, "num_workers must be >= 0"

        logging.debug("Resizing to %s workers", num_workers)

        current_count = len(self._workers)

        if current_count == num_workers:
            # Nothing to do
            pass

        elif current_count > num_workers:
            if num_workers == 0:
                # Create copy of list to iterate over while lock isn't held.
                termworkers = self._workers[:]
                del self._workers[:]
            else:
                # TODO: Implement partial downsizing
                raise NotImplementedError()

            self._termworkers += termworkers

            # Notify workers that something has changed
            self._pool_to_worker.notifyAll()

            # Join all terminating workers
            self._lock.release()
            try:
                for worker in termworkers:
                    logging.debug("Waiting for thread %s", worker.getName())
                    worker.join()
            finally:
                self._lock.acquire()

            # Remove terminated threads. This could be done in a more efficient way
            # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
            # don't leave zombie threads around.
            for worker in termworkers:
                assert worker in self._termworkers, ("Worker not in list of"
                                                     " terminating workers")
                if not worker.isAlive():
                    self._termworkers.remove(worker)

            assert not self._termworkers, "Zombie worker detected"

        elif current_count < num_workers:
            # Create (num_workers - current_count) new workers
            for _ in range(num_workers - current_count):
                worker = self._worker_class(self, self._NewWorkerIdUnlocked())
                self._workers.append(worker)
                worker.start()

    def Resize(self, num_workers):
        """Changes the number of workers in the pool.

        @param num_workers: the new number of workers

        """
        self._lock.acquire()
        try:
            return self._ResizeUnlocked(num_workers)
        finally:
            self._lock.release()

    def TerminateWorkers(self):
        """Terminate all worker threads.

        Unstarted tasks will be ignored.

        """
        logging.debug("Terminating all workers")

        self._lock.acquire()
        try:
            self._ResizeUnlocked(0)

            if self._tasks:
                logging.debug("There are %s tasks left", len(self._tasks))
        finally:
            self._lock.release()

        logging.debug("All workers terminated")