4 # Copyright (C) 2008 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Base classes for worker pools.
class BaseWorker(threading.Thread, object):
    """Base worker class for worker pools.

    Each worker is a thread that repeatedly takes the oldest task from its
    pool's queue and passes the task's arguments to L{RunTask}. Users of a
    worker pool must override RunTask in a subclass.

    """
    def __init__(self, pool, worker_id):
        """Constructor for BaseWorker thread.

        @param pool: the parent worker pool
        @param worker_id: identifier for this worker (used in log messages)

        """
        super(BaseWorker, self).__init__()
        self.pool = pool
        self.worker_id = worker_id
        # Argument tuple of the task currently being run, or None while idle.
        # Read and written only while holding the pool's lock.
        self._current_task = None

    def ShouldTerminate(self):
        """Returns whether a worker should terminate.

        Delegates to the pool's C{ShouldWorkerTerminate}.

        """
        return self.pool.ShouldWorkerTerminate(self)

    def _HasRunningTaskUnlocked(self):
        """Returns whether this worker is currently running a task.

        The caller must already hold the pool's lock.

        """
        return (self._current_task is not None)

    def HasRunningTask(self):
        """Returns whether this worker is currently running a task.

        Acquires the pool's lock for the duration of the check.

        """
        self.pool._lock.acquire()
        try:
            return self._HasRunningTaskUnlocked()
        finally:
            self.pool._lock.release()

    def run(self):
        """Main thread function.

        Waits for new tasks to show up in the pool's queue and runs them,
        one at a time, until the pool marks this worker for termination.
        Any exception raised by L{RunTask} is logged and swallowed so the
        worker keeps serving the queue.

        """
        pool = self.pool

        assert not self.HasRunningTask()

        while True:
            try:
                # We wait on lock to be told either terminate or do a task.
                pool._lock.acquire()
                try:
                    if pool._ShouldWorkerTerminateUnlocked(self):
                        break

                    # We only wait if there's no task for us.
                    if not pool._tasks:
                        logging.debug("Worker %s: waiting for tasks",
                                      self.worker_id)

                        # wait() releases the lock and sleeps until notified
                        pool._pool_to_worker.wait()

                        logging.debug("Worker %s: notified while waiting",
                                      self.worker_id)

                        # Were we woken up in order to terminate?
                        if pool._ShouldWorkerTerminateUnlocked(self):
                            break

                        if not pool._tasks:
                            # Spurious notification, ignore
                            continue

                    # Get task from queue and tell pool about it
                    try:
                        self._current_task = pool._tasks.popleft()
                    finally:
                        pool._worker_to_pool.notifyAll()
                finally:
                    pool._lock.release()

                # Run the actual task. The lock is not held here, so other
                # workers and the pool can make progress meanwhile.
                try:
                    logging.debug("Worker %s: starting task %r",
                                  self.worker_id, self._current_task)
                    self.RunTask(*self._current_task)
                    logging.debug("Worker %s: done with task %r",
                                  self.worker_id, self._current_task)
                except:
                    # Deliberately broad: a failing task must not kill the
                    # worker thread.
                    logging.error("Worker %s: Caught unhandled exception",
                                  self.worker_id, exc_info=True)
            finally:
                # Notify pool
                pool._lock.acquire()
                try:
                    # Compare against None rather than testing truthiness: a
                    # task queued without arguments is the empty tuple, which
                    # is falsy but still has to be cleared. Otherwise the
                    # worker would look busy forever and Quiesce would hang.
                    if self._current_task is not None:
                        self._current_task = None
                        pool._worker_to_pool.notifyAll()
                finally:
                    pool._lock.release()

        logging.debug("Worker %s: terminates", self.worker_id)

    def RunTask(self, *args):
        """Function called to start a task.

        This needs to be implemented by child classes.

        @param args: the arguments the task was queued with

        """
        raise NotImplementedError()
