1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.eclipse.jetty.proxy;
20
21 import java.io.IOException;
22 import java.net.InetSocketAddress;
23 import java.nio.ByteBuffer;
24 import java.nio.channels.SelectionKey;
25 import java.nio.channels.SocketChannel;
26 import java.util.HashSet;
27 import java.util.Set;
28 import java.util.concurrent.ConcurrentHashMap;
29 import java.util.concurrent.ConcurrentMap;
30 import java.util.concurrent.Executor;
31 import javax.servlet.AsyncContext;
32 import javax.servlet.ServletException;
33 import javax.servlet.http.HttpServletRequest;
34 import javax.servlet.http.HttpServletResponse;
35
36 import org.eclipse.jetty.http.HttpHeader;
37 import org.eclipse.jetty.http.HttpHeaderValue;
38 import org.eclipse.jetty.http.HttpMethod;
39 import org.eclipse.jetty.io.ByteBufferPool;
40 import org.eclipse.jetty.io.Connection;
41 import org.eclipse.jetty.io.EndPoint;
42 import org.eclipse.jetty.io.MappedByteBufferPool;
43 import org.eclipse.jetty.io.SelectChannelEndPoint;
44 import org.eclipse.jetty.io.SelectorManager;
45 import org.eclipse.jetty.server.Handler;
46 import org.eclipse.jetty.server.HttpConnection;
47 import org.eclipse.jetty.server.Request;
48 import org.eclipse.jetty.server.handler.HandlerWrapper;
49 import org.eclipse.jetty.util.BufferUtil;
50 import org.eclipse.jetty.util.Callback;
51 import org.eclipse.jetty.util.TypeUtil;
52 import org.eclipse.jetty.util.log.Log;
53 import org.eclipse.jetty.util.log.Logger;
54 import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
55 import org.eclipse.jetty.util.thread.Scheduler;
56
57
58
59
60 public class ConnectHandler extends HandlerWrapper
61 {
62 protected static final Logger LOG = Log.getLogger(ConnectHandler.class);
63
64 private final Set<String> whiteList = new HashSet<>();
65 private final Set<String> blackList = new HashSet<>();
66 private Executor executor;
67 private Scheduler scheduler;
68 private ByteBufferPool bufferPool;
69 private SelectorManager selector;
70 private long connectTimeout = 15000;
71 private long idleTimeout = 30000;
72 private int bufferSize = 4096;
73
/**
 * Creates a handler with no wrapped handler; equivalent to passing {@code null}
 * to {@link #ConnectHandler(Handler)}.
 */
public ConnectHandler()
{
    this(null);
}
78
/**
 * Creates a handler wrapping the given handler, which receives every
 * request whose method is not CONNECT.
 *
 * @param handler the handler to wrap, may be {@code null}
 */
public ConnectHandler(Handler handler)
{
    super();
    setHandler(handler);
}
83
84 public Executor getExecutor()
85 {
86 return executor;
87 }
88
89 public void setExecutor(Executor executor)
90 {
91 this.executor = executor;
92 }
93
94 public Scheduler getScheduler()
95 {
96 return scheduler;
97 }
98
99 public void setScheduler(Scheduler scheduler)
100 {
101 this.scheduler = scheduler;
102 }
103
104 public ByteBufferPool getByteBufferPool()
105 {
106 return bufferPool;
107 }
108
109 public void setByteBufferPool(ByteBufferPool bufferPool)
110 {
111 this.bufferPool = bufferPool;
112 }
113
114
115
116
117 public long getConnectTimeout()
118 {
119 return connectTimeout;
120 }
121
122
123
124
125 public void setConnectTimeout(long connectTimeout)
126 {
127 this.connectTimeout = connectTimeout;
128 }
129
130
131
132
133 public long getIdleTimeout()
134 {
135 return idleTimeout;
136 }
137
138
139
140
141 public void setIdleTimeout(long idleTimeout)
142 {
143 this.idleTimeout = idleTimeout;
144 }
145
146 public int getBufferSize()
147 {
148 return bufferSize;
149 }
150
151 public void setBufferSize(int bufferSize)
152 {
153 this.bufferSize = bufferSize;
154 }
155
156 @Override
157 protected void doStart() throws Exception
158 {
159 if (executor == null)
160 {
161 setExecutor(getServer().getThreadPool());
162 }
163 if (scheduler == null)
164 {
165 setScheduler(new ScheduledExecutorScheduler());
166 addBean(getScheduler());
167 }
168 if (bufferPool == null)
169 {
170 setByteBufferPool(new MappedByteBufferPool());
171 addBean(getByteBufferPool());
172 }
173 addBean(selector = newSelectorManager());
174 selector.setConnectTimeout(getConnectTimeout());
175 super.doStart();
176 }
177
178 protected SelectorManager newSelectorManager()
179 {
180 return new ConnectManager(getExecutor(), getScheduler(), 1);
181 }
182
183 @Override
184 public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
185 {
186 if (HttpMethod.CONNECT.is(request.getMethod()))
187 {
188 String serverAddress = request.getRequestURI();
189 if (LOG.isDebugEnabled())
190 LOG.debug("CONNECT request for {}", serverAddress);
191 try
192 {
193 handleConnect(baseRequest, request, response, serverAddress);
194 }
195 catch (Exception x)
196 {
197
198 LOG.warn("ConnectHandler " + baseRequest.getUri() + " " + x);
199 LOG.debug(x);
200 }
201 }
202 else
203 {
204 super.handle(target, baseRequest, request, response);
205 }
206 }
207
208
209
210
211
212
213
214
215
216
217
218 protected void handleConnect(Request jettyRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
219 {
220 jettyRequest.setHandled(true);
221 try
222 {
223 boolean proceed = handleAuthentication(request, response, serverAddress);
224 if (!proceed)
225 {
226 if (LOG.isDebugEnabled())
227 LOG.debug("Missing proxy authentication");
228 sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
229 return;
230 }
231
232 String host = serverAddress;
233 int port = 80;
234 int colon = serverAddress.indexOf(':');
235 if (colon > 0)
236 {
237 host = serverAddress.substring(0, colon);
238 port = Integer.parseInt(serverAddress.substring(colon + 1));
239 }
240
241 if (!validateDestination(host, port))
242 {
243 if (LOG.isDebugEnabled())
244 LOG.debug("Destination {}:{} forbidden", host, port);
245 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
246 return;
247 }
248
249 SocketChannel channel = SocketChannel.open();
250 channel.socket().setTcpNoDelay(true);
251 channel.configureBlocking(false);
252 InetSocketAddress address = new InetSocketAddress(host, port);
253
254 AsyncContext asyncContext = request.startAsync();
255 asyncContext.setTimeout(0);
256
257 if (LOG.isDebugEnabled())
258 LOG.debug("Connecting to {}", address);
259
260 ConnectContext connectContext = new ConnectContext(request, response, asyncContext, HttpConnection.getCurrentConnection());
261 if (channel.connect(address))
262 selector.accept(channel, connectContext);
263 else
264 selector.connect(channel, connectContext);
265 }
266 catch (Exception x)
267 {
268 onConnectFailure(request, response, null, x);
269 }
270 }
271
272 protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
273 {
274 HttpConnection httpConnection = connectContext.getHttpConnection();
275 ByteBuffer requestBuffer = httpConnection.getRequestBuffer();
276 ByteBuffer buffer = BufferUtil.EMPTY_BUFFER;
277 int remaining = requestBuffer.remaining();
278 if (remaining > 0)
279 {
280 buffer = bufferPool.acquire(remaining, requestBuffer.isDirect());
281 BufferUtil.flipToFill(buffer);
282 buffer.put(requestBuffer);
283 buffer.flip();
284 }
285
286 ConcurrentMap<String, Object> context = connectContext.getContext();
287 HttpServletRequest request = connectContext.getRequest();
288 prepareContext(request, context);
289
290 EndPoint downstreamEndPoint = httpConnection.getEndPoint();
291 DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, buffer);
292 downstreamConnection.setInputBufferSize(getBufferSize());
293
294 upstreamConnection.setConnection(downstreamConnection);
295 downstreamConnection.setConnection(upstreamConnection);
296 if (LOG.isDebugEnabled())
297 LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);
298
299 HttpServletResponse response = connectContext.getResponse();
300 sendConnectResponse(request, response, HttpServletResponse.SC_OK);
301
302 upgradeConnection(request, response, downstreamConnection);
303 connectContext.getAsyncContext().complete();
304 }
305
306 protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
307 {
308 if (LOG.isDebugEnabled())
309 LOG.debug("CONNECT failed", failure);
310 sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
311 if (asyncContext != null)
312 asyncContext.complete();
313 }
314
315 private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
316 {
317 try
318 {
319 response.setStatus(statusCode);
320 if (statusCode != HttpServletResponse.SC_OK)
321 response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
322 response.getOutputStream().close();
323 if (LOG.isDebugEnabled())
324 LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
325 }
326 catch (IOException x)
327 {
328
329 }
330 }
331
332
333
334
335
336
337
338
339
340
341 protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
342 {
343 return true;
344 }
345
346 protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
347 {
348 return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context, buffer);
349 }
350
351 protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
352 {
353 return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
354 }
355
356 protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
357 {
358 }
359
360 private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
361 {
362
363
364 request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
365 response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
366 if (LOG.isDebugEnabled())
367 LOG.debug("Upgraded connection to {}", connection);
368 }
369
370
371
372
373
374
375
376
377
378
379 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
380 {
381 return endPoint.fill(buffer);
382 }
383
384
385
386
387
388
389
390
391 protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
392 {
393 if (LOG.isDebugEnabled())
394 LOG.debug("{} writing {} bytes", this, buffer.remaining());
395 endPoint.write(callback, buffer);
396 }
397
398 public Set<String> getWhiteListHosts()
399 {
400 return whiteList;
401 }
402
403 public Set<String> getBlackListHosts()
404 {
405 return blackList;
406 }
407
408
409
410
411
412
413
414
415 public boolean validateDestination(String host, int port)
416 {
417 String hostPort = host + ":" + port;
418 if (!whiteList.isEmpty())
419 {
420 if (!whiteList.contains(hostPort))
421 {
422 if (LOG.isDebugEnabled())
423 LOG.debug("Host {}:{} not whitelisted", host, port);
424 return false;
425 }
426 }
427 if (!blackList.isEmpty())
428 {
429 if (blackList.contains(hostPort))
430 {
431 if (LOG.isDebugEnabled())
432 LOG.debug("Host {}:{} blacklisted", host, port);
433 return false;
434 }
435 }
436 return true;
437 }
438
439 @Override
440 public void dump(Appendable out, String indent) throws IOException
441 {
442 dumpThis(out);
443 dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
444 }
445
446 protected class ConnectManager extends SelectorManager
447 {
448 protected ConnectManager(Executor executor, Scheduler scheduler, int selectors)
449 {
450 super(executor, scheduler, selectors);
451 }
452
453 @Override
454 protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
455 {
456 return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
457 }
458
459 @Override
460 public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
461 {
462 if (ConnectHandler.LOG.isDebugEnabled())
463 ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
464 ConnectContext connectContext = (ConnectContext)attachment;
465 UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
466 connection.setInputBufferSize(getBufferSize());
467 return connection;
468 }
469
470 @Override
471 protected void connectionFailed(SocketChannel channel, final Throwable ex, final Object attachment)
472 {
473 getExecutor().execute(new Runnable()
474 {
475 public void run()
476 {
477 ConnectContext connectContext = (ConnectContext)attachment;
478 onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex);
479 }
480 });
481 }
482 }
483
484 protected static class ConnectContext
485 {
486 private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
487 private final HttpServletRequest request;
488 private final HttpServletResponse response;
489 private final AsyncContext asyncContext;
490 private final HttpConnection httpConnection;
491
492 public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
493 {
494 this.request = request;
495 this.response = response;
496 this.asyncContext = asyncContext;
497 this.httpConnection = httpConnection;
498 }
499
500 public ConcurrentMap<String, Object> getContext()
501 {
502 return context;
503 }
504
505 public HttpServletRequest getRequest()
506 {
507 return request;
508 }
509
510 public HttpServletResponse getResponse()
511 {
512 return response;
513 }
514
515 public AsyncContext getAsyncContext()
516 {
517 return asyncContext;
518 }
519
520 public HttpConnection getHttpConnection()
521 {
522 return httpConnection;
523 }
524 }
525
526 public class UpstreamConnection extends ProxyConnection
527 {
528 private ConnectContext connectContext;
529
530 public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
531 {
532 super(endPoint, executor, bufferPool, connectContext.getContext());
533 this.connectContext = connectContext;
534 }
535
536 @Override
537 public void onOpen()
538 {
539 super.onOpen();
540 getExecutor().execute(new Runnable()
541 {
542 public void run()
543 {
544 onConnectSuccess(connectContext, UpstreamConnection.this);
545 fillInterested();
546 }
547 });
548 }
549
550 @Override
551 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
552 {
553 return ConnectHandler.this.read(endPoint, buffer);
554 }
555
556 @Override
557 protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
558 {
559 ConnectHandler.this.write(endPoint, buffer, callback);
560 }
561 }
562
563 public class DownstreamConnection extends ProxyConnection
564 {
565 private final ByteBuffer buffer;
566
567 public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
568 {
569 super(endPoint, executor, bufferPool, context);
570 this.buffer = buffer;
571 }
572
573 @Override
574 public void onOpen()
575 {
576 super.onOpen();
577 final int remaining = buffer.remaining();
578 write(getConnection().getEndPoint(), buffer, new Callback()
579 {
580 @Override
581 public void succeeded()
582 {
583 if (LOG.isDebugEnabled())
584 LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
585 fillInterested();
586 }
587
588 @Override
589 public void failed(Throwable x)
590 {
591 if (LOG.isDebugEnabled())
592 LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x);
593 close();
594 getConnection().close();
595 }
596 });
597 }
598
599 @Override
600 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
601 {
602 return ConnectHandler.this.read(endPoint, buffer);
603 }
604
605 @Override
606 protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
607 {
608 ConnectHandler.this.write(endPoint, buffer, callback);
609 }
610 }
611 }