4 # Copyright (C) 2008 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
"""Base classes for worker pools."""
class BaseWorker(threading.Thread, object):
    """Base worker class for worker pools.

    Users of a worker pool must override RunTask in a subclass.

    """

    def __init__(self, pool, worker_id):
        """Constructor for BaseWorker thread.

        Args:
          pool: Parent worker pool
          worker_id: Identifier for this worker

        """
        super(BaseWorker, self).__init__()
        self.pool = pool
        self.worker_id = worker_id
        # Argument tuple of the task being run, or None while idle. It is
        # only read or written while the pool's lock is held.
        self._current_task = None

    def ShouldTerminate(self):
        """Returns whether a worker should terminate.

        The decision is delegated to the parent pool.

        """
        return self.pool.ShouldWorkerTerminate(self)

    def _HasRunningTaskUnlocked(self):
        """Returns whether this worker is currently running a task.

        Does not take the pool's lock itself; HasRunningTask is the variant
        that does.

        """
        return self._current_task is not None

    def HasRunningTask(self):
        """Returns whether this worker is currently running a task.

        """
        with self.pool._lock:
            return self._HasRunningTaskUnlocked()

    def run(self):
        """Main thread function.

        Waits for new tasks to show up in the queue, runs them one at a time
        via RunTask and exits once the pool marks this worker as terminating.

        """
        pool = self.pool

        assert not self.HasRunningTask()

        while True:
            try:
                # We wait on lock to be told either terminate or do a task.
                with pool._lock:
                    if pool._ShouldWorkerTerminateUnlocked(self):
                        break

                    # We only wait if there's no task for us.
                    if not pool._tasks:
                        logging.debug("Worker %s: waiting for tasks",
                                      self.worker_id)

                        # wait() releases the lock and sleeps until notified
                        pool._pool_to_worker.wait()

                        logging.debug("Worker %s: notified while waiting",
                                      self.worker_id)

                        # Were we woken up in order to terminate?
                        if pool._ShouldWorkerTerminateUnlocked(self):
                            break

                        if not pool._tasks:
                            # Spurious notification, ignore
                            continue

                    # Get task from queue and tell pool about it
                    try:
                        self._current_task = pool._tasks.popleft()
                    finally:
                        pool._worker_to_pool.notify_all()

                # Run the actual task (the pool's lock is not held here)
                try:
                    logging.debug("Worker %s: starting task %r",
                                  self.worker_id, self._current_task)
                    self.RunTask(*self._current_task)
                    logging.debug("Worker %s: done with task %r",
                                  self.worker_id, self._current_task)
                except:
                    # Deliberately catches everything: a failing task is
                    # logged and the worker keeps serving the queue.
                    logging.error("Worker %s: Caught unhandled exception",
                                  self.worker_id, exc_info=True)
            finally:
                # Notify pool. This also runs on break/continue above, in
                # which case there is no current task and nothing to report.
                with pool._lock:
                    # Compare against None: a task added without arguments
                    # is the empty tuple, which is falsy but still a task.
                    if self._current_task is not None:
                        self._current_task = None
                        pool._worker_to_pool.notify_all()

        logging.debug("Worker %s: terminates", self.worker_id)

    def RunTask(self, *args):
        """Function called to start a task.

        Must be overridden by subclasses.

        Args:
          *args: Arguments of the task, as handed to the pool when the task
            was queued

        Raises:
          NotImplementedError: always, in this base class

        """
        raise NotImplementedError()
