workerpool: Additional check in BaseWorker.ShouldTerminate
[ganeti-local] / lib / workerpool.py
1 #
2 #
3
4 # Copyright (C) 2008, 2009, 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Base classes for worker pools.
23
24 """
25
26 import collections
27 import logging
28 import threading
29
30 from ganeti import compat
31
32
33 _TERMINATE = object()
34
35
36 class BaseWorker(threading.Thread, object):
37   """Base worker class for worker pools.
38
39   Users of a worker pool must override RunTask in a subclass.
40
41   """
42   # pylint: disable-msg=W0212
43   def __init__(self, pool, worker_id):
44     """Constructor for BaseWorker thread.
45
46     @param pool: the parent worker pool
47     @param worker_id: identifier for this worker
48
49     """
50     super(BaseWorker, self).__init__(name=worker_id)
51     self.pool = pool
52     self._current_task = None
53
54   def ShouldTerminate(self):
55     """Returns whether this worker should terminate.
56
57     Should only be called from within L{RunTask}.
58
59     """
60     self.pool._lock.acquire()
61     try:
62       assert self._HasRunningTaskUnlocked()
63       return self.pool._ShouldWorkerTerminateUnlocked(self)
64     finally:
65       self.pool._lock.release()
66
67   def _HasRunningTaskUnlocked(self):
68     """Returns whether this worker is currently running a task.
69
70     """
71     return (self._current_task is not None)
72
73   def run(self):
74     """Main thread function.
75
76     Waits for new tasks to show up in the queue.
77
78     """
79     pool = self.pool
80
81     assert self._current_task is None
82
83     while True:
84       try:
85         # Wait on lock to be told either to terminate or to do a task
86         pool._lock.acquire()
87         try:
88           task = pool._WaitForTaskUnlocked(self)
89
90           if task is _TERMINATE:
91             # Told to terminate
92             break
93
94           if task is None:
95             # Spurious notification, ignore
96             continue
97
98           self._current_task = task
99
100           assert self._HasRunningTaskUnlocked()
101         finally:
102           pool._lock.release()
103
104         # Run the actual task
105         try:
106           logging.debug("Starting task %r", self._current_task)
107           self.RunTask(*self._current_task)
108           logging.debug("Done with task %r", self._current_task)
109         except: # pylint: disable-msg=W0702
110           logging.exception("Caught unhandled exception")
111       finally:
112         # Notify pool
113         pool._lock.acquire()
114         try:
115           if self._current_task:
116             self._current_task = None
117             pool._worker_to_pool.notifyAll()
118         finally:
119           pool._lock.release()
120
121     logging.debug("Terminates")
122
123   def RunTask(self, *args):
124     """Function called to start a task.
125
126     This needs to be implemented by child classes.
127
128     """
129     raise NotImplementedError()
130
131
132 class WorkerPool(object):
133   """Worker pool with a queue.
134
135   This class is thread-safe.
136
137   Tasks are guaranteed to be started in the order in which they're
138   added to the pool. Due to the nature of threading, they're not
139   guaranteed to finish in the same order.
140
141   """
142   def __init__(self, name, num_workers, worker_class):
143     """Constructor for worker pool.
144
145     @param num_workers: number of workers to be started
146         (dynamic resizing is not yet implemented)
147     @param worker_class: the class to be instantiated for workers;
148         should derive from L{BaseWorker}
149
150     """
151     # Some of these variables are accessed by BaseWorker
152     self._lock = threading.Lock()
153     self._pool_to_pool = threading.Condition(self._lock)
154     self._pool_to_worker = threading.Condition(self._lock)
155     self._worker_to_pool = threading.Condition(self._lock)
156     self._worker_class = worker_class
157     self._name = name
158     self._last_worker_id = 0
159     self._workers = []
160     self._quiescing = False
161
162     # Terminating workers
163     self._termworkers = []
164
165     # Queued tasks
166     self._tasks = collections.deque()
167
168     # Start workers
169     self.Resize(num_workers)
170
171   # TODO: Implement dynamic resizing?
172
173   def _WaitWhileQuiescingUnlocked(self):
174     """Wait until the worker pool has finished quiescing.
175
176     """
177     while self._quiescing:
178       self._pool_to_pool.wait()
179
180   def _AddTaskUnlocked(self, args):
181     assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
182
183     self._tasks.append(args)
184
185     # Notify a waiting worker
186     self._pool_to_worker.notify()
187
188   def AddTask(self, *args):
189     """Adds a task to the queue.
190
191     @param args: arguments passed to L{BaseWorker.RunTask}
192
193     """
194     self._lock.acquire()
195     try:
196       self._WaitWhileQuiescingUnlocked()
197       self._AddTaskUnlocked(args)
198     finally:
199       self._lock.release()
200
201   def AddManyTasks(self, tasks):
202     """Add a list of tasks to the queue.
203
204     @type tasks: list of tuples
205     @param tasks: list of args passed to L{BaseWorker.RunTask}
206
207     """
208     assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
209       "Each task must be a sequence"
210
211     self._lock.acquire()
212     try:
213       self._WaitWhileQuiescingUnlocked()
214
215       for args in tasks:
216         self._AddTaskUnlocked(args)
217     finally:
218       self._lock.release()
219
220   def _WaitForTaskUnlocked(self, worker):
221     """Waits for a task for a worker.
222
223     @type worker: L{BaseWorker}
224     @param worker: Worker thread
225
226     """
227     if self._ShouldWorkerTerminateUnlocked(worker):
228       return _TERMINATE
229
230     # We only wait if there's no task for us.
231     if not self._tasks:
232       logging.debug("Waiting for tasks")
233
234       # wait() releases the lock and sleeps until notified
235       self._pool_to_worker.wait()
236
237       logging.debug("Notified while waiting")
238
239       # Were we woken up in order to terminate?
240       if self._ShouldWorkerTerminateUnlocked(worker):
241         return _TERMINATE
242
243       if not self._tasks:
244         # Spurious notification, ignore
245         return None
246
247     # Get task from queue and tell pool about it
248     try:
249       return self._tasks.popleft()
250     finally:
251       self._worker_to_pool.notifyAll()
252
253   def _ShouldWorkerTerminateUnlocked(self, worker):
254     """Returns whether a worker should terminate.
255
256     """
257     return (worker in self._termworkers)
258
259   def _HasRunningTasksUnlocked(self):
260     """Checks whether there's a task running in a worker.
261
262     """
263     for worker in self._workers + self._termworkers:
264       if worker._HasRunningTaskUnlocked(): # pylint: disable-msg=W0212
265         return True
266     return False
267
268   def Quiesce(self):
269     """Waits until the task queue is empty.
270
271     """
272     self._lock.acquire()
273     try:
274       self._quiescing = True
275
276       # Wait while there are tasks pending or running
277       while self._tasks or self._HasRunningTasksUnlocked():
278         self._worker_to_pool.wait()
279
280     finally:
281       self._quiescing = False
282
283       # Make sure AddTasks continues in case it was waiting
284       self._pool_to_pool.notifyAll()
285
286       self._lock.release()
287
288   def _NewWorkerIdUnlocked(self):
289     """Return an identifier for a new worker.
290
291     """
292     self._last_worker_id += 1
293
294     return "%s%d" % (self._name, self._last_worker_id)
295
296   def _ResizeUnlocked(self, num_workers):
297     """Changes the number of workers.
298
299     """
300     assert num_workers >= 0, "num_workers must be >= 0"
301
302     logging.debug("Resizing to %s workers", num_workers)
303
304     current_count = len(self._workers)
305
306     if current_count == num_workers:
307       # Nothing to do
308       pass
309
310     elif current_count > num_workers:
311       if num_workers == 0:
312         # Create copy of list to iterate over while lock isn't held.
313         termworkers = self._workers[:]
314         del self._workers[:]
315       else:
316         # TODO: Implement partial downsizing
317         raise NotImplementedError()
318         #termworkers = ...
319
320       self._termworkers += termworkers
321
322       # Notify workers that something has changed
323       self._pool_to_worker.notifyAll()
324
325       # Join all terminating workers
326       self._lock.release()
327       try:
328         for worker in termworkers:
329           logging.debug("Waiting for thread %s", worker.getName())
330           worker.join()
331       finally:
332         self._lock.acquire()
333
334       # Remove terminated threads. This could be done in a more efficient way
335       # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
336       # don't leave zombie threads around.
337       for worker in termworkers:
338         assert worker in self._termworkers, ("Worker not in list of"
339                                              " terminating workers")
340         if not worker.isAlive():
341           self._termworkers.remove(worker)
342
343       assert not self._termworkers, "Zombie worker detected"
344
345     elif current_count < num_workers:
346       # Create (num_workers - current_count) new workers
347       for _ in range(num_workers - current_count):
348         worker = self._worker_class(self, self._NewWorkerIdUnlocked())
349         self._workers.append(worker)
350         worker.start()
351
352   def Resize(self, num_workers):
353     """Changes the number of workers in the pool.
354
355     @param num_workers: the new number of workers
356
357     """
358     self._lock.acquire()
359     try:
360       return self._ResizeUnlocked(num_workers)
361     finally:
362       self._lock.release()
363
364   def TerminateWorkers(self):
365     """Terminate all worker threads.
366
367     Unstarted tasks will be ignored.
368
369     """
370     logging.debug("Terminating all workers")
371
372     self._lock.acquire()
373     try:
374       self._ResizeUnlocked(0)
375
376       if self._tasks:
377         logging.debug("There are %s tasks left", len(self._tasks))
378     finally:
379       self._lock.release()
380
381     logging.debug("All workers terminated")