// -----------------------------------------------------------------------
// TODO: Update copyright text.
// -----------------------------------------------------------------------

using System.Collections.Specialized;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.IO;
using System.Net;
using System.Runtime.Serialization;
using System.Threading.Tasks;
using log4net;

namespace Pithos.Network
{
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Text;

    /// <summary>
    /// A <see cref="WebClient"/> that adds a per-request timeout, byte-range and
    /// If-Modified-Since support, query-string parameters, a configurable set of
    /// status codes that are not treated as errors, and retrying of requests that
    /// time out or return 503 Service Unavailable.
    /// </summary>
    public class RestClient : WebClient
    {
        /// <summary>Request timeout in milliseconds; applied only when greater than zero.</summary>
        public int Timeout { get; set; }

        /// <summary>
        /// Reset to false whenever a request is built; set to true by the retry logic
        /// when a request times out or returns 503 Service Unavailable.
        /// </summary>
        public bool TimedOut { get; set; }

        /// <summary>Status code of the last response (or of the last failed response, when known).</summary>
        public HttpStatusCode StatusCode { get; private set; }

        /// <summary>Status description of the last response.</summary>
        public string StatusDescription { get; set; }

        /// <summary>Start of the byte range to request; no Range header is sent when null.</summary>
        public long? RangeFrom { get; set; }

        /// <summary>End of the byte range to request; ignored unless <see cref="RangeFrom"/> is set.</summary>
        public long? RangeTo { get; set; }

        /// <summary>Default retry count, used when a *WithRetry method is called with retries == 0.</summary>
        public int Retries { get; set; }

        private readonly Dictionary<string, string> _parameters = new Dictionary<string, string>();

        /// <summary>Query-string parameters appended to every address passed to the *WithRetry/Head methods.</summary>
        public Dictionary<string, string> Parameters
        {
            get
            {
                Contract.Ensures(_parameters != null);
                return _parameters;
            }
        }

        private static readonly ILog Log = LogManager.GetLogger("RestClient");

        [ContractInvariantMethod]
        private void Invariants()
        {
            Contract.Invariant(Headers != null);
        }

        public RestClient() : base()
        {
        }

        /// <summary>
        /// Creates a client that copies the headers, timeout, retry count, base address,
        /// query parameters and proxy of <paramref name="other"/>.
        /// </summary>
        /// <param name="other">The client to copy settings from.</param>
        /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
        public RestClient(RestClient other)
            : base()
        {
            if (other == null)
                throw new ArgumentNullException("other");
            Contract.EndContractBlock();

            CopyHeaders(other);
            Timeout = other.Timeout;
            Retries = other.Retries;
            BaseAddress = other.BaseAddress;

            foreach (var parameter in other.Parameters)
            {
                Parameters.Add(parameter.Key, parameter.Value);
            }

            this.Proxy = other.Proxy;
        }

        /// <summary>
        /// Builds the request and applies If-Modified-Since, gzip/deflate decompression,
        /// the timeout and the byte range configured on this client.
        /// </summary>
        protected override WebRequest GetWebRequest(Uri address)
        {
            TimedOut = false;
            var request = (HttpWebRequest)base.GetWebRequest(address);

            if (IfModifiedSince.HasValue)
                request.IfModifiedSince = IfModifiedSince.Value;
            request.AutomaticDecompression = DecompressionMethods.Deflate | DecompressionMethods.GZip;
            if (Timeout > 0)
                request.Timeout = Timeout;

            if (RangeFrom.HasValue)
            {
                if (RangeTo.HasValue)
                    request.AddRange(RangeFrom.Value, RangeTo.Value);
                else
                    request.AddRange(RangeFrom.Value);
            }
            return request;
        }

        /// <summary>Value sent as If-Modified-Since; no header is sent when null.</summary>
        public DateTime? IfModifiedSince { get; set; }

        protected override WebResponse GetWebResponse(WebRequest request, IAsyncResult result)
        {
            return ProcessResponse(() => base.GetWebResponse(request, result));
        }

