// RestClient.cs
// TODO: Update copyright text.


using System.Collections.Specialized;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.IO;
using System.Net;
using System.Runtime.Serialization;
using System.Threading.Tasks;

namespace Pithos.Network
{
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Text;

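    /// <summary>
    /// A <see cref="WebClient"/> that applies per-request timeouts and query-string
    /// parameters, records the status of the last HTTP response, and retries requests
    /// that time out or return 503 (Service Unavailable). Requests answered with 404
    /// yield a default value (null for strings) instead of throwing.
    /// </summary>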

    public class RestClient:WebClient
    {
        /// <summary>
        /// Per-request timeout in milliseconds applied to HTTP requests.
        /// Values of zero or less keep the framework default.
        /// </summary>
        public int Timeout { get; set; }

        /// <summary>
        /// Set by the retry logic when a request times out or gets 503 (Service Unavailable);
        /// reset to <c>false</c> every time a new request is created.
        /// </summary>
        public bool TimedOut { get; set; }

        /// <summary>Status code of the most recent HTTP response observed by this client.</summary>
        public HttpStatusCode StatusCode { get; private set; }

        /// <summary>Status description of the most recent HTTP response observed by this client.</summary>
        public string StatusDescription { get; set; }

        /// <summary>
        /// Default retry count, used by the retrying methods when they are called
        /// with <c>retries == 0</c>.
        /// </summary>
        public int Retries { get; set; }

        private readonly Dictionary<string, string> _parameters=new Dictionary<string, string>();

        /// <summary>
        /// Query-string parameters appended to the string addresses passed to
        /// <see cref="DownloadStringWithRetry(string,int)"/>, <see cref="Head"/>,
        /// <see cref="PutWithRetry"/> and <see cref="DeleteWithRetry"/>.
        /// </summary>
        public Dictionary<string, string> Parameters
        {
            get { return _parameters; }
        }

        /// <summary>Creates a client with default settings.</summary>
        public RestClient():base()
        {
        }

        /// <summary>
        /// Creates a client that copies the headers, timeout, retry count, base address,
        /// query parameters and proxy of <paramref name="other"/>.
        /// </summary>
        /// <param name="other">The client whose settings are copied.</param>
        /// <exception cref="ArgumentNullException"><paramref name="other"/> is null (raised by CopyHeaders).</exception>
        public RestClient(RestClient other)
            : base()
        {
            CopyHeaders(other);
            Timeout = other.Timeout;
            Retries = other.Retries;
            BaseAddress = other.BaseAddress;

            foreach (var parameter in other.Parameters)
            {
                Parameters.Add(parameter.Key,parameter.Value);
            }

            this.Proxy = other.Proxy;
        }

        /// <summary>
        /// Creates the request and, for HTTP requests, applies If-Modified-Since,
        /// gzip/deflate decompression and <see cref="Timeout"/>.
        /// </summary>
        protected override WebRequest GetWebRequest(Uri address)
        {
            TimedOut = false;
            var webRequest = base.GetWebRequest(address);
            var request = webRequest as HttpWebRequest;
            // Non-HTTP schemes (e.g. file://) produce other WebRequest types; the
            // HTTP-specific settings below do not apply to them.
            if (request == null)
                return webRequest;

            if (IfModifiedSince.HasValue)
                request.IfModifiedSince = IfModifiedSince.Value;
            request.AutomaticDecompression = DecompressionMethods.Deflate | DecompressionMethods.GZip;
            if (Timeout > 0)
                request.Timeout = Timeout;
            return request;
        }

        /// <summary>When set, sent as the If-Modified-Since header of HTTP requests.</summary>
        public DateTime? IfModifiedSince { get; set; }

        /// <summary>Asynchronous-path response hook: records the HTTP status of the response.</summary>
        protected override WebResponse GetWebResponse(WebRequest request, IAsyncResult result)
        {
            var response = base.GetWebResponse(request, result);
            RecordStatus(response);
            return response;
        }

        /// <summary>
        /// Synchronous-path response hook: records the HTTP status of the response and
        /// traces the body of error responses before rethrowing the original exception.
        /// </summary>
        protected override WebResponse GetWebResponse(WebRequest request)
        {
            try
            {
                var response = base.GetWebResponse(request);
                RecordStatus(response);
                return response;
            }
            catch (WebException exc)
            {
                TraceErrorContent(exc);
                throw;
            }
        }

        /// <summary>Copies status code/description from an HTTP response; ignores other response types.</summary>
        private void RecordStatus(WebResponse response)
        {
            var httpResponse = response as HttpWebResponse;
            if (httpResponse == null)
                return;
            StatusCode = httpResponse.StatusCode;
            StatusDescription = httpResponse.StatusDescription;
        }

        /// <summary>Writes the body of an error response (if it has one) to the trace.</summary>
        private static void TraceErrorContent(WebException exc)
        {
            if (exc.Response == null || exc.Response.ContentLength <= 0)
                return;
            try
            {
                Trace.TraceError(GetContent(exc.Response));
            }
            catch (Exception readError)
            {
                // This is diagnostics only: a failure here must not replace the WebException
                // that the caller is about to rethrow.
                Trace.TraceError("Failed to read error response content: {0}", readError);
            }
        }

        /// <summary>Reads the whole response body as a string and closes the response stream.</summary>
        private static string GetContent(WebResponse webResponse)
        {
            using (var stream = webResponse.GetResponseStream())
            using (var reader = new StreamReader(stream))
            {
                return reader.ReadToEnd();
            }
        }

        /// <summary>
        /// GETs <paramref name="address"/> (relative to <see cref="WebClient.BaseAddress"/>, with
        /// <see cref="Parameters"/> appended), retrying on timeouts and 503 responses.
        /// </summary>
        /// <param name="address">Address relative to the base address, or absolute when no base address is set.</param>
        /// <param name="retries">Retry count; 0 means use <see cref="Retries"/>.</param>
        /// <returns>The body; an empty string for 204 responses; null for 404 responses.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="address"/> is null.</exception>
        public string DownloadStringWithRetry(string address,int retries=0)
        {
            if (address == null)
                throw new ArgumentNullException("address");

            var actualAddress = GetActualAddress(address);

            TraceStart("GET",actualAddress);

            var actualRetries = (retries == 0) ? Retries : retries;

            var task = Retry(() =>
            {
                var uriString = CombineWithBaseAddress(actualAddress);
                var content = base.DownloadString(uriString);

                if (StatusCode == HttpStatusCode.NoContent)
                    return String.Empty;
                return content;
            }, actualRetries);

            return task.Result;
        }

        /// <summary>Sends a HEAD request with retries; see <see cref="RetryWithoutContent"/>.</summary>
        public void Head(string address,int retries=0)
        {
            RetryWithoutContent(address, retries, "HEAD");
        }

        /// <summary>Sends a body-less PUT request with retries; see <see cref="RetryWithoutContent"/>.</summary>
        public void PutWithRetry(string address, int retries = 0)
        {
            RetryWithoutContent(address, retries, "PUT");
        }

        /// <summary>Sends a DELETE request with retries; see <see cref="RetryWithoutContent"/>.</summary>
        public void DeleteWithRetry(string address,int retries=0)
        {
            RetryWithoutContent(address, retries, "DELETE");
        }

        /// <summary>Returns the first value of a header of the most recent response.</summary>
        /// <exception cref="WebException">The header is missing, or no response has been received yet.</exception>
        public string GetHeaderValue(string headerName)
        {
            var headers = this.ResponseHeaders;
            var values = (headers == null) ? null : headers.GetValues(headerName);
            if (values == null || values.Length == 0)
                throw new WebException(String.Format("The {0}  header is missing", headerName));
            return values[0];
        }

        /// <summary>
        /// Sends a request with the given method and no body, retrying on timeouts and 503
        /// responses, and blocks until it completes. Status and headers of the response
        /// remain available through <see cref="StatusCode"/> and <see cref="WebClient.ResponseHeaders"/>.
        /// </summary>
        private void RetryWithoutContent(string address, int retries, string method)
        {
            if (address == null)
                throw new ArgumentNullException("address");

            var actualAddress = GetActualAddress(address);
            var actualRetries = (retries == 0) ? Retries : retries;

            var task = Retry(() =>
            {
                var uriString = CombineWithBaseAddress(actualAddress);
                var uri = new Uri(uriString);
                var request = GetWebRequest(uri);
                request.Method = method;
                // Presumably clears the previous response's headers so that a failed request
                // does not leave stale values visible through ResponseHeaders.
                if (ResponseHeaders!=null)
                    ResponseHeaders.Clear();

                TraceStart(method, uriString);

                // GetWebResponse records StatusCode/StatusDescription.
                var response = GetWebResponse(request);

                // Closing the body stream releases the connection for reuse (an unclosed
                // response keeps it occupied); the response object itself is left open so its
                // headers stay readable through ResponseHeaders (e.g. by GetHeaderValue).
                var body = response.GetResponseStream();
                if (body != null)
                    body.Close();

                return 0;
            }, actualRetries);

            task.Wait();
        }

        /// <summary>
        /// Joins <paramref name="relativeAddress"/> to <see cref="WebClient.BaseAddress"/> with a
        /// single '/', or returns it unchanged when no base address is set.
        /// </summary>
        private string CombineWithBaseAddress(string relativeAddress)
        {
            var baseAddress = BaseAddress;
            if (String.IsNullOrEmpty(baseAddress))
                return relativeAddress;
            // Uri normalisation can add a trailing '/' to BaseAddress (e.g. "http://host/"),
            // so trim it to avoid producing "host//path".
            return String.Join("/", baseAddress.TrimEnd('/'), relativeAddress);
        }

        private static void TraceStart(string method, string actualAddress)
        {
            Trace.WriteLine(String.Format("[{0}] {1} {2}", method, DateTime.Now, actualAddress));
        }

        /// <summary>
        /// Appends <see cref="Parameters"/> to <paramref name="address"/> as a query string,
        /// continuing an existing query string if the address already has one.
        /// </summary>
        private string GetActualAddress(string address)
        {
            if (Parameters.Count == 0)
                return address;

            var addressBuilder = new StringBuilder(address);
            var separator = (address.IndexOf('?') >= 0) ? '&' : '?';
            foreach (var parameter in Parameters)
            {
                // NOTE(review): keys/values are appended verbatim (not URL-encoded); callers
                // presumably pre-encode them — confirm before passing arbitrary text.
                addressBuilder.Append(separator).Append(parameter.Key).Append('=').Append(parameter.Value);
                separator = '&';
            }
            return addressBuilder.ToString();
        }

        /// <summary>
        /// GETs an absolute <paramref name="address"/> (without <see cref="Parameters"/>),
        /// retrying on timeouts and 503 responses.
        /// </summary>
        /// <returns>The body; an empty string for 204 responses; null for 404 responses.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="address"/> is null.</exception>
        public string DownloadStringWithRetry(Uri address,int retries=0)
        {
            if (address == null)
                throw new ArgumentNullException("address");

            TraceStart("GET", address.ToString());

            var actualRetries = (retries == 0) ? Retries : retries;
            var task = Retry(() =>
            {
                var content = base.DownloadString(address);

                if (StatusCode == HttpStatusCode.NoContent)
                    return String.Empty;
                return content;
            }, actualRetries);

            return task.Result;
        }

        /// <summary>
        /// Copies headers from another RestClient
        /// </summary>
        /// <param name="source">The RestClient from which the headers are copied</param>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        public void CopyHeaders(RestClient source)
        {
            Contract.Requires(source != null, "source can't be null");
            if (source == null)
                throw new ArgumentNullException("source", "source can't be null");
            CopyHeaders(source.Headers,Headers);
        }

        /// <summary>
        /// Copies headers from one header collection to another
        /// </summary>
        /// <param name="source">The source collection from which the headers are copied</param>
        /// <param name="target">The target collection to which the headers are copied</param>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="target"/> is null.</exception>
        public static void CopyHeaders(WebHeaderCollection source,WebHeaderCollection target)
        {
            Contract.Requires(source != null, "source can't be null");
            Contract.Requires(target != null, "target can't be null");
            if (source == null)
                throw new ArgumentNullException("source", "source can't be null");
            if (target == null)
                throw new ArgumentNullException("target", "target can't be null");
            for (int i = 0; i < source.Count; i++)
            {
                target.Add(source.GetKey(i), source[i]);
            }
        }

        /// <summary>Throws a <see cref="WebException"/> when <see cref="StatusCode"/> is 400 or higher.</summary>
        public void AssertStatusOK(string message)
        {
            if (StatusCode >= HttpStatusCode.BadRequest)
                throw new WebException(String.Format("{0} with code {1} - {2}", message, StatusCode, StatusDescription));
        }

        /// <summary>
        /// Runs <paramref name="original"/> on the thread pool. A WebException with status 404
        /// completes the task with <c>default(T)</c>; a timeout or 503 is retried up to
        /// <paramref name="retryCount"/> more times and then fails with <see cref="RetryException"/>;
        /// any other failure is propagated unchanged.
        /// </summary>
        private Task<T> Retry<T>(Func<T> original, int retryCount, TaskCompletionSource<T> tcs = null)
        {
            if (tcs == null)
                tcs = new TaskCompletionSource<T>();
            Task.Factory.StartNew(original).ContinueWith(attempt =>
            {
                try
                {
                    if (!attempt.IsFaulted)
                    {
                        tcs.SetFromTask(attempt);
                        return;
                    }

                    var e = attempt.Exception.InnerException;
                    var we = e as WebException;
                    if (we == null)
                    {
                        tcs.SetException(e);
                        return;
                    }

                    var statusCode = GetStatusCode(we);

                    // Return null (default(T)) for 404
                    if (statusCode == HttpStatusCode.NotFound)
                    {
                        tcs.SetResult(default(T));
                    }
                    // Retry for timeouts and service unavailable
                    else if (we.Status == WebExceptionStatus.Timeout ||
                             (we.Status == WebExceptionStatus.ProtocolError && statusCode == HttpStatusCode.ServiceUnavailable))
                    {
                        TimedOut = true;
                        // "<= 0" so that a negative configured retry count cannot retry forever.
                        if (retryCount <= 0)
                        {
                            Trace.TraceError("[ERROR] Timed out too many times. {0}\n", e);
                            tcs.SetException(new RetryException("Timed out too many times.", e));
                        }
                        else
                        {
                            Trace.TraceError(
                                "[RETRY] Timed out after {0} ms. Will retry {1} more times\n{2}", Timeout,
                                retryCount, e);
                            Retry(original, retryCount - 1, tcs);
                        }
                    }
                    else
                    {
                        tcs.SetException(e);
                    }
                }
                catch (Exception unexpected)
                {
                    // An exception escaping this continuation would leave tcs incomplete and
                    // block callers waiting on its Result/Wait() forever.
                    tcs.TrySetException(unexpected);
                }
            });
            return tcs.Task;
        }

        /// <summary>
        /// Returns the HTTP status carried by <paramref name="we"/> and records it in
        /// <see cref="StatusCode"/>; returns RequestTimeout as a placeholder when there is
        /// no HTTP response.
        /// </summary>
        private HttpStatusCode GetStatusCode(WebException we)
        {
            var httpResponse = we.Response as HttpWebResponse;
            if (httpResponse == null)
                return HttpStatusCode.RequestTimeout;

            this.StatusCode = httpResponse.StatusCode;
            return this.StatusCode;
        }
    }

    /// <summary>
    /// Signals that an operation was abandoned after exhausting its retries
    /// (RestClient raises it when a request keeps timing out or returning 503).
    /// </summary>
    /// <remarks>
    /// Marked <see cref="SerializableAttribute"/> so the serialization constructor below is
    /// actually usable (e.g. when the exception crosses an AppDomain boundary).
    /// </remarks>
    [Serializable]
    public class RetryException:Exception
    {
        /// <summary>Creates an exception with the default message.</summary>
        public RetryException()
            :base()
        {
        }

        /// <summary>Creates an exception with the given message.</summary>
        public RetryException(string message)
            :base(message)
        {
        }

        /// <summary>Creates an exception with the given message and the failure that caused the last retry.</summary>
        public RetryException(string message,Exception innerException)
            :base(message,innerException)
        {
        }

        /// <summary>Deserialization constructor.</summary>
        /// <remarks>
        /// NOTE(review): conventionally <c>protected</c>; kept <c>public</c> for backward compatibility.
        /// </remarks>
        public RetryException(SerializationInfo info,StreamingContext context)
            :base(info,context)
        {
        }
    }
}