1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.eclipse.jetty.proxy;
20
21 import java.io.IOException;
22 import java.net.InetAddress;
23 import java.net.URI;
24 import java.net.UnknownHostException;
25 import java.nio.ByteBuffer;
26 import java.util.Enumeration;
27 import java.util.HashSet;
28 import java.util.Iterator;
29 import java.util.Locale;
30 import java.util.Set;
31 import java.util.concurrent.TimeUnit;
32 import java.util.concurrent.TimeoutException;
33 import javax.servlet.AsyncContext;
34 import javax.servlet.ServletConfig;
35 import javax.servlet.ServletException;
36 import javax.servlet.UnavailableException;
37 import javax.servlet.http.HttpServlet;
38 import javax.servlet.http.HttpServletRequest;
39 import javax.servlet.http.HttpServletResponse;
40
41 import org.eclipse.jetty.client.HttpClient;
42 import org.eclipse.jetty.client.api.Request;
43 import org.eclipse.jetty.client.api.Response;
44 import org.eclipse.jetty.client.api.Result;
45 import org.eclipse.jetty.client.util.InputStreamContentProvider;
46 import org.eclipse.jetty.http.HttpField;
47 import org.eclipse.jetty.http.HttpMethod;
48 import org.eclipse.jetty.http.HttpVersion;
49 import org.eclipse.jetty.server.handler.ContextHandler;
50 import org.eclipse.jetty.util.HttpCookieStore;
51 import org.eclipse.jetty.util.log.Log;
52 import org.eclipse.jetty.util.log.Logger;
53 import org.eclipse.jetty.util.thread.QueuedThreadPool;
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78 public class ProxyServlet extends HttpServlet
79 {
80 protected static final String ASYNC_CONTEXT = ProxyServlet.class.getName() + ".asyncContext";
81 private static final Set<String> HOP_HEADERS = new HashSet<>();
82 static
83 {
84 HOP_HEADERS.add("proxy-connection");
85 HOP_HEADERS.add("connection");
86 HOP_HEADERS.add("keep-alive");
87 HOP_HEADERS.add("transfer-encoding");
88 HOP_HEADERS.add("te");
89 HOP_HEADERS.add("trailer");
90 HOP_HEADERS.add("proxy-authorization");
91 HOP_HEADERS.add("proxy-authenticate");
92 HOP_HEADERS.add("upgrade");
93 }
94
95 private final Set<String> _whiteList = new HashSet<>();
96 private final Set<String> _blackList = new HashSet<>();
97
98 protected Logger _log;
99 private String _hostHeader;
100 private String _viaHost;
101 private HttpClient _client;
102 private long _timeout;
103
104 @Override
105 public void init() throws ServletException
106 {
107 _log = createLogger();
108
109 ServletConfig config = getServletConfig();
110
111 _hostHeader = config.getInitParameter("hostHeader");
112
113 _viaHost = config.getInitParameter("viaHost");
114 if (_viaHost == null)
115 _viaHost = viaHost();
116
117 try
118 {
119 _client = createHttpClient();
120
121
122 getServletContext().setAttribute(config.getServletName() + ".HttpClient", _client);
123
124 String whiteList = config.getInitParameter("whiteList");
125 if (whiteList != null)
126 getWhiteListHosts().addAll(parseList(whiteList));
127
128 String blackList = config.getInitParameter("blackList");
129 if (blackList != null)
130 getBlackListHosts().addAll(parseList(blackList));
131 }
132 catch (Exception e)
133 {
134 throw new ServletException(e);
135 }
136 }
137
138 public long getTimeout()
139 {
140 return _timeout;
141 }
142
143 public void setTimeout(long timeout)
144 {
145 this._timeout = timeout;
146 }
147
148 public Set<String> getWhiteListHosts()
149 {
150 return _whiteList;
151 }
152
153 public Set<String> getBlackListHosts()
154 {
155 return _blackList;
156 }
157
158 protected static String viaHost()
159 {
160 try
161 {
162 return InetAddress.getLocalHost().getHostName();
163 }
164 catch (UnknownHostException x)
165 {
166 return "localhost";
167 }
168 }
169
170
171
172
173 protected Logger createLogger()
174 {
175 String name = getServletConfig().getServletName();
176 name = name.replace('-', '.');
177 return Log.getLogger(name);
178 }
179
180 public void destroy()
181 {
182 try
183 {
184 _client.stop();
185 }
186 catch (Exception x)
187 {
188 _log.debug(x);
189 }
190 }
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241 protected HttpClient createHttpClient() throws ServletException
242 {
243 ServletConfig config = getServletConfig();
244
245 HttpClient client = newHttpClient();
246
247 client.setFollowRedirects(false);
248
249
250 client.setCookieStore(new HttpCookieStore.Empty());
251
252 String value = config.getInitParameter("maxThreads");
253 if (value == null)
254 value = "256";
255 QueuedThreadPool executor = new QueuedThreadPool(Integer.parseInt(value));
256 String servletName = config.getServletName();
257 int dot = servletName.lastIndexOf('.');
258 if (dot >= 0)
259 servletName = servletName.substring(dot + 1);
260 executor.setName(servletName);
261 client.setExecutor(executor);
262
263 value = config.getInitParameter("maxConnections");
264 if (value == null)
265 value = "32768";
266 client.setMaxConnectionsPerDestination(Integer.parseInt(value));
267
268 value = config.getInitParameter("idleTimeout");
269 if (value == null)
270 value = "30000";
271 client.setIdleTimeout(Long.parseLong(value));
272
273 value = config.getInitParameter("timeout");
274 if (value == null)
275 value = "60000";
276 _timeout = Long.parseLong(value);
277
278 value = config.getInitParameter("requestBufferSize");
279 if (value != null)
280 client.setRequestBufferSize(Integer.parseInt(value));
281
282 value = config.getInitParameter("responseBufferSize");
283 if (value != null)
284 client.setResponseBufferSize(Integer.parseInt(value));
285
286 try
287 {
288 client.start();
289
290
291 client.getContentDecoderFactories().clear();
292
293 return client;
294 }
295 catch (Exception x)
296 {
297 throw new ServletException(x);
298 }
299 }
300
301
302
303
304 protected HttpClient newHttpClient()
305 {
306 return new HttpClient();
307 }
308
309 private Set<String> parseList(String list)
310 {
311 Set<String> result = new HashSet<>();
312 String[] hosts = list.split(",");
313 for (String host : hosts)
314 {
315 host = host.trim();
316 if (host.length() == 0)
317 continue;
318 result.add(host);
319 }
320 return result;
321 }
322
323
324
325
326
327
328
329
330 public boolean validateDestination(String host, int port)
331 {
332 String hostPort = host + ":" + port;
333 if (!_whiteList.isEmpty())
334 {
335 if (!_whiteList.contains(hostPort))
336 {
337 _log.debug("Host {}:{} not whitelisted", host, port);
338 return false;
339 }
340 }
341 if (!_blackList.isEmpty())
342 {
343 if (_blackList.contains(hostPort))
344 {
345 _log.debug("Host {}:{} blacklisted", host, port);
346 return false;
347 }
348 }
349 return true;
350 }
351
352 @Override
353 protected void service(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException
354 {
355 final int requestId = getRequestId(request);
356
357 URI rewrittenURI = rewriteURI(request);
358
359 if (_log.isDebugEnabled())
360 {
361 StringBuffer uri = request.getRequestURL();
362 if (request.getQueryString() != null)
363 uri.append("?").append(request.getQueryString());
364 _log.debug("{} rewriting: {} -> {}", requestId, uri, rewrittenURI);
365 }
366
367 if (rewrittenURI == null)
368 {
369 response.sendError(HttpServletResponse.SC_FORBIDDEN);
370 return;
371 }
372
373 final Request proxyRequest = _client.newRequest(rewrittenURI)
374 .method(HttpMethod.fromString(request.getMethod()))
375 .version(HttpVersion.fromString(request.getProtocol()));
376
377
378 for (Enumeration<String> headerNames = request.getHeaderNames(); headerNames.hasMoreElements();)
379 {
380 String headerName = headerNames.nextElement();
381 String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
382
383
384 if (HOP_HEADERS.contains(lowerHeaderName))
385 continue;
386
387 if (_hostHeader!=null && lowerHeaderName.equals("host"))
388 continue;
389
390 for (Enumeration<String> headerValues = request.getHeaders(headerName); headerValues.hasMoreElements();)
391 {
392 String headerValue = headerValues.nextElement();
393 if (headerValue != null)
394 proxyRequest.header(headerName, headerValue);
395 }
396 }
397
398
399 if (_hostHeader != null)
400 proxyRequest.header("Host", _hostHeader);
401
402
403 proxyRequest.header("Via", "http/1.1 " + _viaHost);
404 proxyRequest.header("X-Forwarded-For", request.getRemoteAddr());
405 proxyRequest.header("X-Forwarded-Proto", request.getScheme());
406 proxyRequest.header("X-Forwarded-Host", request.getHeader("Host"));
407 proxyRequest.header("X-Forwarded-Server", request.getLocalName());
408
409 proxyRequest.content(new InputStreamContentProvider(request.getInputStream())
410 {
411 @Override
412 public long getLength()
413 {
414 return request.getContentLength();
415 }
416
417 @Override
418 protected ByteBuffer onRead(byte[] buffer, int offset, int length)
419 {
420 _log.debug("{} proxying content to upstream: {} bytes", requestId, length);
421 return super.onRead(buffer, offset, length);
422 }
423 });
424
425 final AsyncContext asyncContext = request.startAsync();
426
427 asyncContext.setTimeout(0);
428 request.setAttribute(ASYNC_CONTEXT, asyncContext);
429
430 customizeProxyRequest(proxyRequest, request);
431
432 if (_log.isDebugEnabled())
433 {
434 StringBuilder builder = new StringBuilder(request.getMethod());
435 builder.append(" ").append(request.getRequestURI());
436 String query = request.getQueryString();
437 if (query != null)
438 builder.append("?").append(query);
439 builder.append(" ").append(request.getProtocol()).append("\r\n");
440 for (Enumeration<String> headerNames = request.getHeaderNames(); headerNames.hasMoreElements();)
441 {
442 String headerName = headerNames.nextElement();
443 builder.append(headerName).append(": ");
444 for (Enumeration<String> headerValues = request.getHeaders(headerName); headerValues.hasMoreElements();)
445 {
446 String headerValue = headerValues.nextElement();
447 if (headerValue != null)
448 builder.append(headerValue);
449 if (headerValues.hasMoreElements())
450 builder.append(",");
451 }
452 builder.append("\r\n");
453 }
454 builder.append("\r\n");
455
456 _log.debug("{} proxying to upstream:{}{}{}{}",
457 requestId,
458 System.lineSeparator(),
459 builder,
460 proxyRequest,
461 System.lineSeparator(),
462 proxyRequest.getHeaders().toString().trim());
463 }
464
465 proxyRequest.timeout(getTimeout(), TimeUnit.MILLISECONDS);
466 proxyRequest.send(new ProxyResponseListener(request, response));
467 }
468
469 protected void onResponseHeaders(HttpServletRequest request, HttpServletResponse response, Response proxyResponse)
470 {
471 for (HttpField field : proxyResponse.getHeaders())
472 {
473 String headerName = field.getName();
474 String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
475 if (HOP_HEADERS.contains(lowerHeaderName))
476 continue;
477
478 String newHeaderValue = filterResponseHeader(request, headerName, field.getValue());
479 if (newHeaderValue == null || newHeaderValue.trim().length() == 0)
480 continue;
481
482 response.addHeader(headerName, newHeaderValue);
483 }
484 }
485
486 protected void onResponseContent(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, byte[] buffer, int offset, int length) throws IOException
487 {
488 response.getOutputStream().write(buffer, offset, length);
489 _log.debug("{} proxying content to downstream: {} bytes", getRequestId(request), length);
490 }
491
492 protected void onResponseSuccess(HttpServletRequest request, HttpServletResponse response, Response proxyResponse)
493 {
494 AsyncContext asyncContext = (AsyncContext)request.getAttribute(ASYNC_CONTEXT);
495 asyncContext.complete();
496 _log.debug("{} proxying successful", getRequestId(request));
497 }
498
499 protected void onResponseFailure(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, Throwable failure)
500 {
501 _log.debug(getRequestId(request) + " proxying failed", failure);
502 if (!response.isCommitted())
503 {
504 if (failure instanceof TimeoutException)
505 response.setStatus(HttpServletResponse.SC_GATEWAY_TIMEOUT);
506 else
507 response.setStatus(HttpServletResponse.SC_BAD_GATEWAY);
508 }
509 AsyncContext asyncContext = (AsyncContext)request.getAttribute(ASYNC_CONTEXT);
510 asyncContext.complete();
511 }
512
513 protected int getRequestId(HttpServletRequest request)
514 {
515 return System.identityHashCode(request);
516 }
517
518 protected URI rewriteURI(HttpServletRequest request)
519 {
520 if (!validateDestination(request.getServerName(), request.getServerPort()))
521 return null;
522
523 StringBuffer uri = request.getRequestURL();
524 String query = request.getQueryString();
525 if (query != null)
526 uri.append("?").append(query);
527
528 return URI.create(uri.toString());
529 }
530
531
532
533
534
535
536
537
538 protected void customizeProxyRequest(Request proxyRequest, HttpServletRequest request)
539 {
540 }
541
542
543
544
545
546
547
548
549
550
551
552 protected String filterResponseHeader(HttpServletRequest request, String headerName, String headerValue)
553 {
554 return headerValue;
555 }
556
557
558
559
560
561
562
563
564
565
566
567
568
569 public static class Transparent extends ProxyServlet
570 {
571 private String _proxyTo;
572 private String _prefix;
573
574 public Transparent()
575 {
576 }
577
578 public Transparent(String proxyTo, String prefix)
579 {
580 _proxyTo = URI.create(proxyTo).normalize().toString();
581 _prefix = URI.create(prefix).normalize().toString();
582 }
583
584 @Override
585 public void init() throws ServletException
586 {
587 super.init();
588
589 ServletConfig config = getServletConfig();
590
591 String prefix = config.getInitParameter("prefix");
592 _prefix = prefix == null ? _prefix : prefix;
593
594
595 String contextPath = getServletContext().getContextPath();
596 _prefix = _prefix == null ? contextPath : (contextPath + _prefix);
597
598 String proxyTo = config.getInitParameter("proxyTo");
599 _proxyTo = proxyTo == null ? _proxyTo : proxyTo;
600
601 if (_proxyTo == null)
602 throw new UnavailableException("Init parameter 'proxyTo' is required.");
603
604 if (!_prefix.startsWith("/"))
605 throw new UnavailableException("Init parameter 'prefix' parameter must start with a '/'.");
606
607 _log.info(config.getServletName() + " @ " + _prefix + " to " + _proxyTo);
608 }
609
610 @Override
611 protected URI rewriteURI(HttpServletRequest request)
612 {
613 String path = request.getRequestURI();
614 if (!path.startsWith(_prefix))
615 return null;
616
617 URI rewrittenURI = URI.create(_proxyTo + path.substring(_prefix.length())).normalize();
618
619 if (!validateDestination(rewrittenURI.getHost(), rewrittenURI.getPort()))
620 return null;
621
622 return rewrittenURI;
623 }
624 }
625
626 private class ProxyResponseListener extends Response.Listener.Empty
627 {
628 private final HttpServletRequest request;
629 private final HttpServletResponse response;
630
631 public ProxyResponseListener(HttpServletRequest request, HttpServletResponse response)
632 {
633 this.request = request;
634 this.response = response;
635 }
636
637 @Override
638 public void onBegin(Response proxyResponse)
639 {
640 response.setStatus(proxyResponse.getStatus());
641 }
642
643 @Override
644 public void onHeaders(Response proxyResponse)
645 {
646 onResponseHeaders(request, response, proxyResponse);
647
648 if (_log.isDebugEnabled())
649 {
650 StringBuilder builder = new StringBuilder("\r\n");
651 builder.append(request.getProtocol()).append(" ").append(response.getStatus()).append(" ").append(proxyResponse.getReason()).append("\r\n");
652 for (String headerName : response.getHeaderNames())
653 {
654 builder.append(headerName).append(": ");
655 for (Iterator<String> headerValues = response.getHeaders(headerName).iterator(); headerValues.hasNext();)
656 {
657 String headerValue = headerValues.next();
658 if (headerValue != null)
659 builder.append(headerValue);
660 if (headerValues.hasNext())
661 builder.append(",");
662 }
663 builder.append("\r\n");
664 }
665 _log.debug("{} proxying to downstream:{}{}{}{}{}",
666 getRequestId(request),
667 System.lineSeparator(),
668 proxyResponse,
669 System.lineSeparator(),
670 proxyResponse.getHeaders().toString().trim(),
671 System.lineSeparator(),
672 builder);
673 }
674 }
675
676 @Override
677 public void onContent(Response proxyResponse, ByteBuffer content)
678 {
679 byte[] buffer;
680 int offset;
681 int length = content.remaining();
682 if (content.hasArray())
683 {
684 buffer = content.array();
685 offset = content.arrayOffset();
686 }
687 else
688 {
689 buffer = new byte[length];
690 content.get(buffer);
691 offset = 0;
692 }
693
694 try
695 {
696 onResponseContent(request, response, proxyResponse, buffer, offset, length);
697 }
698 catch (IOException x)
699 {
700 proxyResponse.abort(x);
701 }
702 }
703
704 @Override
705 public void onSuccess(Response proxyResponse)
706 {
707 onResponseSuccess(request, response, proxyResponse);
708 }
709
710 @Override
711 public void onFailure(Response proxyResponse, Throwable failure)
712 {
713 onResponseFailure(request, response, proxyResponse, failure);
714 }
715
716 @Override
717 public void onComplete(Result result)
718 {
719 _log.debug("{} proxying complete", getRequestId(request));
720 }
721 }
722 }