objects: Add custom de-/serializing code for query responses
[ganeti-local] / lib / workerpool.py
1 #
2 #
3
4 # Copyright (C) 2008, 2009, 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Base classes for worker pools.
23
24 """
25
26 import logging
27 import threading
28 import heapq
29
30 from ganeti import compat
31 from ganeti import errors
32
33
# Sentinel returned by WorkerPool._WaitForTaskUnlocked to tell a worker thread
# to exit its main loop; workers compare against it by identity ("is")
_TERMINATE = object()
# Priority used by WorkerPool.AddTask and WorkerPool.AddManyTasks when the
# caller doesn't specify one
_DEFAULT_PRIORITY = 0
36
37
class DeferTask(Exception):
  """Exception used by a running task to ask for being queued again.

  When L{BaseWorker.RunTask} raises this, the worker doesn't consider the
  task finished but puts it back into the pool's queue. The task can be
  given a different priority at the same time.

  """
  def __init__(self, priority=None):
    """Creates a deferral request.

    @type priority: int
    @param priority: Priority to re-queue the task with; C{None} keeps the
        priority the task currently has

    """
    # Independent of the base class initialization below
    self.priority = priority
    Exception.__init__(self)
54
55
class BaseWorker(threading.Thread, object):
  """Base worker class for worker pools.

  Each worker is a thread which repeatedly takes a task from its parent
  L{WorkerPool} and runs it by calling L{RunTask} with the task's arguments.
  Users of a worker pool must override RunTask in a subclass.

  """
  # Workers access protected members of the pool (e.g. its lock)
  # pylint: disable-msg=W0212
  def __init__(self, pool, worker_id):
    """Constructor for BaseWorker thread.

    @param pool: the parent worker pool
    @param worker_id: identifier for this worker; also used as the thread
        name

    """
    super(BaseWorker, self).__init__(name=worker_id)
    self.pool = pool
    self._worker_id = worker_id
    # Task tuple currently being processed, as handed out by the pool (first
    # element is the priority, last element the arguments for RunTask), or
    # None while idle. run() only changes it while holding the pool lock.
    self._current_task = None

    assert self.getName() == worker_id

  def ShouldTerminate(self):
    """Returns whether this worker should terminate.

    Should only be called from within L{RunTask}.

    @rtype: bool
    @return: whether the pool has put this worker on its list of terminating
        workers (e.g. because the pool is being shrunk)

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()
      return self.pool._ShouldWorkerTerminateUnlocked(self)
    finally:
      self.pool._lock.release()

  def GetCurrentPriority(self):
    """Returns the priority of the current task.

    Should only be called from within L{RunTask}.

    @return: the priority the currently running task was queued with

    """
    self.pool._lock.acquire()
    try:
      assert self._HasRunningTaskUnlocked()

      (priority, _, _) = self._current_task

      return priority
    finally:
      self.pool._lock.release()

  def SetTaskName(self, taskname):
    """Sets the name of the current task.

    Should only be called from within L{RunTask}. The thread name becomes
    C{"<worker_id>/<taskname>"}; an empty or C{None} task name resets it to
    the plain worker ID.

    @type taskname: string
    @param taskname: Task's name

    """
    if taskname:
      name = "%s/%s" % (self._worker_id, taskname)
    else:
      name = self._worker_id

    # Set thread name
    self.setName(name)

  def _HasRunningTaskUnlocked(self):
    """Returns whether this worker is currently running a task.

    Only reads C{_current_task}; takes no lock itself.

    """
    return (self._current_task is not None)

  def run(self):
    """Main thread function.

    Waits for new tasks to show up in the queue and runs them one at a time
    until the pool tells this worker to terminate.

    """
    pool = self.pool

    while True:
      assert self._current_task is None

      # Set to the DeferTask exception if the task asked to be re-queued
      defer = None
      try:
        # Wait on lock to be told either to terminate or to do a task
        pool._lock.acquire()
        try:
          task = pool._WaitForTaskUnlocked(self)

          if task is _TERMINATE:
            # Told to terminate; the "finally" clauses below still run
            break

          if task is None:
            # Spurious notification, ignore
            continue

          self._current_task = task

          # No longer needed, dispose of reference
          del task

          assert self._HasRunningTaskUnlocked()

        finally:
          pool._lock.release()

        # The task itself runs without holding the pool lock
        (priority, _, args) = self._current_task
        try:
          # Run the actual task
          assert defer is None
          logging.debug("Starting task %r, priority %s", args, priority)
          assert self.getName() == self._worker_id
          try:
            self.RunTask(*args) # pylint: disable-msg=W0142
          finally:
            # Restore the plain worker name in case RunTask changed it
            self.SetTaskName(None)
          logging.debug("Done with task %r, priority %s", args, priority)
        except DeferTask, err:
          defer = err

          if defer.priority is None:
            # Use same priority
            defer.priority = priority

          logging.debug("Deferring task %r, new priority %s",
                        args, defer.priority)

          assert self._HasRunningTaskUnlocked()
        except: # pylint: disable-msg=W0702
          # Any other error is logged and swallowed so that a failing task
          # doesn't end the worker thread
          logging.exception("Caught unhandled exception")

        assert self._HasRunningTaskUnlocked()
      finally:
        # Notify pool; this also runs after "break"/"continue" above, in which
        # case there is no current task and only the lock is taken/released
        pool._lock.acquire()
        try:
          if defer:
            assert self._current_task
            # Schedule again for later run. This calls _AddTaskUnlocked
            # directly, i.e. without waiting for the pool to stop quiescing.
            # NOTE(review): if this raises (e.g. _AddTaskUnlocked's assertion
            # that the priority is an int/long fails for a non-integer
            # DeferTask priority), _current_task below is never cleared.
            (_, _, args) = self._current_task
            pool._AddTaskUnlocked(args, defer.priority)

          if self._current_task:
            self._current_task = None
            # Wake up threads waiting for running tasks to finish
            # (e.g. WorkerPool.Quiesce)
            pool._worker_to_pool.notifyAll()
        finally:
          pool._lock.release()

      assert not self._HasRunningTaskUnlocked()

    logging.debug("Terminates")

  def RunTask(self, *args):
    """Function called to start a task.

    This needs to be implemented by child classes. It is called from L{run}
    without the pool lock held. Raising L{DeferTask} makes the worker queue
    the task again; any other exception is logged and otherwise ignored.

    @param args: the task's arguments, as passed when the task was added

    """
    raise NotImplementedError()
