Fixed frequent timeout by always closing WebResponse objects
[pithos-ms-client] / trunk / Pithos.Network / RestClient.cs
1 // -----------------------------------------------------------------------
2 // <copyright file="RestClient.cs" company="Microsoft">
3 // TODO: Update copyright text.
4 // </copyright>
5 // -----------------------------------------------------------------------
6
7 using System.Collections.Specialized;
8 using System.Diagnostics;
9 using System.Diagnostics.Contracts;
10 using System.IO;
11 using System.Net;
12 using System.Runtime.Serialization;
13 using System.Threading.Tasks;
14 using log4net;
15
16
17 namespace Pithos.Network
18 {
19     using System;
20     using System.Collections.Generic;
21     using System.Linq;
22     using System.Text;
23
/// <summary>
/// A <see cref="WebClient"/> that adds a request timeout, byte ranges, conditional requests,
/// query-string parameters and automatic retries for requests that time out or hit a
/// temporarily unavailable service.
/// </summary>
public class RestClient:WebClient
{
    private static readonly ILog Log = LogManager.GetLogger("RestClient");

    private readonly Dictionary<string, string> _parameters=new Dictionary<string, string>();

    private readonly List<HttpStatusCode> _allowedStatusCodes=new List<HttpStatusCode>{HttpStatusCode.NotModified};

    private WebHeaderCollection _responseHeaders;

    /// <summary>
    /// Timeout applied to every request created by this client (the value is assigned to
    /// <see cref="HttpWebRequest.Timeout"/>). Values of zero or less leave the request's default untouched.
    /// </summary>
    public int Timeout { get; set; }

    /// <summary>
    /// Set to true when a retried operation failed with a timeout or a 503 response.
    /// Reset to false each time a new request is created.
    /// </summary>
    public bool TimedOut { get; set; }

    /// <summary>
    /// Status code of the last response received, including responses attached to a WebException
    /// </summary>
    public HttpStatusCode StatusCode { get; private set; }

    /// <summary>
    /// Status description of the last response received
    /// </summary>
    public string StatusDescription { get; set; }

    /// <summary>
    /// Optional start of the byte range to request. When null, no Range header is added.
    /// </summary>
    public long? RangeFrom { get; set; }

    /// <summary>
    /// Optional end of the byte range to request. Only used when <see cref="RangeFrom"/> is set.
    /// </summary>
    public long? RangeTo { get; set; }

    /// <summary>
    /// Default number of retries, used whenever a method is called with retries == 0
    /// </summary>
    public int Retries { get; set; }

    /// <summary>
    /// When set, the value is sent as the If-Modified-Since header of every request
    /// </summary>
    public DateTime? IfModifiedSince { get; set; }

    /// <summary>
    /// Last-Modified value of the last response received
    /// </summary>
    public DateTime LastModified { get; private set; }

    /// <summary>
    /// Query-string parameters appended to every address that is passed as a string
    /// </summary>
    public Dictionary<string, string> Parameters
    {
        get
        {
            Contract.Ensures(_parameters!=null);
            return _parameters;
        }
    }

    /// <summary>
    /// Status codes that are returned to the caller as normal responses even though
    /// the underlying request reports them through a WebException
    /// </summary>
    public List<HttpStatusCode> AllowedStatusCodes
    {
        get
        {
            return _allowedStatusCodes;
        }
    }

    /// <summary>
    /// Headers of the last response. Returns the headers of the base WebClient when they are
    /// available, otherwise the value stored through the setter.
    /// </summary>
    public new WebHeaderCollection ResponseHeaders
    {
        get
        {
            var baseHeaders = base.ResponseHeaders;
            if (baseHeaders == null)
                return _responseHeaders;

            //The base headers take over, discard the stored value
            _responseHeaders = null;
            return baseHeaders;
        }

        set { _responseHeaders = value; }
    }

    [ContractInvariantMethod]
    private void Invariants()
    {
        Contract.Invariant(Headers!=null);
    }

    /// <summary>
    /// Creates a client with default settings
    /// </summary>
    public RestClient():base()
    {

    }

    /// <summary>
    /// Creates a client that copies the headers, timeout, retries, base address,
    /// parameters and proxy of another client
    /// </summary>
    /// <param name="other">The client whose settings are copied</param>
    /// <exception cref="ArgumentNullException">other is null</exception>
    public RestClient(RestClient other)
        : base()
    {
        if (other==null)
            throw new ArgumentNullException("other");
        Contract.EndContractBlock();

        CopyHeaders(other);
        Timeout = other.Timeout;
        Retries = other.Retries;
        BaseAddress = other.BaseAddress;

        foreach (var parameter in other.Parameters)
        {
            Parameters.Add(parameter.Key,parameter.Value);
        }

        this.Proxy = other.Proxy;
    }

    /// <summary>
    /// Creates the request for an address and applies the client's timeout,
    /// If-Modified-Since, decompression and range settings to it
    /// </summary>
    /// <param name="address">The address of the request</param>
    /// <returns>The configured request</returns>
    protected override WebRequest GetWebRequest(Uri address)
    {
        TimedOut = false;
        var request = (HttpWebRequest)base.GetWebRequest(address);

        if (IfModifiedSince.HasValue)
            request.IfModifiedSince = IfModifiedSince.Value;
        request.AutomaticDecompression = DecompressionMethods.Deflate | DecompressionMethods.GZip;
        if (Timeout>0)
            request.Timeout = Timeout;

        if (RangeFrom.HasValue)
        {
            if (RangeTo.HasValue)
                request.AddRange(RangeFrom.Value, RangeTo.Value);
            else
                request.AddRange(RangeFrom.Value);
        }
        return request;
    }

    /// <summary>
    /// Asynchronous version. Returns the response of a request, treating the status codes
    /// in <see cref="AllowedStatusCodes"/> as normal responses, and records its status.
    /// </summary>
    protected override WebResponse GetWebResponse(WebRequest request, IAsyncResult result)
    {
        Log.InfoFormat("ASYNC [{0}] {1}",request.Method, request.RequestUri);
        HttpWebResponse response = null;

        try
        {
            response = (HttpWebResponse)base.GetWebResponse(request, result);
        }
        catch (WebException exc)
        {
            if (!TryGetResponse(exc, out response))
                throw;
        }

        return RecordResponse(response);
    }

    /// <summary>
    /// Synchronous version. Returns the response of a request, treating the status codes
    /// in <see cref="AllowedStatusCodes"/> as normal responses, and records its status.
    /// </summary>
    protected override WebResponse GetWebResponse(WebRequest request)
    {
        HttpWebResponse response = null;
        try
        {
            response = (HttpWebResponse)base.GetWebResponse(request);
        }
        catch (WebException exc)
        {
            if (!TryGetResponse(exc, out response))
                throw;
        }

        return RecordResponse(response);
    }

    /// <summary>
    /// Stores the status and modification date of a response in the client's properties
    /// </summary>
    /// <returns>The same response, to allow use in a return statement</returns>
    private HttpWebResponse RecordResponse(HttpWebResponse response)
    {
        StatusCode = response.StatusCode;
        LastModified = response.LastModified;
        StatusDescription = response.StatusDescription;
        return response;
    }

    /// <summary>
    /// Extracts the response of a failed request if its status code is allowed
    /// </summary>
    /// <param name="exc">The exception raised by the request</param>
    /// <param name="response">The HTTP response attached to the exception, or null if there is none</param>
    /// <returns>True if the response has an allowed status code and can be returned to the caller</returns>
    private bool TryGetResponse(WebException exc, out HttpWebResponse response)
    {
        response = exc.Response as HttpWebResponse;
        //Fail on an empty response, or on a response that is not an HTTP response
        if (response == null)
            return false;

        //Succeed on allowed status codes
        if (AllowedStatusCodes.Contains(response.StatusCode))
            return true;

        //Does the response have any content to log?
        if (response.ContentLength > 0)
        {
            var content = GetContent(response);
            //The content comes from the server and may contain braces,
            //so it must never be used as a format string
            Log.Error(content);
        }
        return false;
    }

    /// <summary>
    /// Reads the entire body of a response as text. Disposes the response stream when done.
    /// </summary>
    private static string GetContent(WebResponse webResponse)
    {
        if (webResponse == null)
            throw new ArgumentNullException("webResponse");
        Contract.EndContractBlock();

        using (var stream = webResponse.GetResponseStream())
        using (var reader = new StreamReader(stream))
        {
            return reader.ReadToEnd();
        }
    }

    /// <summary>
    /// Downloads the resource at an address relative to BaseAddress, retrying on timeouts
    /// </summary>
    /// <param name="address">The relative address. The client's Parameters are appended to it.</param>
    /// <param name="retries">The number of retries. When 0, the Retries property is used.</param>
    /// <returns>The content, an empty string for a 204 response or null for a 404 response</returns>
    /// <exception cref="ArgumentNullException">address is null</exception>
    public string DownloadStringWithRetry(string address,int retries=0)
    {
        if (address == null)
            throw new ArgumentNullException("address");

        var actualAddress = GetActualAddress(address);

        TraceStart("GET",actualAddress);

        return DownloadWithRetry(
            () => base.DownloadString(String.Join("/", BaseAddress.TrimEnd('/'), actualAddress)),
            retries);
    }

    /// <summary>
    /// Downloads the resource at an absolute address, retrying on timeouts
    /// </summary>
    /// <param name="address">The address of the resource</param>
    /// <param name="retries">The number of retries. When 0, the Retries property is used.</param>
    /// <returns>The content, an empty string for a 204 response or null for a 404 response</returns>
    /// <exception cref="ArgumentNullException">address is null</exception>
    public string DownloadStringWithRetry(Uri address,int retries=0)
    {
        if (address == null)
            throw new ArgumentNullException("address");

        return DownloadWithRetry(() => base.DownloadString(address), retries);
    }

    /// <summary>
    /// Runs a download with retries and blocks until it completes.
    /// Failures surface as an AggregateException thrown by Task.Result.
    /// </summary>
    private string DownloadWithRetry(Func<string> download, int retries)
    {
        var actualRetries = (retries == 0) ? Retries : retries;

        var task = Retry(() =>
        {
            var content = download();

            if (StatusCode == HttpStatusCode.NoContent)
                return String.Empty;
            return content;

        }, actualRetries);

        return task.Result;
    }

    /// <summary>
    /// Sends a HEAD request. A 404 response is treated as a normal response.
    /// </summary>
    /// <param name="address">The address, relative to BaseAddress</param>
    /// <param name="retries">The number of retries. When 0, the Retries property is used.</param>
    public void Head(string address,int retries=0)
    {
        //Register 404 only once, otherwise every call adds a duplicate entry to the list
        if (!AllowedStatusCodes.Contains(HttpStatusCode.NotFound))
            AllowedStatusCodes.Add(HttpStatusCode.NotFound);
        RetryWithoutContent(address, retries, "HEAD");
    }

    /// <summary>
    /// Sends a PUT request without a body
    /// </summary>
    /// <param name="address">The address, relative to BaseAddress</param>
    /// <param name="retries">The number of retries. When 0, the Retries property is used.</param>
    public void PutWithRetry(string address, int retries = 0)
    {
        RetryWithoutContent(address, retries, "PUT");
    }

    /// <summary>
    /// Sends a DELETE request
    /// </summary>
    /// <param name="address">The address, relative to BaseAddress</param>
    /// <param name="retries">The number of retries. When 0, the Retries property is used.</param>
    public void DeleteWithRetry(string address,int retries=0)
    {
        RetryWithoutContent(address, retries, "DELETE");
    }

    /// <summary>
    /// Returns the first value of a response header
    /// </summary>
    /// <param name="headerName">The name of the header</param>
    /// <param name="optional">When true, a missing header returns null instead of throwing</param>
    /// <returns>The first value of the header, or null if an optional header is missing</returns>
    /// <exception cref="InvalidOperationException">There are no response headers</exception>
    /// <exception cref="WebException">A required header is missing</exception>
    public string GetHeaderValue(string headerName,bool optional=false)
    {
        if (this.ResponseHeaders==null)
            throw new InvalidOperationException("ResponseHeaders are null");
        Contract.EndContractBlock();

        var values=this.ResponseHeaders.GetValues(headerName);
        if (values != null)
            return values[0];

        if (optional)
            return null;
        //A required header was not found
        throw new WebException(String.Format("The {0}  header is missing", headerName));
    }

    /// <summary>
    /// Adds a request header, unless the value is null, empty or whitespace
    /// </summary>
    public void SetNonEmptyHeaderValue(string headerName, string value)
    {
        if (String.IsNullOrWhiteSpace(value))
            return;
        Headers.Add(headerName,value);
    }

    /// <summary>
    /// Sends a request that has no body and whose response body is ignored, retrying on
    /// timeouts, and blocks until it completes
    /// </summary>
    /// <param name="address">The address, relative to BaseAddress</param>
    /// <param name="retries">The number of retries. When 0, the Retries property is used.</param>
    /// <param name="method">The HTTP method</param>
    private void RetryWithoutContent(string address, int retries, string method)
    {
        if (address == null)
            throw new ArgumentNullException("address");

        var actualAddress = GetActualAddress(address);
        var actualRetries = (retries == 0) ? Retries : retries;

        var task = Retry(() =>
        {
            //Trim the base address like DownloadStringWithRetry does, to avoid a double slash
            var uriString = String.Join("/", BaseAddress.TrimEnd('/'), actualAddress);
            var uri = new Uri(uriString);
            var request =  GetWebRequest(uri);
            request.Method = method;
            if (ResponseHeaders!=null)
                ResponseHeaders.Clear();

            TraceStart(method, uriString);
            if (method == "PUT")
                request.ContentLength = 0;

            var response = (HttpWebResponse)GetWebResponse(request);
            try
            {
                LastModified = response.LastModified;
                StatusCode = response.StatusCode;
                StatusDescription = response.StatusDescription;
            }
            finally
            {
                //Always release the connection, even if reading a property fails.
                //Close is used instead of Dispose so that the response headers
                //can still be read after the call
                response.Close();
            }

            return 0;
        }, actualRetries);

        try
        {
            task.Wait();
        }
        catch (AggregateException ex)
        {
            var exc = ex.InnerException;
            if (exc is RetryException)
            {
                Log.ErrorFormat("[{0}] RETRY FAILED for {1} after {2} retries",method,address,actualRetries);
            }
            else
            {
                Log.ErrorFormat("[{0}] FAILED for {1} with \n{2}", method, address, exc);
            }
            //Callers expect the actual exception, not the AggregateException wrapper
            throw exc;

        }
        catch(Exception ex)
        {
            Log.ErrorFormat("[{0}] FAILED for {1} with \n{2}", method, address, ex);
            throw;
        }
    }

    /// <summary>
    /// Logs the start of a request
    /// </summary>
    private static void TraceStart(string method, string actualAddress)
    {
        Log.InfoFormat("[{0}] {1} {2}", method, DateTime.Now, actualAddress);
    }

    /// <summary>
    /// Appends the client's Parameters to an address as a query string.
    /// Keys and values are appended as they are, without any escaping.
    /// </summary>
    /// <param name="address">The address without a query string</param>
    /// <returns>The address followed by ?key=value&amp;key=value, or the address itself when there are no parameters</returns>
    private string GetActualAddress(string address)
    {
        if (Parameters.Count == 0)
            return address;
        var addressBuilder=new StringBuilder(address);

        var separator = "?";
        foreach (var parameter in Parameters)
        {
            addressBuilder.Append(separator);
            addressBuilder.AppendFormat("{0}={1}", parameter.Key, parameter.Value);
            separator = "&";
        }
        return addressBuilder.ToString();
    }

    /// <summary>
    /// Copies headers from another RestClient
    /// </summary>
    /// <param name="source">The RestClient from which the headers are copied</param>
    public void CopyHeaders(RestClient source)
    {
        if (source == null)
            throw new ArgumentNullException("source", "source can't be null");
        Contract.EndContractBlock();
        //The Headers getter initializes the property, it is never null
        Contract.Assume(Headers!=null);

        CopyHeaders(source.Headers,Headers);
    }

    /// <summary>
    /// Copies headers from one header collection to another
    /// </summary>
    /// <param name="source">The source collection from which the headers are copied</param>
    /// <param name="target">The target collection to which the headers are copied</param>
    public static void CopyHeaders(WebHeaderCollection source,WebHeaderCollection target)
    {
        if (source == null)
            throw new ArgumentNullException("source", "source can't be null");
        if (target == null)
            throw new ArgumentNullException("target", "target can't be null");
        Contract.EndContractBlock();

        for (int i = 0; i < source.Count; i++)
        {
            target.Add(source.GetKey(i), source[i]);
        }
    }

    /// <summary>
    /// Throws if the last status code is 400 or greater
    /// </summary>
    /// <param name="message">Text placed at the start of the exception message</param>
    /// <exception cref="WebException">The last status code is 400 or greater</exception>
    public void AssertStatusOK(string message)
    {
        if (StatusCode >= HttpStatusCode.BadRequest)
            throw new WebException(String.Format("{0} with code {1} - {2}", message, StatusCode, StatusDescription));
    }

    /// <summary>
    /// Runs a function on a new task and runs it again when it fails with a timeout or
    /// a 503 response, until it succeeds or the retries are exhausted
    /// </summary>
    /// <param name="original">The function to execute</param>
    /// <param name="retryCount">The number of retries that remain</param>
    /// <param name="tcs">The completion source shared by all attempts. Created on the first attempt.</param>
    /// <returns>
    /// A task that completes with the function's result, with default(T) when the function fails
    /// with a 404 response, with a RetryException when the retries are exhausted, or with
    /// the function's exception for any other failure
    /// </returns>
    private Task<T> Retry<T>(Func<T> original, int retryCount, TaskCompletionSource<T> tcs = null)
    {
        if (original==null)
            throw new ArgumentNullException("original");
        Contract.EndContractBlock();

        if (tcs == null)
            tcs = new TaskCompletionSource<T>();
        Task.Factory.StartNew(original).ContinueWith(attempt =>
            {
                if (!attempt.IsFaulted)
                {
                    tcs.SetFromTask(attempt);
                    return;
                }

                var e = attempt.Exception.InnerException;
                var we = (e as WebException);
                if (we==null)
                {
                    tcs.SetException(e);
                    return;
                }

                var statusCode = GetStatusCode(we);

                //Return null for 404
                if (statusCode == HttpStatusCode.NotFound)
                {
                    //The exception is discarded, so nobody else can release its response
                    CloseResponse(we);
                    tcs.SetResult(default(T));
                    return;
                }

                //Retry for timeouts and service unavailable
                var canRetry = we.Status == WebExceptionStatus.Timeout ||
                    (we.Status == WebExceptionStatus.ProtocolError && statusCode == HttpStatusCode.ServiceUnavailable);
                if (!canRetry)
                {
                    tcs.SetException(e);
                    return;
                }

                TimedOut = true;
                //A negative count is treated as exhausted, otherwise it would never reach zero
                if (retryCount <= 0)
                {
                    Log.ErrorFormat("[ERROR] Timed out too many times. \n{0}\n",e);
                    tcs.SetException(new RetryException("Timed out too many times.", e));
                }
                else
                {
                    Log.ErrorFormat(
                        "[RETRY] Timed out after {0} ms. Will retry {1} more times\n{2}", Timeout,
                        retryCount, e);
                    //The failed attempt is discarded, release its response before trying again
                    CloseResponse(we);
                    Retry(original, retryCount - 1, tcs);
                }
            });
        return tcs.Task;
    }

    /// <summary>
    /// Closes the response attached to an exception, if there is one.
    /// Never throws: a failure to close is only logged.
    /// </summary>
    private static void CloseResponse(WebException we)
    {
        if (we.Response == null)
            return;
        try
        {
            we.Response.Close();
        }
        catch (Exception ex)
        {
            //This runs inside the retry continuation. An exception here would leave
            //the caller's task uncompleted, so it is logged instead of propagated
            Log.Warn("Failed to close the response of a handled WebException", ex);
        }
    }

    /// <summary>
    /// Returns the status code of the response attached to an exception and stores it in StatusCode.
    /// Returns RequestTimeout, without changing StatusCode, when there is no response.
    /// </summary>
    private HttpStatusCode GetStatusCode(WebException we)
    {
        if (we==null)
            throw new ArgumentNullException("we");

        if (we.Response == null)
            return HttpStatusCode.RequestTimeout;

        var statusCode = ((HttpWebResponse) we.Response).StatusCode;
        this.StatusCode = statusCode;
        return statusCode;
    }

    /// <summary>
    /// Creates a UriBuilder for an object inside a container, below BaseAddress
    /// </summary>
    /// <param name="container">The name of the container</param>
    /// <param name="objectName">The name of the object</param>
    public UriBuilder GetAddressBuilder(string container, string objectName)
    {
        //Trim the base address like DownloadStringWithRetry does, to avoid a double slash
        return new UriBuilder(String.Join("/", BaseAddress.TrimEnd('/'), container, objectName));
    }

    /// <summary>
    /// Returns the response headers whose names start with a prefix
    /// </summary>
    /// <param name="metaPrefix">The prefix of the header names</param>
    /// <returns>The header values, keyed by the header name without the prefix</returns>
    /// <exception cref="ArgumentNullException">metaPrefix is null, empty or whitespace</exception>
    /// <exception cref="InvalidOperationException">There are no response headers</exception>
    public Dictionary<string, string> GetMeta(string metaPrefix)
    {
        if (String.IsNullOrWhiteSpace(metaPrefix))
            throw new ArgumentNullException("metaPrefix");
        Contract.EndContractBlock();

        //Read the property once, its getter is not a plain field access
        var headers = ResponseHeaders;
        if (headers == null)
            throw new InvalidOperationException("ResponseHeaders are null");

        return headers.AllKeys
            .Where(key => key.StartsWith(metaPrefix, StringComparison.Ordinal))
            .ToDictionary(key => key.Substring(metaPrefix.Length), key => headers[key]);
    }
}
519
/// <summary>
/// Thrown when an operation keeps timing out after all its retries are exhausted
/// </summary>
//The attribute is required for the serialization constructor below to be of any use
[Serializable]
public class RetryException:Exception
{
    /// <summary>
    /// Creates an exception with the default message
    /// </summary>
    public RetryException()
        :base()
    {

    }

    /// <summary>
    /// Creates an exception with a message
    /// </summary>
    /// <param name="message">The message that describes the failure</param>
    public RetryException(string message)
        :base(message)
    {

    }

    /// <summary>
    /// Creates an exception with a message and the exception of the last failed attempt
    /// </summary>
    /// <param name="message">The message that describes the failure</param>
    /// <param name="innerException">The exception that caused the failure</param>
    public RetryException(string message,Exception innerException)
        :base(message,innerException)
    {

    }

    /// <summary>
    /// Creates an exception from serialized data
    /// </summary>
    /// <param name="info">The serialized object data</param>
    /// <param name="context">The source or destination of the serialized data</param>
    public RetryException(SerializationInfo info,StreamingContext context)
        :base(info,context)
    {

    }
}
546 }