workerpool: Use itertools.count instead of manual counting
[ganeti-local] / lib / workerpool.py
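This change replaces a manually incremented task counter with itertools.count(). Each queued task is stored in the heap as a (priority, counter, args) tuple, so entries compare by priority first and by insertion order second, which keeps same-priority tasks in FIFO order. A minimal standalone sketch of that ordering trick (illustration only, not part of the Ganeti tree; all names below are made up):

import heapq
import itertools

counter = itertools.count()  # monotonically increasing tie-breaker
tasks = []

# Push three tasks; two share the same priority.
heapq.heappush(tasks, (0, next(counter), ("first",)))
heapq.heappush(tasks, (0, next(counter), ("second",)))
heapq.heappush(tasks, (-1, next(counter), ("urgent",)))

# Pops in priority order, FIFO within the same priority:
# (-1, 2, ('urgent',)), (0, 0, ('first',)), (0, 1, ('second',))
while tasks:
  print(heapq.heappop(tasks))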
#
#

# Copyright (C) 2008, 2009, 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Base classes for worker pools.

"""

import logging
import threading
import heapq
import itertools

from ganeti import compat
from ganeti import errors


_TERMINATE = object()
_DEFAULT_PRIORITY = 0


class DeferTask(Exception):
  """Special exception class to defer a task.

  This class can be raised by L{BaseWorker.RunTask} to defer the execution of a
  task. Optionally, the priority of the task can be changed.

  """
  def __init__(self, priority=None):
    """Initializes this class.

    @type priority: number
    @param priority: New task priority (None means no change)

    """
    Exception.__init__(self)
    self.priority = priority


class BaseWorker(threading.Thread, object):
  """Base worker class for worker pools.

  Users of a worker pool must override RunTask in a subclass.

  """
  # pylint: disable=W0212
  def __init__(self, pool, worker_id):
    """Constructor for BaseWorker thread.

    @param pool: the parent worker pool
    @param worker_id: identifier for this worker

    """
    super(BaseWorker, self).__init__(name=worker_id)
    self.pool = pool
    self._worker_id = worker_id
    self._current_task = None

    assert self.getName() == worker_id

  def ShouldTerminate(self):
    """Returns whether this worker should terminate.

    Should only be called from within L{RunTask}.

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()
      return self.pool._ShouldWorkerTerminateUnlocked(self)
    finally:
      self.pool._lock.release()

  def GetCurrentPriority(self):
    """Returns the priority of the current task.

    Should only be called from within L{RunTask}.

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()

      (priority, _, _) = self._current_task

      return priority
    finally:
      self.pool._lock.release()

  def SetTaskName(self, taskname):
    """Sets the name of the current task.

    Should only be called from within L{RunTask}.

    @type taskname: string
    @param taskname: Task's name

    """
    if taskname:
      name = "%s/%s" % (self._worker_id, taskname)
    else:
      name = self._worker_id

    # Set thread name
    self.setName(name)

  def _HasRunningTaskUnlocked(self):
    """Returns whether this worker is currently running a task.

    """
    return (self._current_task is not None)

  def run(self):
    """Main thread function.

    Waits for new tasks to show up in the queue.

    """
    pool = self.pool

    while True:
      assert self._current_task is None

      defer = None
      try:
        # Wait on lock to be told either to terminate or to do a task
        pool._lock.acquire()
        try:
          task = pool._WaitForTaskUnlocked(self)

          if task is _TERMINATE:
            # Told to terminate
            break

          if task is None:
            # Spurious notification, ignore
            continue

          self._current_task = task

          # No longer needed, dispose of reference
          del task

          assert self._HasRunningTaskUnlocked()

        finally:
          pool._lock.release()

        (priority, _, args) = self._current_task
        try:
          # Run the actual task
          assert defer is None
          logging.debug("Starting task %r, priority %s", args, priority)
          assert self.getName() == self._worker_id
          try:
            self.RunTask(*args) # pylint: disable=W0142
          finally:
            self.SetTaskName(None)
          logging.debug("Done with task %r, priority %s", args, priority)
        except DeferTask, err:
          defer = err

          if defer.priority is None:
            # Use same priority
            defer.priority = priority

          logging.debug("Deferring task %r, new priority %s",
                        args, defer.priority)

          assert self._HasRunningTaskUnlocked()
        except: # pylint: disable=W0702
          logging.exception("Caught unhandled exception")

        assert self._HasRunningTaskUnlocked()
      finally:
        # Notify pool
        pool._lock.acquire()
        try:
          if defer:
            assert self._current_task
            # Schedule again for later run
            (_, _, args) = self._current_task
            pool._AddTaskUnlocked(args, defer.priority)

          if self._current_task:
            self._current_task = None
            pool._worker_to_pool.notifyAll()
        finally:
          pool._lock.release()

      assert not self._HasRunningTaskUnlocked()

    logging.debug("Terminates")

  def RunTask(self, *args):
    """Function called to start a task.

    This needs to be implemented by child classes.

    """
    raise NotImplementedError()


class WorkerPool(object):
  """Worker pool with a queue.

  This class is thread-safe.

  Tasks are guaranteed to be started in the order in which they're
  added to the pool. Due to the nature of threading, they're not
  guaranteed to finish in the same order.

  """
  def __init__(self, name, num_workers, worker_class):
232     """Constructor for worker pool.
233
234     @param num_workers: number of workers to be started
235         (dynamic resizing is not yet implemented)
236     @param worker_class: the class to be instantiated for workers;
237         should derive from L{BaseWorker}
238
239     """
    # Some of these variables are accessed by BaseWorker
    self._lock = threading.Lock()
    self._pool_to_pool = threading.Condition(self._lock)
    self._pool_to_worker = threading.Condition(self._lock)
    self._worker_to_pool = threading.Condition(self._lock)
    self._worker_class = worker_class
    self._name = name
    self._last_worker_id = 0
    self._workers = []
    self._quiescing = False
    self._active = True

    # Terminating workers
    self._termworkers = []

    # Queued tasks
    self._counter = itertools.count()
    self._tasks = []

    # Start workers
    self.Resize(num_workers)

  # TODO: Implement dynamic resizing?

  def _WaitWhileQuiescingUnlocked(self):
    """Wait until the worker pool has finished quiescing.

    """
    while self._quiescing:
      self._pool_to_pool.wait()

  def _AddTaskUnlocked(self, args, priority):
    """Adds a task to the internal queue.

    @type args: sequence
    @param args: Arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority

    """
    assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
    assert isinstance(priority, (int, long)), "Priority must be numeric"

    # A counter is used to ensure elements are processed in their incoming
    # order. For processing they're sorted by priority and then counter.
    heapq.heappush(self._tasks, (priority, self._counter.next(), args))

    # Notify a waiting worker
    self._pool_to_worker.notify()

  def AddTask(self, args, priority=_DEFAULT_PRIORITY):
    """Adds a task to the queue.

    @type args: sequence
    @param args: arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority

    """
    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()
      self._AddTaskUnlocked(args, priority)
    finally:
      self._lock.release()

  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY):
    """Add a list of tasks to the queue.

    @type tasks: list of tuples
    @param tasks: list of args passed to L{BaseWorker.RunTask}
    @type priority: number or list of numbers
    @param priority: Priority for all added tasks or a list with the priority
                     for each task

    """
    assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
      "Each task must be a sequence"

    assert (isinstance(priority, (int, long)) or
            compat.all(isinstance(prio, (int, long)) for prio in priority)), \
           "Priority must be numeric or be a list of numeric values"

    if isinstance(priority, (int, long)):
      priority = [priority] * len(tasks)
    elif len(priority) != len(tasks):
      raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(priority), len(tasks)))

    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()

      assert compat.all(isinstance(prio, (int, long)) for prio in priority)
      assert len(tasks) == len(priority)

      for args, prio in zip(tasks, priority):
        self._AddTaskUnlocked(args, prio)
    finally:
      self._lock.release()

  def SetActive(self, active):
    """Enable/disable processing of tasks.

    This is different from L{Quiesce} in the sense that this function just
    changes an internal flag and doesn't wait for the queue to be empty. Tasks
    already being processed continue normally, but no new tasks will be
    started. New tasks can still be added.

    @type active: bool
    @param active: Whether tasks should be processed

    """
    self._lock.acquire()
    try:
      self._active = active

      if active:
        # Tell all workers to continue processing
        self._pool_to_worker.notifyAll()
    finally:
      self._lock.release()

  def _WaitForTaskUnlocked(self, worker):
    """Waits for a task for a worker.

    @type worker: L{BaseWorker}
    @param worker: Worker thread

    """
    while True:
      if self._ShouldWorkerTerminateUnlocked(worker):
        return _TERMINATE

      # If there's a pending task, return it immediately
      if self._active and self._tasks:
        # Get task from queue and tell pool about it
        try:
          task = heapq.heappop(self._tasks)
        finally:
          self._worker_to_pool.notifyAll()

        return task

      logging.debug("Waiting for tasks")

      # wait() releases the lock and sleeps until notified
      self._pool_to_worker.wait()

      logging.debug("Notified while waiting")

  def _ShouldWorkerTerminateUnlocked(self, worker):
    """Returns whether a worker should terminate.

    """
    return (worker in self._termworkers)

  def _HasRunningTasksUnlocked(self):
    """Checks whether there's a task running in a worker.

    """
    for worker in self._workers + self._termworkers:
      if worker._HasRunningTaskUnlocked(): # pylint: disable=W0212
        return True
    return False

  def HasRunningTasks(self):
    """Checks whether there's at least one task running.

    """
    self._lock.acquire()
    try:
      return self._HasRunningTasksUnlocked()
    finally:
      self._lock.release()

  def Quiesce(self):
    """Waits until the task queue is empty.

    """
    self._lock.acquire()
    try:
      self._quiescing = True

      # Wait while there are tasks pending or running
      while self._tasks or self._HasRunningTasksUnlocked():
        self._worker_to_pool.wait()

    finally:
      self._quiescing = False

      # Make sure AddTasks continues in case it was waiting
      self._pool_to_pool.notifyAll()

      self._lock.release()

  def _NewWorkerIdUnlocked(self):
    """Return an identifier for a new worker.

    """
    self._last_worker_id += 1

    return "%s%d" % (self._name, self._last_worker_id)

  def _ResizeUnlocked(self, num_workers):
    """Changes the number of workers.

    """
    assert num_workers >= 0, "num_workers must be >= 0"

    logging.debug("Resizing to %s workers", num_workers)

    current_count = len(self._workers)

    if current_count == num_workers:
      # Nothing to do
      pass

    elif current_count > num_workers:
      if num_workers == 0:
        # Create copy of list to iterate over while lock isn't held.
        termworkers = self._workers[:]
        del self._workers[:]
      else:
        # TODO: Implement partial downsizing
        raise NotImplementedError()
        #termworkers = ...

      self._termworkers += termworkers

      # Notify workers that something has changed
      self._pool_to_worker.notifyAll()

      # Join all terminating workers
      self._lock.release()
      try:
        for worker in termworkers:
          logging.debug("Waiting for thread %s", worker.getName())
          worker.join()
      finally:
        self._lock.acquire()

      # Remove terminated threads. This could be done in a more efficient way
      # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
      # don't leave zombie threads around.
      for worker in termworkers:
        assert worker in self._termworkers, ("Worker not in list of"
                                             " terminating workers")
        if not worker.isAlive():
          self._termworkers.remove(worker)

      assert not self._termworkers, "Zombie worker detected"

    elif current_count < num_workers:
      # Create (num_workers - current_count) new workers
      for _ in range(num_workers - current_count):
        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
        self._workers.append(worker)
        worker.start()

  def Resize(self, num_workers):
    """Changes the number of workers in the pool.

    @param num_workers: the new number of workers

    """
    self._lock.acquire()
    try:
      return self._ResizeUnlocked(num_workers)
    finally:
      self._lock.release()

  def TerminateWorkers(self):
    """Terminate all worker threads.

    Unstarted tasks will be ignored.

    """
    logging.debug("Terminating all workers")

    self._lock.acquire()
    try:
      self._ResizeUnlocked(0)

      if self._tasks:
        logging.debug("There are %s tasks left", len(self._tasks))
    finally:
      self._lock.release()

    logging.debug("All workers terminated")
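
For reference, a minimal usage sketch of the pool API above (illustration only, not taken from the Ganeti tree; LoggingWorker and the task arguments are made-up names, and the module is assumed to be importable as ganeti.workerpool): users subclass BaseWorker, override RunTask, and hand the subclass to WorkerPool.

import logging

from ganeti import workerpool

class LoggingWorker(workerpool.BaseWorker):
  def RunTask(self, message):
    # Arguments passed to AddTask/AddManyTasks arrive here unpacked.
    logging.info("worker %s got: %s", self.getName(), message)

pool = workerpool.WorkerPool("Demo", 3, LoggingWorker)
try:
  pool.AddTask(("hello",))
  pool.AddManyTasks([("a",), ("b",)], priority=10)
  pool.Quiesce()            # wait until the queue is empty
finally:
  pool.TerminateWorkers()   # stop all workers and join their threads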