View Javadoc

1   // ========================================================================
2   // Copyright (c) 2009 Mort Bay Consulting Pty. Ltd.
3   // ------------------------------------------------------------------------
4   // All rights reserved. This program and the accompanying materials
5   // are made available under the terms of the Eclipse Public License v1.0
6   // and Apache License v2.0 which accompanies this distribution.
7   // The Eclipse Public License is available at 
8   // http://www.eclipse.org/legal/epl-v10.html
9   // The Apache License v2.0 is available at
10  // http://www.opensource.org/licenses/apache2.0.php
11  // You may elect to redistribute this code under either of these licenses. 
12  // ========================================================================
13  
14  package org.eclipse.jetty.servlets;
15  
16  import java.io.IOException;
17  import java.util.HashSet;
18  import java.util.Queue;
19  import java.util.StringTokenizer;
20  import java.util.concurrent.ConcurrentHashMap;
21  import java.util.concurrent.Semaphore;
22  import java.util.concurrent.TimeUnit;
23  
24  import javax.servlet.Filter;
25  import javax.servlet.FilterChain;
26  import javax.servlet.FilterConfig;
27  import javax.servlet.ServletContext;
28  import javax.servlet.ServletException;
29  import javax.servlet.ServletRequest;
30  import javax.servlet.ServletResponse;
31  import javax.servlet.http.HttpServletRequest;
32  import javax.servlet.http.HttpServletResponse;
33  import javax.servlet.http.HttpSession;
34  import javax.servlet.http.HttpSessionBindingEvent;
35  import javax.servlet.http.HttpSessionBindingListener;
36  
37  import org.eclipse.jetty.continuation.Continuation;
38  import org.eclipse.jetty.continuation.ContinuationSupport;
39  import org.eclipse.jetty.util.ArrayQueue;
40  import org.eclipse.jetty.util.log.Log;
41  import org.eclipse.jetty.util.thread.Timeout;
42  
43  /**
44   * Denial of Service filter
45   * 
46   * <p>
 * This filter is based on the {@link QoSFilter}. It is useful for limiting
 * exposure to abuse from request flooding, whether malicious or the result of
 * a misconfigured client.
50   * <p>
51   * The filter keeps track of the number of requests from a connection per
52   * second. If a limit is exceeded, the request is either rejected, delayed, or
53   * throttled.
54   * <p>
55   * When a request is throttled, it is placed in a priority queue. Priority is
56   * given first to authenticated users and users with an HttpSession, then
57   * connections which can be identified by their IP addresses. Connections with
58   * no way to identify them are given lowest priority.
59   * <p>
60   * The {@link #extractUserId(ServletRequest request)} function should be
61   * implemented, in order to uniquely identify authenticated users.
62   * <p>
63   * The following init parameters control the behavior of the filter:
64   * 
65   * maxRequestsPerSec    the maximum number of requests from a connection per
66   *                      second. Requests in excess of this are first delayed, 
67   *                      then throttled.
68   * 
69   * delayMs              is the delay given to all requests over the rate limit, 
70   *                      before they are considered at all. -1 means just reject request, 
71   *                      0 means no delay, otherwise it is the delay.
72   * 
 * maxWaitMs            how long to block waiting for the throttle semaphore.
74   * 
75   * throttledRequests    is the number of requests over the rate limit able to be
76   *                      considered at once.
77   * 
78   * throttleMs           how long to async wait for semaphore.
79   * 
80   * maxRequestMs         how long to allow this request to run.
81   * 
82   * maxIdleTrackerMs     how long to keep track of request rates for a connection, 
83   *                      before deciding that the user has gone away, and discarding it
84   * 
85   * insertHeaders        if true , insert the DoSFilter headers into the response. Defaults to true.
86   * 
87   * trackSessions        if true, usage rate is tracked by session if a session exists. Defaults to true.
88   * 
89   * remotePort           if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.
90   * 
91   * ipWhitelist          a comma-separated list of IP addresses that will not be rate limited
92   */
93  
94  public class DoSFilter implements Filter
95  {
96      final static String __TRACKER = "DoSFilter.Tracker";
97      final static String __THROTTLED = "DoSFilter.Throttled";
98  
99      final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
100     final static int __DEFAULT_DELAY_MS = 100;
101     final static int __DEFAULT_THROTTLE = 5;
102     final static int __DEFAULT_WAIT_MS=50;
103     final static long __DEFAULT_THROTTLE_MS = 30000L;
104     final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
105     final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
106 
107     final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
108     final static String DELAY_MS_INIT_PARAM = "delayMs";
109     final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
110     final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
111     final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
112     final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
113     final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
114     final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
115     final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
116     final static String REMOTE_PORT_INIT_PARAM="remotePort";
117     final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
118 
119     final static int USER_AUTH = 2;
120     final static int USER_SESSION = 2;
121     final static int USER_IP = 1;
122     final static int USER_UNKNOWN = 0;
123 
124     ServletContext _context;
125 
126     protected long _delayMs;
127     protected long _throttleMs;
128     protected long _waitMs;
129     protected long _maxRequestMs;
130     protected long _maxIdleTrackerMs;
131     protected boolean _insertHeaders;
132     protected boolean _trackSessions;
133     protected boolean _remotePort;
134     protected Semaphore _passes;
135     protected Queue<Continuation>[] _queue;
136 
137     protected int _maxRequestsPerSec;
138     protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
139     private HashSet<String> _whitelist; 
140     
141     private final Timeout _requestTimeoutQ = new Timeout();
142     private final Timeout _trackerTimeoutQ = new Timeout();
143 
144     private Thread _timerThread;
145     private volatile boolean _running;
146 
147     public void init(FilterConfig filterConfig)
148     {
149         _context = filterConfig.getServletContext();
150 
151         _queue = new Queue[getMaxPriority() + 1];
152         for (int p = 0; p < _queue.length; p++)
153             _queue[p] = new ArrayQueue<Continuation>();
154 
155         int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
156         if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
157             baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
158         _maxRequestsPerSec = baseRateLimit;
159 
160         long delay = __DEFAULT_DELAY_MS;
161         if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
162             delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
163         _delayMs = delay;
164 
165         int passes = __DEFAULT_THROTTLE;
166         if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
167             passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
168         _passes = new Semaphore(passes,true);
169 
170         long wait = __DEFAULT_WAIT_MS;
171         if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
172             wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
173         _waitMs = wait;
174 
175         long suspend = __DEFAULT_THROTTLE_MS;
176         if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
177             suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
178         _throttleMs = suspend;
179 
180         long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
181         if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
182             maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
183         _maxRequestMs = maxRequestMs;
184 
185         long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
186         if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
187             maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
188         _maxIdleTrackerMs = maxIdleTrackerMs;
189         
190         String whitelistString = "";
191         if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
192             whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
193         
194         // empty 
195         if (whitelistString.length() == 0 )
196             _whitelist = new HashSet<String>();
197         else
198         {
199             StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
200             _whitelist = new HashSet<String>(tokenizer.countTokens());
201             while (tokenizer.hasMoreTokens())
202                 _whitelist.add(tokenizer.nextToken().trim());
203             
204             Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
205         }
206 
207         String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
208         _insertHeaders = tmp==null || Boolean.parseBoolean(tmp); 
209         
210         tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
211         _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
212         
213         tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
214         _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
215 
216         _requestTimeoutQ.setNow();
217         _requestTimeoutQ.setDuration(_maxRequestMs);
218         
219         _trackerTimeoutQ.setNow();
220         _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
221         
222         _running=true;
223         _timerThread = (new Thread()
224         {
225             public void run()
226             {
227                 try
228                 {
229                     while (_running)
230                     {
231                         synchronized (_requestTimeoutQ)
232                         {
233                             _requestTimeoutQ.setNow();
234                             _requestTimeoutQ.tick();
235 
236                             _trackerTimeoutQ.setNow(_requestTimeoutQ.getNow());
237                             _trackerTimeoutQ.tick();
238                         }
239                         try
240                         {
241                             Thread.sleep(100);
242                         }
243                         catch (InterruptedException e)
244                         {
245                             Log.ignore(e);
246                         }
247                     }
248                 }
249                 finally
250                 {
251                     Log.info("DoSFilter timer exited");
252                 }
253             }
254         });
255         _timerThread.start();
256     }
257     
258 
259     public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
260     {
261         final HttpServletRequest srequest = (HttpServletRequest)request;
262         final HttpServletResponse sresponse = (HttpServletResponse)response;
263         
264         final long now=_requestTimeoutQ.getNow();
265         
266         // Look for the rate tracker for this request
267         RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
268             
269         if (tracker==null)
270         {
271             // This is the first time we have seen this request.
272             
273             // get a rate tracker associated with this request, and record one hit
274             tracker = getRateTracker(request);
275             
276             // Calculate the rate and check it is over the allowed limit
277             final boolean overRateLimit = tracker.isRateExceeded(now);
278 
279             // pass it through if  we are not currently over the rate limit
280             if (!overRateLimit)
281             {
282                 doFilterChain(filterchain,srequest,sresponse);
283                 return;
284             }   
285             
286             // We are over the limit.
287             Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
288             
289             // So either reject it, delay it or throttle it
290             switch((int)_delayMs)
291             {
292                 case -1: 
293                 {
294                     // Reject this request
295                     if (_insertHeaders)
296                         ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
297                     ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
298                     return;
299                 }
300                 case 0:
301                 {
302                     // fall through to throttle code
303                     request.setAttribute(__TRACKER,tracker);
304                     break;
305                 }
306                 default:
307                 {
308                     // insert a delay before throttling the request
309                     if (_insertHeaders)
310                         ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
311                     Continuation continuation = ContinuationSupport.getContinuation(request,response);
312                     request.setAttribute(__TRACKER,tracker);
313                     if (_delayMs > 0)
314                         continuation.setTimeout(_delayMs);
315                     continuation.suspend();
316                     return;
317                 }
318             }
319         }
320 
321         // Throttle the request
322         boolean accepted = false;
323         try
324         {
325             // check if we can afford to accept another request at this time
326             accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
327 
328             if (!accepted)
329             {
330                 // we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
331                 final Continuation continuation = ContinuationSupport.getContinuation(request,response);
332                 
333                 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
334                 if (throttled!=Boolean.TRUE && _throttleMs>0)
335                 {
336                     int priority = getPriority(request,tracker);
337                     request.setAttribute(__THROTTLED,Boolean.TRUE);
338                     if (_insertHeaders)
339                         ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
340                     if (_throttleMs > 0)
341                         continuation.setTimeout(_throttleMs);
342                     continuation.suspend();
343 
344                     _queue[priority].add(continuation);
345                     return;
346                 }
347                 // else were we resumed?
348                 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
349                 {
350                     // we were resumed and somebody stole our pass, so we wait for the next one.
351                     _passes.acquire();
352                     accepted = true;
353                 }
354             }
355             
356             // if we were accepted (either immediately or after throttle)
357             if (accepted)       
358                 // call the chain
359                 doFilterChain(filterchain,srequest,sresponse);
360             else                
361             {
362                 // fail the request
363                 if (_insertHeaders)
364                     ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
365                 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
366             }
367         }
368         catch (InterruptedException e)
369         {
370             _context.log("DoS",e);
371             ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
372         }
373         finally
374         {
375             if (accepted)
376             {
377                 // wake up the next highest priority request.
378                 synchronized (_queue)
379                 {
380                     for (int p = _queue.length; p-- > 0;)
381                     {
382                         Continuation continuation = _queue[p].poll();
383 
384                         if (continuation != null)
385                         {
386                             continuation.resume();
387                             break;
388                         }
389                     }
390                 }
391                 _passes.release();
392             }
393         }
394     }
395 
396     /**
397      * @param chain
398      * @param request
399      * @param response
400      * @throws IOException
401      * @throws ServletException
402      */
403     protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) 
404         throws IOException, ServletException
405     {
406         final Thread thread=Thread.currentThread();
407         
408         final Timeout.Task requestTimeout = new Timeout.Task()
409         {
410             public void expired()
411             {
412                 closeConnection(request, response, thread);
413             }
414         };
415 
416         try
417         {
418             synchronized (_requestTimeoutQ)
419             {
420                 _requestTimeoutQ.schedule(requestTimeout);
421             }
422             chain.doFilter(request,response);
423         }
424         finally
425         {
426             synchronized (_requestTimeoutQ)
427             {
428                 requestTimeout.cancel();
429             }
430         }
431     }
432 
433     /**
434      * Takes drastic measures to return this response and stop this thread.
435      * Due to the way the connection is interrupted, may return mixed up headers.
436      * @param request current request
437      * @param response current response, which must be stopped
438      * @param thread the handling thread
439      */
440     protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
441     {
442         // take drastic measures to return this response and stop this thread.
443         if( !response.isCommitted() )
444         {
445             response.setHeader("Connection", "close");
446         }
447         try 
448         {
449             try
450             {
451                 response.getWriter().close();
452             }
453             catch (IllegalStateException e)
454             {
455                 response.getOutputStream().close();
456             }
457         }
458         catch (IOException e)
459         {
460             Log.warn(e);
461         }
462         
463         // interrupt the handling thread
464         thread.interrupt();
465     }
466         
467     /**
468      * Get priority for this request, based on user type
469      * 
470      * @param request
471      * @param tracker
472      * @return priority
473      */
474     protected int getPriority(ServletRequest request, RateTracker tracker)
475     {
476         if (extractUserId(request)!=null)
477             return USER_AUTH;
478         if (tracker!=null)
479             return tracker.getType();
480         return USER_UNKNOWN;
481     }
482 
483     /**
484      * @return the maximum priority that we can assign to a request
485      */
486     protected int getMaxPriority()
487     {
488         return USER_AUTH;
489     }
490 
491     /**
492      * Return a request rate tracker associated with this connection; keeps
493      * track of this connection's request rate. If this is not the first request
494      * from this connection, return the existing object with the stored stats.
495      * If it is the first request, then create a new request tracker.
496      * 
497      * Assumes that each connection has an identifying characteristic, and goes
498      * through them in order, taking the first that matches: user id (logged
499      * in), session id, client IP address. Unidentifiable connections are lumped
500      * into one.
501      * 
502      * When a session expires, its rate tracker is automatically deleted.
503      * 
504      * @param request
505      * @return the request rate tracker for the current connection
506      */
507     public RateTracker getRateTracker(ServletRequest request)
508     {
509         HttpServletRequest srequest = (HttpServletRequest)request;
510 
511         String loadId;
512         final int type;
513         
514         loadId = extractUserId(request);
515         HttpSession session=srequest.getSession(false);
516         if (_trackSessions && session!=null && !session.isNew())
517         {
518             loadId=session.getId();
519             type = USER_SESSION;
520         }
521         else
522         {
523             loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
524             type = USER_IP;
525         }
526 
527         RateTracker tracker=_rateTrackers.get(loadId);
528         
529         if (tracker==null)
530         {
531             RateTracker t;
532             if (_whitelist.contains(request.getRemoteAddr()))
533             {
534                 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
535             }
536             else
537             {
538                 t = new RateTracker(loadId,type,_maxRequestsPerSec);
539             }
540             
541             tracker=_rateTrackers.putIfAbsent(loadId,t);
542             if (tracker==null)
543                 tracker=t;
544             
545             if (type == USER_IP)
546             {
547                 // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
548                 synchronized (_trackerTimeoutQ)
549                 {
550                     _trackerTimeoutQ.schedule(tracker);
551                 }
552             }
553             else if (session!=null)
554                 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
555                 session.setAttribute(__TRACKER,tracker);
556         }
557 
558         return tracker;
559     }
560 
561     public void destroy()
562     {
563         _running=false;
564         _timerThread.interrupt();
565         synchronized (_requestTimeoutQ)
566         {
567             _requestTimeoutQ.cancelAll();
568             _trackerTimeoutQ.cancelAll();
569         }
570     }
571 
572     /**
573      * Returns the user id, used to track this connection.
574      * This SHOULD be overridden by subclasses.
575      * 
576      * @param request
577      * @return a unique user id, if logged in; otherwise null.
578      */
579     protected String extractUserId(ServletRequest request)
580     {
581         return null;
582     }
583 
584     /**
585      * A RateTracker is associated with a connection, and stores request rate
586      * data.
587      */
588     class RateTracker extends Timeout.Task implements HttpSessionBindingListener
589     {
590         protected final String _id;
591         protected final int _type;
592         protected final long[] _timestamps;
593         protected int _next;
594         
595         public RateTracker(String id, int type,int maxRequestsPerSecond)
596         {
597             _id = id;
598             _type = type;
599             _timestamps=new long[maxRequestsPerSecond];
600             _next=0;
601         }
602 
603         /**
604          * @return the current calculated request rate over the last second
605          */
606         public boolean isRateExceeded(long now)
607         {
608             final long last;
609             synchronized (this)
610             {
611                 last=_timestamps[_next];
612                 _timestamps[_next]=now;
613                 _next= (_next+1)%_timestamps.length;
614             }
615 
616             boolean exceeded=last!=0 && (now-last)<1000L;
617             return exceeded;
618         }
619 
620 
621         public String getId()
622         {
623             return _id;
624         }
625 
626         public int getType()
627         {
628             return _type;
629         }
630 
631         
632         public void valueBound(HttpSessionBindingEvent event)
633         {
634         }
635 
636         public void valueUnbound(HttpSessionBindingEvent event)
637         {
638             _rateTrackers.remove(_id);
639         }
640         
641         public void expired()
642         {
643             long now = _trackerTimeoutQ.getNow();
644             int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length; 
645             long last=_timestamps[latestIndex];
646             boolean hasRecentRequest = last != 0 && (now-last)<1000L;
647             
648             if (hasRecentRequest)
649                 reschedule();
650             else
651                 _rateTrackers.remove(_id);
652         }
653     }
654     
655     class FixedRateTracker extends RateTracker
656     {
657         public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
658         {
659             super(id,type,numRecentRequestsTracked);
660         }
661 
662         public boolean isRateExceeded(long now)
663         {
664             // rate limit is never exceeded, but we keep track of the request timestamps
665             // so that we know whether there was recent activity on this tracker
666             // and whether it should be expired
667             synchronized (this)
668             {
669                 _timestamps[_next]=now;
670                 _next= (_next+1)%_timestamps.length;
671             }
672 
673             return false;
674         }        
675     }
676 }