1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.eclipse.jetty.proxy;
20
21 import java.io.IOException;
22 import java.net.InetSocketAddress;
23 import java.nio.ByteBuffer;
24 import java.nio.channels.SelectionKey;
25 import java.nio.channels.SocketChannel;
26 import java.util.HashSet;
27 import java.util.Set;
28 import java.util.concurrent.ConcurrentHashMap;
29 import java.util.concurrent.ConcurrentMap;
30 import java.util.concurrent.Executor;
31
32 import javax.servlet.AsyncContext;
33 import javax.servlet.ServletException;
34 import javax.servlet.http.HttpServletRequest;
35 import javax.servlet.http.HttpServletResponse;
36
37 import org.eclipse.jetty.http.HttpHeader;
38 import org.eclipse.jetty.http.HttpHeaderValue;
39 import org.eclipse.jetty.http.HttpMethod;
40 import org.eclipse.jetty.io.ByteBufferPool;
41 import org.eclipse.jetty.io.Connection;
42 import org.eclipse.jetty.io.EndPoint;
43 import org.eclipse.jetty.io.MappedByteBufferPool;
44 import org.eclipse.jetty.io.SelectChannelEndPoint;
45 import org.eclipse.jetty.io.SelectorManager;
46 import org.eclipse.jetty.server.Handler;
47 import org.eclipse.jetty.server.HttpConnection;
48 import org.eclipse.jetty.server.Request;
49 import org.eclipse.jetty.server.handler.HandlerWrapper;
50 import org.eclipse.jetty.util.BufferUtil;
51 import org.eclipse.jetty.util.Callback;
52 import org.eclipse.jetty.util.TypeUtil;
53 import org.eclipse.jetty.util.log.Log;
54 import org.eclipse.jetty.util.log.Logger;
55 import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
56 import org.eclipse.jetty.util.thread.Scheduler;
57
58
59
60
61 public class ConnectHandler extends HandlerWrapper
62 {
63 protected static final Logger LOG = Log.getLogger(ConnectHandler.class);
64
65 private final Set<String> whiteList = new HashSet<>();
66 private final Set<String> blackList = new HashSet<>();
67 private Executor executor;
68 private Scheduler scheduler;
69 private ByteBufferPool bufferPool;
70 private SelectorManager selector;
71 private long connectTimeout = 15000;
72 private long idleTimeout = 30000;
73 private int bufferSize = 4096;
74
75 public ConnectHandler()
76 {
77 this(null);
78 }
79
80 public ConnectHandler(Handler handler)
81 {
82 setHandler(handler);
83 }
84
85 public Executor getExecutor()
86 {
87 return executor;
88 }
89
90 public void setExecutor(Executor executor)
91 {
92 this.executor = executor;
93 }
94
95 public Scheduler getScheduler()
96 {
97 return scheduler;
98 }
99
100 public void setScheduler(Scheduler scheduler)
101 {
102 this.scheduler = scheduler;
103 }
104
105 public ByteBufferPool getByteBufferPool()
106 {
107 return bufferPool;
108 }
109
110 public void setByteBufferPool(ByteBufferPool bufferPool)
111 {
112 this.bufferPool = bufferPool;
113 }
114
115
116
117
118 public long getConnectTimeout()
119 {
120 return connectTimeout;
121 }
122
123
124
125
126 public void setConnectTimeout(long connectTimeout)
127 {
128 this.connectTimeout = connectTimeout;
129 }
130
131
132
133
134 public long getIdleTimeout()
135 {
136 return idleTimeout;
137 }
138
139
140
141
142 public void setIdleTimeout(long idleTimeout)
143 {
144 this.idleTimeout = idleTimeout;
145 }
146
147 public int getBufferSize()
148 {
149 return bufferSize;
150 }
151
152 public void setBufferSize(int bufferSize)
153 {
154 this.bufferSize = bufferSize;
155 }
156
157 @Override
158 protected void doStart() throws Exception
159 {
160 if (executor == null)
161 {
162 setExecutor(getServer().getThreadPool());
163 }
164 if (scheduler == null)
165 {
166 setScheduler(new ScheduledExecutorScheduler());
167 addBean(getScheduler());
168 }
169 if (bufferPool == null)
170 {
171 setByteBufferPool(new MappedByteBufferPool());
172 addBean(getByteBufferPool());
173 }
174 addBean(selector = newSelectorManager());
175 selector.setConnectTimeout(getConnectTimeout());
176 super.doStart();
177 }
178
179 protected SelectorManager newSelectorManager()
180 {
181 return new ConnectManager(getExecutor(), getScheduler(), 1);
182 }
183
184 @Override
185 public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
186 {
187 if (HttpMethod.CONNECT.is(request.getMethod()))
188 {
189 String serverAddress = request.getRequestURI();
190 LOG.debug("CONNECT request for {}", serverAddress);
191 try
192 {
193 handleConnect(baseRequest, request, response, serverAddress);
194 }
195 catch (Exception x)
196 {
197
198 LOG.warn("ConnectHandler " + baseRequest.getUri() + " " + x);
199 LOG.debug(x);
200 }
201 }
202 else
203 {
204 super.handle(target, baseRequest, request, response);
205 }
206 }
207
208
209
210
211
212
213
214
215
216
217
218 protected void handleConnect(Request jettyRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
219 {
220 jettyRequest.setHandled(true);
221 try
222 {
223 boolean proceed = handleAuthentication(request, response, serverAddress);
224 if (!proceed)
225 {
226 LOG.debug("Missing proxy authentication");
227 sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
228 return;
229 }
230
231 String host = serverAddress;
232 int port = 80;
233 int colon = serverAddress.indexOf(':');
234 if (colon > 0)
235 {
236 host = serverAddress.substring(0, colon);
237 port = Integer.parseInt(serverAddress.substring(colon + 1));
238 }
239
240 if (!validateDestination(host, port))
241 {
242 LOG.debug("Destination {}:{} forbidden", host, port);
243 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
244 return;
245 }
246
247 SocketChannel channel = SocketChannel.open();
248 channel.socket().setTcpNoDelay(true);
249 channel.configureBlocking(false);
250 InetSocketAddress address = new InetSocketAddress(host, port);
251 channel.connect(address);
252
253 AsyncContext asyncContext = request.startAsync();
254 asyncContext.setTimeout(0);
255
256 LOG.debug("Connecting to {}", address);
257 ConnectContext connectContext = new ConnectContext(request, response, asyncContext, HttpConnection.getCurrentConnection());
258 selector.connect(channel, connectContext);
259 }
260 catch (Exception x)
261 {
262 onConnectFailure(request, response, null, x);
263 }
264 }
265
266 protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
267 {
268 HttpConnection httpConnection = connectContext.getHttpConnection();
269 ByteBuffer requestBuffer = httpConnection.getRequestBuffer();
270 ByteBuffer buffer = BufferUtil.EMPTY_BUFFER;
271 int remaining = requestBuffer.remaining();
272 if (remaining > 0)
273 {
274 buffer = bufferPool.acquire(remaining, requestBuffer.isDirect());
275 BufferUtil.flipToFill(buffer);
276 buffer.put(requestBuffer);
277 buffer.flip();
278 }
279
280 ConcurrentMap<String, Object> context = connectContext.getContext();
281 HttpServletRequest request = connectContext.getRequest();
282 prepareContext(request, context);
283
284 EndPoint downstreamEndPoint = httpConnection.getEndPoint();
285 DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, buffer);
286 downstreamConnection.setInputBufferSize(getBufferSize());
287
288 upstreamConnection.setConnection(downstreamConnection);
289 downstreamConnection.setConnection(upstreamConnection);
290 LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);
291
292 HttpServletResponse response = connectContext.getResponse();
293 sendConnectResponse(request, response, HttpServletResponse.SC_OK);
294
295 upgradeConnection(request, response, downstreamConnection);
296 connectContext.getAsyncContext().complete();
297 }
298
299 protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
300 {
301 LOG.debug("CONNECT failed", failure);
302 sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
303 if (asyncContext != null)
304 asyncContext.complete();
305 }
306
307 private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
308 {
309 try
310 {
311 response.setStatus(statusCode);
312 if (statusCode != HttpServletResponse.SC_OK)
313 response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
314 response.getOutputStream().close();
315 LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
316 }
317 catch (IOException x)
318 {
319
320 }
321 }
322
323
324
325
326
327
328
329
330
331
332 protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
333 {
334 return true;
335 }
336
337 protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
338 {
339 return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context, buffer);
340 }
341
342 protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
343 {
344 return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
345 }
346
347 protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
348 {
349 }
350
351 private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
352 {
353
354
355 request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
356 response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
357 LOG.debug("Upgraded connection to {}", connection);
358 }
359
360
361
362
363
364
365
366
367
368
369 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
370 {
371 return endPoint.fill(buffer);
372 }
373
374
375
376
377
378
379
380
381 protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
382 {
383 LOG.debug("{} writing {} bytes", this, buffer.remaining());
384 endPoint.write(callback, buffer);
385 }
386
387 public Set<String> getWhiteListHosts()
388 {
389 return whiteList;
390 }
391
392 public Set<String> getBlackListHosts()
393 {
394 return blackList;
395 }
396
397
398
399
400
401
402
403
404 public boolean validateDestination(String host, int port)
405 {
406 String hostPort = host + ":" + port;
407 if (!whiteList.isEmpty())
408 {
409 if (!whiteList.contains(hostPort))
410 {
411 LOG.debug("Host {}:{} not whitelisted", host, port);
412 return false;
413 }
414 }
415 if (!blackList.isEmpty())
416 {
417 if (blackList.contains(hostPort))
418 {
419 LOG.debug("Host {}:{} blacklisted", host, port);
420 return false;
421 }
422 }
423 return true;
424 }
425
426 @Override
427 public void dump(Appendable out, String indent) throws IOException
428 {
429 dumpThis(out);
430 dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
431 }
432
433 protected class ConnectManager extends SelectorManager
434 {
435
436 private ConnectManager(Executor executor, Scheduler scheduler, int selectors)
437 {
438 super(executor, scheduler, selectors);
439 }
440
441 @Override
442 protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
443 {
444 return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
445 }
446
447 @Override
448 public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
449 {
450 ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
451 ConnectContext connectContext = (ConnectContext)attachment;
452 UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
453 connection.setInputBufferSize(getBufferSize());
454 return connection;
455 }
456
457 @Override
458 protected void connectionFailed(SocketChannel channel, Throwable ex, Object attachment)
459 {
460 ConnectContext connectContext = (ConnectContext)attachment;
461 onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex);
462 }
463 }
464
465 protected static class ConnectContext
466 {
467 private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
468 private final HttpServletRequest request;
469 private final HttpServletResponse response;
470 private final AsyncContext asyncContext;
471 private final HttpConnection httpConnection;
472
473 public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
474 {
475 this.request = request;
476 this.response = response;
477 this.asyncContext = asyncContext;
478 this.httpConnection = httpConnection;
479 }
480
481 public ConcurrentMap<String, Object> getContext()
482 {
483 return context;
484 }
485
486 public HttpServletRequest getRequest()
487 {
488 return request;
489 }
490
491 public HttpServletResponse getResponse()
492 {
493 return response;
494 }
495
496 public AsyncContext getAsyncContext()
497 {
498 return asyncContext;
499 }
500
501 public HttpConnection getHttpConnection()
502 {
503 return httpConnection;
504 }
505 }
506
507 public class UpstreamConnection extends ProxyConnection
508 {
509 private ConnectContext connectContext;
510
511 public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
512 {
513 super(endPoint, executor, bufferPool, connectContext.getContext());
514 this.connectContext = connectContext;
515 }
516
517 @Override
518 public void onOpen()
519 {
520 super.onOpen();
521 onConnectSuccess(connectContext, this);
522 fillInterested();
523 }
524
525 @Override
526 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
527 {
528 return ConnectHandler.this.read(endPoint, buffer);
529 }
530
531 @Override
532 protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
533 {
534 ConnectHandler.this.write(endPoint, buffer, callback);
535 }
536 }
537
538 public class DownstreamConnection extends ProxyConnection
539 {
540 private final ByteBuffer buffer;
541
542 public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
543 {
544 super(endPoint, executor, bufferPool, context);
545 this.buffer = buffer;
546 }
547
548 @Override
549 public void onOpen()
550 {
551 super.onOpen();
552 final int remaining = buffer.remaining();
553 write(getConnection().getEndPoint(), buffer, new Callback()
554 {
555 @Override
556 public void succeeded()
557 {
558 LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
559 fillInterested();
560 }
561
562 @Override
563 public void failed(Throwable x)
564 {
565 LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x);
566 close();
567 getConnection().close();
568 }
569 });
570 }
571
572 @Override
573 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
574 {
575 return ConnectHandler.this.read(endPoint, buffer);
576 }
577
578 @Override
579 protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
580 {
581 ConnectHandler.this.write(endPoint, buffer, callback);
582 }
583 }
584 }