4 # Copyright (C) 2008 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
"""Base classes for worker pools.

Provides L{BaseWorker}, a thread that executes queued tasks, and
L{WorkerPool}, which owns the task queue and a set of such workers.

"""

import collections
import logging
import threading
class BaseWorker(threading.Thread, object):
    """Base worker class for worker pools.

    Users of a worker pool must override RunTask in a subclass.

    Workers cooperate with their L{WorkerPool} through the pool's lock,
    condition variables and task queue, which is why private attributes of
    the pool are accessed here.

    """
    # pylint: disable-msg=W0212
    def __init__(self, pool, worker_id):
        """Constructor for BaseWorker thread.

        @param pool: the parent worker pool
        @param worker_id: identifier for this worker

        """
        super(BaseWorker, self).__init__()
        self.pool = pool
        self.worker_id = worker_id
        # Arguments of the task currently being run, or None while idle
        self._current_task = None

    def ShouldTerminate(self):
        """Returns whether a worker should terminate.

        Delegates to the pool's L{WorkerPool.ShouldWorkerTerminate}.

        """
        return self.pool.ShouldWorkerTerminate(self)

    def _HasRunningTaskUnlocked(self):
        """Returns whether this worker is currently running a task.

        The caller must hold the pool lock.

        """
        return self._current_task is not None

    def HasRunningTask(self):
        """Returns whether this worker is currently running a task.

        Acquires the pool lock for the duration of the check.

        """
        with self.pool._lock:
            return self._HasRunningTaskUnlocked()

    def run(self):
        """Main thread function.

        Waits for new tasks to show up in the queue and runs them one at a
        time until the pool asks this worker to terminate.

        """
        pool = self.pool

        assert not self.HasRunningTask()

        while True:
            try:
                # We wait on lock to be told either terminate or do a task.
                with pool._lock:
                    if pool._ShouldWorkerTerminateUnlocked(self):
                        break

                    # We only wait if there's no task for us.
                    if not pool._tasks:
                        logging.debug("Worker %s: waiting for tasks",
                                      self.worker_id)

                        # wait() releases the lock and sleeps until notified
                        pool._pool_to_worker.wait()

                        logging.debug("Worker %s: notified while waiting",
                                      self.worker_id)

                        # Were we woken up in order to terminate?
                        if pool._ShouldWorkerTerminateUnlocked(self):
                            break

                        if not pool._tasks:
                            # Spurious notification, ignore
                            continue

                    # Get task from queue and tell pool about it
                    try:
                        self._current_task = pool._tasks.popleft()
                    finally:
                        pool._worker_to_pool.notify_all()

                # Run the actual task; the pool lock is not held here
                try:
                    logging.debug("Worker %s: starting task %r",
                                  self.worker_id, self._current_task)
                    self.RunTask(*self._current_task)
                    logging.debug("Worker %s: done with task %r",
                                  self.worker_id, self._current_task)
                except: # pylint: disable-msg=W0702
                    # Deliberately catch everything: a failing task is logged
                    # and must not end this worker's loop.
                    logging.error("Worker %s: Caught unhandled exception",
                                  self.worker_id, exc_info=True)
            finally:
                # Notify pool that this worker is idle again. Compare against
                # None: a task added without arguments is the empty tuple,
                # which is falsy but still has to be cleared.
                with pool._lock:
                    if self._current_task is not None:
                        self._current_task = None
                        pool._worker_to_pool.notify_all()

        logging.debug("Worker %s: terminates", self.worker_id)

    def RunTask(self, *args):
        """Function called to start a task.

        This needs to be implemented by child classes.

        @param args: the arguments the task was added to the pool with

        """
        raise NotImplementedError()