class WorkerPool(object):
    """Worker pool with a queue.

    This class is thread-safe.

    Tasks are guaranteed to be started in the order in which they're added to
    the pool. Due to the nature of threading, they're not guaranteed to finish
    in the same order.

    """

    def __init__(self, num_workers, worker_class):
        """Constructor for worker pool.

        Args:
          num_workers: Number of workers to be started (dynamic resizing is
            not yet implemented)
          worker_class: Class to be instantiated for workers; should derive
            from BaseWorker

        """
        # Some of these variables are accessed by BaseWorker
        self._lock = threading.Lock()
        # All three conditions share self._lock:
        # - _pool_to_pool: Quiesce tells blocked AddTask calls to continue
        # - _pool_to_worker: pool tells workers about new tasks/termination
        # - _worker_to_pool: workers tell the pool a task started/finished
        self._pool_to_pool = threading.Condition(self._lock)
        self._pool_to_worker = threading.Condition(self._lock)
        self._worker_to_pool = threading.Condition(self._lock)
        self._worker_class = worker_class
        self._last_worker_id = 0
        self._workers = []
        self._quiescing = False

        # Terminating workers
        self._termworkers = []

        # Queued tasks
        self._tasks = collections.deque()

        # Start workers
        self.Resize(num_workers)

    # TODO: Implement dynamic resizing?

    def AddTask(self, *args):
        """Adds a task to the queue.

        Blocks while the pool is quiescing.

        Args:
          *args: Arguments passed to BaseWorker.RunTask

        """
        with self._lock:
            # Don't add new tasks while we're quiescing
            while self._quiescing:
                self._pool_to_pool.wait()

            # Add task to internal queue
            self._tasks.append(args)

            # Wake one idling worker up
            self._pool_to_worker.notify()

    def _ShouldWorkerTerminateUnlocked(self, worker):
        """Returns whether a worker should terminate.

        Does not take the lock itself; ShouldWorkerTerminate is the variant
        that does.

        """
        return worker in self._termworkers

    def ShouldWorkerTerminate(self, worker):
        """Returns whether a worker should terminate.

        Args:
          worker: The worker asking whether it has to stop

        """
        with self._lock:
            # The worker has to be looked up, not the pool itself.
            return self._ShouldWorkerTerminateUnlocked(worker)

    def _HasRunningTasksUnlocked(self):
        """Checks whether there's a task running in a worker.

        Both active and terminating workers are considered.

        """
        return any(worker._HasRunningTaskUnlocked()
                   for worker in self._workers + self._termworkers)

    def Quiesce(self):
        """Waits until the task queue is empty.

        New tasks are held back in AddTask until all queued and running tasks
        are done.

        """
        with self._lock:
            self._quiescing = True
            try:
                # Wait while there are tasks pending or running
                while self._tasks or self._HasRunningTasksUnlocked():
                    self._worker_to_pool.wait()
            finally:
                self._quiescing = False

                # Make sure AddTasks continues in case it was waiting
                self._pool_to_pool.notify_all()

    def _NewWorkerIdUnlocked(self):
        """Returns the next worker identifier; counting starts at 1.

        """
        self._last_worker_id += 1
        return self._last_worker_id

    def _ResizeUnlocked(self, num_workers):
        """Changes the number of workers.

        Called with the lock held. When shutting workers down the lock is
        released while joining them and re-acquired before returning.

        Args:
          num_workers: New number of workers; must be >= 0

        Raises:
          NotImplementedError: when asked to shrink to anything but zero
            workers

        """
        assert num_workers >= 0, "num_workers must be >= 0"

        logging.debug("Resizing to %s workers", num_workers)

        current_count = len(self._workers)

        if current_count == num_workers:
            # Nothing to do
            pass

        elif current_count > num_workers:
            if num_workers == 0:
                # Create copy of list to iterate over while lock isn't held.
                termworkers = self._workers[:]
                del self._workers[:]
            else:
                # TODO: Implement partial downsizing
                raise NotImplementedError()

            self._termworkers += termworkers

            # Notify workers that something has changed
            self._pool_to_worker.notify_all()

            # Join all terminating workers. They need the lock to notice
            # that they should stop, so it is dropped while waiting.
            self._lock.release()
            try:
                for worker in termworkers:
                    logging.debug("Waiting for thread %s", worker.name)
                    worker.join()
            finally:
                self._lock.acquire()

            # Remove terminated threads. This could be done in a more
            # efficient way (del self._termworkers[:]), but checking
            # worker.is_alive() makes sure we don't leave zombie threads
            # around.
            for worker in termworkers:
                assert worker in self._termworkers, ("Worker not in list of"
                                                     " terminating workers")
                if not worker.is_alive():
                    self._termworkers.remove(worker)

            assert not self._termworkers, "Zombie worker detected"

        elif current_count < num_workers:
            # Create (num_workers - current_count) new workers
            for _ in range(num_workers - current_count):
                worker = self._worker_class(self, self._NewWorkerIdUnlocked())
                self._workers.append(worker)
                worker.start()

    def Resize(self, num_workers):
        """Changes the number of workers in the pool.

        Args:
          num_workers: New number of workers

        """
        with self._lock:
            return self._ResizeUnlocked(num_workers)

    def TerminateWorkers(self):
        """Terminate all worker threads.

        Unstarted tasks will be ignored.

        """
        logging.debug("Terminating all workers")

        with self._lock:
            self._ResizeUnlocked(0)

            if self._tasks:
                logging.debug("There are %s tasks left", len(self._tasks))

        logging.debug("All workers terminated")