Statistics
| Branch: | Tag: | Revision:

root / lib / workerpool.py @ 5a9c3f46

History | View | Annotate | Download (9 kB)

1
#
2
#
3

    
4
# Copyright (C) 2008 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Base classes for worker pools.
23

24
"""
25

    
26
import collections
27
import logging
28
import threading
29

    
30

    
31
class BaseWorker(threading.Thread, object):
  """Base worker class for worker pools.

  A worker is a thread which repeatedly takes the oldest queued task
  from its parent pool and passes the task's arguments to L{RunTask}.

  Users of a worker pool must override RunTask in a subclass.

  """
  def __init__(self, pool, worker_id):
    """Constructor for BaseWorker thread.

    @param pool: the parent worker pool
    @param worker_id: identifier for this worker

    """
    super(BaseWorker, self).__init__()
    self.pool = pool
    self.worker_id = worker_id
    # Argument tuple of the task being processed, None while idle.
    # L{run} only modifies it while holding the pool's lock.
    self._current_task = None

  def ShouldTerminate(self):
    """Returns whether a worker should terminate.

    The decision is delegated to the parent pool.

    """
    return self.pool.ShouldWorkerTerminate(self)

  def _HasRunningTaskUnlocked(self):
    """Returns whether this worker is currently running a task.

    Does not take the pool's lock; meant for callers already holding it.

    """
    return (self._current_task is not None)

  def HasRunningTask(self):
    """Returns whether this worker is currently running a task.

    Takes the pool's lock while checking.

    """
    with self.pool._lock:
      return self._HasRunningTaskUnlocked()

  def run(self):
    """Main thread function.

    Waits for new tasks to show up in the queue and runs them, one at a
    time, until the pool reports that this worker should terminate.
    Exceptions raised by a task are logged and do not stop the worker.

    """
    pool = self.pool

    assert not self.HasRunningTask()

    while True:
      try:
        # We wait on lock to be told either terminate or do a task.
        with pool._lock:
          if pool._ShouldWorkerTerminateUnlocked(self):
            break

          # We only wait if there's no task for us.
          if not pool._tasks:
            logging.debug("Worker %s: waiting for tasks", self.worker_id)

            # wait() releases the lock and sleeps until notified
            pool._pool_to_worker.wait()

            logging.debug("Worker %s: notified while waiting", self.worker_id)

            # Were we woken up in order to terminate?
            if pool._ShouldWorkerTerminateUnlocked(self):
              break

            if not pool._tasks:
              # Spurious notification, ignore
              continue

          # Get task from queue and tell pool about it
          try:
            self._current_task = pool._tasks.popleft()
          finally:
            pool._worker_to_pool.notify_all()

        # Run the actual task
        try:
          logging.debug("Worker %s: starting task %r",
                        self.worker_id, self._current_task)
          self.RunTask(*self._current_task)
          logging.debug("Worker %s: done with task %r",
                        self.worker_id, self._current_task)
        except:
          # Catch-all kept on purpose: a failing task is only logged and
          # the worker carries on with the next one.
          logging.error("Worker %s: Caught unhandled exception",
                        self.worker_id, exc_info=True)
      finally:
        # Notify pool
        with pool._lock:
          # Compare against None instead of testing truthiness: a task
          # queued without arguments is an empty (false) tuple and would
          # otherwise never be cleared, so this worker would be reported
          # as running a task forever.
          if self._HasRunningTaskUnlocked():
            self._current_task = None
            pool._worker_to_pool.notify_all()

    logging.debug("Worker %s: terminates", self.worker_id)

  def RunTask(self, *args):
    """Function called to start a task.

    This needs to be implemented by child classes.

    @param args: the arguments the task was queued with
    @raise NotImplementedError: always, unless overridden

    """
    raise NotImplementedError()
143

    
144

    
145
class WorkerPool(object):
  """Worker pool with a queue.

  This class is thread-safe.

  Tasks are guaranteed to be started in the order in which they're
  added to the pool. Due to the nature of threading, they're not
  guaranteed to finish in the same order.

  """
  def __init__(self, num_workers, worker_class):
    """Constructor for worker pool.

    @param num_workers: number of workers to be started
        (dynamic resizing is not yet implemented)
    @param worker_class: the class to be instantiated for workers;
        should derive from L{BaseWorker}

    """
    # Some of these variables are accessed by BaseWorker

    # Single lock shared by all three condition variables below
    self._lock = threading.Lock()
    # Notified by Quiesce once it is done, waited on by AddTask
    self._pool_to_pool = threading.Condition(self._lock)
    # Notified when a task was queued or workers were asked to terminate
    self._pool_to_worker = threading.Condition(self._lock)
    # Waited on by Quiesce until no task is queued or running
    self._worker_to_pool = threading.Condition(self._lock)
    self._worker_class = worker_class
    self._last_worker_id = 0
    self._workers = []
    self._quiescing = False

    # Terminating workers
    self._termworkers = []

    # Queued tasks
    self._tasks = collections.deque()

    # Start workers
    self.Resize(num_workers)

  # TODO: Implement dynamic resizing?

  def AddTask(self, *args):
    """Adds a task to the queue.

    Blocks while L{Quiesce} is in progress.

    @param args: arguments passed to L{BaseWorker.RunTask}

    """
    with self._lock:
      # Don't add new tasks while we're quiescing
      while self._quiescing:
        self._pool_to_pool.wait()

      # Add task to internal queue
      self._tasks.append(args)

      # Wake one idling worker up
      self._pool_to_worker.notify()

  def _ShouldWorkerTerminateUnlocked(self, worker):
    """Returns whether a worker should terminate.

    Does not take the lock; meant for callers already holding it.

    """
    return (worker in self._termworkers)

  def ShouldWorkerTerminate(self, worker):
    """Returns whether a worker should terminate.

    @param worker: the worker to check

    """
    with self._lock:
      return self._ShouldWorkerTerminateUnlocked(worker)

  def _HasRunningTasksUnlocked(self):
    """Checks whether there's a task running in a worker.

    Both active and terminating workers are taken into account.

    """
    return any(worker._HasRunningTaskUnlocked()
               for worker in self._workers + self._termworkers)

  def Quiesce(self):
    """Waits until the task queue is empty.

    Also waits for tasks which are already running. While waiting,
    callers of L{AddTask} are blocked.

    """
    with self._lock:
      try:
        self._quiescing = True

        # Wait while there are tasks pending or running
        while self._tasks or self._HasRunningTasksUnlocked():
          self._worker_to_pool.wait()

      finally:
        self._quiescing = False

        # Make sure AddTasks continues in case it was waiting
        self._pool_to_pool.notify_all()

  def _NewWorkerIdUnlocked(self):
    """Return an identifier for a new worker.

    Identifiers are consecutive integers starting at 1.

    """
    self._last_worker_id += 1
    return self._last_worker_id

  def _ResizeUnlocked(self, num_workers):
    """Changes the number of workers.

    Must be called with the pool's lock held. The lock is released
    temporarily while waiting for terminating workers to exit. Only
    growing the pool and shrinking it to zero workers are implemented.

    @param num_workers: the new number of workers, must be >= 0
    @raise NotImplementedError: when shrinking to a non-zero number
        of workers

    """
    assert num_workers >= 0, "num_workers must be >= 0"

    logging.debug("Resizing to %s workers", num_workers)

    current_count = len(self._workers)

    if current_count == num_workers:
      # Nothing to do
      pass

    elif current_count > num_workers:
      if num_workers == 0:
        # Create copy of list to iterate over while lock isn't held.
        termworkers = self._workers[:]
        del self._workers[:]
      else:
        # TODO: Implement partial downsizing
        raise NotImplementedError()
        #termworkers = ...

      self._termworkers += termworkers

      # Notify workers that something has changed
      self._pool_to_worker.notify_all()

      # Join all terminating workers. The lock must not be held here,
      # as workers need it to notice that they should terminate.
      self._lock.release()
      try:
        for worker in termworkers:
          logging.debug("Waiting for thread %s", worker.name)
          worker.join()
      finally:
        self._lock.acquire()

      # Remove terminated threads. This could be done in a more efficient way
      # (del self._termworkers[:]), but checking worker.is_alive() makes sure
      # we don't leave zombie threads around.
      for worker in termworkers:
        assert worker in self._termworkers, ("Worker not in list of"
                                             " terminating workers")
        if not worker.is_alive():
          self._termworkers.remove(worker)

      assert not self._termworkers, "Zombie worker detected"

    elif current_count < num_workers:
      # Create (num_workers - current_count) new workers
      for _ in range(num_workers - current_count):
        worker = self._worker_class(self, self._NewWorkerIdUnlocked())
        self._workers.append(worker)
        worker.start()

  def Resize(self, num_workers):
    """Changes the number of workers in the pool.

    @param num_workers: the new number of workers

    """
    with self._lock:
      return self._ResizeUnlocked(num_workers)

  def TerminateWorkers(self):
    """Terminate all worker threads.

    Unstarted tasks will be ignored.

    """
    logging.debug("Terminating all workers")

    with self._lock:
      self._ResizeUnlocked(0)

      if self._tasks:
        logging.debug("There are %s tasks left", len(self._tasks))

    logging.debug("All workers terminated")