View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.proxy;
20  
21  import java.io.IOException;
22  import java.net.InetSocketAddress;
23  import java.nio.ByteBuffer;
24  import java.nio.channels.SelectionKey;
25  import java.nio.channels.SocketChannel;
26  import java.util.HashSet;
27  import java.util.Set;
28  import java.util.concurrent.ConcurrentHashMap;
29  import java.util.concurrent.ConcurrentMap;
30  import java.util.concurrent.Executor;
31  
32  import javax.servlet.AsyncContext;
33  import javax.servlet.ServletException;
34  import javax.servlet.http.HttpServletRequest;
35  import javax.servlet.http.HttpServletResponse;
36  
37  import org.eclipse.jetty.http.HttpHeader;
38  import org.eclipse.jetty.http.HttpHeaderValue;
39  import org.eclipse.jetty.http.HttpMethod;
40  import org.eclipse.jetty.io.ByteBufferPool;
41  import org.eclipse.jetty.io.Connection;
42  import org.eclipse.jetty.io.EndPoint;
43  import org.eclipse.jetty.io.ManagedSelector;
44  import org.eclipse.jetty.io.MappedByteBufferPool;
45  import org.eclipse.jetty.io.SelectChannelEndPoint;
46  import org.eclipse.jetty.io.SelectorManager;
47  import org.eclipse.jetty.server.Handler;
48  import org.eclipse.jetty.server.HttpConnection;
49  import org.eclipse.jetty.server.HttpTransport;
50  import org.eclipse.jetty.server.Request;
51  import org.eclipse.jetty.server.handler.HandlerWrapper;
52  import org.eclipse.jetty.util.BufferUtil;
53  import org.eclipse.jetty.util.Callback;
54  import org.eclipse.jetty.util.TypeUtil;
55  import org.eclipse.jetty.util.log.Log;
56  import org.eclipse.jetty.util.log.Logger;
57  import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
58  import org.eclipse.jetty.util.thread.Scheduler;
59  
60  /**
61   * <p>Implementation of a {@link Handler} that supports HTTP CONNECT.</p>
62   */
63  public class ConnectHandler extends HandlerWrapper
64  {
65      protected static final Logger LOG = Log.getLogger(ConnectHandler.class);
66  
67      private final Set<String> whiteList = new HashSet<>();
68      private final Set<String> blackList = new HashSet<>();
69      private Executor executor;
70      private Scheduler scheduler;
71      private ByteBufferPool bufferPool;
72      private SelectorManager selector;
73      private long connectTimeout = 15000;
74      private long idleTimeout = 30000;
75      private int bufferSize = 4096;
76  
77      public ConnectHandler()
78      {
79          this(null);
80      }
81  
82      public ConnectHandler(Handler handler)
83      {
84          setHandler(handler);
85      }
86  
87      public Executor getExecutor()
88      {
89          return executor;
90      }
91  
92      public void setExecutor(Executor executor)
93      {
94          this.executor = executor;
95      }
96  
97      public Scheduler getScheduler()
98      {
99          return scheduler;
100     }
101 
102     public void setScheduler(Scheduler scheduler)
103     {
104         this.scheduler = scheduler;
105     }
106 
107     public ByteBufferPool getByteBufferPool()
108     {
109         return bufferPool;
110     }
111 
112     public void setByteBufferPool(ByteBufferPool bufferPool)
113     {
114         this.bufferPool = bufferPool;
115     }
116 
117     /**
118      * @return the timeout, in milliseconds, to connect to the remote server
119      */
120     public long getConnectTimeout()
121     {
122         return connectTimeout;
123     }
124 
125     /**
126      * @param connectTimeout the timeout, in milliseconds, to connect to the remote server
127      */
128     public void setConnectTimeout(long connectTimeout)
129     {
130         this.connectTimeout = connectTimeout;
131     }
132 
133     /**
134      * @return the idle timeout, in milliseconds
135      */
136     public long getIdleTimeout()
137     {
138         return idleTimeout;
139     }
140 
141     /**
142      * @param idleTimeout the idle timeout, in milliseconds
143      */
144     public void setIdleTimeout(long idleTimeout)
145     {
146         this.idleTimeout = idleTimeout;
147     }
148 
149     public int getBufferSize()
150     {
151         return bufferSize;
152     }
153 
154     public void setBufferSize(int bufferSize)
155     {
156         this.bufferSize = bufferSize;
157     }
158 
159     @Override
160     protected void doStart() throws Exception
161     {
162         if (executor == null)
163         {
164             setExecutor(getServer().getThreadPool());
165         }
166         if (scheduler == null)
167         {
168             setScheduler(new ScheduledExecutorScheduler());
169             addBean(getScheduler());
170         }
171         if (bufferPool == null)
172         {
173             setByteBufferPool(new MappedByteBufferPool());
174             addBean(getByteBufferPool());
175         }
176         addBean(selector = newSelectorManager());
177         selector.setConnectTimeout(getConnectTimeout());
178         super.doStart();
179     }
180 
181     protected SelectorManager newSelectorManager()
182     {
183         return new ConnectManager(getExecutor(), getScheduler(), 1);
184     }
185 
186     @Override
187     public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
188     {
189         if (HttpMethod.CONNECT.is(request.getMethod()))
190         {
191             String serverAddress = request.getRequestURI();
192             if (LOG.isDebugEnabled())
193                 LOG.debug("CONNECT request for {}", serverAddress);
194             try
195             {
196                 handleConnect(baseRequest, request, response, serverAddress);
197             }
198             catch (Exception x)
199             {
200                 // TODO
201                 LOG.warn("ConnectHandler " + baseRequest.getHttpURI() + " " + x);
202                 LOG.debug(x);
203             }
204         }
205         else
206         {
207             super.handle(target, baseRequest, request, response);
208         }
209     }
210 
211     /**
212      * <p>Handles a CONNECT request.</p>
213      * <p>CONNECT requests may have authentication headers such as {@code Proxy-Authorization}
214      * that authenticate the client with the proxy.</p>
215      *
216      * @param baseRequest  Jetty-specific http request
217      * @param request       the http request
218      * @param response      the http response
219      * @param serverAddress the remote server address in the form {@code host:port}
220      */
221     protected void handleConnect(Request baseRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
222     {
223         baseRequest.setHandled(true);
224         try
225         {
226             boolean proceed = handleAuthentication(request, response, serverAddress);
227             if (!proceed)
228             {
229                 if (LOG.isDebugEnabled())
230                     LOG.debug("Missing proxy authentication");
231                 sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
232                 return;
233             }
234 
235             String host = serverAddress;
236             int port = 80;
237             int colon = serverAddress.indexOf(':');
238             if (colon > 0)
239             {
240                 host = serverAddress.substring(0, colon);
241                 port = Integer.parseInt(serverAddress.substring(colon + 1));
242             }
243 
244             if (!validateDestination(host, port))
245             {
246                 if (LOG.isDebugEnabled())
247                     LOG.debug("Destination {}:{} forbidden", host, port);
248                 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
249                 return;
250             }
251 
252             SocketChannel channel = SocketChannel.open();
253             channel.socket().setTcpNoDelay(true);
254             channel.configureBlocking(false);
255 
256             AsyncContext asyncContext = request.startAsync();
257             asyncContext.setTimeout(0);
258 
259             HttpTransport transport = baseRequest.getHttpChannel().getHttpTransport();
260             
261             // TODO Handle CONNECT over HTTP2!
262             if (!(transport instanceof HttpConnection))
263             {
264                 if (LOG.isDebugEnabled())
265                     LOG.debug("CONNECT forbidden for {}", transport);
266                 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
267                 return;
268             }
269 
270             InetSocketAddress address = newConnectAddress(host, port);
271             if (LOG.isDebugEnabled())
272                 LOG.debug("Connecting to {}", address);
273             ConnectContext connectContext = new ConnectContext(request, response, asyncContext, (HttpConnection)transport);
274             if (channel.connect(address))
275                 selector.accept(channel, connectContext);
276             else
277                 selector.connect(channel, connectContext);
278         }
279         catch (Exception x)
280         {
281             onConnectFailure(request, response, null, x);
282         }
283     }
284 
285     /* ------------------------------------------------------------ */
286     /** Create the address the connect channel will connect to.
287      * @param host The host from the connect request
288      * @param port The port from the connect request
289      * @return The InetSocketAddress to connect to.
290      */
291     protected InetSocketAddress newConnectAddress(String host, int port)
292     {
293         return new InetSocketAddress(host, port);
294     }
295     
296     protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
297     {
298         HttpConnection httpConnection = connectContext.getHttpConnection();
299         ByteBuffer requestBuffer = httpConnection.getRequestBuffer();
300         ByteBuffer buffer = BufferUtil.EMPTY_BUFFER;
301         int remaining = requestBuffer.remaining();
302         if (remaining > 0)
303         {
304             buffer = bufferPool.acquire(remaining, requestBuffer.isDirect());
305             BufferUtil.flipToFill(buffer);
306             buffer.put(requestBuffer);
307             buffer.flip();
308         }
309 
310         ConcurrentMap<String, Object> context = connectContext.getContext();
311         HttpServletRequest request = connectContext.getRequest();
312         prepareContext(request, context);
313 
314         EndPoint downstreamEndPoint = httpConnection.getEndPoint();
315         DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, buffer);
316         downstreamConnection.setInputBufferSize(getBufferSize());
317 
318         upstreamConnection.setConnection(downstreamConnection);
319         downstreamConnection.setConnection(upstreamConnection);
320         if (LOG.isDebugEnabled())
321             LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);
322 
323         HttpServletResponse response = connectContext.getResponse();
324         sendConnectResponse(request, response, HttpServletResponse.SC_OK);
325 
326         upgradeConnection(request, response, downstreamConnection);
327         connectContext.getAsyncContext().complete();
328     }
329 
330     protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
331     {
332         if (LOG.isDebugEnabled())
333             LOG.debug("CONNECT failed", failure);
334         sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
335         if (asyncContext != null)
336             asyncContext.complete();
337     }
338 
339     private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
340     {
341         try
342         {
343             response.setStatus(statusCode);
344             if (statusCode != HttpServletResponse.SC_OK)
345                 response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
346             response.getOutputStream().close();
347             if (LOG.isDebugEnabled())
348                 LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
349         }
350         catch (IOException x)
351         {
352             // TODO: nothing we can do, close the connection
353         }
354     }
355 
356     /**
357      * <p>Handles the authentication before setting up the tunnel to the remote server.</p>
358      * <p>The default implementation returns true.</p>
359      *
360      * @param request  the HTTP request
361      * @param response the HTTP response
362      * @param address  the address of the remote server in the form {@code host:port}.
363      * @return true to allow to connect to the remote host, false otherwise
364      */
365     protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
366     {
367         return true;
368     }
369 
370     protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
371     {
372         return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context, buffer);
373     }
374 
375     protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
376     {
377         return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
378     }
379 
380     protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
381     {
382     }
383 
384     private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
385     {
386         // Set the new connection as request attribute and change the status to 101
387         // so that Jetty understands that it has to upgrade the connection
388         request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
389         response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
390         if (LOG.isDebugEnabled())
391             LOG.debug("Upgraded connection to {}", connection);
392     }
393 
394     /**
395      * <p>Reads (with non-blocking semantic) into the given {@code buffer} from the given {@code endPoint}.</p>
396      *
397      * @param endPoint the endPoint to read from
398      * @param buffer   the buffer to read data into
399      * @return the number of bytes read (possibly 0 since the read is non-blocking)
400      *         or -1 if the channel has been closed remotely
401      * @throws IOException if the endPoint cannot be read
402      */
403     protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
404     {
405         return endPoint.fill(buffer);
406     }
407 
408     /**
409      * <p>Writes (with non-blocking semantic) the given buffer of data onto the given endPoint.</p>
410      *
411      * @param endPoint the endPoint to write to
412      * @param buffer   the buffer to write
413      * @param callback the completion callback to invoke
414      */
415     protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
416     {
417         if (LOG.isDebugEnabled())
418             LOG.debug("{} writing {} bytes", this, buffer.remaining());
419         endPoint.write(callback, buffer);
420     }
421 
422     public Set<String> getWhiteListHosts()
423     {
424         return whiteList;
425     }
426 
427     public Set<String> getBlackListHosts()
428     {
429         return blackList;
430     }
431 
432     /**
433      * Checks the given {@code host} and {@code port} against whitelist and blacklist.
434      *
435      * @param host the host to check
436      * @param port the port to check
437      * @return true if it is allowed to connect to the given host and port
438      */
439     public boolean validateDestination(String host, int port)
440     {
441         String hostPort = host + ":" + port;
442         if (!whiteList.isEmpty())
443         {
444             if (!whiteList.contains(hostPort))
445             {
446                 if (LOG.isDebugEnabled())
447                     LOG.debug("Host {}:{} not whitelisted", host, port);
448                 return false;
449             }
450         }
451         if (!blackList.isEmpty())
452         {
453             if (blackList.contains(hostPort))
454             {
455                 if (LOG.isDebugEnabled())
456                     LOG.debug("Host {}:{} blacklisted", host, port);
457                 return false;
458             }
459         }
460         return true;
461     }
462 
463     @Override
464     public void dump(Appendable out, String indent) throws IOException
465     {
466         dumpThis(out);
467         dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
468     }
469 
470     protected class ConnectManager extends SelectorManager
471     {
472         protected ConnectManager(Executor executor, Scheduler scheduler, int selectors)
473         {
474             super(executor, scheduler, selectors);
475         }
476 
477         @Override
478         protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
479         {
480             return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
481         }
482 
483         @Override
484         public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
485         {
486             if (ConnectHandler.LOG.isDebugEnabled())
487                 ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
488             ConnectContext connectContext = (ConnectContext)attachment;
489             UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
490             connection.setInputBufferSize(getBufferSize());
491             return connection;
492         }
493 
494         @Override
495         protected void connectionFailed(SocketChannel channel, final Throwable ex, final Object attachment)
496         {
497             getExecutor().execute(new Runnable()
498             {
499                 public void run()
500                 {
501                     ConnectContext connectContext = (ConnectContext)attachment;
502                     onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex);
503                 }
504             });
505         }
506     }
507 
508     protected static class ConnectContext
509     {
510         private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
511         private final HttpServletRequest request;
512         private final HttpServletResponse response;
513         private final AsyncContext asyncContext;
514         private final HttpConnection httpConnection;
515 
516         public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
517         {
518             this.request = request;
519             this.response = response;
520             this.asyncContext = asyncContext;
521             this.httpConnection = httpConnection;
522         }
523 
524         public ConcurrentMap<String, Object> getContext()
525         {
526             return context;
527         }
528 
529         public HttpServletRequest getRequest()
530         {
531             return request;
532         }
533 
534         public HttpServletResponse getResponse()
535         {
536             return response;
537         }
538 
539         public AsyncContext getAsyncContext()
540         {
541             return asyncContext;
542         }
543 
544         public HttpConnection getHttpConnection()
545         {
546             return httpConnection;
547         }
548     }
549 
550     public class UpstreamConnection extends ProxyConnection
551     {
552         private ConnectContext connectContext;
553 
554         public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
555         {
556             super(endPoint, executor, bufferPool, connectContext.getContext());
557             this.connectContext = connectContext;
558         }
559 
560         @Override
561         public void onOpen()
562         {
563             super.onOpen();
564             getExecutor().execute(new Runnable()
565             {
566                 public void run()
567                 {
568                     onConnectSuccess(connectContext, UpstreamConnection.this);
569                     fillInterested();
570                 }
571             });
572         }
573 
574         @Override
575         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
576         {
577             return ConnectHandler.this.read(endPoint, buffer);
578         }
579 
580         @Override
581         protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
582         {
583             ConnectHandler.this.write(endPoint, buffer, callback);
584         }
585     }
586 
587     public class DownstreamConnection extends ProxyConnection
588     {
589         private final ByteBuffer buffer;
590 
591         public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
592         {
593             super(endPoint, executor, bufferPool, context);
594             this.buffer = buffer;
595         }
596 
597         @Override
598         public void onOpen()
599         {
600             super.onOpen();
601             final int remaining = buffer.remaining();
602             write(getConnection().getEndPoint(), buffer, new Callback()
603             {
604                 @Override
605                 public void succeeded()
606                 {
607                     if (LOG.isDebugEnabled())
608                         LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
609                     fillInterested();
610                 }
611 
612                 @Override
613                 public void failed(Throwable x)
614                 {
615                     if (LOG.isDebugEnabled())
616                         LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x);
617                     close();
618                     getConnection().close();
619                 }
620             });
621         }
622 
623         @Override
624         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
625         {
626             return ConnectHandler.this.read(endPoint, buffer);
627         }
628 
629         @Override
630         protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
631         {
632             ConnectHandler.this.write(endPoint, buffer, callback);
633         }
634     }
635 }