View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2013 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.proxy;
20  
21  import java.io.IOException;
22  import java.net.InetSocketAddress;
23  import java.nio.ByteBuffer;
24  import java.nio.channels.SelectionKey;
25  import java.nio.channels.SocketChannel;
26  import java.util.HashSet;
27  import java.util.Set;
28  import java.util.concurrent.ConcurrentHashMap;
29  import java.util.concurrent.ConcurrentMap;
30  import java.util.concurrent.Executor;
31  import javax.servlet.AsyncContext;
32  import javax.servlet.ServletException;
33  import javax.servlet.http.HttpServletRequest;
34  import javax.servlet.http.HttpServletResponse;
35  
36  import org.eclipse.jetty.http.HttpHeader;
37  import org.eclipse.jetty.http.HttpHeaderValue;
38  import org.eclipse.jetty.http.HttpMethod;
39  import org.eclipse.jetty.io.ByteBufferPool;
40  import org.eclipse.jetty.io.Connection;
41  import org.eclipse.jetty.io.EndPoint;
42  import org.eclipse.jetty.io.MappedByteBufferPool;
43  import org.eclipse.jetty.io.SelectChannelEndPoint;
44  import org.eclipse.jetty.io.SelectorManager;
45  import org.eclipse.jetty.server.Handler;
46  import org.eclipse.jetty.server.HttpConnection;
47  import org.eclipse.jetty.server.Request;
48  import org.eclipse.jetty.server.handler.HandlerWrapper;
49  import org.eclipse.jetty.util.BufferUtil;
50  import org.eclipse.jetty.util.Callback;
51  import org.eclipse.jetty.util.TypeUtil;
52  import org.eclipse.jetty.util.log.Log;
53  import org.eclipse.jetty.util.log.Logger;
54  import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
55  import org.eclipse.jetty.util.thread.Scheduler;
56  
57  /**
58   * <p>Implementation of a {@link Handler} that supports HTTP CONNECT.</p>
59   */
60  public class ConnectHandler extends HandlerWrapper
61  {
62      protected static final Logger LOG = Log.getLogger(ConnectHandler.class);
63  
64      private final Set<String> whiteList = new HashSet<>();
65      private final Set<String> blackList = new HashSet<>();
66      private Executor executor;
67      private Scheduler scheduler;
68      private ByteBufferPool bufferPool;
69      private SelectorManager selector;
70      private long connectTimeout = 15000;
71      private long idleTimeout = 30000;
72      private int bufferSize = 4096;
73  
74      public ConnectHandler()
75      {
76          this(null);
77      }
78  
79      public ConnectHandler(Handler handler)
80      {
81          setHandler(handler);
82      }
83  
84      public Executor getExecutor()
85      {
86          return executor;
87      }
88  
89      public void setExecutor(Executor executor)
90      {
91          this.executor = executor;
92      }
93  
94      public Scheduler getScheduler()
95      {
96          return scheduler;
97      }
98  
99      public void setScheduler(Scheduler scheduler)
100     {
101         this.scheduler = scheduler;
102     }
103 
104     public ByteBufferPool getByteBufferPool()
105     {
106         return bufferPool;
107     }
108 
109     public void setByteBufferPool(ByteBufferPool bufferPool)
110     {
111         this.bufferPool = bufferPool;
112     }
113 
114     /**
115      * @return the timeout, in milliseconds, to connect to the remote server
116      */
117     public long getConnectTimeout()
118     {
119         return connectTimeout;
120     }
121 
122     /**
123      * @param connectTimeout the timeout, in milliseconds, to connect to the remote server
124      */
125     public void setConnectTimeout(long connectTimeout)
126     {
127         this.connectTimeout = connectTimeout;
128     }
129 
130     /**
131      * @return the idle timeout, in milliseconds
132      */
133     public long getIdleTimeout()
134     {
135         return idleTimeout;
136     }
137 
138     /**
139      * @param idleTimeout the idle timeout, in milliseconds
140      */
141     public void setIdleTimeout(long idleTimeout)
142     {
143         this.idleTimeout = idleTimeout;
144     }
145 
146     public int getBufferSize()
147     {
148         return bufferSize;
149     }
150 
151     public void setBufferSize(int bufferSize)
152     {
153         this.bufferSize = bufferSize;
154     }
155 
156     @Override
157     protected void doStart() throws Exception
158     {
159         if (executor == null)
160         {
161             setExecutor(getServer().getThreadPool());
162         }
163         if (scheduler == null)
164         {
165             setScheduler(new ScheduledExecutorScheduler());
166             addBean(getScheduler());
167         }
168         if (bufferPool == null)
169         {
170             setByteBufferPool(new MappedByteBufferPool());
171             addBean(getByteBufferPool());
172         }
173         addBean(selector = newSelectorManager());
174         selector.setConnectTimeout(getConnectTimeout());
175         super.doStart();
176     }
177 
178     protected SelectorManager newSelectorManager()
179     {
180         return new Manager(getExecutor(), getScheduler(), 1);
181     }
182 
183     @Override
184     public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
185     {
186         if (HttpMethod.CONNECT.is(request.getMethod()))
187         {
188             String serverAddress = request.getRequestURI();
189             LOG.debug("CONNECT request for {}", serverAddress);
190             try
191             {
192                 handleConnect(baseRequest, request, response, serverAddress);
193             }
194             catch (Exception x)
195             {
196                 // TODO
197                 LOG.warn("ConnectHandler " + baseRequest.getUri() + " " + x);
198                 LOG.debug(x);
199             }
200         }
201         else
202         {
203             super.handle(target, baseRequest, request, response);
204         }
205     }
206 
207     /**
208      * <p>Handles a CONNECT request.</p>
209      * <p>CONNECT requests may have authentication headers such as {@code Proxy-Authorization}
210      * that authenticate the client with the proxy.</p>
211      *
212      * @param jettyRequest  Jetty-specific http request
213      * @param request       the http request
214      * @param response      the http response
215      * @param serverAddress the remote server address in the form {@code host:port}
216      */
217     protected void handleConnect(Request jettyRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
218     {
219         jettyRequest.setHandled(true);
220         try
221         {
222             boolean proceed = handleAuthentication(request, response, serverAddress);
223             if (!proceed)
224             {
225                 LOG.debug("Missing proxy authentication");
226                 sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
227                 return;
228             }
229 
230             String host = serverAddress;
231             int port = 80;
232             int colon = serverAddress.indexOf(':');
233             if (colon > 0)
234             {
235                 host = serverAddress.substring(0, colon);
236                 port = Integer.parseInt(serverAddress.substring(colon + 1));
237             }
238 
239             if (!validateDestination(host, port))
240             {
241                 LOG.debug("Destination {}:{} forbidden", host, port);
242                 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
243                 return;
244             }
245 
246             SocketChannel channel = SocketChannel.open();
247             channel.socket().setTcpNoDelay(true);
248             channel.configureBlocking(false);
249             InetSocketAddress address = new InetSocketAddress(host, port);
250             channel.connect(address);
251 
252             AsyncContext asyncContext = request.startAsync();
253             asyncContext.setTimeout(0);
254 
255             LOG.debug("Connecting to {}", address);
256             ConnectContext connectContext = new ConnectContext(request, response, asyncContext, HttpConnection.getCurrentConnection());
257             selector.connect(channel, connectContext);
258         }
259         catch (Exception x)
260         {
261             onConnectFailure(request, response, null, x);
262         }
263     }
264 
265     protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
266     {
267         HttpConnection httpConnection = connectContext.getHttpConnection();
268         ByteBuffer requestBuffer = httpConnection.getRequestBuffer();
269         ByteBuffer buffer = BufferUtil.EMPTY_BUFFER;
270         int remaining = requestBuffer.remaining();
271         if (remaining > 0)
272         {
273             buffer = bufferPool.acquire(remaining, requestBuffer.isDirect());
274             BufferUtil.flipToFill(buffer);
275             buffer.put(requestBuffer);
276             buffer.flip();
277         }
278 
279         ConcurrentMap<String, Object> context = connectContext.getContext();
280         HttpServletRequest request = connectContext.getRequest();
281         prepareContext(request, context);
282 
283         EndPoint downstreamEndPoint = httpConnection.getEndPoint();
284         DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, buffer);
285         downstreamConnection.setInputBufferSize(getBufferSize());
286 
287         upstreamConnection.setConnection(downstreamConnection);
288         downstreamConnection.setConnection(upstreamConnection);
289         LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);
290 
291         HttpServletResponse response = connectContext.getResponse();
292         sendConnectResponse(request, response, HttpServletResponse.SC_OK);
293 
294         upgradeConnection(request, response, downstreamConnection);
295         connectContext.getAsyncContext().complete();
296     }
297 
298     protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
299     {
300         LOG.debug("CONNECT failed", failure);
301         sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
302         if (asyncContext != null)
303             asyncContext.complete();
304     }
305 
306     private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
307     {
308         try
309         {
310             response.setStatus(statusCode);
311             if (statusCode != HttpServletResponse.SC_OK)
312                 response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
313             response.getOutputStream().close();
314             LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
315         }
316         catch (IOException x)
317         {
318             // TODO: nothing we can do, close the connection
319         }
320     }
321 
322     /**
323      * <p>Handles the authentication before setting up the tunnel to the remote server.</p>
324      * <p>The default implementation returns true.</p>
325      *
326      * @param request  the HTTP request
327      * @param response the HTTP response
328      * @param address  the address of the remote server in the form {@code host:port}.
329      * @return true to allow to connect to the remote host, false otherwise
330      */
331     protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
332     {
333         return true;
334     }
335 
336     protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
337     {
338         return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context, buffer);
339     }
340 
341     protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
342     {
343         return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
344     }
345 
346     protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
347     {
348     }
349 
350     private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
351     {
352         // Set the new connection as request attribute and change the status to 101
353         // so that Jetty understands that it has to upgrade the connection
354         request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
355         response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
356         LOG.debug("Upgraded connection to {}", connection);
357     }
358 
359     /**
360      * <p>Reads (with non-blocking semantic) into the given {@code buffer} from the given {@code endPoint}.</p>
361      *
362      * @param endPoint the endPoint to read from
363      * @param buffer   the buffer to read data into
364      * @return the number of bytes read (possibly 0 since the read is non-blocking)
365      *         or -1 if the channel has been closed remotely
366      * @throws IOException if the endPoint cannot be read
367      */
368     protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
369     {
370         return endPoint.fill(buffer);
371     }
372 
373     /**
374      * <p>Writes (with non-blocking semantic) the given buffer of data onto the given endPoint.</p>
375      *
376      * @param endPoint the endPoint to write to
377      * @param buffer   the buffer to write
378      * @param callback the completion callback to invoke
379      */
380     protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
381     {
382         LOG.debug("{} writing {} bytes", this, buffer.remaining());
383         endPoint.write(callback, buffer);
384     }
385 
386     public Set<String> getWhiteListHosts()
387     {
388         return whiteList;
389     }
390 
391     public Set<String> getBlackListHosts()
392     {
393         return blackList;
394     }
395 
396     /**
397      * Checks the given {@code host} and {@code port} against whitelist and blacklist.
398      *
399      * @param host the host to check
400      * @param port the port to check
401      * @return true if it is allowed to connect to the given host and port
402      */
403     public boolean validateDestination(String host, int port)
404     {
405         String hostPort = host + ":" + port;
406         if (!whiteList.isEmpty())
407         {
408             if (!whiteList.contains(hostPort))
409             {
410                 LOG.debug("Host {}:{} not whitelisted", host, port);
411                 return false;
412             }
413         }
414         if (!blackList.isEmpty())
415         {
416             if (blackList.contains(hostPort))
417             {
418                 LOG.debug("Host {}:{} blacklisted", host, port);
419                 return false;
420             }
421         }
422         return true;
423     }
424 
425     @Override
426     public void dump(Appendable out, String indent) throws IOException
427     {
428         dumpThis(out);
429         dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
430     }
431 
432     protected class Manager extends SelectorManager
433     {
434 
435         private Manager(Executor executor, Scheduler scheduler, int selectors)
436         {
437             super(executor, scheduler, selectors);
438         }
439 
440         @Override
441         protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
442         {
443             return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
444         }
445 
446         @Override
447         public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
448         {
449             ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
450             ConnectContext connectContext = (ConnectContext)attachment;
451             UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
452             connection.setInputBufferSize(getBufferSize());
453             return connection;
454         }
455 
456         @Override
457         protected void connectionFailed(SocketChannel channel, Throwable ex, Object attachment)
458         {
459             ConnectContext connectContext = (ConnectContext)attachment;
460             onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex);
461         }
462     }
463 
464     protected static class ConnectContext
465     {
466         private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
467         private final HttpServletRequest request;
468         private final HttpServletResponse response;
469         private final AsyncContext asyncContext;
470         private final HttpConnection httpConnection;
471 
472         public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
473         {
474             this.request = request;
475             this.response = response;
476             this.asyncContext = asyncContext;
477             this.httpConnection = httpConnection;
478         }
479 
480         public ConcurrentMap<String, Object> getContext()
481         {
482             return context;
483         }
484 
485         public HttpServletRequest getRequest()
486         {
487             return request;
488         }
489 
490         public HttpServletResponse getResponse()
491         {
492             return response;
493         }
494 
495         public AsyncContext getAsyncContext()
496         {
497             return asyncContext;
498         }
499 
500         public HttpConnection getHttpConnection()
501         {
502             return httpConnection;
503         }
504     }
505 
506     public class UpstreamConnection extends ProxyConnection
507     {
508         private ConnectContext connectContext;
509 
510         public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
511         {
512             super(endPoint, executor, bufferPool, connectContext.getContext());
513             this.connectContext = connectContext;
514         }
515 
516         @Override
517         public void onOpen()
518         {
519             super.onOpen();
520             onConnectSuccess(connectContext, this);
521             fillInterested();
522         }
523 
524         @Override
525         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
526         {
527             return ConnectHandler.this.read(endPoint, buffer);
528         }
529 
530         @Override
531         protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
532         {
533             ConnectHandler.this.write(endPoint, buffer, callback);
534         }
535     }
536 
537     public class DownstreamConnection extends ProxyConnection
538     {
539         private final ByteBuffer buffer;
540 
541         public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
542         {
543             super(endPoint, executor, bufferPool, context);
544             this.buffer = buffer;
545         }
546 
547         @Override
548         public void onOpen()
549         {
550             super.onOpen();
551             final int remaining = buffer.remaining();
552             write(getConnection().getEndPoint(), buffer, new Callback()
553             {
554                 @Override
555                 public void succeeded()
556                 {
557                     LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
558                     fillInterested();
559                 }
560 
561                 @Override
562                 public void failed(Throwable x)
563                 {
564                     LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x);
565                     close();
566                     getConnection().close();
567                 }
568             });
569         }
570 
571         @Override
572         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
573         {
574             return ConnectHandler.this.read(endPoint, buffer);
575         }
576 
577         @Override
578         protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
579         {
580             ConnectHandler.this.write(endPoint, buffer, callback);
581         }
582     }
583 }