View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2013 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.proxy;
20  
21  import java.io.IOException;
22  import java.net.InetAddress;
23  import java.net.URI;
24  import java.net.UnknownHostException;
25  import java.nio.ByteBuffer;
26  import java.util.Enumeration;
27  import java.util.HashSet;
28  import java.util.Iterator;
29  import java.util.Locale;
30  import java.util.Set;
31  import java.util.concurrent.TimeUnit;
32  import java.util.concurrent.TimeoutException;
33  import javax.servlet.AsyncContext;
34  import javax.servlet.ServletConfig;
35  import javax.servlet.ServletException;
36  import javax.servlet.UnavailableException;
37  import javax.servlet.http.HttpServlet;
38  import javax.servlet.http.HttpServletRequest;
39  import javax.servlet.http.HttpServletResponse;
40  
41  import org.eclipse.jetty.client.HttpClient;
42  import org.eclipse.jetty.client.api.Request;
43  import org.eclipse.jetty.client.api.Response;
44  import org.eclipse.jetty.client.api.Result;
45  import org.eclipse.jetty.client.util.InputStreamContentProvider;
46  import org.eclipse.jetty.http.HttpField;
47  import org.eclipse.jetty.http.HttpHeader;
48  import org.eclipse.jetty.http.HttpMethod;
49  import org.eclipse.jetty.http.HttpVersion;
50  import org.eclipse.jetty.server.handler.ContextHandler;
51  import org.eclipse.jetty.util.HttpCookieStore;
52  import org.eclipse.jetty.util.log.Log;
53  import org.eclipse.jetty.util.log.Logger;
54  import org.eclipse.jetty.util.thread.QueuedThreadPool;
55  
56  /**
57   * Asynchronous ProxyServlet.
58   * <p/>
59   * Forwards requests to another server either as a standard web reverse proxy
60   * (as defined by RFC2616) or as a transparent reverse proxy.
61   * <p/>
62   * To facilitate JMX monitoring, the {@link HttpClient} instance is set as context attribute,
63   * prefixed with the servlet's name and exposed by the mechanism provided by
64   * {@link ContextHandler#MANAGED_ATTRIBUTES}.
65   * <p/>
66   * The following init parameters may be used to configure the servlet:
67   * <ul>
68   * <li>hostHeader - forces the host header to a particular value</li>
69   * <li>viaHost - the name to use in the Via header: Via: http/1.1 &lt;viaHost&gt;</li>
70   * <li>whiteList - comma-separated list of allowed proxy hosts</li>
71   * <li>blackList - comma-separated list of forbidden proxy hosts</li>
72   * </ul>
73   * <p/>
74   * In addition, see {@link #createHttpClient()} for init parameters used to configure
75   * the {@link HttpClient} instance.
76   *
77   * @see ConnectHandler
78   */
79  public class ProxyServlet extends HttpServlet
80  {
81      protected static final String ASYNC_CONTEXT = ProxyServlet.class.getName() + ".asyncContext";
82      private static final Set<String> HOP_HEADERS = new HashSet<>();
83      static
84      {
85          HOP_HEADERS.add("proxy-connection");
86          HOP_HEADERS.add("connection");
87          HOP_HEADERS.add("keep-alive");
88          HOP_HEADERS.add("transfer-encoding");
89          HOP_HEADERS.add("te");
90          HOP_HEADERS.add("trailer");
91          HOP_HEADERS.add("proxy-authorization");
92          HOP_HEADERS.add("proxy-authenticate");
93          HOP_HEADERS.add("upgrade");
94      }
95  
96      private final Set<String> _whiteList = new HashSet<>();
97      private final Set<String> _blackList = new HashSet<>();
98  
99      protected Logger _log;
100     private String _hostHeader;
101     private String _viaHost;
102     private HttpClient _client;
103     private long _timeout;
104 
105     @Override
106     public void init() throws ServletException
107     {
108         _log = createLogger();
109 
110         ServletConfig config = getServletConfig();
111 
112         _hostHeader = config.getInitParameter("hostHeader");
113 
114         _viaHost = config.getInitParameter("viaHost");
115         if (_viaHost == null)
116             _viaHost = viaHost();
117 
118         try
119         {
120             _client = createHttpClient();
121 
122             // Put the HttpClient in the context to leverage ContextHandler.MANAGED_ATTRIBUTES
123             getServletContext().setAttribute(config.getServletName() + ".HttpClient", _client);
124 
125             String whiteList = config.getInitParameter("whiteList");
126             if (whiteList != null)
127                 getWhiteListHosts().addAll(parseList(whiteList));
128 
129             String blackList = config.getInitParameter("blackList");
130             if (blackList != null)
131                 getBlackListHosts().addAll(parseList(blackList));
132         }
133         catch (Exception e)
134         {
135             throw new ServletException(e);
136         }
137     }
138 
139     public long getTimeout()
140     {
141         return _timeout;
142     }
143 
144     public void setTimeout(long timeout)
145     {
146         this._timeout = timeout;
147     }
148 
149     public Set<String> getWhiteListHosts()
150     {
151         return _whiteList;
152     }
153 
154     public Set<String> getBlackListHosts()
155     {
156         return _blackList;
157     }
158 
159     protected static String viaHost()
160     {
161         try
162         {
163             return InetAddress.getLocalHost().getHostName();
164         }
165         catch (UnknownHostException x)
166         {
167             return "localhost";
168         }
169     }
170 
171     /**
172      * @return a logger instance with a name derived from this servlet's name.
173      */
174     protected Logger createLogger()
175     {
176         String name = getServletConfig().getServletName();
177         name = name.replace('-', '.');
178         return Log.getLogger(name);
179     }
180 
181     public void destroy()
182     {
183         try
184         {
185             _client.stop();
186         }
187         catch (Exception x)
188         {
189             _log.debug(x);
190         }
191     }
192 
193     /**
194      * Creates a {@link HttpClient} instance, configured with init parameters of this servlet.
195      * <p/>
196      * The init parameters used to configure the {@link HttpClient} instance are:
197      * <table>
198      * <thead>
199      * <tr>
200      * <th>init-param</th>
201      * <th>default</th>
202      * <th>description</th>
203      * </tr>
204      * </thead>
205      * <tbody>
206      * <tr>
207      * <td>maxThreads</td>
208      * <td>256</td>
209      * <td>The max number of threads of HttpClient's Executor</td>
210      * </tr>
211      * <tr>
212      * <td>maxConnections</td>
213      * <td>32768</td>
214      * <td>The max number of connections per destination, see {@link HttpClient#setMaxConnectionsPerDestination(int)}</td>
215      * </tr>
216      * <tr>
217      * <td>idleTimeout</td>
218      * <td>30000</td>
219      * <td>The idle timeout in milliseconds, see {@link HttpClient#setIdleTimeout(long)}</td>
220      * </tr>
221      * <tr>
222      * <td>timeout</td>
223      * <td>60000</td>
224      * <td>The total timeout in milliseconds, see {@link Request#timeout(long, TimeUnit)}</td>
225      * </tr>
226      * <tr>
227      * <td>requestBufferSize</td>
228      * <td>HttpClient's default</td>
229      * <td>The request buffer size, see {@link HttpClient#setRequestBufferSize(int)}</td>
230      * </tr>
231      * <tr>
232      * <td>responseBufferSize</td>
233      * <td>HttpClient's default</td>
234      * <td>The response buffer size, see {@link HttpClient#setResponseBufferSize(int)}</td>
235      * </tr>
236      * </tbody>
237      * </table>
238      *
239      * @return a {@link HttpClient} configured from the {@link #getServletConfig() servlet configuration}
240      * @throws ServletException if the {@link HttpClient} cannot be created
241      */
242     protected HttpClient createHttpClient() throws ServletException
243     {
244         ServletConfig config = getServletConfig();
245 
246         HttpClient client = newHttpClient();
247         // Redirects must be proxied as is, not followed
248         client.setFollowRedirects(false);
249 
250         // Must not store cookies, otherwise cookies of different clients will mix
251         client.setCookieStore(new HttpCookieStore.Empty());
252 
253         String value = config.getInitParameter("maxThreads");
254         if (value == null)
255             value = "256";
256         QueuedThreadPool executor = new QueuedThreadPool(Integer.parseInt(value));
257         String servletName = config.getServletName();
258         int dot = servletName.lastIndexOf('.');
259         if (dot >= 0)
260             servletName = servletName.substring(dot + 1);
261         executor.setName(servletName);
262         client.setExecutor(executor);
263 
264         value = config.getInitParameter("maxConnections");
265         if (value == null)
266             value = "32768";
267         client.setMaxConnectionsPerDestination(Integer.parseInt(value));
268 
269         value = config.getInitParameter("idleTimeout");
270         if (value == null)
271             value = "30000";
272         client.setIdleTimeout(Long.parseLong(value));
273 
274         value = config.getInitParameter("timeout");
275         if (value == null)
276             value = "60000";
277         _timeout = Long.parseLong(value);
278 
279         value = config.getInitParameter("requestBufferSize");
280         if (value != null)
281             client.setRequestBufferSize(Integer.parseInt(value));
282 
283         value = config.getInitParameter("responseBufferSize");
284         if (value != null)
285             client.setResponseBufferSize(Integer.parseInt(value));
286 
287         try
288         {
289             client.start();
290 
291             // Content must not be decoded, otherwise the client gets confused
292             client.getContentDecoderFactories().clear();
293 
294             return client;
295         }
296         catch (Exception x)
297         {
298             throw new ServletException(x);
299         }
300     }
301 
302     /**
303      * @return a new HttpClient instance
304      */
305     protected HttpClient newHttpClient()
306     {
307         return new HttpClient();
308     }
309 
310     private Set<String> parseList(String list)
311     {
312         Set<String> result = new HashSet<>();
313         String[] hosts = list.split(",");
314         for (String host : hosts)
315         {
316             host = host.trim();
317             if (host.length() == 0)
318                 continue;
319             result.add(host);
320         }
321         return result;
322     }
323 
324     /**
325      * Checks the given {@code host} and {@code port} against whitelist and blacklist.
326      *
327      * @param host the host to check
328      * @param port the port to check
329      * @return true if it is allowed to be proxy to the given host and port
330      */
331     public boolean validateDestination(String host, int port)
332     {
333         String hostPort = host + ":" + port;
334         if (!_whiteList.isEmpty())
335         {
336             if (!_whiteList.contains(hostPort))
337             {
338                 _log.debug("Host {}:{} not whitelisted", host, port);
339                 return false;
340             }
341         }
342         if (!_blackList.isEmpty())
343         {
344             if (_blackList.contains(hostPort))
345             {
346                 _log.debug("Host {}:{} blacklisted", host, port);
347                 return false;
348             }
349         }
350         return true;
351     }
352 
353     @Override
354     protected void service(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException
355     {
356         final int requestId = getRequestId(request);
357 
358         URI rewrittenURI = rewriteURI(request);
359 
360         if (_log.isDebugEnabled())
361         {
362             StringBuffer uri = request.getRequestURL();
363             if (request.getQueryString() != null)
364                 uri.append("?").append(request.getQueryString());
365             _log.debug("{} rewriting: {} -> {}", requestId, uri, rewrittenURI);
366         }
367 
368         if (rewrittenURI == null)
369         {
370             response.sendError(HttpServletResponse.SC_FORBIDDEN);
371             return;
372         }
373 
374         final Request proxyRequest = _client.newRequest(rewrittenURI)
375                 .method(HttpMethod.fromString(request.getMethod()))
376                 .version(HttpVersion.fromString(request.getProtocol()));
377 
378         // Copy headers
379         for (Enumeration<String> headerNames = request.getHeaderNames(); headerNames.hasMoreElements();)
380         {
381             String headerName = headerNames.nextElement();
382             String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
383 
384             // Remove hop-by-hop headers
385             if (HOP_HEADERS.contains(lowerHeaderName))
386                 continue;
387 
388             if (_hostHeader!=null && lowerHeaderName.equals("host"))
389                 continue;
390 
391             for (Enumeration<String> headerValues = request.getHeaders(headerName); headerValues.hasMoreElements();)
392             {
393                 String headerValue = headerValues.nextElement();
394                 if (headerValue != null)
395                     proxyRequest.header(headerName, headerValue);
396             }
397         }
398 
399         // Force the Host header if configured
400         if (_hostHeader != null)
401             proxyRequest.header(HttpHeader.HOST, _hostHeader);
402 
403         // Add proxy headers
404         proxyRequest.header(HttpHeader.VIA, "http/1.1 " + _viaHost);
405         proxyRequest.header(HttpHeader.X_FORWARDED_FOR, request.getRemoteAddr());
406         proxyRequest.header(HttpHeader.X_FORWARDED_PROTO, request.getScheme());
407         proxyRequest.header(HttpHeader.X_FORWARDED_HOST, request.getHeader(HttpHeader.HOST.asString()));
408         proxyRequest.header(HttpHeader.X_FORWARDED_SERVER, request.getLocalName());
409 
410         proxyRequest.content(new InputStreamContentProvider(request.getInputStream())
411         {
412             @Override
413             public long getLength()
414             {
415                 return request.getContentLength();
416             }
417 
418             @Override
419             protected ByteBuffer onRead(byte[] buffer, int offset, int length)
420             {
421                 _log.debug("{} proxying content to upstream: {} bytes", requestId, length);
422                 return super.onRead(buffer, offset, length);
423             }
424         });
425 
426         final AsyncContext asyncContext = request.startAsync();
427         // We do not timeout the continuation, but the proxy request
428         asyncContext.setTimeout(0);
429         request.setAttribute(ASYNC_CONTEXT, asyncContext);
430 
431         customizeProxyRequest(proxyRequest, request);
432 
433         if (_log.isDebugEnabled())
434         {
435             StringBuilder builder = new StringBuilder(request.getMethod());
436             builder.append(" ").append(request.getRequestURI());
437             String query = request.getQueryString();
438             if (query != null)
439                 builder.append("?").append(query);
440             builder.append(" ").append(request.getProtocol()).append("\r\n");
441             for (Enumeration<String> headerNames = request.getHeaderNames(); headerNames.hasMoreElements();)
442             {
443                 String headerName = headerNames.nextElement();
444                 builder.append(headerName).append(": ");
445                 for (Enumeration<String> headerValues = request.getHeaders(headerName); headerValues.hasMoreElements();)
446                 {
447                     String headerValue = headerValues.nextElement();
448                     if (headerValue != null)
449                         builder.append(headerValue);
450                     if (headerValues.hasMoreElements())
451                         builder.append(",");
452                 }
453                 builder.append("\r\n");
454             }
455             builder.append("\r\n");
456 
457             _log.debug("{} proxying to upstream:{}{}{}{}",
458                     requestId,
459                     System.lineSeparator(),
460                     builder,
461                     proxyRequest,
462                     System.lineSeparator(),
463                     proxyRequest.getHeaders().toString().trim());
464         }
465 
466         proxyRequest.timeout(getTimeout(), TimeUnit.MILLISECONDS);
467         proxyRequest.send(new ProxyResponseListener(request, response));
468     }
469 
470     protected void onResponseHeaders(HttpServletRequest request, HttpServletResponse response, Response proxyResponse)
471     {
472         for (HttpField field : proxyResponse.getHeaders())
473         {
474             String headerName = field.getName();
475             String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
476             if (HOP_HEADERS.contains(lowerHeaderName))
477                 continue;
478 
479             String newHeaderValue = filterResponseHeader(request, headerName, field.getValue());
480             if (newHeaderValue == null || newHeaderValue.trim().length() == 0)
481                 continue;
482 
483             response.addHeader(headerName, newHeaderValue);
484         }
485     }
486 
487     protected void onResponseContent(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, byte[] buffer, int offset, int length) throws IOException
488     {
489         response.getOutputStream().write(buffer, offset, length);
490         _log.debug("{} proxying content to downstream: {} bytes", getRequestId(request), length);
491     }
492 
493     protected void onResponseSuccess(HttpServletRequest request, HttpServletResponse response, Response proxyResponse)
494     {
495         AsyncContext asyncContext = (AsyncContext)request.getAttribute(ASYNC_CONTEXT);
496         asyncContext.complete();
497         _log.debug("{} proxying successful", getRequestId(request));
498     }
499 
500     protected void onResponseFailure(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, Throwable failure)
501     {
502         _log.debug(getRequestId(request) + " proxying failed", failure);
503         if (!response.isCommitted())
504         {
505             if (failure instanceof TimeoutException)
506                 response.setStatus(HttpServletResponse.SC_GATEWAY_TIMEOUT);
507             else
508                 response.setStatus(HttpServletResponse.SC_BAD_GATEWAY);
509         }
510         AsyncContext asyncContext = (AsyncContext)request.getAttribute(ASYNC_CONTEXT);
511         asyncContext.complete();
512     }
513 
514     protected int getRequestId(HttpServletRequest request)
515     {
516         return System.identityHashCode(request);
517     }
518 
519     protected URI rewriteURI(HttpServletRequest request)
520     {
521         if (!validateDestination(request.getServerName(), request.getServerPort()))
522             return null;
523 
524         StringBuffer uri = request.getRequestURL();
525         String query = request.getQueryString();
526         if (query != null)
527             uri.append("?").append(query);
528 
529         return URI.create(uri.toString());
530     }
531 
532     /**
533      * Extension point for subclasses to customize the proxy request.
534      * The default implementation does nothing.
535      *
536      * @param proxyRequest the proxy request to customize
537      * @param request the request to be proxied
538      */
539     protected void customizeProxyRequest(Request proxyRequest, HttpServletRequest request)
540     {
541     }
542 
543     /**
544      * Extension point for remote server response header filtering.
545      * The default implementation returns the header value as is.
546      * If null is returned, this header won't be forwarded back to the client.
547      *
548      * @param headerName the header name
549      * @param headerValue the header value
550      * @param request the request to proxy
551      * @return filteredHeaderValue the new header value
552      */
553     protected String filterResponseHeader(HttpServletRequest request, String headerName, String headerValue)
554     {
555         return headerValue;
556     }
557 
558     /**
559      * Transparent Proxy.
560      * <p/>
561      * This convenience extension to ProxyServlet configures the servlet as a transparent proxy.
562      * The servlet is configured with init parameters:
563      * <ul>
564      * <li>proxyTo - a URI like http://host:80/context to which the request is proxied.
565      * <li>prefix - a URI prefix that is striped from the start of the forwarded URI.
566      * </ul>
567      * For example, if a request is received at /foo/bar and the 'proxyTo' parameter is "http://host:80/context"
568      * and the 'prefix' parameter is "/foo", then the request would be proxied to "http://host:80/context/bar".
569      */
570     public static class Transparent extends ProxyServlet
571     {
572         private String _proxyTo;
573         private String _prefix;
574 
575         public Transparent()
576         {
577         }
578 
579         public Transparent(String proxyTo, String prefix)
580         {
581             _proxyTo = URI.create(proxyTo).normalize().toString();
582             _prefix = URI.create(prefix).normalize().toString();
583         }
584 
585         @Override
586         public void init() throws ServletException
587         {
588             super.init();
589 
590             ServletConfig config = getServletConfig();
591 
592             String prefix = config.getInitParameter("prefix");
593             _prefix = prefix == null ? _prefix : prefix;
594 
595             // Adjust prefix value to account for context path
596             String contextPath = getServletContext().getContextPath();
597             _prefix = _prefix == null ? contextPath : (contextPath + _prefix);
598 
599             String proxyTo = config.getInitParameter("proxyTo");
600             _proxyTo = proxyTo == null ? _proxyTo : proxyTo;
601 
602             if (_proxyTo == null)
603                 throw new UnavailableException("Init parameter 'proxyTo' is required.");
604 
605             if (!_prefix.startsWith("/"))
606                 throw new UnavailableException("Init parameter 'prefix' parameter must start with a '/'.");
607 
608             _log.debug(config.getServletName() + " @ " + _prefix + " to " + _proxyTo);
609         }
610 
611         @Override
612         protected URI rewriteURI(HttpServletRequest request)
613         {
614             String path = request.getRequestURI();
615             if (!path.startsWith(_prefix))
616                 return null;
617 
618             StringBuilder uri = new StringBuilder(_proxyTo);
619             uri.append(path.substring(_prefix.length()));
620             String query = request.getQueryString();
621             if (query != null)
622                 uri.append("?").append(query);
623             URI rewrittenURI = URI.create(uri.toString()).normalize();
624 
625             if (!validateDestination(rewrittenURI.getHost(), rewrittenURI.getPort()))
626                 return null;
627 
628             return rewrittenURI;
629         }
630     }
631 
632     private class ProxyResponseListener extends Response.Listener.Empty
633     {
634         private final HttpServletRequest request;
635         private final HttpServletResponse response;
636 
637         public ProxyResponseListener(HttpServletRequest request, HttpServletResponse response)
638         {
639             this.request = request;
640             this.response = response;
641         }
642 
643         @Override
644         public void onBegin(Response proxyResponse)
645         {
646             response.setStatus(proxyResponse.getStatus());
647         }
648 
649         @Override
650         public void onHeaders(Response proxyResponse)
651         {
652             onResponseHeaders(request, response, proxyResponse);
653 
654             if (_log.isDebugEnabled())
655             {
656                 StringBuilder builder = new StringBuilder("\r\n");
657                 builder.append(request.getProtocol()).append(" ").append(response.getStatus()).append(" ").append(proxyResponse.getReason()).append("\r\n");
658                 for (String headerName : response.getHeaderNames())
659                 {
660                     builder.append(headerName).append(": ");
661                     for (Iterator<String> headerValues = response.getHeaders(headerName).iterator(); headerValues.hasNext();)
662                     {
663                         String headerValue = headerValues.next();
664                         if (headerValue != null)
665                             builder.append(headerValue);
666                         if (headerValues.hasNext())
667                             builder.append(",");
668                     }
669                     builder.append("\r\n");
670                 }
671                 _log.debug("{} proxying to downstream:{}{}{}{}{}",
672                         getRequestId(request),
673                         System.lineSeparator(),
674                         proxyResponse,
675                         System.lineSeparator(),
676                         proxyResponse.getHeaders().toString().trim(),
677                         System.lineSeparator(),
678                         builder);
679             }
680         }
681 
682         @Override
683         public void onContent(Response proxyResponse, ByteBuffer content)
684         {
685             byte[] buffer;
686             int offset;
687             int length = content.remaining();
688             if (content.hasArray())
689             {
690                 buffer = content.array();
691                 offset = content.arrayOffset();
692             }
693             else
694             {
695                 buffer = new byte[length];
696                 content.get(buffer);
697                 offset = 0;
698             }
699 
700             try
701             {
702                 onResponseContent(request, response, proxyResponse, buffer, offset, length);
703             }
704             catch (IOException x)
705             {
706                 proxyResponse.abort(x);
707             }
708         }
709 
710         @Override
711         public void onSuccess(Response proxyResponse)
712         {
713             onResponseSuccess(request, response, proxyResponse);
714         }
715 
716         @Override
717         public void onFailure(Response proxyResponse, Throwable failure)
718         {
719             onResponseFailure(request, response, proxyResponse, failure);
720         }
721 
722         @Override
723         public void onComplete(Result result)
724         {
725             _log.debug("{} proxying complete", getRequestId(request));
726         }
727     }
728 }