class WorkerPool(object):
    """Worker pool with a queue.

    This class is thread-safe.

    Tasks are guaranteed to be started in the order in which they're
    added to the pool. Due to the nature of threading, they're not
    guaranteed to finish in the same order.

    """
    def __init__(self, name, num_workers, worker_class):
        """Constructor for worker pool.

        @param name: prefix for the identifiers given to workers
        @param num_workers: number of workers to be started
            (shrinking the pool to a non-zero size is not yet implemented)
        @param worker_class: the class to be instantiated for workers;
            should derive from L{BaseWorker}

        """
        # Some of these variables are accessed by BaseWorker
        self._lock = threading.Lock()
        # Wakes AddTask callers blocked while the pool is quiescing
        self._pool_to_pool = threading.Condition(self._lock)
        # Tells workers about new tasks or pending termination
        self._pool_to_worker = threading.Condition(self._lock)
        # Signalled by workers when they take or finish a task
        self._worker_to_pool = threading.Condition(self._lock)
        self._worker_class = worker_class
        self._name = name
        self._last_worker_id = 0
        self._workers = []
        self._quiescing = False

        # Terminating workers
        self._termworkers = []

        # Queued tasks, one tuple of RunTask arguments each
        self._tasks = collections.deque()

        # Start workers
        self.Resize(num_workers)

        # TODO: Implement dynamic resizing?

    def AddTask(self, *args):
        """Adds a task to the queue.

        Blocks while L{Quiesce} is in progress.

        @param args: arguments passed to L{BaseWorker.RunTask}

        """
        with self._lock:
            # Don't add new tasks while we're quiescing
            while self._quiescing:
                self._pool_to_pool.wait()

            # Add task to internal queue
            self._tasks.append(args)

            # Wake one idling worker up
            self._pool_to_worker.notify()

    def _ShouldWorkerTerminateUnlocked(self, worker):
        """Returns whether a worker should terminate.

        The caller must hold the pool lock.

        """
        return worker in self._termworkers

    def ShouldWorkerTerminate(self, worker):
        """Returns whether a worker should terminate.

        Locked variant of L{_ShouldWorkerTerminateUnlocked}.

        """
        with self._lock:
            return self._ShouldWorkerTerminateUnlocked(worker)

    def _HasRunningTasksUnlocked(self):
        """Checks whether there's a task running in a worker.

        The caller must hold the pool lock. Terminating workers are included
        because they may still be finishing their last task.

        """
        # pylint: disable-msg=W0212
        return any(worker._HasRunningTaskUnlocked()
                   for worker in self._workers + self._termworkers)

    def Quiesce(self):
        """Waits until the task queue is empty and no task is running.

        L{AddTask} blocks while this is waiting.

        """
        with self._lock:
            self._quiescing = True
            try:
                # Wait while there are tasks pending or running
                while self._tasks or self._HasRunningTasksUnlocked():
                    self._worker_to_pool.wait()
            finally:
                self._quiescing = False

                # Make sure AddTasks continues in case it was waiting
                self._pool_to_pool.notify_all()

    def _NewWorkerIdUnlocked(self):
        """Return an identifier for a new worker.

        Identifiers are the pool name followed by a counter starting at 1.

        """
        self._last_worker_id += 1

        return "%s%d" % (self._name, self._last_worker_id)

    def _ResizeUnlocked(self, num_workers):
        """Changes the number of workers.

        Must be called with the pool lock held; the lock is temporarily
        released while waiting for terminating workers to exit.

        @param num_workers: the new number of workers
        @raise NotImplementedError: when shrinking to a non-zero size

        """
        assert num_workers >= 0, "num_workers must be >= 0"

        logging.debug("Resizing to %s workers", num_workers)

        current_count = len(self._workers)

        if current_count == num_workers:
            # Nothing to do
            pass

        elif current_count > num_workers:
            if num_workers == 0:
                # Create copy of list to iterate over while lock isn't held.
                termworkers = self._workers[:]
                del self._workers[:]
            else:
                # TODO: Implement partial downsizing
                raise NotImplementedError()

            self._termworkers += termworkers

            # Notify workers that something has changed
            self._pool_to_worker.notify_all()

            # Join all terminating workers. The lock has to be released,
            # otherwise workers waiting on a condition of this lock could
            # never wake up and exit.
            self._lock.release()
            try:
                for worker in termworkers:
                    logging.debug("Waiting for thread %s", worker.name)
                    worker.join()
            finally:
                self._lock.acquire()

            # Remove terminated threads. This could be done in a more efficient
            # way (del self._termworkers[:]), but checking worker.is_alive()
            # makes sure we don't leave zombie threads around.
            for worker in termworkers:
                assert worker in self._termworkers, ("Worker not in list of"
                                                     " terminating workers")
                if not worker.is_alive():
                    self._termworkers.remove(worker)

            assert not self._termworkers, "Zombie worker detected"

        elif current_count < num_workers:
            # Create (num_workers - current_count) new workers
            for _ in range(num_workers - current_count):
                worker = self._worker_class(self, self._NewWorkerIdUnlocked())
                self._workers.append(worker)
                worker.start()

    def Resize(self, num_workers):
        """Changes the number of workers in the pool.

        @param num_workers: the new number of workers

        """
        with self._lock:
            return self._ResizeUnlocked(num_workers)

    def TerminateWorkers(self):
        """Terminate all worker threads.

        Unstarted tasks will be ignored.

        """
        logging.debug("Terminating all workers")

        with self._lock:
            self._ResizeUnlocked(0)

            if self._tasks:
                logging.debug("There are %s tasks left", len(self._tasks))

        logging.debug("All workers terminated")