Statistics
| Branch: | Revision:

root / trunk / Libraries / ParallelExtensionsExtras / TaskSchedulers / WorkStealingTaskScheduler.cs @ d78cbf09

History | View | Annotate | Download (18.8 kB)

1
//--------------------------------------------------------------------------
2
// 
3
//  Copyright (c) Microsoft Corporation.  All rights reserved. 
4
// 
5
//  File: WorkStealingTaskScheduler.cs
6
//
7
//--------------------------------------------------------------------------
8

    
9
using System.Collections.Generic;
10

    
11
namespace System.Threading.Tasks.Schedulers
12
{
13
    /// <summary>Provides a work-stealing scheduler.</summary>
14
    public class WorkStealingTaskScheduler : TaskScheduler, IDisposable
15
    {
16
        private readonly int m_concurrencyLevel;
17
        private readonly Queue<Task> m_queue = new Queue<Task>();
18
        private WorkStealingQueue<Task>[] m_wsQueues = new WorkStealingQueue<Task>[Environment.ProcessorCount];
19
        private Lazy<Thread[]> m_threads;
20
        private int m_threadsWaiting;
21
        private bool m_shutdown;
22
        [ThreadStatic]
23
        private static WorkStealingQueue<Task> m_wsq;
24

    
25
        /// <summary>Initializes a new instance of the WorkStealingTaskScheduler class.</summary>
26
        /// <remarks>This constructors defaults to using twice as many threads as there are processors.</remarks>
27
        public WorkStealingTaskScheduler() : this(Environment.ProcessorCount * 2) { }
28

    
29
        /// <summary>Initializes a new instance of the WorkStealingTaskScheduler class.</summary>
30
        /// <param name="concurrencyLevel">The number of threads to use in the scheduler.</param>
31
        public WorkStealingTaskScheduler(int concurrencyLevel)
32
        {
33
            // Store the concurrency level
34
            if (concurrencyLevel <= 0) throw new ArgumentOutOfRangeException("concurrencyLevel");
35
            m_concurrencyLevel = concurrencyLevel;
36

    
37
            // Set up threads
38
            m_threads = new Lazy<Thread[]>(() =>
39
            {
40
                var threads = new Thread[m_concurrencyLevel];
41
                for (int i = 0; i < threads.Length; i++)
42
                {
43
                    threads[i] = new Thread(DispatchLoop) { IsBackground = true };
44
                    threads[i].Start();
45
                }
46
                return threads;
47
            });
48
        }
49

    
50
        /// <summary>Queues a task to the scheduler.</summary>
51
        /// <param name="task">The task to be scheduled.</param>
52
        protected override void QueueTask(Task task)
53
        {
54
            // Make sure the pool is started, e.g. that all threads have been created.
55
            m_threads.Force();
56

    
57
            // If the task is marked as long-running, give it its own dedicated thread
58
            // rather than queueing it.
59
            if ((task.CreationOptions & TaskCreationOptions.LongRunning) != 0)
60
            {
61
                new Thread(state => base.TryExecuteTask((Task)state)) { IsBackground = true }.Start(task);
62
            }
63
            else
64
            {
65
                // Otherwise, insert the work item into a queue, possibly waking a thread.
66
                // If there's a local queue and the task does not prefer to be in the global queue,
67
                // add it to the local queue.
68
                WorkStealingQueue<Task> wsq = m_wsq;
69
                if (wsq != null && ((task.CreationOptions & TaskCreationOptions.PreferFairness) == 0))
70
                {
71
                    // Add to the local queue and notify any waiting threads that work is available.
72
                    // Races may occur which result in missed event notifications, but they're benign in that
73
                    // this thread will eventually pick up the work item anyway, as will other threads when another
74
                    // work item notification is received.
75
                    wsq.LocalPush(task);
76
                    if (m_threadsWaiting > 0) // OK to read lock-free.
77
                    {
78
                        lock (m_queue) { Monitor.Pulse(m_queue); }
79
                    }
80
                }
81
                // Otherwise, add the work item to the global queue
82
                else
83
                {
84
                    lock (m_queue)
85
                    {
86
                        m_queue.Enqueue(task);
87
                        if (m_threadsWaiting > 0) Monitor.Pulse(m_queue);
88
                    }
89
                }
90
            }
91
        }
92

    
93
        /// <summary>Executes a task on the current thread.</summary>
94
        /// <param name="task">The task to be executed.</param>
95
        /// <param name="taskWasPreviouslyQueued">Ignored.</param>
96
        /// <returns>Whether the task could be executed.</returns>
97
        protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
98
        {
99
            return TryExecuteTask(task);
100

    
101
            // // Optional replacement: Instead of always trying to execute the task (which could
102
            // // benignly leave a task in the queue that's already been executed), we
103
            // // can search the current work-stealing queue and remove the task,
104
            // // executing it inline only if it's found.
105
            // WorkStealingQueue<Task> wsq = m_wsq;
106
            // return wsq != null && wsq.TryFindAndPop(task) && TryExecuteTask(task);
107
        }
108

    
109
        /// <summary>Gets the maximum concurrency level supported by this scheduler.</summary>
110
        public override int MaximumConcurrencyLevel
111
        {
112
            get { return m_concurrencyLevel; }
113
        }
114

    
115
        /// <summary>Gets all of the tasks currently scheduled to this scheduler.</summary>
116
        /// <returns>An enumerable containing all of the scheduled tasks.</returns>
117
        protected override IEnumerable<Task> GetScheduledTasks()
118
        {
119
            // Keep track of all of the tasks we find
120
            List<Task> tasks = new List<Task>();
121

    
122
            // Get all of the global tasks.  We use TryEnter so as not to hang
123
            // a debugger if the lock is held by a frozen thread.
124
            bool lockTaken = false;
125
            try
126
            {
127
                Monitor.TryEnter(m_queue, ref lockTaken);
128
                if (lockTaken) tasks.AddRange(m_queue.ToArray());
129
                else throw new NotSupportedException();
130
            }
131
            finally
132
            {
133
                if (lockTaken) Monitor.Exit(m_queue);
134
            }
135

    
136
            // Now get all of the tasks from the work-stealing queues
137
            WorkStealingQueue<Task>[] queues = m_wsQueues;
138
            for (int i = 0; i < queues.Length; i++)
139
            {
140
                WorkStealingQueue<Task> wsq = queues[i];
141
                if (wsq != null) tasks.AddRange(wsq.ToArray());
142
            }
143

    
144
            // Return to the debugger all of the collected task instances
145
            return tasks;
146
        }
147

    
148
        /// <summary>Adds a work-stealing queue to the set of queues.</summary>
149
        /// <param name="wsq">The queue to be added.</param>
150
        private void AddWsq(WorkStealingQueue<Task> wsq)
151
        {
152
            lock (m_wsQueues)
153
            {
154
                // Find the next open slot in the array. If we find one,
155
                // store the queue and we're done.
156
                int i;
157
                for (i = 0; i < m_wsQueues.Length; i++)
158
                {
159
                    if (m_wsQueues[i] == null)
160
                    {
161
                        m_wsQueues[i] = wsq;
162
                        return;
163
                    }
164
                }
165

    
166
                // We couldn't find an open slot, so double the length 
167
                // of the array by creating a new one, copying over,
168
                // and storing the new one. Here, i == m_wsQueues.Length.
169
                WorkStealingQueue<Task>[] queues = new WorkStealingQueue<Task>[i * 2];
170
                Array.Copy(m_wsQueues, queues, i);
171
                queues[i] = wsq;
172
                m_wsQueues = queues;
173
            }
174
        }
175

    
176
        /// <summary>Remove a work-stealing queue from the set of queues.</summary>
177
        /// <param name="wsq">The work-stealing queue to remove.</param>
178
        private void RemoveWsq(WorkStealingQueue<Task> wsq)
179
        {
180
            lock (m_wsQueues)
181
            {
182
                // Find the queue, and if/when we find it, null out its array slot
183
                for (int i = 0; i < m_wsQueues.Length; i++)
184
                {
185
                    if (m_wsQueues[i] == wsq)
186
                    {
187
                        m_wsQueues[i] = null;
188
                    }
189
                }
190
            }
191
        }
192

    
193
        /// <summary>
194
        /// The dispatch loop run by each thread in the scheduler.
195
        /// </summary>
196
        private void DispatchLoop()
197
        {
198
            // Create a new queue for this thread, store it in TLS for later retrieval,
199
            // and add it to the set of queues for this scheduler.
200
            WorkStealingQueue<Task> wsq = new WorkStealingQueue<Task>();
201
            m_wsq = wsq;
202
            AddWsq(wsq);
203

    
204
            try
205
            {
206
                // Until there's no more work to do...
207
                while (true)
208
                {
209
                    Task wi = null;
210

    
211
                    // Search order: (1) local WSQ, (2) global Q, (3) steals from other queues.
212
                    if (!wsq.LocalPop(ref wi))
213
                    {
214
                        // We weren't able to get a task from the local WSQ
215
                        bool searchedForSteals = false;
216
                        while (true)
217
                        {
218
                            lock (m_queue)
219
                            {
220
                                // If shutdown was requested, exit the thread.
221
                                if (m_shutdown)
222
                                    return;
223

    
224
                                // (2) try the global queue.
225
                                if (m_queue.Count != 0)
226
                                {
227
                                    // We found a work item! Grab it ...
228
                                    wi = m_queue.Dequeue();
229
                                    break;
230
                                }
231
                                else if (searchedForSteals)
232
                                {
233
                                    // Note that we're not waiting for work, and then wait
234
                                    m_threadsWaiting++;
235
                                    try { Monitor.Wait(m_queue); }
236
                                    finally { m_threadsWaiting--; }
237

    
238
                                    // If we were signaled due to shutdown, exit the thread.
239
                                    if (m_shutdown)
240
                                        return;
241

    
242
                                    searchedForSteals = false;
243
                                    continue;
244
                                }
245
                            }
246

    
247
                            // (3) try to steal.
248
                            WorkStealingQueue<Task>[] wsQueues = m_wsQueues;
249
                            int i;
250
                            for (i = 0; i < wsQueues.Length; i++)
251
                            {
252
                                WorkStealingQueue<Task> q = wsQueues[i];
253
                                if (q != null && q != wsq && q.TrySteal(ref wi)) break;
254
                            }
255

    
256
                            if (i != wsQueues.Length) break;
257

    
258
                            searchedForSteals = true;
259
                        }
260
                    }
261

    
262
                    // ...and Invoke it.
263
                    TryExecuteTask(wi);
264
                }
265
            }
266
            finally
267
            {
268
                RemoveWsq(wsq);
269
            }
270
        }
271

    
272
        /// <summary>Signal the scheduler to shutdown and wait for all threads to finish.</summary>
273
        public void Dispose()
274
        {
275
            m_shutdown = true;
276
            if (m_queue != null && m_threads.IsValueCreated)
277
            {
278
                var threads = m_threads.Value;
279
                lock (m_queue) Monitor.PulseAll(m_queue);
280
                for (int i = 0; i < threads.Length; i++) threads[i].Join();
281
            }
282
        }
283
    }
284

    
285
    /// <summary>A work-stealing queue.</summary>
286
    /// <typeparam name="T">Specifies the type of data stored in the queue.</typeparam>
287
    internal class WorkStealingQueue<T> where T : class
288
    {
289
        private const int INITIAL_SIZE = 32;
290
        private T[] m_array = new T[INITIAL_SIZE];
291
        private int m_mask = INITIAL_SIZE - 1;
292
        private volatile int m_headIndex = 0;
293
        private volatile int m_tailIndex = 0;
294

    
295
        private object m_foreignLock = new object();
296

    
297
        internal void LocalPush(T obj)
298
        {
299
            int tail = m_tailIndex;
300

    
301
            // When there are at least 2 elements' worth of space, we can take the fast path.
302
            if (tail < m_headIndex + m_mask)
303
            {
304
                m_array[tail & m_mask] = obj;
305
                m_tailIndex = tail + 1;
306
            }
307
            else
308
            {
309
                // We need to contend with foreign pops, so we lock.
310
                lock (m_foreignLock)
311
                {
312
                    int head = m_headIndex;
313
                    int count = m_tailIndex - m_headIndex;
314

    
315
                    // If there is still space (one left), just add the element.
316
                    if (count >= m_mask)
317
                    {
318
                        // We're full; expand the queue by doubling its size.
319
                        T[] newArray = new T[m_array.Length << 1];
320
                        for (int i = 0; i < m_array.Length; i++)
321
                            newArray[i] = m_array[(i + head) & m_mask];
322

    
323
                        // Reset the field values, incl. the mask.
324
                        m_array = newArray;
325
                        m_headIndex = 0;
326
                        m_tailIndex = tail = count;
327
                        m_mask = (m_mask << 1) | 1;
328
                    }
329

    
330
                    m_array[tail & m_mask] = obj;
331
                    m_tailIndex = tail + 1;
332
                }
333
            }
334
        }
335

    
336
        internal bool LocalPop(ref T obj)
337
        {
338
            while (true)
339
            {
340
                // Decrement the tail using a fence to ensure subsequent read doesn't come before.
341
                int tail = m_tailIndex;
342
                if (m_headIndex >= tail)
343
                {
344
                    obj = null;
345
                    return false;
346
                }
347

    
348
                tail -= 1;
349
#pragma warning disable 0420
350
                Interlocked.Exchange(ref m_tailIndex, tail);
351
#pragma warning restore 0420
352

    
353
                // If there is no interaction with a take, we can head down the fast path.
354
                if (m_headIndex <= tail)
355
                {
356
                    int idx = tail & m_mask;
357
                    obj = m_array[idx];
358

    
359
                    // Check for nulls in the array.
360
                    if (obj == null) continue;
361

    
362
                    m_array[idx] = null;
363
                    return true;
364
                }
365
                else
366
                {
367
                    // Interaction with takes: 0 or 1 elements left.
368
                    lock (m_foreignLock)
369
                    {
370
                        if (m_headIndex <= tail)
371
                        {
372
                            // Element still available. Take it.
373
                            int idx = tail & m_mask;
374
                            obj = m_array[idx];
375

    
376
                            // Check for nulls in the array.
377
                            if (obj == null) continue;
378

    
379
                            m_array[idx] = null;
380
                            return true;
381
                        }
382
                        else
383
                        {
384
                            // We lost the race, element was stolen, restore the tail.
385
                            m_tailIndex = tail + 1;
386
                            obj = null;
387
                            return false;
388
                        }
389
                    }
390
                }
391
            }
392
        }
393

    
394
        internal bool TrySteal(ref T obj)
395
        {
396
            obj = null;
397

    
398
            while (true)
399
            {
400
                if (m_headIndex >= m_tailIndex)
401
                    return false;
402

    
403
                lock (m_foreignLock)
404
                {
405
                    // Increment head, and ensure read of tail doesn't move before it (fence).
406
                    int head = m_headIndex;
407
#pragma warning disable 0420
408
                    Interlocked.Exchange(ref m_headIndex, head + 1);
409
#pragma warning restore 0420
410

    
411
                    if (head < m_tailIndex)
412
                    {
413
                        int idx = head & m_mask;
414
                        obj = m_array[idx];
415

    
416
                        // Check for nulls in the array.
417
                        if (obj == null) continue;
418

    
419
                        m_array[idx] = null;
420
                        return true;
421
                    }
422
                    else
423
                    {
424
                        // Failed, restore head.
425
                        m_headIndex = head;
426
                        obj = null;
427
                    }
428
                }
429

    
430
                return false;
431
            }
432
        }
433

    
434
        internal bool TryFindAndPop(T obj)
435
        {
436
            // We do an O(N) search for the work item. The theory of work stealing and our
437
            // inlining logic is that most waits will happen on recently queued work.  And
438
            // since recently queued work will be close to the tail end (which is where we
439
            // begin our search), we will likely find it quickly.  In the worst case, we
440
            // will traverse the whole local queue; this is typically not going to be a
441
            // problem (although degenerate cases are clearly an issue) because local work
442
            // queues tend to be somewhat shallow in length, and because if we fail to find
443
            // the work item, we are about to block anyway (which is very expensive).
444

    
445
            for (int i = m_tailIndex - 1; i >= m_headIndex; i--)
446
            {
447
                if (m_array[i & m_mask] == obj)
448
                {
449
                    // If we found the element, block out steals to avoid interference.
450
                    lock (m_foreignLock)
451
                    {
452
                        // If we lost the race, bail.
453
                        if (m_array[i & m_mask] == null)
454
                        {
455
                            return false;
456
                        }
457

    
458
                        // Otherwise, null out the element.
459
                        m_array[i & m_mask] = null;
460

    
461
                        // And then check to see if we can fix up the indexes (if we're at
462
                        // the edge).  If we can't, we just leave nulls in the array and they'll
463
                        // get filtered out eventually (but may lead to superflous resizing).
464
                        if (i == m_tailIndex)
465
                            m_tailIndex -= 1;
466
                        else if (i == m_headIndex)
467
                            m_headIndex += 1;
468

    
469
                        return true;
470
                    }
471
                }
472
            }
473

    
474
            return false;
475
        }
476

    
477
        internal T[] ToArray()
478
        {
479
            List<T> list = new List<T>();
480
            for (int i = m_tailIndex - 1; i >= m_headIndex; i--)
481
            {
482
                T obj = m_array[i & m_mask];
483
                if (obj != null) list.Add(obj);
484
            }
485
            return list.ToArray();
486
        }
487
    }
488
}