class WorkerPool(object):
    """Worker pool with a queue.

    This class is thread-safe.

    Tasks are guaranteed to be started in the order in which they're
    added to the pool. Due to the nature of threading, they're not
    guaranteed to finish in the same order.

    """
    def __init__(self, num_workers, worker_class):
        """Constructor for worker pool.

        @param num_workers: number of workers to be started
            (dynamic resizing is not yet implemented)
        @param worker_class: the class to be instantiated for workers;
            should derive from L{BaseWorker}

        """
        # A single lock backs all three conditions below. Workers read
        # several of these private attributes directly.
        shared_lock = threading.Lock()
        self._lock = shared_lock
        # Quiesce -> AddTask callers blocked while quiescing
        self._pool_to_pool = threading.Condition(shared_lock)
        # pool -> workers waiting for tasks or a termination request
        self._pool_to_worker = threading.Condition(shared_lock)
        # workers -> pool, signalled whenever a task is taken or finished
        self._worker_to_pool = threading.Condition(shared_lock)
        self._worker_class = worker_class
        self._last_worker_id = 0
        self._workers = []
        self._quiescing = False

        # Workers asked to terminate but not yet reaped
        self._termworkers = []

        # Pending task argument tuples, consumed oldest-first
        self._tasks = collections.deque()

        self.Resize(num_workers)

    # TODO: Implement dynamic resizing?

    def AddTask(self, *args):
        """Appends a task to the queue and wakes one waiting worker.

        Blocks while a L{Quiesce} call is in progress.

        @param args: arguments passed to L{BaseWorker.RunTask}

        """
        self._lock.acquire()
        try:
            # New work is held back until the quiesce has finished
            while self._quiescing:
                self._pool_to_pool.wait()

            self._tasks.append(args)

            # A single new task needs at most a single worker
            self._pool_to_worker.notify()
        finally:
            self._lock.release()

    def _ShouldWorkerTerminateUnlocked(self, worker):
        """Tells whether C{worker} has been marked for termination.

        The caller must hold the pool's lock.

        """
        return worker in self._termworkers

    def ShouldWorkerTerminate(self, worker):
        """Tells whether C{worker} has been marked for termination.

        Acquires the pool's lock for the duration of the check.

        """
        self._lock.acquire()
        try:
            answer = self._ShouldWorkerTerminateUnlocked(worker)
        finally:
            self._lock.release()
        return answer

    def _HasRunningTasksUnlocked(self):
        """Tells whether any active or terminating worker is mid-task.

        The caller must hold the pool's lock.

        """
        for group in (self._workers, self._termworkers):
            for candidate in group:
                if candidate._HasRunningTaskUnlocked():
                    return True
        return False

    def Quiesce(self):
        """Blocks until no task is queued and no worker is running one.

        New L{AddTask} calls wait until this method returns.

        """
        self._lock.acquire()
        try:
            self._quiescing = True

            # Sleep until workers report progress and nothing is left to do
            while self._tasks or self._HasRunningTasksUnlocked():
                self._worker_to_pool.wait()
        finally:
            self._quiescing = False
            # Release any AddTask callers that blocked during the quiesce
            self._pool_to_pool.notifyAll()
            self._lock.release()

    def _NewWorkerIdUnlocked(self):
        """Hands out the next sequential worker identifier, starting at 1.

        The caller must hold the pool's lock.

        """
        self._last_worker_id = self._last_worker_id + 1
        return self._last_worker_id

    def _ResizeUnlocked(self, num_workers):
        """Changes the number of workers.

        The caller must hold the pool's lock. Shrinking is only supported
        down to zero workers. In that case the lock is temporarily released
        while the old workers are joined.

        @param num_workers: the desired number of workers
        @raise NotImplementedError: when shrinking to a non-zero count

        """
        assert num_workers >= 0, "num_workers must be >= 0"

        logging.debug("Resizing to %s workers", num_workers)

        have = len(self._workers)

        if num_workers > have:
            self._StartWorkersUnlocked(num_workers - have)
        elif num_workers < have:
            if num_workers != 0:
                # TODO: Implement partial downsizing
                raise NotImplementedError()
            self._StopAllWorkersUnlocked()
        # Otherwise the pool already has the requested size

    def _StartWorkersUnlocked(self, count):
        """Creates, registers and starts C{count} new worker threads.

        The caller must hold the pool's lock.

        """
        for _ in xrange(count):
            fresh = self._worker_class(self, self._NewWorkerIdUnlocked())
            self._workers.append(fresh)
            fresh.start()

    def _StopAllWorkersUnlocked(self):
        """Marks every worker for termination and joins them all.

        The caller must hold the pool's lock. The lock is dropped while
        joining and re-acquired before returning.

        """
        # Snapshot the list so it can be walked while the lock isn't held
        leaving = list(self._workers)
        del self._workers[:]
        self._termworkers.extend(leaving)

        # Wake waiting workers so they re-check whether to terminate
        self._pool_to_worker.notifyAll()

        self._lock.release()
        try:
            for thread in leaving:
                logging.debug("Waiting for thread %s", thread.getName())
                thread.join()
        finally:
            self._lock.acquire()

        # Reap only threads that have really exited, instead of clearing
        # the list wholesale, so a surviving thread trips the final
        # assertion rather than being silently forgotten.
        for thread in leaving:
            assert thread in self._termworkers, ("Worker not in list of"
                                                 " terminating workers")
            if not thread.isAlive():
                self._termworkers.remove(thread)

        assert not self._termworkers, "Zombie worker detected"

    def Resize(self, num_workers):
        """Changes the number of workers in the pool.

        @param num_workers: the new number of workers

        """
        self._lock.acquire()
        try:
            return self._ResizeUnlocked(num_workers)
        finally:
            self._lock.release()

    def TerminateWorkers(self):
        """Terminate all worker threads.

        Unstarted tasks will be ignored; they stay in the queue.

        """
        logging.debug("Terminating all workers")

        self._lock.acquire()
        try:
            self._ResizeUnlocked(0)

            pending = len(self._tasks)
            if pending:
                logging.debug("There are %s tasks left", pending)
        finally:
            self._lock.release()

        logging.debug("All workers terminated")