Statistics
| Branch: | Tag: | Revision:

root / lib / workerpool.py @ 5fbbd028

History | View | Annotate | Download (13.3 kB)

1
#
2
#
3

    
4
# Copyright (C) 2008, 2009, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Base classes for worker pools.
23

24
"""
25

    
26
import logging
27
import threading
28
import heapq
29

    
30
from ganeti import compat
31
from ganeti import errors
32

    
33

    
34
# Unique marker returned by WorkerPool._WaitForTaskUnlocked to tell a worker
# thread that it has to exit its main loop; compared by identity only
_TERMINATE = object()

# Priority assigned to tasks added without an explicit priority
_DEFAULT_PRIORITY = 0
36

    
37

    
38
class DeferTask(Exception):
  """Exception used to postpone the task currently being run.

  L{BaseWorker.RunTask} may raise this exception to have its task put back
  into the queue instead of being treated as finished. The exception can
  carry a new priority for the re-queued task.

  """
  def __init__(self, priority=None):
    """Stores the requested priority.

    @type priority: number
    @param priority: Priority to use when the task is queued again (None
        means the task keeps its current priority)

    """
    self.priority = priority
    Exception.__init__(self)
54

    
55

    
56
class BaseWorker(threading.Thread, object):
57
  """Base worker class for worker pools.
58

59
  Users of a worker pool must override RunTask in a subclass.
60

61
  """
62
  # pylint: disable-msg=W0212
63
  def __init__(self, pool, worker_id):
64
    """Constructor for BaseWorker thread.
65

66
    @param pool: the parent worker pool
67
    @param worker_id: identifier for this worker
68

69
    """
70
    super(BaseWorker, self).__init__(name=worker_id)
71
    self.pool = pool
72
    self._worker_id = worker_id
73
    self._current_task = None
74

    
75
    assert self.getName() == worker_id
76

    
77
  def ShouldTerminate(self):
78
    """Returns whether this worker should terminate.
79

80
    Should only be called from within L{RunTask}.
81

82
    """
83
    self.pool._lock.acquire()
84
    try:
85
      assert self._HasRunningTaskUnlocked()
86
      return self.pool._ShouldWorkerTerminateUnlocked(self)
87
    finally:
88
      self.pool._lock.release()
89

    
90
  def GetCurrentPriority(self):
91
    """Returns the priority of the current task.
92

93
    Should only be called from within L{RunTask}.
94

95
    """
96
    self.pool._lock.acquire()
97
    try:
98
      assert self._HasRunningTaskUnlocked()
99

    
100
      (priority, _, _) = self._current_task
101

    
102
      return priority
103
    finally:
104
      self.pool._lock.release()
105

    
106
  def SetTaskName(self, taskname):
107
    """Sets the name of the current task.
108

109
    Should only be called from within L{RunTask}.
110

111
    @type taskname: string
112
    @param taskname: Task's name
113

114
    """
115
    if taskname:
116
      name = "%s/%s" % (self._worker_id, taskname)
117
    else:
118
      name = self._worker_id
119

    
120
    # Set thread name
121
    self.setName(name)
122

    
123
  def _HasRunningTaskUnlocked(self):
124
    """Returns whether this worker is currently running a task.
125

126
    """
127
    return (self._current_task is not None)
128

    
129
  def run(self):
130
    """Main thread function.
131

132
    Waits for new tasks to show up in the queue.
133

134
    """
135
    pool = self.pool
136

    
137
    while True:
138
      assert self._current_task is None
139

    
140
      defer = None
141
      try:
142
        # Wait on lock to be told either to terminate or to do a task
143
        pool._lock.acquire()
144
        try:
145
          task = pool._WaitForTaskUnlocked(self)
146

    
147
          if task is _TERMINATE:
148
            # Told to terminate
149
            break
150

    
151
          if task is None:
152
            # Spurious notification, ignore
153
            continue
154

    
155
          self._current_task = task
156

    
157
          # No longer needed, dispose of reference
158
          del task
159

    
160
          assert self._HasRunningTaskUnlocked()
161

    
162
        finally:
163
          pool._lock.release()
164

    
165
        (priority, _, args) = self._current_task
166
        try:
167
          # Run the actual task
168
          assert defer is None
169
          logging.debug("Starting task %r, priority %s", args, priority)
170
          assert self.getName() == self._worker_id
171
          try:
172
            self.RunTask(*args) # pylint: disable-msg=W0142
173
          finally:
174
            self.SetTaskName(None)
175
          logging.debug("Done with task %r, priority %s", args, priority)
176
        except DeferTask, err:
177
          defer = err
178

    
179
          if defer.priority is None:
180
            # Use same priority
181
            defer.priority = priority
182

    
183
          logging.debug("Deferring task %r, new priority %s", defer.priority)
184

    
185
          assert self._HasRunningTaskUnlocked()
186
        except: # pylint: disable-msg=W0702
187
          logging.exception("Caught unhandled exception")
188

    
189
        assert self._HasRunningTaskUnlocked()
190
      finally:
191
        # Notify pool
192
        pool._lock.acquire()
193
        try:
194
          if defer:
195
            assert self._current_task
196
            # Schedule again for later run
197
            (_, _, args) = self._current_task
198
            pool._AddTaskUnlocked(args, defer.priority)
199

    
200
          if self._current_task:
201
            self._current_task = None
202
            pool._worker_to_pool.notifyAll()
203
        finally:
204
          pool._lock.release()
205

    
206
      assert not self._HasRunningTaskUnlocked()
207

    
208
    logging.debug("Terminates")
209

    
210
  def RunTask(self, *args):
211
    """Function called to start a task.
