1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.eclipse.jetty.proxy;
20
21 import java.io.Closeable;
22 import java.io.IOException;
23 import java.net.InetSocketAddress;
24 import java.nio.ByteBuffer;
25 import java.nio.channels.SelectionKey;
26 import java.nio.channels.SocketChannel;
27 import java.util.HashSet;
28 import java.util.Set;
29 import java.util.concurrent.ConcurrentHashMap;
30 import java.util.concurrent.ConcurrentMap;
31 import java.util.concurrent.Executor;
32
33 import javax.servlet.AsyncContext;
34 import javax.servlet.ServletException;
35 import javax.servlet.http.HttpServletRequest;
36 import javax.servlet.http.HttpServletResponse;
37
38 import org.eclipse.jetty.http.HttpHeader;
39 import org.eclipse.jetty.http.HttpHeaderValue;
40 import org.eclipse.jetty.http.HttpMethod;
41 import org.eclipse.jetty.io.ByteBufferPool;
42 import org.eclipse.jetty.io.Connection;
43 import org.eclipse.jetty.io.EndPoint;
44 import org.eclipse.jetty.io.ManagedSelector;
45 import org.eclipse.jetty.io.MappedByteBufferPool;
46 import org.eclipse.jetty.io.SelectChannelEndPoint;
47 import org.eclipse.jetty.io.SelectorManager;
48 import org.eclipse.jetty.server.Handler;
49 import org.eclipse.jetty.server.HttpConnection;
50 import org.eclipse.jetty.server.HttpTransport;
51 import org.eclipse.jetty.server.Request;
52 import org.eclipse.jetty.server.handler.HandlerWrapper;
53 import org.eclipse.jetty.util.BufferUtil;
54 import org.eclipse.jetty.util.Callback;
55 import org.eclipse.jetty.util.Promise;
56 import org.eclipse.jetty.util.TypeUtil;
57 import org.eclipse.jetty.util.log.Log;
58 import org.eclipse.jetty.util.log.Logger;
59 import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
60 import org.eclipse.jetty.util.thread.Scheduler;
61
62
63
64
65 public class ConnectHandler extends HandlerWrapper
66 {
67 protected static final Logger LOG = Log.getLogger(ConnectHandler.class);
68
69 private final Set<String> whiteList = new HashSet<>();
70 private final Set<String> blackList = new HashSet<>();
71 private Executor executor;
72 private Scheduler scheduler;
73 private ByteBufferPool bufferPool;
74 private SelectorManager selector;
75 private long connectTimeout = 15000;
76 private long idleTimeout = 30000;
77 private int bufferSize = 4096;
78
79 public ConnectHandler()
80 {
81 this(null);
82 }
83
84 public ConnectHandler(Handler handler)
85 {
86 setHandler(handler);
87 }
88
89 public Executor getExecutor()
90 {
91 return executor;
92 }
93
94 public void setExecutor(Executor executor)
95 {
96 this.executor = executor;
97 }
98
99 public Scheduler getScheduler()
100 {
101 return scheduler;
102 }
103
104 public void setScheduler(Scheduler scheduler)
105 {
106 this.scheduler = scheduler;
107 }
108
109 public ByteBufferPool getByteBufferPool()
110 {
111 return bufferPool;
112 }
113
114 public void setByteBufferPool(ByteBufferPool bufferPool)
115 {
116 this.bufferPool = bufferPool;
117 }
118
119
120
121
122 public long getConnectTimeout()
123 {
124 return connectTimeout;
125 }
126
127
128
129
130 public void setConnectTimeout(long connectTimeout)
131 {
132 this.connectTimeout = connectTimeout;
133 }
134
135
136
137
138 public long getIdleTimeout()
139 {
140 return idleTimeout;
141 }
142
143
144
145
146 public void setIdleTimeout(long idleTimeout)
147 {
148 this.idleTimeout = idleTimeout;
149 }
150
151 public int getBufferSize()
152 {
153 return bufferSize;
154 }
155
156 public void setBufferSize(int bufferSize)
157 {
158 this.bufferSize = bufferSize;
159 }
160
161 @Override
162 protected void doStart() throws Exception
163 {
164 if (executor == null)
165 executor = getServer().getThreadPool();
166
167 if (scheduler == null)
168 addBean(scheduler = new ScheduledExecutorScheduler());
169
170 if (bufferPool == null)
171 addBean(bufferPool = new MappedByteBufferPool());
172
173 addBean(selector = newSelectorManager());
174 selector.setConnectTimeout(getConnectTimeout());
175
176 super.doStart();
177 }
178
179 protected SelectorManager newSelectorManager()
180 {
181 return new ConnectManager(getExecutor(), getScheduler(), 1);
182 }
183
184 @Override
185 public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
186 {
187 if (HttpMethod.CONNECT.is(request.getMethod()))
188 {
189 String serverAddress = request.getRequestURI();
190 if (LOG.isDebugEnabled())
191 LOG.debug("CONNECT request for {}", serverAddress);
192
193 handleConnect(baseRequest, request, response, serverAddress);
194 }
195 else
196 {
197 super.handle(target, baseRequest, request, response);
198 }
199 }
200
201
202
203
204
205
206
207
208
209
210
211 protected void handleConnect(Request baseRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
212 {
213 baseRequest.setHandled(true);
214 try
215 {
216 boolean proceed = handleAuthentication(request, response, serverAddress);
217 if (!proceed)
218 {
219 if (LOG.isDebugEnabled())
220 LOG.debug("Missing proxy authentication");
221 sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
222 return;
223 }
224
225 String host = serverAddress;
226 int port = 80;
227 int colon = serverAddress.indexOf(':');
228 if (colon > 0)
229 {
230 host = serverAddress.substring(0, colon);
231 port = Integer.parseInt(serverAddress.substring(colon + 1));
232 }
233
234 if (!validateDestination(host, port))
235 {
236 if (LOG.isDebugEnabled())
237 LOG.debug("Destination {}:{} forbidden", host, port);
238 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
239 return;
240 }
241
242 HttpTransport transport = baseRequest.getHttpChannel().getHttpTransport();
243
244 if (!(transport instanceof HttpConnection))
245 {
246 if (LOG.isDebugEnabled())
247 LOG.debug("CONNECT not supported for {}", transport);
248 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
249 return;
250 }
251
252 AsyncContext asyncContext = request.startAsync();
253 asyncContext.setTimeout(0);
254
255 if (LOG.isDebugEnabled())
256 LOG.debug("Connecting to {}:{}", host, port);
257
258 connectToServer(request, host, port, new Promise<SocketChannel>()
259 {
260 @Override
261 public void succeeded(SocketChannel channel)
262 {
263 ConnectContext connectContext = new ConnectContext(request, response, asyncContext, (HttpConnection)transport);
264 if (channel.isConnected())
265 selector.accept(channel, connectContext);
266 else
267 selector.connect(channel, connectContext);
268 }
269
270 @Override
271 public void failed(Throwable x)
272 {
273 onConnectFailure(request, response, asyncContext, x);
274 }
275 });
276 }
277 catch (Exception x)
278 {
279 onConnectFailure(request, response, null, x);
280 }
281 }
282
283 protected void connectToServer(HttpServletRequest request, String host, int port, Promise<SocketChannel> promise)
284 {
285 SocketChannel channel = null;
286 try
287 {
288 channel = SocketChannel.open();
289 channel.socket().setTcpNoDelay(true);
290 channel.configureBlocking(false);
291 InetSocketAddress address = newConnectAddress(host, port);
292 channel.connect(address);
293 promise.succeeded(channel);
294 }
295 catch (Throwable x)
296 {
297 close(channel);
298 promise.failed(x);
299 }
300 }
301
302 private void close(Closeable closeable)
303 {
304 try
305 {
306 if (closeable != null)
307 closeable.close();
308 }
309 catch (Throwable x)
310 {
311 LOG.ignore(x);
312 }
313 }
314
315
316
317
318
319
320
321
322 protected InetSocketAddress newConnectAddress(String host, int port)
323 {
324 return new InetSocketAddress(host, port);
325 }
326
327 protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
328 {
329 ConcurrentMap<String, Object> context = connectContext.getContext();
330 HttpServletRequest request = connectContext.getRequest();
331 prepareContext(request, context);
332
333 HttpConnection httpConnection = connectContext.getHttpConnection();
334 EndPoint downstreamEndPoint = httpConnection.getEndPoint();
335 DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, BufferUtil.EMPTY_BUFFER);
336 downstreamConnection.setInputBufferSize(getBufferSize());
337
338 upstreamConnection.setConnection(downstreamConnection);
339 downstreamConnection.setConnection(upstreamConnection);
340 if (LOG.isDebugEnabled())
341 LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);
342
343 HttpServletResponse response = connectContext.getResponse();
344 sendConnectResponse(request, response, HttpServletResponse.SC_OK);
345
346 upgradeConnection(request, response, downstreamConnection);
347
348 connectContext.getAsyncContext().complete();
349 }
350
351 protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
352 {
353 if (LOG.isDebugEnabled())
354 LOG.debug("CONNECT failed", failure);
355 sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
356 if (asyncContext != null)
357 asyncContext.complete();
358 }
359
360 private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
361 {
362 try
363 {
364 response.setStatus(statusCode);
365 if (statusCode != HttpServletResponse.SC_OK)
366 response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
367 response.getOutputStream().close();
368 if (LOG.isDebugEnabled())
369 LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
370 }
371 catch (IOException x)
372 {
373 if (LOG.isDebugEnabled())
374 LOG.debug("Could not send CONNECT response", x);
375 }
376 }
377
378
379
380
381
382
383
384
385
386
387 protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
388 {
389 return true;
390 }
391
392
393
394
395 @Deprecated
396 protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
397 {
398 return newDownstreamConnection(endPoint, context);
399 }
400
401 protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context)
402 {
403 return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context);
404 }
405
406 protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
407 {
408 return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
409 }
410
411 protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
412 {
413 }
414
415 private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
416 {
417
418
419 request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
420 response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
421 if (LOG.isDebugEnabled())
422 LOG.debug("Upgraded connection to {}", connection);
423 }
424
425
426
427
428
429
430
431
432
433
434
435 protected int read(EndPoint endPoint, ByteBuffer buffer, ConcurrentMap<String, Object> context) throws IOException
436 {
437 int read = read(endPoint, buffer);
438 if (LOG.isDebugEnabled())
439 LOG.debug("{} read {} bytes", this, read);
440 return read;
441 }
442
443
444
445
446 @Deprecated
447 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
448 {
449 return endPoint.fill(buffer);
450 }
451
452
453
454
455
456
457
458
459
460 protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback, ConcurrentMap<String, Object> context)
461 {
462 if (LOG.isDebugEnabled())
463 LOG.debug("{} writing {} bytes", this, buffer.remaining());
464 write(endPoint, buffer, callback);
465 }
466
467
468
469
470 @Deprecated
471 protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
472 {
473 endPoint.write(callback, buffer);
474 }
475
476 public Set<String> getWhiteListHosts()
477 {
478 return whiteList;
479 }
480
481 public Set<String> getBlackListHosts()
482 {
483 return blackList;
484 }
485
486
487
488
489
490
491
492
493 public boolean validateDestination(String host, int port)
494 {
495 String hostPort = host + ":" + port;
496 if (!whiteList.isEmpty())
497 {
498 if (!whiteList.contains(hostPort))
499 {
500 if (LOG.isDebugEnabled())
501 LOG.debug("Host {}:{} not whitelisted", host, port);
502 return false;
503 }
504 }
505 if (!blackList.isEmpty())
506 {
507 if (blackList.contains(hostPort))
508 {
509 if (LOG.isDebugEnabled())
510 LOG.debug("Host {}:{} blacklisted", host, port);
511 return false;
512 }
513 }
514 return true;
515 }
516
517 @Override
518 public void dump(Appendable out, String indent) throws IOException
519 {
520 dumpThis(out);
521 dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
522 }
523
524 protected class ConnectManager extends SelectorManager
525 {
526 protected ConnectManager(Executor executor, Scheduler scheduler, int selectors)
527 {
528 super(executor, scheduler, selectors);
529 }
530
531 @Override
532 protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
533 {
534 return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
535 }
536
537 @Override
538 public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
539 {
540 if (ConnectHandler.LOG.isDebugEnabled())
541 ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
542 ConnectContext connectContext = (ConnectContext)attachment;
543 UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
544 connection.setInputBufferSize(getBufferSize());
545 return connection;
546 }
547
548 @Override
549 protected void connectionFailed(SocketChannel channel, final Throwable ex, final Object attachment)
550 {
551 close(channel);
552 ConnectContext connectContext = (ConnectContext)attachment;
553 onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex);
554 }
555 }
556
557 protected static class ConnectContext
558 {
559 private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
560 private final HttpServletRequest request;
561 private final HttpServletResponse response;
562 private final AsyncContext asyncContext;
563 private final HttpConnection httpConnection;
564
565 public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
566 {
567 this.request = request;
568 this.response = response;
569 this.asyncContext = asyncContext;
570 this.httpConnection = httpConnection;
571 }
572
573 public ConcurrentMap<String, Object> getContext()
574 {
575 return context;
576 }
577
578 public HttpServletRequest getRequest()
579 {
580 return request;
581 }
582
583 public HttpServletResponse getResponse()
584 {
585 return response;
586 }
587
588 public AsyncContext getAsyncContext()
589 {
590 return asyncContext;
591 }
592
593 public HttpConnection getHttpConnection()
594 {
595 return httpConnection;
596 }
597 }
598
599 public class UpstreamConnection extends ProxyConnection
600 {
601 private ConnectContext connectContext;
602
603 public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
604 {
605 super(endPoint, executor, bufferPool, connectContext.getContext());
606 this.connectContext = connectContext;
607 }
608
609 @Override
610 public void onOpen()
611 {
612 super.onOpen();
613 onConnectSuccess(connectContext, UpstreamConnection.this);
614 fillInterested();
615 }
616
617 @Override
618 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
619 {
620 return ConnectHandler.this.read(endPoint, buffer, getContext());
621 }
622
623 @Override
624 protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
625 {
626 ConnectHandler.this.write(endPoint, buffer, callback, getContext());
627 }
628 }
629
630 public class DownstreamConnection extends ProxyConnection implements Connection.UpgradeTo
631 {
632 private ByteBuffer buffer;
633
634 public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context)
635 {
636 super(endPoint, executor, bufferPool, context);
637 }
638
639
640
641
642 @Deprecated
643 public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
644 {
645 this(endPoint, executor, bufferPool, context);
646 }
647
648 @Override
649 public void onUpgradeTo(ByteBuffer buffer)
650 {
651 this.buffer = buffer == null ? BufferUtil.EMPTY_BUFFER : buffer;
652 }
653
654 @Override
655 public void onOpen()
656 {
657 super.onOpen();
658 final int remaining = buffer.remaining();
659 write(getConnection().getEndPoint(), buffer, new Callback()
660 {
661 @Override
662 public void succeeded()
663 {
664 if (LOG.isDebugEnabled())
665 LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
666 fillInterested();
667 }
668
669 @Override
670 public void failed(Throwable x)
671 {
672 if (LOG.isDebugEnabled())
673 LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x);
674 close();
675 getConnection().close();
676 }
677 });
678 }
679
680 @Override
681 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
682 {
683 return ConnectHandler.this.read(endPoint, buffer, getContext());
684 }
685
686 @Override
687 protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
688 {
689 ConnectHandler.this.write(endPoint, buffer, callback, getContext());
690 }
691 }
692 }