1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.eclipse.jetty.proxy;
20
21 import java.io.IOException;
22 import java.net.InetSocketAddress;
23 import java.nio.ByteBuffer;
24 import java.nio.channels.SelectionKey;
25 import java.nio.channels.SocketChannel;
26 import java.util.HashSet;
27 import java.util.Set;
28 import java.util.concurrent.ConcurrentHashMap;
29 import java.util.concurrent.ConcurrentMap;
30 import java.util.concurrent.Executor;
31 import javax.servlet.AsyncContext;
32 import javax.servlet.ServletException;
33 import javax.servlet.http.HttpServletRequest;
34 import javax.servlet.http.HttpServletResponse;
35
36 import org.eclipse.jetty.http.HttpHeader;
37 import org.eclipse.jetty.http.HttpHeaderValue;
38 import org.eclipse.jetty.http.HttpMethod;
39 import org.eclipse.jetty.io.ByteBufferPool;
40 import org.eclipse.jetty.io.Connection;
41 import org.eclipse.jetty.io.EndPoint;
42 import org.eclipse.jetty.io.MappedByteBufferPool;
43 import org.eclipse.jetty.io.SelectChannelEndPoint;
44 import org.eclipse.jetty.io.SelectorManager;
45 import org.eclipse.jetty.server.Handler;
46 import org.eclipse.jetty.server.HttpConnection;
47 import org.eclipse.jetty.server.Request;
48 import org.eclipse.jetty.server.handler.HandlerWrapper;
49 import org.eclipse.jetty.util.BufferUtil;
50 import org.eclipse.jetty.util.Callback;
51 import org.eclipse.jetty.util.TypeUtil;
52 import org.eclipse.jetty.util.log.Log;
53 import org.eclipse.jetty.util.log.Logger;
54 import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
55 import org.eclipse.jetty.util.thread.Scheduler;
56
57
58
59
60 public class ConnectHandler extends HandlerWrapper
61 {
62 protected static final Logger LOG = Log.getLogger(ConnectHandler.class);
63
64 private final Set<String> whiteList = new HashSet<>();
65 private final Set<String> blackList = new HashSet<>();
66 private Executor executor;
67 private Scheduler scheduler;
68 private ByteBufferPool bufferPool;
69 private SelectorManager selector;
70 private long connectTimeout = 15000;
71 private long idleTimeout = 30000;
72 private int bufferSize = 4096;
73
74 public ConnectHandler()
75 {
76 this(null);
77 }
78
79 public ConnectHandler(Handler handler)
80 {
81 setHandler(handler);
82 }
83
84 public Executor getExecutor()
85 {
86 return executor;
87 }
88
89 public void setExecutor(Executor executor)
90 {
91 this.executor = executor;
92 }
93
94 public Scheduler getScheduler()
95 {
96 return scheduler;
97 }
98
99 public void setScheduler(Scheduler scheduler)
100 {
101 this.scheduler = scheduler;
102 }
103
104 public ByteBufferPool getByteBufferPool()
105 {
106 return bufferPool;
107 }
108
109 public void setByteBufferPool(ByteBufferPool bufferPool)
110 {
111 this.bufferPool = bufferPool;
112 }
113
114
115
116
117 public long getConnectTimeout()
118 {
119 return connectTimeout;
120 }
121
122
123
124
125 public void setConnectTimeout(long connectTimeout)
126 {
127 this.connectTimeout = connectTimeout;
128 }
129
130
131
132
133 public long getIdleTimeout()
134 {
135 return idleTimeout;
136 }
137
138
139
140
141 public void setIdleTimeout(long idleTimeout)
142 {
143 this.idleTimeout = idleTimeout;
144 }
145
146 public int getBufferSize()
147 {
148 return bufferSize;
149 }
150
151 public void setBufferSize(int bufferSize)
152 {
153 this.bufferSize = bufferSize;
154 }
155
156 @Override
157 protected void doStart() throws Exception
158 {
159 if (executor == null)
160 {
161 setExecutor(getServer().getThreadPool());
162 }
163 if (scheduler == null)
164 {
165 setScheduler(new ScheduledExecutorScheduler());
166 addBean(getScheduler());
167 }
168 if (bufferPool == null)
169 {
170 setByteBufferPool(new MappedByteBufferPool());
171 addBean(getByteBufferPool());
172 }
173 addBean(selector = newSelectorManager());
174 selector.setConnectTimeout(getConnectTimeout());
175 super.doStart();
176 }
177
178 protected SelectorManager newSelectorManager()
179 {
180 return new Manager(getExecutor(), getScheduler(), 1);
181 }
182
183 @Override
184 public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
185 {
186 if (HttpMethod.CONNECT.is(request.getMethod()))
187 {
188 String serverAddress = request.getRequestURI();
189 LOG.debug("CONNECT request for {}", serverAddress);
190 try
191 {
192 handleConnect(baseRequest, request, response, serverAddress);
193 }
194 catch (Exception x)
195 {
196
197 LOG.warn("ConnectHandler " + baseRequest.getUri() + " " + x);
198 LOG.debug(x);
199 }
200 }
201 else
202 {
203 super.handle(target, baseRequest, request, response);
204 }
205 }
206
207
208
209
210
211
212
213
214
215
216
217 protected void handleConnect(Request jettyRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
218 {
219 jettyRequest.setHandled(true);
220 try
221 {
222 boolean proceed = handleAuthentication(request, response, serverAddress);
223 if (!proceed)
224 {
225 LOG.debug("Missing proxy authentication");
226 sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
227 return;
228 }
229
230 String host = serverAddress;
231 int port = 80;
232 int colon = serverAddress.indexOf(':');
233 if (colon > 0)
234 {
235 host = serverAddress.substring(0, colon);
236 port = Integer.parseInt(serverAddress.substring(colon + 1));
237 }
238
239 if (!validateDestination(host, port))
240 {
241 LOG.debug("Destination {}:{} forbidden", host, port);
242 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
243 return;
244 }
245
246 SocketChannel channel = SocketChannel.open();
247 channel.socket().setTcpNoDelay(true);
248 channel.configureBlocking(false);
249 InetSocketAddress address = new InetSocketAddress(host, port);
250 channel.connect(address);
251
252 AsyncContext asyncContext = request.startAsync();
253 asyncContext.setTimeout(0);
254
255 LOG.debug("Connecting to {}", address);
256 ConnectContext connectContext = new ConnectContext(request, response, asyncContext, HttpConnection.getCurrentConnection());
257 selector.connect(channel, connectContext);
258 }
259 catch (Exception x)
260 {
261 onConnectFailure(request, response, null, x);
262 }
263 }
264
265 protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
266 {
267 HttpConnection httpConnection = connectContext.getHttpConnection();
268 ByteBuffer requestBuffer = httpConnection.getRequestBuffer();
269 ByteBuffer buffer = BufferUtil.EMPTY_BUFFER;
270 int remaining = requestBuffer.remaining();
271 if (remaining > 0)
272 {
273 buffer = bufferPool.acquire(remaining, requestBuffer.isDirect());
274 BufferUtil.flipToFill(buffer);
275 buffer.put(requestBuffer);
276 buffer.flip();
277 }
278
279 ConcurrentMap<String, Object> context = connectContext.getContext();
280 HttpServletRequest request = connectContext.getRequest();
281 prepareContext(request, context);
282
283 EndPoint downstreamEndPoint = httpConnection.getEndPoint();
284 DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, buffer);
285 downstreamConnection.setInputBufferSize(getBufferSize());
286
287 upstreamConnection.setConnection(downstreamConnection);
288 downstreamConnection.setConnection(upstreamConnection);
289 LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);
290
291 HttpServletResponse response = connectContext.getResponse();
292 sendConnectResponse(request, response, HttpServletResponse.SC_OK);
293
294 upgradeConnection(request, response, downstreamConnection);
295 connectContext.getAsyncContext().complete();
296 }
297
298 protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
299 {
300 LOG.debug("CONNECT failed", failure);
301 sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
302 if (asyncContext != null)
303 asyncContext.complete();
304 }
305
306 private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
307 {
308 try
309 {
310 response.setStatus(statusCode);
311 if (statusCode != HttpServletResponse.SC_OK)
312 response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
313 response.getOutputStream().close();
314 LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
315 }
316 catch (IOException x)
317 {
318
319 }
320 }
321
322
323
324
325
326
327
328
329
330
331 protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
332 {
333 return true;
334 }
335
336 protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
337 {
338 return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context, buffer);
339 }
340
341 protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
342 {
343 return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
344 }
345
346 protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
347 {
348 }
349
350 private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
351 {
352
353
354 request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
355 response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
356 LOG.debug("Upgraded connection to {}", connection);
357 }
358
359
360
361
362
363
364
365
366
367
368 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
369 {
370 return endPoint.fill(buffer);
371 }
372
373
374
375
376
377
378
379
380 protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
381 {
382 LOG.debug("{} writing {} bytes", this, buffer.remaining());
383 endPoint.write(callback, buffer);
384 }
385
386 public Set<String> getWhiteListHosts()
387 {
388 return whiteList;
389 }
390
391 public Set<String> getBlackListHosts()
392 {
393 return blackList;
394 }
395
396
397
398
399
400
401
402
403 public boolean validateDestination(String host, int port)
404 {
405 String hostPort = host + ":" + port;
406 if (!whiteList.isEmpty())
407 {
408 if (!whiteList.contains(hostPort))
409 {
410 LOG.debug("Host {}:{} not whitelisted", host, port);
411 return false;
412 }
413 }
414 if (!blackList.isEmpty())
415 {
416 if (blackList.contains(hostPort))
417 {
418 LOG.debug("Host {}:{} blacklisted", host, port);
419 return false;
420 }
421 }
422 return true;
423 }
424
425 @Override
426 public void dump(Appendable out, String indent) throws IOException
427 {
428 dumpThis(out);
429 dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
430 }
431
432 protected class Manager extends SelectorManager
433 {
434
435 private Manager(Executor executor, Scheduler scheduler, int selectors)
436 {
437 super(executor, scheduler, selectors);
438 }
439
440 @Override
441 protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
442 {
443 return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
444 }
445
446 @Override
447 public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
448 {
449 ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
450 ConnectContext connectContext = (ConnectContext)attachment;
451 UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
452 connection.setInputBufferSize(getBufferSize());
453 return connection;
454 }
455
456 @Override
457 protected void connectionFailed(SocketChannel channel, Throwable ex, Object attachment)
458 {
459 ConnectContext connectContext = (ConnectContext)attachment;
460 onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex);
461 }
462 }
463
464 protected static class ConnectContext
465 {
466 private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
467 private final HttpServletRequest request;
468 private final HttpServletResponse response;
469 private final AsyncContext asyncContext;
470 private final HttpConnection httpConnection;
471
472 public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
473 {
474 this.request = request;
475 this.response = response;
476 this.asyncContext = asyncContext;
477 this.httpConnection = httpConnection;
478 }
479
480 public ConcurrentMap<String, Object> getContext()
481 {
482 return context;
483 }
484
485 public HttpServletRequest getRequest()
486 {
487 return request;
488 }
489
490 public HttpServletResponse getResponse()
491 {
492 return response;
493 }
494
495 public AsyncContext getAsyncContext()
496 {
497 return asyncContext;
498 }
499
500 public HttpConnection getHttpConnection()
501 {
502 return httpConnection;
503 }
504 }
505
506 public class UpstreamConnection extends ProxyConnection
507 {
508 private ConnectContext connectContext;
509
510 public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
511 {
512 super(endPoint, executor, bufferPool, connectContext.getContext());
513 this.connectContext = connectContext;
514 }
515
516 @Override
517 public void onOpen()
518 {
519 super.onOpen();
520 onConnectSuccess(connectContext, this);
521 fillInterested();
522 }
523
524 @Override
525 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
526 {
527 return ConnectHandler.this.read(endPoint, buffer);
528 }
529
530 @Override
531 protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
532 {
533 ConnectHandler.this.write(endPoint, buffer, callback);
534 }
535 }
536
537 public class DownstreamConnection extends ProxyConnection
538 {
539 private final ByteBuffer buffer;
540
541 public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
542 {
543 super(endPoint, executor, bufferPool, context);
544 this.buffer = buffer;
545 }
546
547 @Override
548 public void onOpen()
549 {
550 super.onOpen();
551 final int remaining = buffer.remaining();
552 write(getConnection().getEndPoint(), buffer, new Callback()
553 {
554 @Override
555 public void succeeded()
556 {
557 LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
558 fillInterested();
559 }
560
561 @Override
562 public void failed(Throwable x)
563 {
564 LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x);
565 close();
566 getConnection().close();
567 }
568 });
569 }
570
571 @Override
572 protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
573 {
574 return ConnectHandler.this.read(endPoint, buffer);
575 }
576
577 @Override
578 protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
579 {
580 ConnectHandler.this.write(endPoint, buffer, callback);
581 }
582 }
583 }