        protected override WebResponse GetWebResponse(WebRequest request)
        {
            return ProcessResponse(() => base.GetWebResponse(request));
        }

        /// <summary>
        /// Runs <paramref name="getResponse"/> and records StatusCode, LastModified and
        /// StatusDescription from the response.
        /// </summary>
        /// <remarks>
        /// A <see cref="WebException"/> whose HTTP status is listed in
        /// <see cref="AllowedStatusCodes"/> is turned into a normal return of the error
        /// response. Any other WebException has its body (if any) logged and is rethrown.
        /// </remarks>
        private WebResponse ProcessResponse(Func<WebResponse> getResponse)
        {
            try
            {
                var response = (HttpWebResponse)getResponse();
                StatusCode = response.StatusCode;
                LastModified = response.LastModified;
                StatusDescription = response.StatusDescription;
                return response;
            }
            catch (WebException exc)
            {
                // Only HTTP responses carry a status code we can compare against the allowed list
                var response = exc.Response as HttpWebResponse;
                if (response != null)
                {
                    if (AllowedStatusCodes.Contains(response.StatusCode))
                    {
                        StatusCode = response.StatusCode;
                        LastModified = response.LastModified;
                        StatusDescription = response.StatusDescription;
                        return response;
                    }

                    if (response.ContentLength > 0)
                    {
                        string content = GetContent(response);
                        // Log the body verbatim: it must not be used as a format string,
                        // since server bodies may contain '{' or '}'
                        Log.Error(content);
                    }
                }
                throw;
            }
        }

        private readonly List<HttpStatusCode> _allowedStatusCodes = new List<HttpStatusCode> { HttpStatusCode.NotModified };

        /// <summary>
        /// Status codes whose WebException is converted into a normal response.
        /// Initially contains only 304 Not Modified.
        /// </summary>
        public List<HttpStatusCode> AllowedStatusCodes
        {
            get { return _allowedStatusCodes; }
        }

        /// <summary>Last-Modified value of the last response processed.</summary>
        public DateTime LastModified { get; private set; }

        /// <summary>Reads the whole body of <paramref name="webResponse"/> as a string.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="webResponse"/> is null.</exception>
        private static string GetContent(WebResponse webResponse)
        {
            if (webResponse == null)
                throw new ArgumentNullException("webResponse");
            Contract.EndContractBlock();

            string content;
            using (var stream = webResponse.GetResponseStream())
            using (var reader = new StreamReader(stream))
            {
                content = reader.ReadToEnd();
            }
            return content;
        }

        /// <summary>
        /// GETs <paramref name="address"/> (relative to BaseAddress, with <see cref="Parameters"/>
        /// appended) as a string, retrying on timeout / 503.
        /// </summary>
        /// <param name="address">Address relative to BaseAddress.</param>
        /// <param name="retries">Retry count; 0 means use <see cref="Retries"/>.</param>
        /// <returns>
        /// The body; String.Empty for 204 No Content; null when the server answers 404.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="address"/> is null.</exception>
        /// <exception cref="AggregateException">Wraps the failure raised by the request or by the retry logic.</exception>
        public string DownloadStringWithRetry(string address, int retries = 0)
        {
            if (address == null)
                throw new ArgumentNullException("address");

            var actualAddress = GetActualAddress(address);

            TraceStart("GET", actualAddress);

            var actualRetries = (retries == 0) ? Retries : retries;

            var task = Retry(() =>
            {
                var uriString = String.Join("/", BaseAddress.TrimEnd('/'), actualAddress);
                var content = base.DownloadString(uriString);

                if (StatusCode == HttpStatusCode.NoContent)
                    return String.Empty;
                return content;
            }, actualRetries);

            var result = task.Result;
            return result;
        }

