View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.proxy;
20  
21  import java.io.IOException;
22  import java.net.InetSocketAddress;
23  import java.nio.ByteBuffer;
24  import java.nio.channels.SelectionKey;
25  import java.nio.channels.SocketChannel;
26  import java.util.HashSet;
27  import java.util.Set;
28  import java.util.concurrent.ConcurrentHashMap;
29  import java.util.concurrent.ConcurrentMap;
30  import java.util.concurrent.Executor;
31  import javax.servlet.AsyncContext;
32  import javax.servlet.ServletException;
33  import javax.servlet.http.HttpServletRequest;
34  import javax.servlet.http.HttpServletResponse;
35  
36  import org.eclipse.jetty.http.HttpHeader;
37  import org.eclipse.jetty.http.HttpHeaderValue;
38  import org.eclipse.jetty.http.HttpMethod;
39  import org.eclipse.jetty.io.ByteBufferPool;
40  import org.eclipse.jetty.io.Connection;
41  import org.eclipse.jetty.io.EndPoint;
42  import org.eclipse.jetty.io.MappedByteBufferPool;
43  import org.eclipse.jetty.io.SelectChannelEndPoint;
44  import org.eclipse.jetty.io.SelectorManager;
45  import org.eclipse.jetty.server.Handler;
46  import org.eclipse.jetty.server.HttpConnection;
47  import org.eclipse.jetty.server.Request;
48  import org.eclipse.jetty.server.handler.HandlerWrapper;
49  import org.eclipse.jetty.util.BufferUtil;
50  import org.eclipse.jetty.util.Callback;
51  import org.eclipse.jetty.util.TypeUtil;
52  import org.eclipse.jetty.util.log.Log;
53  import org.eclipse.jetty.util.log.Logger;
54  import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
55  import org.eclipse.jetty.util.thread.Scheduler;
56  
57  /**
58   * <p>Implementation of a {@link Handler} that supports HTTP CONNECT.</p>
59   */
60  public class ConnectHandler extends HandlerWrapper
61  {
62      protected static final Logger LOG = Log.getLogger(ConnectHandler.class);
63  
64      private final Set<String> whiteList = new HashSet<>();
65      private final Set<String> blackList = new HashSet<>();
66      private Executor executor;
67      private Scheduler scheduler;
68      private ByteBufferPool bufferPool;
69      private SelectorManager selector;
70      private long connectTimeout = 15000;
71      private long idleTimeout = 30000;
72      private int bufferSize = 4096;
73  
74      public ConnectHandler()
75      {
76          this(null);
77      }
78  
79      public ConnectHandler(Handler handler)
80      {
81          setHandler(handler);
82      }
83  
84      public Executor getExecutor()
85      {
86          return executor;
87      }
88  
89      public void setExecutor(Executor executor)
90      {
91          this.executor = executor;
92      }
93  
94      public Scheduler getScheduler()
95      {
96          return scheduler;
97      }
98  
99      public void setScheduler(Scheduler scheduler)
100     {
101         this.scheduler = scheduler;
102     }
103 
104     public ByteBufferPool getByteBufferPool()
105     {
106         return bufferPool;
107     }
108 
109     public void setByteBufferPool(ByteBufferPool bufferPool)
110     {
111         this.bufferPool = bufferPool;
112     }
113 
114     /**
115      * @return the timeout, in milliseconds, to connect to the remote server
116      */
117     public long getConnectTimeout()
118     {
119         return connectTimeout;
120     }
121 
122     /**
123      * @param connectTimeout the timeout, in milliseconds, to connect to the remote server
124      */
125     public void setConnectTimeout(long connectTimeout)
126     {
127         this.connectTimeout = connectTimeout;
128     }
129 
130     /**
131      * @return the idle timeout, in milliseconds
132      */
133     public long getIdleTimeout()
134     {
135         return idleTimeout;
136     }
137 
138     /**
139      * @param idleTimeout the idle timeout, in milliseconds
140      */
141     public void setIdleTimeout(long idleTimeout)
142     {
143         this.idleTimeout = idleTimeout;
144     }
145 
146     public int getBufferSize()
147     {
148         return bufferSize;
149     }
150 
151     public void setBufferSize(int bufferSize)
152     {
153         this.bufferSize = bufferSize;
154     }
155 
156     @Override
157     protected void doStart() throws Exception
158     {
159         if (executor == null)
160         {
161             setExecutor(getServer().getThreadPool());
162         }
163         if (scheduler == null)
164         {
165             setScheduler(new ScheduledExecutorScheduler());
166             addBean(getScheduler());
167         }
168         if (bufferPool == null)
169         {
170             setByteBufferPool(new MappedByteBufferPool());
171             addBean(getByteBufferPool());
172         }
173         addBean(selector = newSelectorManager());
174         selector.setConnectTimeout(getConnectTimeout());
175         super.doStart();
176     }
177 
178     protected SelectorManager newSelectorManager()
179     {
180         return new ConnectManager(getExecutor(), getScheduler(), 1);
181     }
182 
183     @Override
184     public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
185     {
186         if (HttpMethod.CONNECT.is(request.getMethod()))
187         {
188             String serverAddress = request.getRequestURI();
189             if (LOG.isDebugEnabled())
190                 LOG.debug("CONNECT request for {}", serverAddress);
191             try
192             {
193                 handleConnect(baseRequest, request, response, serverAddress);
194             }
195             catch (Exception x)
196             {
197                 // TODO
198                 LOG.warn("ConnectHandler " + baseRequest.getUri() + " " + x);
199                 LOG.debug(x);
200             }
201         }
202         else
203         {
204             super.handle(target, baseRequest, request, response);
205         }
206     }
207 
208     /**
209      * <p>Handles a CONNECT request.</p>
210      * <p>CONNECT requests may have authentication headers such as {@code Proxy-Authorization}
211      * that authenticate the client with the proxy.</p>
212      *
213      * @param jettyRequest  Jetty-specific http request
214      * @param request       the http request
215      * @param response      the http response
216      * @param serverAddress the remote server address in the form {@code host:port}
217      */
218     protected void handleConnect(Request jettyRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
219     {
220         jettyRequest.setHandled(true);
221         try
222         {
223             boolean proceed = handleAuthentication(request, response, serverAddress);
224             if (!proceed)
225             {
226                 if (LOG.isDebugEnabled())
227                     LOG.debug("Missing proxy authentication");
228                 sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
229                 return;
230             }
231 
232             String host = serverAddress;
233             int port = 80;
234             int colon = serverAddress.indexOf(':');
235             if (colon > 0)
236             {
237                 host = serverAddress.substring(0, colon);
238                 port = Integer.parseInt(serverAddress.substring(colon + 1));
239             }
240 
241             if (!validateDestination(host, port))
242             {
243                 if (LOG.isDebugEnabled())
244                     LOG.debug("Destination {}:{} forbidden", host, port);
245                 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
246                 return;
247             }
248 
249             SocketChannel channel = SocketChannel.open();
250             channel.socket().setTcpNoDelay(true);
251             channel.configureBlocking(false);
252             InetSocketAddress address = new InetSocketAddress(host, port);
253 
254             AsyncContext asyncContext = request.startAsync();
255             asyncContext.setTimeout(0);
256 
257             if (LOG.isDebugEnabled())
258                 LOG.debug("Connecting to {}", address);
259 
260             ConnectContext connectContext = new ConnectContext(request, response, asyncContext, HttpConnection.getCurrentConnection());
261             if (channel.connect(address))
262                 selector.accept(channel, connectContext);
263             else
264                 selector.connect(channel, connectContext);
265         }
266         catch (Exception x)
267         {
268             onConnectFailure(request, response, null, x);
269         }
270     }
271 
272     protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
273     {
274         HttpConnection httpConnection = connectContext.getHttpConnection();
275         ByteBuffer requestBuffer = httpConnection.getRequestBuffer();
276         ByteBuffer buffer = BufferUtil.EMPTY_BUFFER;
277         int remaining = requestBuffer.remaining();
278         if (remaining > 0)
279         {
280             buffer = bufferPool.acquire(remaining, requestBuffer.isDirect());
281             BufferUtil.flipToFill(buffer);
282             buffer.put(requestBuffer);
283             buffer.flip();
284         }
285 
286         ConcurrentMap<String, Object> context = connectContext.getContext();
287         HttpServletRequest request = connectContext.getRequest();
288         prepareContext(request, context);
289 
290         EndPoint downstreamEndPoint = httpConnection.getEndPoint();
291         DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, buffer);
292         downstreamConnection.setInputBufferSize(getBufferSize());
293 
294         upstreamConnection.setConnection(downstreamConnection);
295         downstreamConnection.setConnection(upstreamConnection);
296         if (LOG.isDebugEnabled())
297             LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);
298 
299         HttpServletResponse response = connectContext.getResponse();
300         sendConnectResponse(request, response, HttpServletResponse.SC_OK);
301 
302         upgradeConnection(request, response, downstreamConnection);
303         connectContext.getAsyncContext().complete();
304     }
305 
306     protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
307     {
308         if (LOG.isDebugEnabled())
309             LOG.debug("CONNECT failed", failure);
310         sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
311         if (asyncContext != null)
312             asyncContext.complete();
313     }
314 
315     private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
316     {
317         try
318         {
319             response.setStatus(statusCode);
320             if (statusCode != HttpServletResponse.SC_OK)
321                 response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
322             response.getOutputStream().close();
323             if (LOG.isDebugEnabled())
324                 LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
325         }
326         catch (IOException x)
327         {
328             // TODO: nothing we can do, close the connection
329         }
330     }
331 
332     /**
333      * <p>Handles the authentication before setting up the tunnel to the remote server.</p>
334      * <p>The default implementation returns true.</p>
335      *
336      * @param request  the HTTP request
337      * @param response the HTTP response
338      * @param address  the address of the remote server in the form {@code host:port}.
339      * @return true to allow to connect to the remote host, false otherwise
340      */
341     protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
342     {
343         return true;
344     }
345 
346     protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
347     {
348         return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context, buffer);
349     }
350 
351     protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
352     {
353         return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
354     }
355 
356     protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
357     {
358     }
359 
360     private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
361     {
362         // Set the new connection as request attribute and change the status to 101
363         // so that Jetty understands that it has to upgrade the connection
364         request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
365         response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
366         if (LOG.isDebugEnabled())
367             LOG.debug("Upgraded connection to {}", connection);
368     }
369 
370     /**
371      * <p>Reads (with non-blocking semantic) into the given {@code buffer} from the given {@code endPoint}.</p>
372      *
373      * @param endPoint the endPoint to read from
374      * @param buffer   the buffer to read data into
375      * @return the number of bytes read (possibly 0 since the read is non-blocking)
376      *         or -1 if the channel has been closed remotely
377      * @throws IOException if the endPoint cannot be read
378      */
379     protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
380     {
381         return endPoint.fill(buffer);
382     }
383 
384     /**
385      * <p>Writes (with non-blocking semantic) the given buffer of data onto the given endPoint.</p>
386      *
387      * @param endPoint the endPoint to write to
388      * @param buffer   the buffer to write
389      * @param callback the completion callback to invoke
390      */
391     protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
392     {
393         if (LOG.isDebugEnabled())
394             LOG.debug("{} writing {} bytes", this, buffer.remaining());
395         endPoint.write(callback, buffer);
396     }
397 
398     public Set<String> getWhiteListHosts()
399     {
400         return whiteList;
401     }
402 
403     public Set<String> getBlackListHosts()
404     {
405         return blackList;
406     }
407 
408     /**
409      * Checks the given {@code host} and {@code port} against whitelist and blacklist.
410      *
411      * @param host the host to check
412      * @param port the port to check
413      * @return true if it is allowed to connect to the given host and port
414      */
415     public boolean validateDestination(String host, int port)
416     {
417         String hostPort = host + ":" + port;
418         if (!whiteList.isEmpty())
419         {
420             if (!whiteList.contains(hostPort))
421             {
422                 if (LOG.isDebugEnabled())
423                     LOG.debug("Host {}:{} not whitelisted", host, port);
424                 return false;
425             }
426         }
427         if (!blackList.isEmpty())
428         {
429             if (blackList.contains(hostPort))
430             {
431                 if (LOG.isDebugEnabled())
432                     LOG.debug("Host {}:{} blacklisted", host, port);
433                 return false;
434             }
435         }
436         return true;
437     }
438 
439     @Override
440     public void dump(Appendable out, String indent) throws IOException
441     {
442         dumpThis(out);
443         dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
444     }
445 
446     protected class ConnectManager extends SelectorManager
447     {
448         protected ConnectManager(Executor executor, Scheduler scheduler, int selectors)
449         {
450             super(executor, scheduler, selectors);
451         }
452 
453         @Override
454         protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
455         {
456             return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
457         }
458 
459         @Override
460         public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
461         {
462             if (ConnectHandler.LOG.isDebugEnabled())
463                 ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
464             ConnectContext connectContext = (ConnectContext)attachment;
465             UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
466             connection.setInputBufferSize(getBufferSize());
467             return connection;
468         }
469 
470         @Override
471         protected void connectionFailed(SocketChannel channel, final Throwable ex, final Object attachment)
472         {
473             getExecutor().execute(new Runnable()
474             {
475                 public void run()
476                 {
477                     ConnectContext connectContext = (ConnectContext)attachment;
478                     onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex);
479                 }
480             });
481         }
482     }
483 
484     protected static class ConnectContext
485     {
486         private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
487         private final HttpServletRequest request;
488         private final HttpServletResponse response;
489         private final AsyncContext asyncContext;
490         private final HttpConnection httpConnection;
491 
492         public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
493         {
494             this.request = request;
495             this.response = response;
496             this.asyncContext = asyncContext;
497             this.httpConnection = httpConnection;
498         }
499 
500         public ConcurrentMap<String, Object> getContext()
501         {
502             return context;
503         }
504 
505         public HttpServletRequest getRequest()
506         {
507             return request;
508         }
509 
510         public HttpServletResponse getResponse()
511         {
512             return response;
513         }
514 
515         public AsyncContext getAsyncContext()
516         {
517             return asyncContext;
518         }
519 
520         public HttpConnection getHttpConnection()
521         {
522             return httpConnection;
523         }
524     }
525 
526     public class UpstreamConnection extends ProxyConnection
527     {
528         private ConnectContext connectContext;
529 
530         public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
531         {
532             super(endPoint, executor, bufferPool, connectContext.getContext());
533             this.connectContext = connectContext;
534         }
535 
536         @Override
537         public void onOpen()
538         {
539             super.onOpen();
540             getExecutor().execute(new Runnable()
541             {
542                 public void run()
543                 {
544                     onConnectSuccess(connectContext, UpstreamConnection.this);
545                     fillInterested();
546                 }
547             });
548         }
549 
550         @Override
551         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
552         {
553             return ConnectHandler.this.read(endPoint, buffer);
554         }
555 
556         @Override
557         protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
558         {
559             ConnectHandler.this.write(endPoint, buffer, callback);
560         }
561     }
562 
563     public class DownstreamConnection extends ProxyConnection
564     {
565         private final ByteBuffer buffer;
566 
567         public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
568         {
569             super(endPoint, executor, bufferPool, context);
570             this.buffer = buffer;
571         }
572 
573         @Override
574         public void onOpen()
575         {
576             super.onOpen();
577             final int remaining = buffer.remaining();
578             write(getConnection().getEndPoint(), buffer, new Callback()
579             {
580                 @Override
581                 public void succeeded()
582                 {
583                     if (LOG.isDebugEnabled())
584                         LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
585                     fillInterested();
586                 }
587 
588                 @Override
589                 public void failed(Throwable x)
590                 {
591                     if (LOG.isDebugEnabled())
592                         LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x);
593                     close();
594                     getConnection().close();
595                 }
596             });
597         }
598 
599         @Override
600         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
601         {
602             return ConnectHandler.this.read(endPoint, buffer);
603         }
604 
605         @Override
606         protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
607         {
608             ConnectHandler.this.write(endPoint, buffer, callback);
609         }
610     }
611 }