218
219
class WorkerPool(object):
  """Worker pool with a queue.

  This class is thread-safe.

  Tasks are started in order of priority, lowest value first (the queue is a
  heap of C{(priority, counter, args)} tuples); tasks with equal priority are
  started in the order in which they were added to the pool. Due to the
  nature of threading, they're not guaranteed to finish in the same order.

  """
  def __init__(self, name, num_workers, worker_class):
    """Constructor for worker pool.

    @type name: string
    @param name: prefix for worker identifiers; a sequential number is
        appended for each new worker (e.g. C{name1}, C{name2})
    @param num_workers: number of workers to be started (see L{Resize} for
        which later changes are supported)
    @param worker_class: the class to be instantiated for workers;
        should derive from L{BaseWorker}

    """
    # Some of these variables are accessed by BaseWorker
    self._lock = threading.Lock()
    # Signalled when quiescing ends, to wake up blocked task producers
    self._pool_to_pool = threading.Condition(self._lock)
    # Signalled when a task is queued or workers should check for termination
    self._pool_to_worker = threading.Condition(self._lock)
    # Signalled when a task is taken from the queue or finishes running
    self._worker_to_pool = threading.Condition(self._lock)
    self._worker_class = worker_class
    self._name = name
    self._last_worker_id = 0
    self._workers = []
    self._quiescing = False

    # Terminating workers
    self._termworkers = []

    # Queued tasks
    self._counter = 0
    self._tasks = []

    # Start workers
    self.Resize(num_workers)

  # TODO: Implement dynamic resizing?

  def _WaitWhileQuiescingUnlocked(self):
    """Wait until the worker pool has finished quiescing.

    Must be called with the pool lock held; the lock is released while
    waiting.

    """
    while self._quiescing:
      self._pool_to_pool.wait()

  def _AddTaskUnlocked(self, args, priority):
    """Adds a task to the internal queue.

    Must be called with the pool lock held. Unlike L{AddTask}, this does not
    wait for quiescing to finish.

    @type args: sequence
    @param args: Arguments passed to L{BaseWorker.RunTask}
    @type priority: int
    @param priority: Task priority

    """
    assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
    assert isinstance(priority, (int, long)), "Priority must be numeric"

    # This counter is used to ensure elements are processed in their
    # incoming order. For processing they're sorted by priority and then
    # counter. As counters are unique, "args" is never compared.
    self._counter += 1

    heapq.heappush(self._tasks, (priority, self._counter, args))

    # Notify a waiting worker
    self._pool_to_worker.notify()

  def AddTask(self, args, priority=_DEFAULT_PRIORITY):
    """Adds a task to the queue.

    Blocks while the pool is quiescing (see L{Quiesce}).

    @type args: sequence
    @param args: arguments passed to L{BaseWorker.RunTask}
    @type priority: int
    @param priority: Task priority

    """
    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()
      self._AddTaskUnlocked(args, priority)
    finally:
      self._lock.release()

  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY):
    """Add a list of tasks to the queue.

    Blocks while the pool is quiescing (see L{Quiesce}); afterwards all tasks
    are added under a single acquisition of the pool lock.

    @type tasks: list of tuples
    @param tasks: list of args passed to L{BaseWorker.RunTask}
    @type priority: int or list of ints
    @param priority: Priority for all added tasks or a list with the priority
                     for each task
    @raise errors.ProgrammerError: if a list of priorities is given whose
        length doesn't match the number of tasks

    """
    assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
      "Each task must be a sequence"

    assert (isinstance(priority, (int, long)) or
            compat.all(isinstance(prio, (int, long)) for prio in priority)), \
           "Priority must be numeric or be a list of numeric values"

    if isinstance(priority, (int, long)):
      priority = [priority] * len(tasks)
    elif len(priority) != len(tasks):
      raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
                                   " number of tasks (%s)" %
                                   (len(priority), len(tasks)))

    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()

      assert compat.all(isinstance(prio, (int, long)) for prio in priority)
      assert len(tasks) == len(priority)

      # Use a separate loop variable instead of rebinding "priority", which
      # is the list being iterated over
      for (args, prio) in zip(tasks, priority):
        self._AddTaskUnlocked(args, prio)
    finally:
      self._lock.release()

  def _WaitForTaskUnlocked(self, worker):
    """Waits for a task for a worker.

    Must be called with the pool lock held. Waits for at most one
    notification before returning.

    @type worker: L{BaseWorker}
    @param worker: Worker thread
    @return: C{_TERMINATE} if the worker should exit, C{None} if it was woken
        up without a task being available, otherwise the queued
        C{(priority, counter, args)} tuple

    """
    if self._ShouldWorkerTerminateUnlocked(worker):
      return _TERMINATE

    # We only wait if there's no task for us.
    if not self._tasks:
      logging.debug("Waiting for tasks")

      # wait() releases the lock and sleeps until notified
      self._pool_to_worker.wait()

      logging.debug("Notified while waiting")

      # Were we woken up in order to terminate?
      if self._ShouldWorkerTerminateUnlocked(worker):
        return _TERMINATE

      if not self._tasks:
        # Spurious notification, ignore
        return None

    # Get task from queue and tell pool about it
    try:
      return heapq.heappop(self._tasks)
    finally:
      self._worker_to_pool.notifyAll()

  def _ShouldWorkerTerminateUnlocked(self, worker):
    """Returns whether a worker should terminate.

    Must be called with the pool lock held.

    """
    return (worker in self._termworkers)

  def _HasRunningTasksUnlocked(self):
    """Checks whether there's a task running in a worker.

    Must be called with the pool lock held. Both active and terminating
    workers are checked.

    """
    for worker in self._workers + self._termworkers:
      if worker._HasRunningTaskUnlocked(): # pylint: disable-msg=W0212
        return True
    return False

  def Quiesce(self):
    """Waits until the task queue is empty and no task is running.

    While this is waiting, L{AddTask} and L{AddManyTasks} block. Tasks
    re-queued by workers (see L{DeferTask}) bypass that check and are
    waited for as well.

    """
    self._lock.acquire()
    try:
      self._quiescing = True

      # Wait while there are tasks pending or running
      while self._tasks or self._HasRunningTasksUnlocked():
        self._worker_to_pool.wait()

    finally:
      self._quiescing = False

      # Make sure AddTasks continues in case it was waiting
      self._pool_to_pool.notifyAll()

      self._lock.release()

  def _NewWorkerIdUnlocked(self):
    """Return an identifier for a new worker.

    Must be called with the pool lock held.

    """
    self._last_worker_id += 1

    return "%s%d" % (self._name, self._last_worker_id)

  def _ResizeUnlocked(self, num_workers):
    """Changes the number of workers.

    Must be called with the pool lock held. When shrinking, the lock is
    temporarily released while waiting for the workers to exit.

    @type num_workers: int
    @param num_workers: the new number of workers (must be >= 0)
    @raise NotImplementedError: when asked to shrink the pool to a size other
        than zero

    """
    assert num_workers >= 0, "num_workers must be >= 0"

    logging.debug("Resizing to %s workers", num_workers)

    current_count = len(self._workers)

    if current_count == num_workers:
      # Nothing to do
      pass

    elif current_count > num_workers:
      if num_workers == 0:
        # Create copy of list to iterate over while lock isn't held.
        termworkers = self._workers[:]
        del self._workers[:]
      else:
        # TODO: Implement partial downsizing
        raise NotImplementedError()

      self._termworkers += termworkers

      # Notify workers that something has changed
      self._pool_to_worker.notifyAll()

      # Join all terminating workers
      self._lock.release()
      try:
        for worker in termworkers:
          logging.debug("Waiting for thread %s", worker.getName())
          worker.join()
      finally:
        self._lock.acquire()

      # Remove terminated threads. This could be done in a more efficient way
      # (del self._termworkers[:]), but checking worker.isAlive() makes sure we
      # don't leave zombie threads around.
      for worker in termworkers:
        assert worker in self._termworkers, ("Worker not in list of"
                                             " terminating workers")
        if not worker.isAlive():
          self._termworkers.remove(worker)

      assert not self._termworkers, "Zombie worker detected"

    elif current_count < num_workers:
      # Create (num_workers - current_count) new workers
      for _ in range(num_workers - current_count):
        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
        self._workers.append(worker)
        worker.start()

  def Resize(self, num_workers):
    """Changes the number of workers in the pool.

    Growing the pool and shrinking it to zero are supported; shrinking to
    any other size raises C{NotImplementedError}.

    @param num_workers: the new number of workers

    """
    self._lock.acquire()
    try:
      return self._ResizeUnlocked(num_workers)
    finally:
      self._lock.release()

  def TerminateWorkers(self):
    """Terminate all worker threads.

    Waits for all workers to exit. Tasks which have not been started yet
    remain in the queue and are not run; only their number is logged.

    """
    logging.debug("Terminating all workers")

    self._lock.acquire()
    try:
      self._ResizeUnlocked(0)

      if self._tasks:
        logging.debug("There are %s tasks left", len(self._tasks))
    finally:
      self._lock.release()

    logging.debug("All workers terminated")
503
504     logging.debug("All workers terminated")