        /// <summary>
        /// Sends a HEAD request. 404 Not Found is added to <see cref="AllowedStatusCodes"/>
        /// so a missing resource is reported through <see cref="StatusCode"/> instead of an exception.
        /// </summary>
        /// <remarks>
        /// The 404 entry stays in <see cref="AllowedStatusCodes"/> for all later requests on this client.
        /// </remarks>
        public void Head(string address, int retries = 0)
        {
            // Add only once: calling Head repeatedly must not grow the shared list
            if (!AllowedStatusCodes.Contains(HttpStatusCode.NotFound))
                AllowedStatusCodes.Add(HttpStatusCode.NotFound);
            RetryWithoutContent(address, retries, "HEAD");
        }

        /// <summary>Sends an empty-bodied PUT, retrying on timeout / 503.</summary>
        public void PutWithRetry(string address, int retries = 0)
        {
            RetryWithoutContent(address, retries, "PUT");
        }

        /// <summary>Sends a DELETE, retrying on timeout / 503.</summary>
        public void DeleteWithRetry(string address, int retries = 0)
        {
            RetryWithoutContent(address, retries, "DELETE");
        }

        /// <summary>Returns the first value of response header <paramref name="headerName"/>.</summary>
        /// <param name="headerName">Header to read.</param>
        /// <param name="optional">When true, a missing header yields null instead of an exception.</param>
        /// <exception cref="InvalidOperationException">ResponseHeaders is null.</exception>
        /// <exception cref="WebException">The header is missing and <paramref name="optional"/> is false.</exception>
        public string GetHeaderValue(string headerName, bool optional = false)
        {
            if (this.ResponseHeaders == null)
                throw new InvalidOperationException("ResponseHeaders are null");
            Contract.EndContractBlock();

            var values = this.ResponseHeaders.GetValues(headerName);
            if (values != null)
                return values[0];

            if (optional)
                return null;
            //A required header was not found
            throw new WebException(String.Format("The {0} header is missing", headerName));
        }

        /// <summary>Adds a request header unless <paramref name="value"/> is null, empty or whitespace.</summary>
        public void SetNonEmptyHeaderValue(string headerName, string value)
        {
            if (String.IsNullOrWhiteSpace(value))
                return;
            Headers.Add(headerName, value);
        }

        /// <summary>
        /// Issues a body-less request with <paramref name="method"/>, retrying on timeout / 503,
        /// and records StatusCode / StatusDescription.
        /// </summary>
        /// <param name="address">Address relative to BaseAddress.</param>
        /// <param name="retries">Retry count; 0 means use <see cref="Retries"/>.</param>
        /// <param name="method">HTTP method; "PUT" gets an explicit zero Content-Length.</param>
        /// <exception cref="ArgumentNullException"><paramref name="address"/> is null.</exception>
        private void RetryWithoutContent(string address, int retries, string method)
        {
            if (address == null)
                throw new ArgumentNullException("address");

            var actualAddress = GetActualAddress(address);
            var actualRetries = (retries == 0) ? Retries : retries;

            var task = Retry(() =>
            {
                // Trim like DownloadStringWithRetry so a trailing '/' on BaseAddress doesn't yield "//"
                var uriString = String.Join("/", BaseAddress.TrimEnd('/'), actualAddress);
                var uri = new Uri(uriString);
                var request = GetWebRequest(uri);
                request.Method = method;
                if (ResponseHeaders != null)
                    ResponseHeaders.Clear();

                TraceStart(method, uriString);
                if (method == "PUT")
                    request.ContentLength = 0;

                // NOTE(review): this response is never closed. ResponseHeaders (read by
                // GetHeaderValue) may depend on it staying usable, so verify before disposing it here.
                var response = (HttpWebResponse)GetWebResponse(request);
                StatusCode = response.StatusCode;
                StatusDescription = response.StatusDescription;
                return 0;
            }, actualRetries);

            try
            {
                task.Wait();
            }
            catch (AggregateException ex)
            {
                var exc = ex.InnerException;
                if (exc is RetryException)
                {
                    Log.ErrorFormat("[{0}] RETRY FAILED for {1} after {2} retries", method, address, actualRetries);
                }
                else
                {
                    Log.ErrorFormat("[{0}] FAILED for {1} with \n{2}", method, address, exc);
                }
                // Unwrap so callers see the original exception type (its stack trace is reset here)
                throw exc;
            }
            catch (Exception ex)
            {
                Log.ErrorFormat("[{0}] FAILED for {1} with \n{2}", method, address, ex);
                throw;
            }
        }

