#
#

# Copyright (C) 2008 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Base classes for worker pools.

"""
class BaseWorker(threading.Thread, object):
    """Base worker class for worker pools.

    Users of a worker pool must override RunTask in a subclass.

    A worker repeatedly takes the oldest entry from the pool's task queue
    and passes it to L{RunTask} until the pool reports that the worker
    should terminate.

    """
    # pylint: disable-msg=W0212
    def __init__(self, pool, worker_id):
        """Constructor for BaseWorker thread.

        @param pool: the parent worker pool
        @param worker_id: identifier for this worker, used as thread name

        """
        super(BaseWorker, self).__init__(name=worker_id)
        self.pool = pool
        # Argument tuple of the task currently being run, None while idle
        self._current_task = None

    def ShouldTerminate(self):
        """Returns whether a worker should terminate.

        @return: whatever the pool's ShouldWorkerTerminate reports for
                 this worker

        """
        return self.pool.ShouldWorkerTerminate(self)

    def _HasRunningTaskUnlocked(self):
        """Returns whether this worker is currently running a task.

        The caller is expected to hold the pool lock.

        """
        return (self._current_task is not None)

    def HasRunningTask(self):
        """Returns whether this worker is currently running a task.

        Acquires the pool lock for the duration of the check.

        """
        with self.pool._lock:
            return self._HasRunningTaskUnlocked()

    def run(self):
        """Main thread function.

        Waits for new tasks to show up in the queue.

        """
        pool = self.pool

        assert not self.HasRunningTask()

        while True:
            try:
                # We wait on lock to be told either terminate or do a task.
                with pool._lock:
                    if pool._ShouldWorkerTerminateUnlocked(self):
                        break

                    # We only wait if there's no task for us.
                    if not pool._tasks:
                        logging.debug("Waiting for tasks")

                        # wait() releases the lock and sleeps until notified
                        pool._pool_to_worker.wait()

                        logging.debug("Notified while waiting")

                        # Were we woken up in order to terminate?
                        if pool._ShouldWorkerTerminateUnlocked(self):
                            break

                        if not pool._tasks:
                            # Spurious notification, ignore
                            continue

                    # Get task from queue and tell pool about it
                    try:
                        self._current_task = pool._tasks.popleft()
                    finally:
                        pool._worker_to_pool.notify_all()

                # Run the actual task
                try:
                    logging.debug("Starting task %r", self._current_task)
                    self.RunTask(*self._current_task)
                    logging.debug("Done with task %r", self._current_task)
                # Deliberately catch everything: a failing task must not
                # take the worker thread down with it.
                except: # pylint: disable-msg=W0702
                    logging.exception("Caught unhandled exception")
            finally:
                # Notify pool. The comparison has to be against None: a task
                # added without arguments is stored as an empty tuple, which
                # is false in a boolean context but still counts as running.
                with pool._lock:
                    if self._current_task is not None:
                        self._current_task = None
                        pool._worker_to_pool.notify_all()

        logging.debug("Terminates")

    def RunTask(self, *args):
        """Function called to start a task.

        This needs to be implemented by child classes.

        @param args: the arguments the task was queued with
        @raise NotImplementedError: always, unless overridden

        """
        raise NotImplementedError()
class WorkerPool(object):
    """Worker pool with a queue.

    This class is thread-safe.

    Tasks are guaranteed to be started in the order in which they're
    added to the pool. Due to the nature of threading, they're not
    guaranteed to finish in the same order.

    """
    def __init__(self, name, num_workers, worker_class):
        """Constructor for worker pool.

        @param name: prefix used when generating worker identifiers
        @param num_workers: number of workers to be started
                            (dynamic resizing is not yet implemented)
        @param worker_class: the class to be instantiated for workers;
                             should derive from L{BaseWorker}

        """
        # Some of these variables are accessed by BaseWorker
        self._lock = threading.Lock()
        # All three conditions share the one pool lock
        self._pool_to_pool = threading.Condition(self._lock)
        self._pool_to_worker = threading.Condition(self._lock)
        self._worker_to_pool = threading.Condition(self._lock)
        self._worker_class = worker_class
        self._name = name
        self._last_worker_id = 0
        self._workers = []
        self._quiescing = False

        # Terminating workers
        self._termworkers = []

        # Queued tasks
        self._tasks = collections.deque()

        # Start workers
        self.Resize(num_workers)

    # TODO: Implement dynamic resizing?

    def AddTask(self, *args):
        """Adds a task to the queue.

        Blocks while the pool is quiescing.

        @param args: arguments passed to L{BaseWorker.RunTask}

        """
        with self._lock:
            # Don't add new tasks while we're quiescing
            while self._quiescing:
                self._pool_to_pool.wait()

            # Add task to internal queue
            self._tasks.append(args)

            # Wake one idling worker up
            self._pool_to_worker.notify()

    def _ShouldWorkerTerminateUnlocked(self, worker):
        """Returns whether a worker should terminate.

        The caller is expected to hold the pool lock.

        @param worker: the worker asking

        """
        return (worker in self._termworkers)

    def ShouldWorkerTerminate(self, worker):
        """Returns whether a worker should terminate.

        @param worker: the worker asking

        """
        with self._lock:
            return self._ShouldWorkerTerminateUnlocked(worker)

    def _HasRunningTasksUnlocked(self):
        """Checks whether there's a task running in a worker.

        Both active and terminating workers are considered. The caller is
        expected to hold the pool lock.

        """
        for worker in self._workers + self._termworkers:
            if worker._HasRunningTaskUnlocked(): # pylint: disable-msg=W0212
                return True
        return False

    def Quiesce(self):
        """Waits until the task queue is empty.

        While waiting, L{AddTask} blocks; it is released again once no task
        is pending or running.

        """
        with self._lock:
            self._quiescing = True
            try:
                # Wait while there are tasks pending or running
                while self._tasks or self._HasRunningTasksUnlocked():
                    self._worker_to_pool.wait()
            finally:
                self._quiescing = False

                # Make sure AddTasks continues in case it was waiting
                self._pool_to_pool.notify_all()

    def _NewWorkerIdUnlocked(self):
        """Return an identifier for a new worker.

        The identifier is the pool name followed by a running number.

        """
        self._last_worker_id += 1

        return "%s%d" % (self._name, self._last_worker_id)

    def _ResizeUnlocked(self, num_workers):
        """Changes the number of workers.

        The caller must hold the pool lock. The lock is temporarily
        released while waiting for terminating workers to finish.

        @param num_workers: the new number of workers
        @raise NotImplementedError: when asked to shrink the pool to
                                    anything other than zero workers

        """
        assert num_workers >= 0, "num_workers must be >= 0"

        logging.debug("Resizing to %s workers", num_workers)

        current_count = len(self._workers)

        if current_count == num_workers:
            # Nothing to do
            pass

        elif current_count > num_workers:
            if num_workers != 0:
                # TODO: Implement partial downsizing
                raise NotImplementedError()

            # Create copy of list to iterate over while lock isn't held.
            termworkers = self._workers[:]
            del self._workers[:]

            self._termworkers += termworkers

            # Notify workers that something has changed
            self._pool_to_worker.notify_all()

            # Join all terminating workers. The lock must not be held here,
            # the workers need it to notice that they should terminate.
            self._lock.release()
            try:
                for worker in termworkers:
                    logging.debug("Waiting for thread %s", worker.name)
                    worker.join()
            finally:
                self._lock.acquire()

            # Remove terminated threads. This could be done in a more
            # efficient way (del self._termworkers[:]), but checking
            # worker.is_alive() makes sure we don't leave zombie threads
            # around.
            for worker in termworkers:
                assert worker in self._termworkers, ("Worker not in list of"
                                                     " terminating workers")
                if not worker.is_alive():
                    self._termworkers.remove(worker)

            assert not self._termworkers, "Zombie worker detected"

        elif current_count < num_workers:
            # Create (num_workers - current_count) new workers
            for _ in range(num_workers - current_count):
                worker = self._worker_class(self, self._NewWorkerIdUnlocked())
                self._workers.append(worker)
                worker.start()

    def Resize(self, num_workers):
        """Changes the number of workers in the pool.

        @param num_workers: the new number of workers

        """
        with self._lock:
            return self._ResizeUnlocked(num_workers)

    def TerminateWorkers(self):
        """Terminate all worker threads.

        Unstarted tasks will be ignored.

        """
        logging.debug("Terminating all workers")

        with self._lock:
            self._ResizeUnlocked(0)

            if self._tasks:
                logging.debug("There are %s tasks left", len(self._tasks))

        logging.debug("All workers terminated")