View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2013 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.proxy;
20  
21  import java.io.IOException;
22  import java.net.InetSocketAddress;
23  import java.nio.ByteBuffer;
24  import java.nio.channels.SelectionKey;
25  import java.nio.channels.SocketChannel;
26  import java.util.HashSet;
27  import java.util.Set;
28  import java.util.concurrent.ConcurrentHashMap;
29  import java.util.concurrent.ConcurrentMap;
30  import java.util.concurrent.Executor;
31  
32  import javax.servlet.AsyncContext;
33  import javax.servlet.ServletException;
34  import javax.servlet.http.HttpServletRequest;
35  import javax.servlet.http.HttpServletResponse;
36  
37  import org.eclipse.jetty.http.HttpHeader;
38  import org.eclipse.jetty.http.HttpHeaderValue;
39  import org.eclipse.jetty.http.HttpMethod;
40  import org.eclipse.jetty.io.ByteBufferPool;
41  import org.eclipse.jetty.io.Connection;
42  import org.eclipse.jetty.io.EndPoint;
43  import org.eclipse.jetty.io.MappedByteBufferPool;
44  import org.eclipse.jetty.io.SelectChannelEndPoint;
45  import org.eclipse.jetty.io.SelectorManager;
46  import org.eclipse.jetty.server.Handler;
47  import org.eclipse.jetty.server.HttpConnection;
48  import org.eclipse.jetty.server.Request;
49  import org.eclipse.jetty.server.handler.HandlerWrapper;
50  import org.eclipse.jetty.util.BufferUtil;
51  import org.eclipse.jetty.util.Callback;
52  import org.eclipse.jetty.util.TypeUtil;
53  import org.eclipse.jetty.util.log.Log;
54  import org.eclipse.jetty.util.log.Logger;
55  import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
56  import org.eclipse.jetty.util.thread.Scheduler;
57  
58  /**
59   * <p>Implementation of a {@link Handler} that supports HTTP CONNECT.</p>
60   */
61  public class ConnectHandler extends HandlerWrapper
62  {
63      protected static final Logger LOG = Log.getLogger(ConnectHandler.class);
64  
65      private final Set<String> whiteList = new HashSet<>();
66      private final Set<String> blackList = new HashSet<>();
67      private Executor executor;
68      private Scheduler scheduler;
69      private ByteBufferPool bufferPool;
70      private SelectorManager selector;
71      private long connectTimeout = 15000;
72      private long idleTimeout = 30000;
73      private int bufferSize = 4096;
74  
75      public ConnectHandler()
76      {
77          this(null);
78      }
79  
80      public ConnectHandler(Handler handler)
81      {
82          setHandler(handler);
83      }
84  
85      public Executor getExecutor()
86      {
87          return executor;
88      }
89  
90      public void setExecutor(Executor executor)
91      {
92          this.executor = executor;
93      }
94  
95      public Scheduler getScheduler()
96      {
97          return scheduler;
98      }
99  
100     public void setScheduler(Scheduler scheduler)
101     {
102         this.scheduler = scheduler;
103     }
104 
105     public ByteBufferPool getByteBufferPool()
106     {
107         return bufferPool;
108     }
109 
110     public void setByteBufferPool(ByteBufferPool bufferPool)
111     {
112         this.bufferPool = bufferPool;
113     }
114 
115     /**
116      * @return the timeout, in milliseconds, to connect to the remote server
117      */
118     public long getConnectTimeout()
119     {
120         return connectTimeout;
121     }
122 
123     /**
124      * @param connectTimeout the timeout, in milliseconds, to connect to the remote server
125      */
126     public void setConnectTimeout(long connectTimeout)
127     {
128         this.connectTimeout = connectTimeout;
129     }
130 
131     /**
132      * @return the idle timeout, in milliseconds
133      */
134     public long getIdleTimeout()
135     {
136         return idleTimeout;
137     }
138 
139     /**
140      * @param idleTimeout the idle timeout, in milliseconds
141      */
142     public void setIdleTimeout(long idleTimeout)
143     {
144         this.idleTimeout = idleTimeout;
145     }
146 
147     public int getBufferSize()
148     {
149         return bufferSize;
150     }
151 
152     public void setBufferSize(int bufferSize)
153     {
154         this.bufferSize = bufferSize;
155     }
156 
157     @Override
158     protected void doStart() throws Exception
159     {
160         if (executor == null)
161         {
162             setExecutor(getServer().getThreadPool());
163         }
164         if (scheduler == null)
165         {
166             setScheduler(new ScheduledExecutorScheduler());
167             addBean(getScheduler());
168         }
169         if (bufferPool == null)
170         {
171             setByteBufferPool(new MappedByteBufferPool());
172             addBean(getByteBufferPool());
173         }
174         addBean(selector = newSelectorManager());
175         selector.setConnectTimeout(getConnectTimeout());
176         super.doStart();
177     }
178 
179     protected SelectorManager newSelectorManager()
180     {
181         return new ConnectManager(getExecutor(), getScheduler(), 1);
182     }
183 
184     @Override
185     public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
186     {
187         if (HttpMethod.CONNECT.is(request.getMethod()))
188         {
189             String serverAddress = request.getRequestURI();
190             LOG.debug("CONNECT request for {}", serverAddress);
191             try
192             {
193                 handleConnect(baseRequest, request, response, serverAddress);
194             }
195             catch (Exception x)
196             {
197                 // TODO
198                 LOG.warn("ConnectHandler " + baseRequest.getUri() + " " + x);
199                 LOG.debug(x);
200             }
201         }
202         else
203         {
204             super.handle(target, baseRequest, request, response);
205         }
206     }
207 
208     /**
209      * <p>Handles a CONNECT request.</p>
210      * <p>CONNECT requests may have authentication headers such as {@code Proxy-Authorization}
211      * that authenticate the client with the proxy.</p>
212      *
213      * @param jettyRequest  Jetty-specific http request
214      * @param request       the http request
215      * @param response      the http response
216      * @param serverAddress the remote server address in the form {@code host:port}
217      */
218     protected void handleConnect(Request jettyRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
219     {
220         jettyRequest.setHandled(true);
221         try
222         {
223             boolean proceed = handleAuthentication(request, response, serverAddress);
224             if (!proceed)
225             {
226                 LOG.debug("Missing proxy authentication");
227                 sendConnectResponse(request, response, HttpServletResponse.SC_PROXY_AUTHENTICATION_REQUIRED);
228                 return;
229             }
230 
231             String host = serverAddress;
232             int port = 80;
233             int colon = serverAddress.indexOf(':');
234             if (colon > 0)
235             {
236                 host = serverAddress.substring(0, colon);
237                 port = Integer.parseInt(serverAddress.substring(colon + 1));
238             }
239 
240             if (!validateDestination(host, port))
241             {
242                 LOG.debug("Destination {}:{} forbidden", host, port);
243                 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
244                 return;
245             }
246 
247             SocketChannel channel = SocketChannel.open();
248             channel.socket().setTcpNoDelay(true);
249             channel.configureBlocking(false);
250             InetSocketAddress address = new InetSocketAddress(host, port);
251             channel.connect(address);
252 
253             AsyncContext asyncContext = request.startAsync();
254             asyncContext.setTimeout(0);
255 
256             LOG.debug("Connecting to {}", address);
257             ConnectContext connectContext = new ConnectContext(request, response, asyncContext, HttpConnection.getCurrentConnection());
258             selector.connect(channel, connectContext);
259         }
260         catch (Exception x)
261         {
262             onConnectFailure(request, response, null, x);
263         }
264     }
265 
266     protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
267     {
268         HttpConnection httpConnection = connectContext.getHttpConnection();
269         ByteBuffer requestBuffer = httpConnection.getRequestBuffer();
270         ByteBuffer buffer = BufferUtil.EMPTY_BUFFER;
271         int remaining = requestBuffer.remaining();
272         if (remaining > 0)
273         {
274             buffer = bufferPool.acquire(remaining, requestBuffer.isDirect());
275             BufferUtil.flipToFill(buffer);
276             buffer.put(requestBuffer);
277             buffer.flip();
278         }
279 
280         ConcurrentMap<String, Object> context = connectContext.getContext();
281         HttpServletRequest request = connectContext.getRequest();
282         prepareContext(request, context);
283 
284         EndPoint downstreamEndPoint = httpConnection.getEndPoint();
285         DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, buffer);
286         downstreamConnection.setInputBufferSize(getBufferSize());
287 
288         upstreamConnection.setConnection(downstreamConnection);
289         downstreamConnection.setConnection(upstreamConnection);
290         LOG.debug("Connection setup completed: {}<->{}", downstreamConnection, upstreamConnection);
291 
292         HttpServletResponse response = connectContext.getResponse();
293         sendConnectResponse(request, response, HttpServletResponse.SC_OK);
294 
295         upgradeConnection(request, response, downstreamConnection);
296         connectContext.getAsyncContext().complete();
297     }
298 
299     protected void onConnectFailure(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, Throwable failure)
300     {
301         LOG.debug("CONNECT failed", failure);
302         sendConnectResponse(request, response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
303         if (asyncContext != null)
304             asyncContext.complete();
305     }
306 
307     private void sendConnectResponse(HttpServletRequest request, HttpServletResponse response, int statusCode)
308     {
309         try
310         {
311             response.setStatus(statusCode);
312             if (statusCode != HttpServletResponse.SC_OK)
313                 response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
314             response.getOutputStream().close();
315             LOG.debug("CONNECT response sent {} {}", request.getProtocol(), response.getStatus());
316         }
317         catch (IOException x)
318         {
319             // TODO: nothing we can do, close the connection
320         }
321     }
322 
323     /**
324      * <p>Handles the authentication before setting up the tunnel to the remote server.</p>
325      * <p>The default implementation returns true.</p>
326      *
327      * @param request  the HTTP request
328      * @param response the HTTP response
329      * @param address  the address of the remote server in the form {@code host:port}.
330      * @return true to allow to connect to the remote host, false otherwise
331      */
332     protected boolean handleAuthentication(HttpServletRequest request, HttpServletResponse response, String address)
333     {
334         return true;
335     }
336 
337     protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
338     {
339         return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context, buffer);
340     }
341 
342     protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
343     {
344         return new UpstreamConnection(endPoint, getExecutor(), getByteBufferPool(), connectContext);
345     }
346 
347     protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
348     {
349     }
350 
351     private void upgradeConnection(HttpServletRequest request, HttpServletResponse response, Connection connection)
352     {
353         // Set the new connection as request attribute and change the status to 101
354         // so that Jetty understands that it has to upgrade the connection
355         request.setAttribute(HttpConnection.UPGRADE_CONNECTION_ATTRIBUTE, connection);
356         response.setStatus(HttpServletResponse.SC_SWITCHING_PROTOCOLS);
357         LOG.debug("Upgraded connection to {}", connection);
358     }
359 
360     /**
361      * <p>Reads (with non-blocking semantic) into the given {@code buffer} from the given {@code endPoint}.</p>
362      *
363      * @param endPoint the endPoint to read from
364      * @param buffer   the buffer to read data into
365      * @return the number of bytes read (possibly 0 since the read is non-blocking)
366      *         or -1 if the channel has been closed remotely
367      * @throws IOException if the endPoint cannot be read
368      */
369     protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
370     {
371         return endPoint.fill(buffer);
372     }
373 
374     /**
375      * <p>Writes (with non-blocking semantic) the given buffer of data onto the given endPoint.</p>
376      *
377      * @param endPoint the endPoint to write to
378      * @param buffer   the buffer to write
379      * @param callback the completion callback to invoke
380      */
381     protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
382     {
383         LOG.debug("{} writing {} bytes", this, buffer.remaining());
384         endPoint.write(callback, buffer);
385     }
386 
387     public Set<String> getWhiteListHosts()
388     {
389         return whiteList;
390     }
391 
392     public Set<String> getBlackListHosts()
393     {
394         return blackList;
395     }
396 
397     /**
398      * Checks the given {@code host} and {@code port} against whitelist and blacklist.
399      *
400      * @param host the host to check
401      * @param port the port to check
402      * @return true if it is allowed to connect to the given host and port
403      */
404     public boolean validateDestination(String host, int port)
405     {
406         String hostPort = host + ":" + port;
407         if (!whiteList.isEmpty())
408         {
409             if (!whiteList.contains(hostPort))
410             {
411                 LOG.debug("Host {}:{} not whitelisted", host, port);
412                 return false;
413             }
414         }
415         if (!blackList.isEmpty())
416         {
417             if (blackList.contains(hostPort))
418             {
419                 LOG.debug("Host {}:{} blacklisted", host, port);
420                 return false;
421             }
422         }
423         return true;
424     }
425 
426     @Override
427     public void dump(Appendable out, String indent) throws IOException
428     {
429         dumpThis(out);
430         dump(out, indent, getBeans(), TypeUtil.asList(getHandlers()));
431     }
432 
433     protected class ConnectManager extends SelectorManager
434     {
435 
436         private ConnectManager(Executor executor, Scheduler scheduler, int selectors)
437         {
438             super(executor, scheduler, selectors);
439         }
440 
441         @Override
442         protected EndPoint newEndPoint(SocketChannel channel, ManagedSelector selector, SelectionKey selectionKey) throws IOException
443         {
444             return new SelectChannelEndPoint(channel, selector, selectionKey, getScheduler(), getIdleTimeout());
445         }
446 
447         @Override
448         public Connection newConnection(SocketChannel channel, EndPoint endpoint, Object attachment) throws IOException
449         {
450             ConnectHandler.LOG.debug("Connected to {}", channel.getRemoteAddress());
451             ConnectContext connectContext = (ConnectContext)attachment;
452             UpstreamConnection connection = newUpstreamConnection(endpoint, connectContext);
453             connection.setInputBufferSize(getBufferSize());
454             return connection;
455         }
456 
457         @Override
458         protected void connectionFailed(SocketChannel channel, Throwable ex, Object attachment)
459         {
460             ConnectContext connectContext = (ConnectContext)attachment;
461             onConnectFailure(connectContext.request, connectContext.response, connectContext.asyncContext, ex);
462         }
463     }
464 
465     protected static class ConnectContext
466     {
467         private final ConcurrentMap<String, Object> context = new ConcurrentHashMap<>();
468         private final HttpServletRequest request;
469         private final HttpServletResponse response;
470         private final AsyncContext asyncContext;
471         private final HttpConnection httpConnection;
472 
473         public ConnectContext(HttpServletRequest request, HttpServletResponse response, AsyncContext asyncContext, HttpConnection httpConnection)
474         {
475             this.request = request;
476             this.response = response;
477             this.asyncContext = asyncContext;
478             this.httpConnection = httpConnection;
479         }
480 
481         public ConcurrentMap<String, Object> getContext()
482         {
483             return context;
484         }
485 
486         public HttpServletRequest getRequest()
487         {
488             return request;
489         }
490 
491         public HttpServletResponse getResponse()
492         {
493             return response;
494         }
495 
496         public AsyncContext getAsyncContext()
497         {
498             return asyncContext;
499         }
500 
501         public HttpConnection getHttpConnection()
502         {
503             return httpConnection;
504         }
505     }
506 
507     public class UpstreamConnection extends ProxyConnection
508     {
509         private ConnectContext connectContext;
510 
511         public UpstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConnectContext connectContext)
512         {
513             super(endPoint, executor, bufferPool, connectContext.getContext());
514             this.connectContext = connectContext;
515         }
516 
517         @Override
518         public void onOpen()
519         {
520             super.onOpen();
521             onConnectSuccess(connectContext, this);
522             fillInterested();
523         }
524 
525         @Override
526         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
527         {
528             return ConnectHandler.this.read(endPoint, buffer);
529         }
530 
531         @Override
532         protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
533         {
534             ConnectHandler.this.write(endPoint, buffer, callback);
535         }
536     }
537 
538     public class DownstreamConnection extends ProxyConnection
539     {
540         private final ByteBuffer buffer;
541 
542         public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
543         {
544             super(endPoint, executor, bufferPool, context);
545             this.buffer = buffer;
546         }
547 
548         @Override
549         public void onOpen()
550         {
551             super.onOpen();
552             final int remaining = buffer.remaining();
553             write(getConnection().getEndPoint(), buffer, new Callback()
554             {
555                 @Override
556                 public void succeeded()
557                 {
558                     LOG.debug("{} wrote initial {} bytes to server", DownstreamConnection.this, remaining);
559                     fillInterested();
560                 }
561 
562                 @Override
563                 public void failed(Throwable x)
564                 {
565                     LOG.debug(this + " failed to write initial " + remaining + " bytes to server", x);
566                     close();
567                     getConnection().close();
568                 }
569             });
570         }
571 
572         @Override
573         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
574         {
575             return ConnectHandler.this.read(endPoint, buffer);
576         }
577 
578         @Override
579         protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
580         {
581             ConnectHandler.this.write(endPoint, buffer, callback);
582         }
583     }
584 }