workerpool: Simplify _WaitForTaskUnlocked
[ganeti-local] / lib / workerpool.py
1 #
2 #
3
4 # Copyright (C) 2008, 2009, 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Base classes for worker pools.
23
24 """
25
26 import logging
27 import threading
28 import heapq
29
30 from ganeti import compat
31 from ganeti import errors
32
33
34 _TERMINATE = object()
35 _DEFAULT_PRIORITY = 0
36
37
class DeferTask(Exception):
  """Exception a task raises to postpone its own execution.

  When L{BaseWorker.RunTask} raises this, the task is put back into the
  pool's queue instead of being considered done. A new priority for the
  requeued task may be supplied.

  """
  def __init__(self, priority=None):
    """Initializes this class.

    @type priority: number
    @param priority: Priority for the requeued task; None keeps the task's
        current priority

    """
    super(DeferTask, self).__init__()
    self.priority = priority
54
55
class BaseWorker(threading.Thread, object):
  """Base worker class for worker pools.

  Users of a worker pool must override RunTask in a subclass.

  Each worker is a thread belonging to a L{WorkerPool}. In L{run} it
  repeatedly fetches a task from the pool while holding the pool's lock,
  executes it via L{RunTask} without holding the lock, and then reports
  back to the pool (again under the lock).

  """
  # pylint: disable=W0212
  def __init__(self, pool, worker_id):
    """Constructor for BaseWorker thread.

    @param pool: the parent worker pool
    @param worker_id: identifier for this worker; also used as the thread's
        initial name

    """
    super(BaseWorker, self).__init__(name=worker_id)
    self.pool = pool
    self._worker_id = worker_id
    # Task tuple (priority, counter, args) currently being run, or None when
    # idle; only assigned while holding the pool's lock
    self._current_task = None

    assert self.getName() == worker_id

  def ShouldTerminate(self):
    """Returns whether this worker should terminate.

    Should only be called from within L{RunTask}. Long-running tasks can
    poll this to stop early when the pool is being shrunk.

    @rtype: bool
    @return: True if the pool has marked this worker for termination

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()
      return self.pool._ShouldWorkerTerminateUnlocked(self)
    finally:
      self.pool._lock.release()

  def GetCurrentPriority(self):
    """Returns the priority of the current task.

    Should only be called from within L{RunTask}.

    @return: the priority the currently running task was queued with

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()

      (priority, _, _) = self._current_task

      return priority
    finally:
      self.pool._lock.release()

  def SetTaskName(self, taskname):
    """Sets the name of the current task.

    Should only be called from within L{RunTask}. The thread name becomes
    C{"<worker_id>/<taskname>"}; an empty or None C{taskname} restores the
    plain worker ID.

    @type taskname: string
    @param taskname: Task's name

    """
    if taskname:
      name = "%s/%s" % (self._worker_id, taskname)
    else:
      name = self._worker_id

    # Set thread name
    self.setName(name)

  def _HasRunningTaskUnlocked(self):
    """Returns whether this worker is currently running a task.

    @rtype: bool

    """
    return (self._current_task is not None)

  def run(self):
    """Main thread function.

    Waits for new tasks to show up in the queue and runs them until the
    pool tells this worker to terminate. A task raising L{DeferTask} is put
    back into the queue. Any other exception is logged and the task is
    dropped, and the worker continues with the next task.

    """
    pool = self.pool

    while True:
      assert self._current_task is None

      # Set to the DeferTask instance if the task asks to be requeued
      defer = None
      try:
        # Wait on lock to be told either to terminate or to do a task
        pool._lock.acquire()
        try:
          task = pool._WaitForTaskUnlocked(self)

          if task is _TERMINATE:
            # Told to terminate; the finally clauses below still run (and
            # find nothing to clean up) before the loop exits
            break

          if task is None:
            # Spurious notification, ignore
            # NOTE(review): _WaitForTaskUnlocked in this module loops until it
            # has a task or must terminate, so this looks unreachable; kept as
            # a safeguard
            continue

          self._current_task = task

          # No longer needed, dispose of reference
          del task

          assert self._HasRunningTaskUnlocked()

        finally:
          pool._lock.release()

        # The task runs without holding the pool's lock
        (priority, _, args) = self._current_task
        try:
          # Run the actual task
          assert defer is None
          logging.debug("Starting task %r, priority %s", args, priority)
          assert self.getName() == self._worker_id
          try:
            self.RunTask(*args) # pylint: disable=W0142
          finally:
            # Restore the thread name in case the task changed it
            self.SetTaskName(None)
          logging.debug("Done with task %r, priority %s", args, priority)
        except DeferTask, err:
          defer = err

          if defer.priority is None:
            # Use same priority
            defer.priority = priority

          logging.debug("Deferring task %r, new priority %s",
                        args, defer.priority)

          assert self._HasRunningTaskUnlocked()
        except: # pylint: disable=W0702
          # Deliberately catch everything so one failing task cannot kill
          # the worker thread; the task is not retried
          logging.exception("Caught unhandled exception")

        assert self._HasRunningTaskUnlocked()
      finally:
        # Notify pool
        pool._lock.acquire()
        try:
          if defer:
            assert self._current_task
            # Schedule again for later run
            (_, _, args) = self._current_task
            pool._AddTaskUnlocked(args, defer.priority)

          if self._current_task:
            # Mark this worker idle and wake anyone waiting on task
            # completion (e.g. WorkerPool.Quiesce)
            self._current_task = None
            pool._worker_to_pool.notifyAll()
        finally:
          pool._lock.release()

      assert not self._HasRunningTaskUnlocked()

    logging.debug("Terminates")

  def RunTask(self, *args):
    """Function called to start a task.

    This needs to be implemented by child classes. It receives the task's
    argument sequence unpacked as positional arguments and may raise
    L{DeferTask} to have the task requeued.

    """
    raise NotImplementedError()
216     """
217     raise NotImplementedError()
218
219
class WorkerPool(object):
  """Worker pool with a queue.

  This class is thread-safe.

  Tasks are started in priority order (lowest numeric priority first) and,
  within the same priority, in the order in which they were added to the
  pool. Due to the nature of threading, they're not guaranteed to finish in
  the same order.

  Methods whose name ends in C{Unlocked} expect the caller to already hold
  C{self._lock}.

  """
  def __init__(self, name, num_workers, worker_class):
    """Constructor for worker pool.

    @type name: string
    @param name: Prefix for worker thread names; each worker is named
        C{name} followed by a sequence number
    @type num_workers: int
    @param num_workers: number of workers to be started (see L{Resize} for
        the supported ways of changing this later)
    @param worker_class: the class to be instantiated for workers;
        should derive from L{BaseWorker}

    """
    # Some of these variables are accessed by BaseWorker
    self._lock = threading.Lock()
    # Signalled when quiescing ends, to wake callers blocked in AddTask
    self._pool_to_pool = threading.Condition(self._lock)
    # Wakes workers waiting for a task or for termination
    self._pool_to_worker = threading.Condition(self._lock)
    # Signalled when a worker takes a task from the queue or finishes one
    self._worker_to_pool = threading.Condition(self._lock)
    self._worker_class = worker_class
    self._name = name
    self._last_worker_id = 0
    self._workers = []
    self._quiescing = False
    self._active = True

    # Terminating workers
    self._termworkers = []

    # Queued tasks, kept as a heap of (priority, counter, args) tuples
    self._counter = 0
    self._tasks = []

    # Start workers
    self.Resize(num_workers)

  # TODO: Implement partial downsizing (see _ResizeUnlocked)

  def _WaitWhileQuiescingUnlocked(self):
    """Wait until the worker pool has finished quiescing.

    """
    while self._quiescing:
      self._pool_to_pool.wait()

  def _AddTaskUnlocked(self, args, priority):
    """Adds a task to the internal queue.

    @type args: sequence
    @param args: Arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority

    """
    assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
    assert isinstance(priority, (int, long)), "Priority must be numeric"

    # This counter is used to ensure elements are processed in their
    # incoming order. For processing they're sorted by priority and then
    # counter.
    self._counter += 1

    heapq.heappush(self._tasks, (priority, self._counter, args))

    # Notify a waiting worker
    self._pool_to_worker.notify()

  def AddTask(self, args, priority=_DEFAULT_PRIORITY):
    """Adds a task to the queue.

    Blocks while the pool is quiescing (see L{Quiesce}).

    @type args: sequence
    @param args: arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority

    """
    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()
      self._AddTaskUnlocked(args, priority)
    finally:
      self._lock.release()

  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY):
    """Add a list of tasks to the queue.

    Blocks while the pool is quiescing (see L{Quiesce}). All tasks are
    added while holding the lock once.

    @type tasks: list of tuples
    @param tasks: list of args passed to L{BaseWorker.RunTask}
    @type priority: number or list of numbers
    @param priority: Priority for all added tasks or a list with the priority
                     for each task
    @raise errors.ProgrammerError: if a list of priorities is given whose
        length differs from the number of tasks

    """
    assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
      "Each task must be a sequence"

    assert (isinstance(priority, (int, long)) or
            compat.all(isinstance(prio, (int, long)) for prio in priority)), \
           "Priority must be numeric or be a list of numeric values"

    if isinstance(priority, (int, long)):
      priority = [priority] * len(tasks)
    elif len(priority) != len(tasks):
      raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(priority), len(tasks)))

    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()

      assert compat.all(isinstance(prio, (int, long)) for prio in priority)
      assert len(tasks) == len(priority)

      # Use a distinct loop name so the "priority" list isn't shadowed
      for args, prio in zip(tasks, priority):
        self._AddTaskUnlocked(args, prio)
    finally:
      self._lock.release()

  def SetActive(self, active):
    """Enable/disable processing of tasks.

    This is different from L{Quiesce} in the sense that this function just
    changes an internal flag and doesn't wait for the queue to be empty. Tasks
    already being processed continue normally, but no new tasks will be
    started. New tasks can still be added.

    @type active: bool
    @param active: Whether tasks should be processed

    """
    self._lock.acquire()
    try:
      self._active = active

      if active:
        # Tell all workers to continue processing
        self._pool_to_worker.notifyAll()
    finally:
      self._lock.release()

  def _WaitForTaskUnlocked(self, worker):
    """Waits for a task for a worker.

    Blocks until either the worker is marked for termination or the pool is
    active and has a queued task.

    @type worker: L{BaseWorker}
    @param worker: Worker thread
    @return: C{_TERMINATE} if the worker must exit, otherwise the task tuple
        (priority, counter, args) removed from the queue

    """
    while True:
      if self._ShouldWorkerTerminateUnlocked(worker):
        return _TERMINATE

      # If there's a pending task, return it immediately
      if self._active and self._tasks:
        # Get task from queue and tell pool about it
        try:
          task = heapq.heappop(self._tasks)
        finally:
          self._worker_to_pool.notifyAll()

        return task

      logging.debug("Waiting for tasks")

      # wait() releases the lock and sleeps until notified
      self._pool_to_worker.wait()

      logging.debug("Notified while waiting")

  def _ShouldWorkerTerminateUnlocked(self, worker):
    """Returns whether a worker should terminate.

    @rtype: bool

    """
    return (worker in self._termworkers)

  def _HasRunningTasksUnlocked(self):
    """Checks whether there's a task running in a worker.

    Both active and terminating workers are considered.

    @rtype: bool

    """
    for worker in self._workers + self._termworkers:
      if worker._HasRunningTaskUnlocked(): # pylint: disable=W0212
        return True
    return False

  def HasRunningTasks(self):
    """Checks whether there's at least one task running.

    @rtype: bool

    """
    self._lock.acquire()
    try:
      return self._HasRunningTasksUnlocked()
    finally:
      self._lock.release()

  def Quiesce(self):
    """Waits until the task queue is empty and no task is running.

    While this runs, L{AddTask} and L{AddManyTasks} block. Queued tasks are
    only taken by workers while the pool is active (see L{SetActive}), so
    with a deactivated pool and a non-empty queue this keeps waiting.

    """
    self._lock.acquire()
    try:
      self._quiescing = True

      # Wait while there are tasks pending or running
      while self._tasks or self._HasRunningTasksUnlocked():
        self._worker_to_pool.wait()

    finally:
      self._quiescing = False

      # Make sure AddTasks continues in case it was waiting
      self._pool_to_pool.notifyAll()

      self._lock.release()

  def _NewWorkerIdUnlocked(self):
    """Return an identifier for a new worker.

    @rtype: string
    @return: the pool name followed by an increasing sequence number

    """
    self._last_worker_id += 1

    return "%s%d" % (self._name, self._last_worker_id)

  def _ResizeUnlocked(self, num_workers):
    """Changes the number of workers.

    Growing the pool starts new worker threads. Shrinking is only supported
    down to zero workers (anything else raises C{NotImplementedError}). When
    shrinking, the lock is released while waiting for the terminating threads
    to exit and is re-acquired before returning.

    @type num_workers: int
    @param num_workers: the new number of workers (must be >= 0)

    """
    assert num_workers >= 0, "num_workers must be >= 0"

    logging.debug("Resizing to %s workers", num_workers)

    current_count = len(self._workers)

    if current_count == num_workers:
      # Nothing to do
      pass

    elif current_count > num_workers:
      if num_workers == 0:
        # Create copy of list to iterate over while lock isn't held.
        termworkers = self._workers[:]
        del self._workers[:]
      else:
        # TODO: Implement partial downsizing
        raise NotImplementedError()
        #termworkers = ...

      self._termworkers += termworkers

      # Notify workers that something has changed
      self._pool_to_worker.notifyAll()

      # Join all terminating workers; the lock must be released so the
      # workers can acquire it to finish their current task and exit
      self._lock.release()
      try:
        for worker in termworkers:
          logging.debug("Waiting for thread %s", worker.getName())
          worker.join()
      finally:
        self._lock.acquire()

      # Remove terminated threads. This could be done in a more efficient way
      # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
      # don't leave zombie threads around.
      for worker in termworkers:
        assert worker in self._termworkers, ("Worker not in list of"
                                             " terminating workers")
        if not worker.isAlive():
          self._termworkers.remove(worker)

      assert not self._termworkers, "Zombie worker detected"

    elif current_count < num_workers:
      # Create (num_workers - current_count) new workers
      for _ in range(num_workers - current_count):
        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
        self._workers.append(worker)
        worker.start()

  def Resize(self, num_workers):
    """Changes the number of workers in the pool.

    @param num_workers: the new number of workers

    """
    self._lock.acquire()
    try:
      return self._ResizeUnlocked(num_workers)
    finally:
      self._lock.release()

  def TerminateWorkers(self):
    """Terminate all worker threads.

    Unstarted tasks will be ignored; they stay in the queue and their count
    is logged.

    """
    logging.debug("Terminating all workers")

    self._lock.acquire()
    try:
      self._ResizeUnlocked(0)

      if self._tasks:
        logging.debug("There are %s tasks left", len(self._tasks))
    finally:
      self._lock.release()

    logging.debug("All workers terminated")
531
532     logging.debug("All workers terminated")