        private static void TraceStart(string method, string actualAddress)
        {
            Log.InfoFormat("[{0}] {1} {2}", method, DateTime.Now, actualAddress);
        }

        /// <summary>
        /// Appends <see cref="Parameters"/> to <paramref name="address"/> as a query string
        /// ("?k=v&amp;k2=v2"). Returns the address unchanged when there are no parameters.
        /// </summary>
        /// <remarks>
        /// NOTE(review): keys and values are appended as-is (no URL encoding); confirm callers pre-encode them.
        /// </remarks>
        private string GetActualAddress(string address)
        {
            if (Parameters.Count == 0)
                return address;

            var addressBuilder = new StringBuilder(address);

            bool isFirst = true;
            foreach (var parameter in Parameters)
            {
                if (isFirst)
                    addressBuilder.AppendFormat("?{0}={1}", parameter.Key, parameter.Value);
                else
                    addressBuilder.AppendFormat("&{0}={1}", parameter.Key, parameter.Value);
                isFirst = false;
            }
            return addressBuilder.ToString();
        }

        /// <summary>
        /// GETs an absolute <paramref name="address"/> as a string, retrying on timeout / 503.
        /// <see cref="Parameters"/> are NOT appended.
        /// </summary>
        /// <returns>The body; String.Empty for 204 No Content; null when the server answers 404.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="address"/> is null.</exception>
        public string DownloadStringWithRetry(Uri address, int retries = 0)
        {
            if (address == null)
                throw new ArgumentNullException("address");

            var actualRetries = (retries == 0) ? Retries : retries;
            var task = Retry(() =>
            {
                var content = base.DownloadString(address);

                if (StatusCode == HttpStatusCode.NoContent)
                    return String.Empty;
                return content;
            }, actualRetries);

            var result = task.Result;
            return result;
        }

        /// <summary>
        /// Copies headers from another RestClient
        /// </summary>
        /// <param name="source">The RestClient from which the headers are copied</param>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        public void CopyHeaders(RestClient source)
        {
            if (source == null)
                throw new ArgumentNullException("source", "source can't be null");
            Contract.EndContractBlock();
            //The Headers getter initializes the property, it is never null
            Contract.Assume(Headers != null);

            CopyHeaders(source.Headers, Headers);
        }

        /// <summary>
        /// Copies headers from one header collection to another
        /// </summary>
        /// <param name="source">The source collection from which the headers are copied</param>
        /// <param name="target">The target collection to which the headers are copied</param>
        /// <exception cref="ArgumentNullException">Either collection is null.</exception>
        public static void CopyHeaders(WebHeaderCollection source, WebHeaderCollection target)
        {
            if (source == null)
                throw new ArgumentNullException("source", "source can't be null");
            if (target == null)
                throw new ArgumentNullException("target", "target can't be null");
            Contract.EndContractBlock();

            for (int i = 0; i < source.Count; i++)
            {
                target.Add(source.GetKey(i), source[i]);
            }
        }

        /// <summary>Throws a <see cref="WebException"/> when the last status code is 400 or higher.</summary>
        /// <param name="message">Prefix for the exception message.</param>
        public void AssertStatusOK(string message)
        {
            if (StatusCode >= HttpStatusCode.BadRequest)
                throw new WebException(String.Format("{0} with code {1} - {2}", message, StatusCode, StatusDescription));
        }