212

213
    This needs to be implemented by child classes.
214

215
    """
216
    raise NotImplementedError()
217

    
218

    
219
class WorkerPool(object):
  """Worker pool with a queue.

  This class is thread-safe.

  Queued tasks are handed to workers sorted by priority and, among tasks of
  the same priority, in the order in which they were added to the pool. Due
  to the nature of threading, tasks are not guaranteed to finish in the
  order in which they were started.

  """
  def __init__(self, name, num_workers, worker_class):
    """Constructor for worker pool.

    @param name: prefix used when building worker identifiers
    @param num_workers: number of workers to be started
        (resizing is only implemented for growing the pool and for
        shrinking it to zero workers)
    @param worker_class: the class to be instantiated for workers;
        should derive from L{BaseWorker}

    """
    self._name = name
    self._worker_class = worker_class

    # One lock shared by all conditions; some of these variables are
    # accessed by BaseWorker
    self._lock = threading.Lock()
    self._pool_to_pool = threading.Condition(self._lock)
    self._pool_to_worker = threading.Condition(self._lock)
    self._worker_to_pool = threading.Condition(self._lock)

    # Active workers, workers told to terminate and the source of worker IDs
    self._workers = []
    self._termworkers = []
    self._last_worker_id = 0

    # Heap of queued tasks and the counter used to keep insertion order
    self._tasks = []
    self._counter = 0

    self._quiescing = False

    # Start workers
    self.Resize(num_workers)

  # TODO: Implement dynamic resizing?

  def _WaitWhileQuiescingUnlocked(self):
    """Blocks for as long as L{Quiesce} is in progress.

    """
    while self._quiescing:
      self._pool_to_pool.wait()

  def _AddTaskUnlocked(self, args, priority):
    """Puts a task onto the internal heap and wakes up one worker.

    @type args: sequence
    @param args: Arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority

    """
    assert isinstance(args, (tuple, list)), "Arguments must be a sequence"
    assert isinstance(priority, (int, long)), "Priority must be numeric"

    # The counter is the second sort key, hence tasks of the same priority
    # are processed in their incoming order
    self._counter += 1
    entry = (priority, self._counter, args)

    heapq.heappush(self._tasks, entry)

    # Notify a waiting worker
    self._pool_to_worker.notify()

  def AddTask(self, args, priority=_DEFAULT_PRIORITY):
    """Adds a single task to the queue.

    Blocks while the pool is quiescing.

    @type args: sequence
    @param args: arguments passed to L{BaseWorker.RunTask}
    @type priority: number
    @param priority: Task priority

    """
    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()
      self._AddTaskUnlocked(args, priority)
    finally:
      self._lock.release()

  def AddManyTasks(self, tasks, priority=_DEFAULT_PRIORITY):
    """Adds several tasks to the queue while holding the lock once.

    Blocks while the pool is quiescing.

    @type tasks: list of tuples
    @param tasks: list of args passed to L{BaseWorker.RunTask}
    @type priority: number or list of numbers
    @param priority: Priority for all added tasks or a list with the priority
                     for each task
    @raise errors.ProgrammerError: if a list of priorities is given whose
        length differs from the number of tasks

    """
    assert compat.all(isinstance(task, (tuple, list)) for task in tasks), \
      "Each task must be a sequence"

    same_for_all = isinstance(priority, (int, long))

    assert (same_for_all or
            compat.all(isinstance(prio, (int, long)) for prio in priority)), \
           "Priority must be numeric or be a list of numeric values"

    if same_for_all:
      # Expand the single value to one priority per task
      priorities = [priority] * len(tasks)
    else:
      priorities = priority
      if len(priorities) != len(tasks):
        raise errors.ProgrammerError("Number of priorities (%s) doesn't match"
                                     " number of tasks (%s)" %
                                     (len(priorities), len(tasks)))

    self._lock.acquire()
    try:
      self._WaitWhileQuiescingUnlocked()

      assert compat.all(isinstance(prio, (int, long)) for prio in priorities)
      assert len(tasks) == len(priorities)

      for (task_args, task_prio) in zip(tasks, priorities):
        self._AddTaskUnlocked(task_args, task_prio)
    finally:
      self._lock.release()

  def _WaitForTaskUnlocked(self, worker):
    """Fetches the next task for a worker, waiting once if there is none.

    @type worker: L{BaseWorker}
    @param worker: Worker thread
    @return: L{_TERMINATE} if the worker has to exit, None if the worker was
        woken up without a task being available, the queue entry otherwise

    """
    if self._ShouldWorkerTerminateUnlocked(worker):
      return _TERMINATE

    # Waiting is only necessary if the queue is empty
    if not self._tasks:
      logging.debug("Waiting for tasks")

      # wait() releases the lock and sleeps until notified
      self._pool_to_worker.wait()

      logging.debug("Notified while waiting")

      # The notification might have been a request to terminate
      if self._ShouldWorkerTerminateUnlocked(worker):
        return _TERMINATE

      if not self._tasks:
        # Spurious notification, ignore
        return None

    # Take the task off the queue and tell the pool about it
    try:
      next_task = heapq.heappop(self._tasks)
    finally:
      self._worker_to_pool.notifyAll()

    return next_task

  def _ShouldWorkerTerminateUnlocked(self, worker):
    """Tells whether a worker is on the list of terminating workers.

    """
    is_terminating = worker in self._termworkers
    return is_terminating

  def _HasRunningTasksUnlocked(self):
    """Checks whether any worker, terminating or not, is running a task.

    """
    all_workers = self._workers + self._termworkers

    for candidate in all_workers:
      if candidate._HasRunningTaskUnlocked(): # pylint: disable-msg=W0212
        return True

    return False

  def Quiesce(self):
    """Waits until no task is queued or running anymore.

    While waiting, calls adding new tasks are blocked.

    """
    self._lock.acquire()
    try:
      self._quiescing = True

      # Keep waiting as long as there are tasks pending or running
      while True:
        if not self._tasks and not self._HasRunningTasksUnlocked():
          break
        self._worker_to_pool.wait()

    finally:
      self._quiescing = False

      # Let callers blocked in AddTask/AddManyTasks continue
      self._pool_to_pool.notifyAll()

      self._lock.release()

  def _NewWorkerIdUnlocked(self):
    """Builds the identifier for the next worker.

    @return: pool name followed by a running number

    """
    self._last_worker_id += 1
    new_id = "%s%d" % (self._name, self._last_worker_id)

    return new_id

  def _ResizeUnlocked(self, num_workers):
    """Changes the number of workers.

    The lock is released temporarily while waiting for terminating workers.

    @param num_workers: the new number of workers
    @raise NotImplementedError: when shrinking to a number of workers other
        than zero

    """
    assert num_workers >= 0, "num_workers must be >= 0"

    logging.debug("Resizing to %s workers", num_workers)

    delta = num_workers - len(self._workers)

    if delta == 0:
      # Nothing to do
      return

    if delta > 0:
      # Create and start the missing workers
      for _ in range(delta):
        new_worker = self._worker_class(self, self._NewWorkerIdUnlocked())
        self._workers.append(new_worker)
        new_worker.start()
      return

    # From here on the pool is being shrunk
    if num_workers != 0:
      # TODO: Implement partial downsizing
      raise NotImplementedError()

    # Work on a copy of the list as it's used while the lock isn't held
    leaving = self._workers[:]
    del self._workers[:]

    self._termworkers.extend(leaving)

    # Notify workers that something has changed
    self._pool_to_worker.notifyAll()

    # Join all terminating workers; they need the lock to find out
    self._lock.release()
    try:
      for thread in leaving:
        logging.debug("Waiting for thread %s", thread.getName())
        thread.join()
    finally:
      self._lock.acquire()

    # Remove terminated threads. This could be done in a more efficient way
    # (del self._termworkers[:]), but checking isAlive() makes sure no
    # zombie threads are left around.
    for thread in leaving:
      assert thread in self._termworkers, \
        "Worker not in list of terminating workers"
      if not thread.isAlive():
        self._termworkers.remove(thread)

    assert not self._termworkers, "Zombie worker detected"

  def Resize(self, num_workers):
    """Changes the number of workers in the pool.

    @param num_workers: the new number of workers

    """
    self._lock.acquire()
    try:
      return self._ResizeUnlocked(num_workers)
    finally:
      self._lock.release()

  def TerminateWorkers(self):
    """Terminate all worker threads.

    Unstarted tasks will be ignored.

    """
    logging.debug("Terminating all workers")

    self._lock.acquire()
    try:
      self._ResizeUnlocked(0)

      remaining = len(self._tasks)
      if remaining:
        logging.debug("There are %s tasks left", remaining)
    finally:
      self._lock.release()

    logging.debug("All workers terminated")