View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2016 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.proxy;
20  
21  import java.io.Closeable;
22  import java.io.IOException;
23  import java.net.InetSocketAddress;
24  import java.nio.ByteBuffer;
25  import java.nio.channels.SelectionKey;
26  import java.nio.channels.SocketChannel;
27  import java.util.HashSet;
28  import java.util.Set;
29  import java.util.concurrent.ConcurrentHashMap;
30  import java.util.concurrent.ConcurrentMap;
31  import java.util.concurrent.Executor;
32  
33  import javax.servlet.AsyncContext;
34  import javax.servlet.ServletException;
35  import javax.servlet.http.HttpServletRequest;
36  import javax.servlet.http.HttpServletResponse;
37  
38  import org.eclipse.jetty.http.HttpHeader;
39  import org.eclipse.jetty.http.HttpHeaderValue;
40  import org.eclipse.jetty.http.HttpMethod;
41  import org.eclipse.jetty.io.ByteBufferPool;
42  import org.eclipse.jetty.io.Connection;
43  import org.eclipse.jetty.io.EndPoint;
44  import org.eclipse.jetty.io.ManagedSelector;
45  import org.eclipse.jetty.io.MappedByteBufferPool;
46  import org.eclipse.jetty.io.SelectChannelEndPoint;
47  import org.eclipse.jetty.io.SelectorManager;
48  import org.eclipse.jetty.server.Handler;
49  import org.eclipse.jetty.server.HttpConnection;
50  import org.eclipse.jetty.server.HttpTransport;
51  import org.eclipse.jetty.server.Request;
52  import org.eclipse.jetty.server.handler.HandlerWrapper;
53  import org.eclipse.jetty.util.BufferUtil;
54  import org.eclipse.jetty.util.Callback;
55  import org.eclipse.jetty.util.Promise;
56  import org.eclipse.jetty.util.TypeUtil;
57  import org.eclipse.jetty.util.log.Log;
58  import org.eclipse.jetty.util.log.Logger;
59  import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
60  import org.eclipse.jetty.util.thread.Scheduler;
61  
62  /**
63   * <p>Implementation of a {@link Handler} that supports HTTP CONNECT.</p>
64   */
65  public class ConnectHandler extends HandlerWrapper
66  {
67      protected static final Logger LOG = Log.getLogger(ConnectHandler.class);
68  
69      private final Set<String> whiteList = new HashSet<>();
70      private final Set<String> blackList = new HashSet<>();
71      private Executor executor;
72      private Scheduler scheduler;
73      private ByteBufferPool bufferPool;
74      private SelectorManager selector;
75      private long connectTimeout = 15000;
76      private long idleTimeout = 30000;
77      private int bufferSize = 4096;
78  
79      public ConnectHandler()
80      {
81          this(null);
82      }
83  
84      public ConnectHandler(Handler handler)
85      {
86          setHandler(handler);
87      }
88  
89      public Executor getExecutor()
90      {
91          return executor;
92      }
93  
94      public void setExecutor(Executor executor)
95      {
96          this.executor = executor;
97      }
98  
99      public Scheduler getScheduler()
100     {
101         return scheduler;
102     }
103 
104     public void setScheduler(Scheduler scheduler)
105     {
106         this.scheduler = scheduler;
107     }
108 
109     public ByteBufferPool getByteBufferPool()
110     {
111         return bufferPool;
112     }
113 
114     public void setByteBufferPool(ByteBufferPool bufferPool)
115     {
116         this.bufferPool = bufferPool;
117     }
118 
119     /**
120      * @return the timeout, in milliseconds, to connect to the remote server
121      */
122     public long getConnectTimeout()
123     {
124         return connectTimeout;
125     }
126 
127     /**
128      * @param connectTimeout the timeout, in milliseconds, to connect to the remote server
129      */
130     public void setConnectTimeout(long connectTimeout)
131     {
132         this.connectTimeout = connectTimeout;
133     }
134 
135     /**
136      * @return the idle timeout, in milliseconds
137      */
138     public long getIdleTimeout()
139     {
140         return idleTimeout;
141     }
142 
143     /**
144      * @param idleTimeout the idle timeout, in milliseconds
145      */
146     public void setIdleTimeout(long idleTimeout)
147     {
148         this.idleTimeout = idleTimeout;
149     }
150 
151     public int getBufferSize()
152     {
153         return bufferSize;
154     }
155 
156     public void setBufferSize(int bufferSize)
157     {
158         this.bufferSize = bufferSize;
159     }
160 
161     @Override
162     protected void doStart() throws Exception
163     {
164         if (executor == null)
165             executor = getServer().getThreadPool();
166 
167         if (scheduler == null)
168             addBean(scheduler = new ScheduledExecutorScheduler());
169 
170         if (bufferPool == null)
171             addBean(bufferPool = new MappedByteBufferPool());
172 
173         addBean(selector = newSelectorManager());
174         selector.setConnectTimeout(getConnectTimeout());
175 
176         super.doStart();
177     }
178 
179     protected SelectorManager newSelectorManager()
180     {
181         return new ConnectManager(getExecutor(), getScheduler(), 1);
182     }
183 
184     @Override
185     public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
186     {
187         if (HttpMethod.CONNECT.is(request.getMethod()))
188         {
189             String serverAddress = request.getRequestURI();
190             if (LOG.isDebugEnabled())
191                 LOG.debug("CONNECT request for {}", serverAddress);
192 
193             handleConnect(baseRequest, request, response, serverAddress);
194         }
195         else
196         {
197             super.handle(target, baseRequest, request, response);
198         }
199     }
200 
201     /**
202      * <p>Handles a CONNECT request.</p>
203      * <p>CONNECT requests may have authentication headers such as {@code Proxy-Authorization}
204      * that authenticate the client with the proxy.</p>
205      *
206      * @param baseRequest  Jetty-specific http request
207      * @param request       the http request
208      * @param response      the http response
209      * @param serverAddress the remote server address in the form {@code host:port}
210      */
211     protected void handleConnect(Request baseRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
212     {
213         baseRequest.setHandled(true);
214         try
215         {
216             boolean proceed = handleAuthentication(request, response, serverAddress);
217             if (!proceed)
218             {
219                 if (LOG.isDebugEnabled())
220                     LOG.debug("Missing proxy authentication");
221                 sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
222                 return;
223             }
224 
225             String host = serverAddress;
226             int port = 80;
227             int colon = serverAddress.indexOf(':');
228             if (colon > 0)
229             {
230                 host = serverAddress.substring(0, colon);
231                 port = Integer.parseInt(serverAddress.substring(colon + 1));
232             }
233 
234             if (!validateDestination(host, port))
235             {
236                 if (LOG.isDebugEnabled())
237                     LOG.debug("Destination {}:{} forbidden", host, port);
238                 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
239                 return;
240             }
241 
242             HttpTransport transport = baseRequest.getHttpChannel().getHttpTransport();
243             // TODO Handle CONNECT over HTTP2!
244             if (!(transport instanceof HttpConnection))
245             {
246                 if (LOG.isDebugEnabled())
247                     LOG.debug("CONNECT not supported for {}", transport);
248                 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
249                 return;
250             }
251 
252             AsyncContext asyncContext = request.startAsync();
253             asyncContext.setTimeout(0);
254 
255             if (LOG.isDebugEnabled())
256                 LOG.debug("Connecting to {}:{}", host, port);
257 
258             connectToServer(request, host, port, new Promise<SocketChannel>()
259             {
260                 @Override
261                 public void succeeded(SocketChannel channel)
262                 {
263                     ConnectContext connectContext = new ConnectContext(request, response, asyncContext, (HttpConnection)transport);
264                     if (channel.isConnected())
265                         selector.accept(channel, connectContext);
266                     else
267                         selector.connect(channel, connectContext);
268                 }
269 
270                 @Override
271                 public void failed(Throwable x)
272                 {
273                     onConnectFailure(request, response, asyncContext, x);
274                 }
275             });
276         }
277         catch (Exception x)
278         {
279             onConnectFailure(request, response, null, x);
280         }
281     }
282 
283     protected void connectToServer(HttpServletRequest request, String host, int port, Promise<SocketChannel> promise)
284     {
285         SocketChannel channel = null;
286         try
287         {
288             channel = SocketChannel.open();
289             channel.socket().setTcpNoDelay(true);
290             channel.configureBlocking(false);
291             InetSocketAddress address = newConnectAddress(host, port);
292             channel.connect(address);
293             promise.succeeded(channel);
294         }
295         catch (Throwable x)
296         {
297             close(channel);
298             promise.failed(x);
299         }
300     }
301 
302     private void close(Closeable closeable)
303     {
304         try
305         {
306             if (closeable != null)
307                 closeable.close();
308         }
309         catch (Throwable x)
310         {
311             LOG.ignore(x);
312         }
313     }
314 
315     /**
316      * Creates the server address to connect to.
317      *
318      * @param host The host from the CONNECT request
319      * @param port The port from the CONNECT request
320      * @return The InetSocketAddress to connect to.
321      */
322     protected InetSocketAddress newConnectAddress(String host, int port)
323     {
324         return new InetSocketAddress(host, port);
325     }
326 
327     protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
328     {
329         ConcurrentMap<String, Object> context = connectContext.getContext();
330         HttpServletRequest request = connectContext.getRequest();
331         prepareContext(request, context);
332 
333         HttpConnection httpConnection = connectContext.getHttpConnection();
334         EndPoint downstreamEndPoint = httpConnection.getEndPoint();
335         DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, BufferUtil.EMPTY_BUFFER);
336         downstreamConnection.setInputBufferSize(getBufferSize());
337 
338         upstreamConnection.setConnection(downstreamConnection);
339         downstreamConnection.setConnection(upstreamConnection);
340         if (LOG.isDebugEnabled())
341             LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);
342 
343         HttpServletResponse response = connectContext.getResponse();
344         sendConnectResponse(request, response, HttpServletResponse.SC_OK);
345 
346         upgradeConnection(request, response, downstreamConnection);
347 
348         connectContext.getAsyncContext().complete();
349     }
350 
351     protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
352     {
353         if (LOG.isDebugEnabled())
354             LOG.debug("CONNECT failed", failure);
355         sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
356         if (asyncContext != null)
357             asyncContext.complete();
358     }
359 
360     private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
361     {
362         try
363         {
364             response.setStatus(statusCode);
365             if (statusCode != HttpServletResponse.SC_OK)
366                 response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
367             response.getOutputStream().close();
368             if (LOG.isDebugEnabled())
369                 LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
370         }
371         catch (IOException x)
372         {
373             if (LOG.isDebugEnabled())
374                 LOG.debug("Could not send CONNECT response", x);
375         }
376     }
377 
378     /**
379      * <p>Handles the authentication before setting up the tunnel to the remote server.</p>
380      * <p>The default implementation returns true.</p>
381      *
382      * @param request  the HTTP request
383      * @param response the HTTP response
384      * @param address  the address of the remote server in the form {@code host:port}.
385      * @return true to allow to connect to the remote host, false otherwise
386      */
387     protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
388     {
389         return true;
390     }
391 
392     /**
393      * @deprecated use {@link #newDownstreamConnection(EndPoint, ConcurrentMap)} instead
394      */
395     @Deprecated
396     protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
397     {
398         return newDownstreamConnection(endPoint, context);
399     }
400 
401     protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context)
402     {
403         return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context);
404     }
405 
406     protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
407     {
408         return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
409     }
410 
411     protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
412     {
413     }
414 
415     private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
416     {
417         // Set the new connection as request attribute and change the status to 101
418         // so that Jetty understands that it has to upgrade the connection
419         request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
420         response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
421         if (LOG.isDebugEnabled())
422             LOG.debug("Upgraded connection to {}", connection);
423     }
424 
425     /**
426      * <p>Reads (with non-blocking semantic) into the given {@code buffer} from the given {@code endPoint}.</p>
427      *
428      * @param endPoint the endPoint to read from
429      * @param buffer   the buffer to read data into
430      * @param context  the context information related to the connection
431      * @return the number of bytes read (possibly 0 since the read is non-blocking)
432      *         or -1 if the channel has been closed remotely
433      * @throws IOException if the endPoint cannot be read
434      */
435     protected int read(EndPoint endPoint, ByteBuffer buffer, ConcurrentMap<String, Object> context) throws IOException
436     {
437         int read = read(endPoint, buffer);
438         if (LOG.isDebugEnabled())
439             LOG.debug("{} read {} bytes", this, read);
440         return read;
441     }
442 
443     /**
444      * @deprecated override {@link #read(EndPoint, ByteBuffer, ConcurrentMap)} instead
445      */
446     @Deprecated
447     protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
448     {
449         return endPoint.fill(buffer);
450     }
451 
452     /**
453      * <p>Writes (with non-blocking semantic) the given buffer of data onto the given endPoint.</p>
454      *
455      * @param endPoint the endPoint to write to
456      * @param buffer   the buffer to write
457      * @param callback the completion callback to invoke
458      * @param context  the context information related to the connection
459      */
460     protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback, ConcurrentMap<String, Object> context)
461     {
462         if (LOG.isDebugEnabled())
463             LOG.debug("{} writing {} bytes", this, buffer.remaining());
464         write(endPoint, buffer, callback);
465     }
466 
467     /**
468      * @deprecated override {@link #write(EndPoint, ByteBuffer, Callback, ConcurrentMap)} instead
469      */
470     @Deprecated
471     protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
472     {
473         endPoint.write(callback, buffer);
474     }
475 
476     public Set<String> getWhiteListHosts()
477     {
478         return whiteList;
479     }
480 
481     public Set<String> getBlackListHosts()
482     {
483         return blackList;
484     }
485 
486     /**
487      * Checks the given {@code host} and {@code port} against whitelist and blacklist.
488      *
489      * @param host the host to check
490      * @param port the port to check
491      * @return true if it is allowed to connect to the given host and port
492      */
493     public boolean validateDestination(String host, int port)
494     {
495         String hostPort = host + ":" + port;
496         if (!whiteList.isEmpty())
497         {
498             if (!whiteList.contains(hostPort))
499             {
500                 if (LOG.isDebugEnabled())
501                     LOG.debug("Host {}:{} not whitelisted", host, port);
502                 return false;
503             }
504         }
505         if (!blackList.isEmpty())
506         {
507             if (blackList.contains(hostPort))
508             {
509                 if (LOG.isDebugEnabled())
510                     LOG.debug("Host {}:{} blacklisted", host, port);
511                 return false;
512             }
513         }
514         return true;
515     }
516 
517     @Override
518     public void dump(Appendable out, String indent) throws IOException
519     {
520         dumpThis(out);
521         dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
522     }
523 
524     protected class ConnectManager extends SelectorManager
525     {
526         protected ConnectManager(Executor executor, Scheduler scheduler, int selectors)
527         {
528             super(executor, scheduler, selectors);
529         }
530 
531         @Override
532         protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
533         {
534             return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
535         }
536 
537         @Override
538         public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
539         {
540             if (ConnectHandler.LOG.isDebugEnabled())
541                 ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
542             ConnectContext connectContext = (ConnectContext)attachment;
543             UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
544             connection.setInputBufferSize(getBufferSize());
545             return connection;
546         }
547 
548         @Override
549         protected void connectionFailed(SocketChannel channel, final Throwable ex, final Object attachment)
550         {
551             close(channel);
552             ConnectContext connectContext = (ConnectContext)attachment;
553             onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex);
554         }
555     }
556 
557     protected static class ConnectContext
558     {
559         private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
560         private final HttpServletRequest request;
561         private final HttpServletResponse response;
562         private final AsyncContext asyncContext;
563         private final HttpConnection httpConnection;
564 
565         public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
566         {
567             this.request = request;
568             this.response = response;
569             this.asyncContext = asyncContext;
570             this.httpConnection = httpConnection;
571         }
572 
573         public ConcurrentMap<String, Object> getContext()
574         {
575             return context;
576         }
577 
578         public HttpServletRequest getRequest()
579         {
580             return request;
581         }
582 
583         public HttpServletResponse getResponse()
584         {
585             return response;
586         }
587 
588         public AsyncContext getAsyncContext()
589         {
590             return asyncContext;
591         }
592 
593         public HttpConnection getHttpConnection()
594         {
595             return httpConnection;
596         }
597     }
598 
599     public class UpstreamConnection extends ProxyConnection
600     {
601         private ConnectContext connectContext;
602 
603         public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
604         {
605             super(endPoint, executor, bufferPool, connectContext.getContext());
606             this.connectContext = connectContext;
607         }
608 
609         @Override
610         public void onOpen()
611         {
612             super.onOpen();
613             onConnectSuccess(connectContext, UpstreamConnection.this);
614             fillInterested();
615         }
616 
617         @Override
618         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
619         {
620             return ConnectHandler.this.read(endPoint, buffer, getContext());
621         }
622 
623         @Override
624         protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
625         {
626             ConnectHandler.this.write(endPoint, buffer, callback, getContext());
627         }
628     }
629 
630     public class DownstreamConnection extends ProxyConnection implements Connection.UpgradeTo
631     {
632         private ByteBuffer buffer;
633 
634         public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context)
635         {
636             super(endPoint, executor, bufferPool, context);
637         }
638 
639         /**
640          * @deprecated use {@link #DownstreamConnection(EndPoint, Executor, ByteBufferPool, ConcurrentMap)} instead
641          */
642         @Deprecated
643         public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
644         {
645             this(endPoint, executor, bufferPool, context);
646         }
647 
648         @Override
649         public void onUpgradeTo(ByteBuffer buffer)
650         {
651             this.buffer = buffer == null ? BufferUtil.EMPTY_BUFFER : buffer;
652         }
653 
654         @Override
655         public void onOpen()
656         {
657             super.onOpen();
658             final int remaining = buffer.remaining();
659             write(getConnection().getEndPoint(), buffer, new Callback()
660             {
661                 @Override
662                 public void succeeded()
663                 {
664                     if (LOG.isDebugEnabled())
665                         LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
666                     fillInterested();
667                 }
668 
669                 @Override
670                 public void failed(Throwable x)
671                 {
672                     if (LOG.isDebugEnabled())
673                         LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x);
674                     close();
675                     getConnection().close();
676                 }
677             });
678         }
679 
680         @Override
681         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
682         {
683             return ConnectHandler.this.read(endPoint, buffer, getContext());
684         }
685 
686         @Override
687         protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
688         {
689             ConnectHandler.this.write(endPoint, buffer, callback, getContext());
690         }
691     }
692 }