        /// <summary>
        /// Runs <paramref name="original"/> via Task.Factory.StartNew and completes the returned
        /// task with its result. Timeouts and 503 responses are retried up to
        /// <paramref name="retryCount"/> more times. After that the task faults with a
        /// <see cref="RetryException"/>.
        /// </summary>
        /// <remarks>
        /// A <see cref="WebException"/> with a 404 status completes the task with default(T)
        /// instead of faulting it. Any other exception faults the task.
        /// NOTE(review): StatusCode and TimedOut are written from task continuations.
        /// </remarks>
        /// <param name="original">The operation to run.</param>
        /// <param name="retryCount">Remaining retries.</param>
        /// <param name="tcs">Completion source shared across recursive retries; null on the first call.</param>
        private Task<T> Retry<T>(Func<T> original, int retryCount, TaskCompletionSource<T> tcs = null)
        {
            if (original == null)
                throw new ArgumentNullException("original");
            Contract.EndContractBlock();

            if (tcs == null)
                tcs = new TaskCompletionSource<T>();

            Task.Factory.StartNew(original).ContinueWith(attempt =>
            {
                try
                {
                    if (!attempt.IsFaulted)
                    {
                        tcs.SetFromTask(attempt);
                        return;
                    }

                    var e = attempt.Exception.InnerException;
                    var we = e as WebException;
                    if (we == null)
                    {
                        tcs.SetException(e);
                        return;
                    }

                    var statusCode = GetStatusCode(we);

                    //Return null for 404
                    if (statusCode == HttpStatusCode.NotFound)
                    {
                        tcs.SetResult(default(T));
                    }
                    //Retry for timeouts and service unavailable
                    else if (we.Status == WebExceptionStatus.Timeout ||
                             (we.Status == WebExceptionStatus.ProtocolError && statusCode == HttpStatusCode.ServiceUnavailable))
                    {
                        TimedOut = true;
                        if (retryCount == 0)
                        {
                            Log.ErrorFormat("[ERROR] Timed out too many times. \n{0}\n", e);
                            tcs.SetException(new RetryException("Timed out too many times.", e));
                        }
                        else
                        {
                            Log.ErrorFormat(
                                "[RETRY] Timed out after {0} ms. Will retry {1} more times\n{2}", Timeout,
                                retryCount, e);
                            Retry(original, retryCount - 1, tcs);
                        }
                    }
                    else
                    {
                        tcs.SetException(e);
                    }
                }
                catch (Exception exc)
                {
                    // Never leave the completion source pending: callers block on .Result / .Wait()
                    tcs.TrySetException(exc);
                }
            });
            return tcs.Task;
        }

        /// <summary>
        /// Returns the HTTP status carried by <paramref name="we"/> and stores it in
        /// <see cref="StatusCode"/>. Returns RequestTimeout when there is no HTTP response.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="we"/> is null.</exception>
        private HttpStatusCode GetStatusCode(WebException we)
        {
            if (we == null)
                throw new ArgumentNullException("we");

            var statusCode = HttpStatusCode.RequestTimeout;
            // Use 'as' so a non-HTTP response can't throw inside the retry continuation
            var response = we.Response as HttpWebResponse;
            if (response != null)
            {
                statusCode = response.StatusCode;
                this.StatusCode = statusCode;
            }
            return statusCode;
        }

        /// <summary>Builds a UriBuilder for BaseAddress/container/objectName.</summary>
        public UriBuilder GetAddressBuilder(string container, string objectName)
        {
            // Trim like DownloadStringWithRetry so a trailing '/' on BaseAddress doesn't yield "//"
            var builder = new UriBuilder(String.Join("/", BaseAddress.TrimEnd('/'), container, objectName));
            return builder;
        }
    }

    /// <summary>Raised (as a task fault) when a request is still timing out after all retries.</summary>
    [Serializable]
    public class RetryException : Exception
    {
        public RetryException()
            : base()
        {
        }

        public RetryException(string message)
            : base(message)
        {
        }

        public RetryException(string message, Exception innerException)
            : base(message, innerException)
        {
        }

        public RetryException(SerializationInfo info, StreamingContext context)
            : base(info, context)
        {
        }
    }
}