View Javadoc

1   // ========================================================================
2   // Copyright (c) 2009 Mort Bay Consulting Pty. Ltd.
3   // ------------------------------------------------------------------------
4   // All rights reserved. This program and the accompanying materials
5   // are made available under the terms of the Eclipse Public License v1.0
6   // and Apache License v2.0 which accompanies this distribution.
7   // The Eclipse Public License is available at
8   // http://www.eclipse.org/legal/epl-v10.html
9   // The Apache License v2.0 is available at
10  // http://www.opensource.org/licenses/apache2.0.php
11  // You may elect to redistribute this code under either of these licenses.
12  // ========================================================================
13  
14  package org.eclipse.jetty.servlets;
15  
16  import java.io.IOException;
17  import java.util.HashSet;
18  import java.util.Queue;
19  import java.util.StringTokenizer;
20  import java.util.concurrent.ConcurrentHashMap;
21  import java.util.concurrent.ConcurrentLinkedQueue;
22  import java.util.concurrent.Semaphore;
23  import java.util.concurrent.TimeUnit;
24  import javax.servlet.Filter;
25  import javax.servlet.FilterChain;
26  import javax.servlet.FilterConfig;
27  import javax.servlet.ServletContext;
28  import javax.servlet.ServletException;
29  import javax.servlet.ServletRequest;
30  import javax.servlet.ServletResponse;
31  import javax.servlet.http.HttpServletRequest;
32  import javax.servlet.http.HttpServletResponse;
33  import javax.servlet.http.HttpSession;
34  import javax.servlet.http.HttpSessionBindingEvent;
35  import javax.servlet.http.HttpSessionBindingListener;
36  
37  import org.eclipse.jetty.continuation.Continuation;
38  import org.eclipse.jetty.continuation.ContinuationListener;
39  import org.eclipse.jetty.continuation.ContinuationSupport;
40  import org.eclipse.jetty.util.log.Log;
41  import org.eclipse.jetty.util.thread.Timeout;
42  
43  /**
44   * Denial of Service filter
45   *
46   * <p>
47   * This filter is based on the {@link QoSFilter}. It is useful for limiting
48   * exposure to abuse from request flooding, whether malicious, or as a result of
49   * a misconfigured client.
50   * <p>
51   * The filter keeps track of the number of requests from a connection per
52   * second. If a limit is exceeded, the request is either rejected, delayed, or
53   * throttled.
54   * <p>
55   * When a request is throttled, it is placed in a priority queue. Priority is
56   * given first to authenticated users and users with an HttpSession, then
57   * connections which can be identified by their IP addresses. Connections with
58   * no way to identify them are given lowest priority.
59   * <p>
60   * The {@link #extractUserId(ServletRequest request)} function should be
61   * implemented, in order to uniquely identify authenticated users.
62   * <p>
63   * The following init parameters control the behavior of the filter:
64   *
65   * maxRequestsPerSec    the maximum number of requests from a connection per
66   *                      second. Requests in excess of this are first delayed,
67   *                      then throttled.
68   *
69   * delayMs              is the delay given to all requests over the rate limit,
70   *                      before they are considered at all. -1 means just reject request,
71   *                      0 means no delay, otherwise it is the delay.
72   *
73   * maxWaitMs            how long to block waiting for the throttle semaphore.
74   *
75   * throttledRequests    is the number of requests over the rate limit able to be
76   *                      considered at once.
77   *
78   * throttleMs           how long to async wait for semaphore.
79   *
80   * maxRequestMs         how long to allow this request to run.
81   *
82   * maxIdleTrackerMs     how long to keep track of request rates for a connection,
83   *                      before deciding that the user has gone away, and discarding it
84   *
85   * insertHeaders        if true, insert the DoSFilter headers into the response. Defaults to true.
86   *
87   * trackSessions        if true, usage rate is tracked by session if a session exists. Defaults to true.
88   *
89   * remotePort           if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.
90   *
91   * ipWhitelist          a comma-separated list of IP addresses that will not be rate limited
92   */
93  
94  public class DoSFilter implements Filter
95  {
96      final static String __TRACKER = "DoSFilter.Tracker";
97      final static String __THROTTLED = "DoSFilter.Throttled";
98  
99      final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
100     final static int __DEFAULT_DELAY_MS = 100;
101     final static int __DEFAULT_THROTTLE = 5;
102     final static int __DEFAULT_WAIT_MS=50;
103     final static long __DEFAULT_THROTTLE_MS = 30000L;
104     final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
105     final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
106 
107     final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
108     final static String DELAY_MS_INIT_PARAM = "delayMs";
109     final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
110     final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
111     final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
112     final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
113     final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
114     final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
115     final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
116     final static String REMOTE_PORT_INIT_PARAM="remotePort";
117     final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
118 
119     final static int USER_AUTH = 2;
120     final static int USER_SESSION = 2;
121     final static int USER_IP = 1;
122     final static int USER_UNKNOWN = 0;
123 
124     ServletContext _context;
125 
126     protected long _delayMs;
127     protected long _throttleMs;
128     protected long _waitMs;
129     protected long _maxRequestMs;
130     protected long _maxIdleTrackerMs;
131     protected boolean _insertHeaders;
132     protected boolean _trackSessions;
133     protected boolean _remotePort;
134     protected Semaphore _passes;
135     protected Queue<Continuation>[] _queue;
136     protected ContinuationListener[] _listener;
137 
138     protected int _maxRequestsPerSec;
139     protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
140     private final HashSet<String> _whitelist = new HashSet<String>();
141 
142     private final Timeout _requestTimeoutQ = new Timeout();
143     private final Timeout _trackerTimeoutQ = new Timeout();
144 
145     private Thread _timerThread;
146     private volatile boolean _running;
147 
148     public void init(FilterConfig filterConfig)
149     {
150         _context = filterConfig.getServletContext();
151 
152         _queue = new Queue[getMaxPriority() + 1];
153         _listener = new ContinuationListener[getMaxPriority() + 1];
154         for (int p = 0; p < _queue.length; p++)
155         {
156             _queue[p] = new ConcurrentLinkedQueue<Continuation>();
157 
158             final int priority=p;
159             _listener[p] = new ContinuationListener()
160             {
161                 public void onComplete(Continuation continuation)
162                 {
163                 }
164 
165                 public void onTimeout(Continuation continuation)
166                 {
167                     _queue[priority].remove(continuation);
168                 }
169             };
170         }
171 
172         _rateTrackers.clear();
173         _whitelist.clear();
174 
175         int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
176         if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
177             baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
178         _maxRequestsPerSec = baseRateLimit;
179 
180         long delay = __DEFAULT_DELAY_MS;
181         if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
182             delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
183         _delayMs = delay;
184 
185         int passes = __DEFAULT_THROTTLE;
186         if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
187             passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
188         _passes = new Semaphore(passes,true);
189 
190         long wait = __DEFAULT_WAIT_MS;
191         if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
192             wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
193         _waitMs = wait;
194 
195         long suspend = __DEFAULT_THROTTLE_MS;
196         if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
197             suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
198         _throttleMs = suspend;
199 
200         long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
201         if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
202             maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
203         _maxRequestMs = maxRequestMs;
204 
205         long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
206         if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
207             maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
208         _maxIdleTrackerMs = maxIdleTrackerMs;
209 
210         String whitelistString = "";
211         if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
212             whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
213 
214         if (whitelistString.length() > 0)
215         {
216             StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
217             while (tokenizer.hasMoreTokens())
218                 _whitelist.add(tokenizer.nextToken().trim());
219 
220             Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
221         }
222 
223         String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
224         _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
225 
226         tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
227         _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
228 
229         tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
230         _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
231 
232         _requestTimeoutQ.setNow();
233         _requestTimeoutQ.setDuration(_maxRequestMs);
234 
235         _trackerTimeoutQ.setNow();
236         _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
237 
238         _running=true;
239         _timerThread = (new Thread()
240         {
241             public void run()
242             {
243                 try
244                 {
245                     while (_running)
246                     {
247                         long now;
248                         synchronized (_requestTimeoutQ)
249                         {
250                             now = _requestTimeoutQ.setNow();
251                             _requestTimeoutQ.tick();
252                         }
253                         synchronized (_trackerTimeoutQ)
254                         {
255                             _trackerTimeoutQ.setNow(now);
256                             _trackerTimeoutQ.tick();
257                         }
258                         try
259                         {
260                             Thread.sleep(100);
261                         }
262                         catch (InterruptedException e)
263                         {
264                             Log.ignore(e);
265                         }
266                     }
267                 }
268                 finally
269                 {
270                     Log.info("DoSFilter timer exited");
271                 }
272             }
273         });
274         _timerThread.start();
275     }
276 
277 
278     public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
279     {
280         final HttpServletRequest srequest = (HttpServletRequest)request;
281         final HttpServletResponse sresponse = (HttpServletResponse)response;
282 
283         final long now=_requestTimeoutQ.getNow();
284 
285         // Look for the rate tracker for this request
286         RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
287 
288         if (tracker==null)
289         {
290             // This is the first time we have seen this request.
291 
292             // get a rate tracker associated with this request, and record one hit
293             tracker = getRateTracker(request);
294 
295             // Calculate the rate and check it is over the allowed limit
296             final boolean overRateLimit = tracker.isRateExceeded(now);
297 
298             // pass it through if  we are not currently over the rate limit
299             if (!overRateLimit)
300             {
301                 doFilterChain(filterchain,srequest,sresponse);
302                 return;
303             }
304 
305             // We are over the limit.
306             Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
307 
308             // So either reject it, delay it or throttle it
309             switch((int)_delayMs)
310             {
311                 case -1:
312                 {
313                     // Reject this request
314                     if (_insertHeaders)
315                         ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
316                     ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
317                     return;
318                 }
319                 case 0:
320                 {
321                     // fall through to throttle code
322                     request.setAttribute(__TRACKER,tracker);
323                     break;
324                 }
325                 default:
326                 {
327                     // insert a delay before throttling the request
328                     if (_insertHeaders)
329                         ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
330                     Continuation continuation = ContinuationSupport.getContinuation(request);
331                     request.setAttribute(__TRACKER,tracker);
332                     if (_delayMs > 0)
333                         continuation.setTimeout(_delayMs);
334                     continuation.suspend();
335                     return;
336                 }
337             }
338         }
339 
340         // Throttle the request
341         boolean accepted = false;
342         try
343         {
344             // check if we can afford to accept another request at this time
345             accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
346 
347             if (!accepted)
348             {
349                 // we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
350                 final Continuation continuation = ContinuationSupport.getContinuation(request);
351 
352                 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
353                 if (throttled!=Boolean.TRUE && _throttleMs>0)
354                 {
355                     int priority = getPriority(request,tracker);
356                     request.setAttribute(__THROTTLED,Boolean.TRUE);
357                     if (_insertHeaders)
358                         ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
359                     if (_throttleMs > 0)
360                         continuation.setTimeout(_throttleMs);
361                     continuation.suspend();
362 
363                     continuation.addContinuationListener(_listener[priority]);
364                     _queue[priority].add(continuation);
365                     return;
366                 }
367                 // else were we resumed?
368                 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
369                 {
370                     // we were resumed and somebody stole our pass, so we wait for the next one.
371                     _passes.acquire();
372                     accepted = true;
373                 }
374             }
375 
376             // if we were accepted (either immediately or after throttle)
377             if (accepted)
378                 // call the chain
379                 doFilterChain(filterchain,srequest,sresponse);
380             else
381             {
382                 // fail the request
383                 if (_insertHeaders)
384                     ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
385                 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
386             }
387         }
388         catch (InterruptedException e)
389         {
390             _context.log("DoS",e);
391             ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
392         }
393         finally
394         {
395             if (accepted)
396             {
397                 // wake up the next highest priority request.
398                 for (int p = _queue.length; p-- > 0;)
399                 {
400                     Continuation continuation = _queue[p].poll();
401                     if (continuation != null && continuation.isSuspended())
402                     {
403                         continuation.resume();
404                         break;
405                     }
406                 }
407                 _passes.release();
408             }
409         }
410     }
411 
412     /**
413      * @param chain
414      * @param request
415      * @param response
416      * @throws IOException
417      * @throws ServletException
418      */
419     protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
420         throws IOException, ServletException
421     {
422         final Thread thread=Thread.currentThread();
423 
424         final Timeout.Task requestTimeout = new Timeout.Task()
425         {
426             public void expired()
427             {
428                 closeConnection(request, response, thread);
429             }
430         };
431 
432         try
433         {
434             synchronized (_requestTimeoutQ)
435             {
436                 _requestTimeoutQ.schedule(requestTimeout);
437             }
438             chain.doFilter(request,response);
439         }
440         finally
441         {
442             synchronized (_requestTimeoutQ)
443             {
444                 requestTimeout.cancel();
445             }
446         }
447     }
448 
449     /**
450      * Takes drastic measures to return this response and stop this thread.
451      * Due to the way the connection is interrupted, may return mixed up headers.
452      * @param request current request
453      * @param response current response, which must be stopped
454      * @param thread the handling thread
455      */
456     protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
457     {
458         // take drastic measures to return this response and stop this thread.
459         if( !response.isCommitted() )
460         {
461             response.setHeader("Connection", "close");
462         }
463         try
464         {
465             try
466             {
467                 response.getWriter().close();
468             }
469             catch (IllegalStateException e)
470             {
471                 response.getOutputStream().close();
472             }
473         }
474         catch (IOException e)
475         {
476             Log.warn(e);
477         }
478 
479         // interrupt the handling thread
480         thread.interrupt();
481     }
482 
483     /**
484      * Get priority for this request, based on user type
485      *
486      * @param request
487      * @param tracker
488      * @return priority
489      */
490     protected int getPriority(ServletRequest request, RateTracker tracker)
491     {
492         if (extractUserId(request)!=null)
493             return USER_AUTH;
494         if (tracker!=null)
495             return tracker.getType();
496         return USER_UNKNOWN;
497     }
498 
499     /**
500      * @return the maximum priority that we can assign to a request
501      */
502     protected int getMaxPriority()
503     {
504         return USER_AUTH;
505     }
506 
507     /**
508      * Return a request rate tracker associated with this connection; keeps
509      * track of this connection's request rate. If this is not the first request
510      * from this connection, return the existing object with the stored stats.
511      * If it is the first request, then create a new request tracker.
512      *
513      * Assumes that each connection has an identifying characteristic, and goes
514      * through them in order, taking the first that matches: user id (logged
515      * in), session id, client IP address. Unidentifiable connections are lumped
516      * into one.
517      *
518      * When a session expires, its rate tracker is automatically deleted.
519      *
520      * @param request
521      * @return the request rate tracker for the current connection
522      */
523     public RateTracker getRateTracker(ServletRequest request)
524     {
525         HttpServletRequest srequest = (HttpServletRequest)request;
526         HttpSession session=srequest.getSession(false);
527 
528         String loadId = extractUserId(request);
529         final int type;
530         if (loadId != null)
531         {
532             type = USER_AUTH;
533         }
534         else
535         {
536             if (_trackSessions && session!=null && !session.isNew())
537             {
538                 loadId=session.getId();
539                 type = USER_SESSION;
540             }
541             else
542             {
543                 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
544                 type = USER_IP;
545             }
546         }
547 
548         RateTracker tracker=_rateTrackers.get(loadId);
549 
550         if (tracker==null)
551         {
552             RateTracker t;
553             if (_whitelist.contains(request.getRemoteAddr()))
554             {
555                 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
556             }
557             else
558             {
559                 t = new RateTracker(loadId,type,_maxRequestsPerSec);
560             }
561 
562             tracker=_rateTrackers.putIfAbsent(loadId,t);
563             if (tracker==null)
564                 tracker=t;
565 
566             if (type == USER_IP)
567             {
568                 // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
569                 synchronized (_trackerTimeoutQ)
570                 {
571                     _trackerTimeoutQ.schedule(tracker);
572                 }
573             }
574             else if (session!=null)
575                 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
576                 session.setAttribute(__TRACKER,tracker);
577         }
578 
579         return tracker;
580     }
581 
582     public void destroy()
583     {
584         _running=false;
585         _timerThread.interrupt();
586         synchronized (_requestTimeoutQ)
587         {
588             _requestTimeoutQ.cancelAll();
589         }
590         synchronized (_trackerTimeoutQ)
591         {
592             _trackerTimeoutQ.cancelAll();
593         }
594         _rateTrackers.clear();
595         _whitelist.clear();
596     }
597 
598     /**
599      * Returns the user id, used to track this connection.
600      * This SHOULD be overridden by subclasses.
601      *
602      * @param request
603      * @return a unique user id, if logged in; otherwise null.
604      */
605     protected String extractUserId(ServletRequest request)
606     {
607         return null;
608     }
609 
610     /**
611      * A RateTracker is associated with a connection, and stores request rate
612      * data.
613      */
614     class RateTracker extends Timeout.Task implements HttpSessionBindingListener
615     {
616         protected final String _id;
617         protected final int _type;
618         protected final long[] _timestamps;
619         protected int _next;
620 
621         public RateTracker(String id, int type,int maxRequestsPerSecond)
622         {
623             _id = id;
624             _type = type;
625             _timestamps=new long[maxRequestsPerSecond];
626             _next=0;
627         }
628 
629         /**
630          * @return the current calculated request rate over the last second
631          */
632         public boolean isRateExceeded(long now)
633         {
634             final long last;
635             synchronized (this)
636             {
637                 last=_timestamps[_next];
638                 _timestamps[_next]=now;
639                 _next= (_next+1)%_timestamps.length;
640             }
641 
642             boolean exceeded=last!=0 && (now-last)<1000L;
643             return exceeded;
644         }
645 
646 
647         public String getId()
648         {
649             return _id;
650         }
651 
652         public int getType()
653         {
654             return _type;
655         }
656 
657 
658         public void valueBound(HttpSessionBindingEvent event)
659         {
660         }
661 
662         public void valueUnbound(HttpSessionBindingEvent event)
663         {
664             _rateTrackers.remove(_id);
665         }
666 
667         public void expired()
668         {
669             long now = _trackerTimeoutQ.getNow();
670             int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
671             long last=_timestamps[latestIndex];
672             boolean hasRecentRequest = last != 0 && (now-last)<1000L;
673 
674             if (hasRecentRequest)
675                 reschedule();
676             else
677                 _rateTrackers.remove(_id);
678         }
679 
680         @Override
681         public String toString()
682         {
683             return "RateTracker/"+_id+"/"+_type;
684         }
685     }
686 
687     class FixedRateTracker extends RateTracker
688     {
689         public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
690         {
691             super(id,type,numRecentRequestsTracked);
692         }
693 
694         @Override
695         public boolean isRateExceeded(long now)
696         {
697             // rate limit is never exceeded, but we keep track of the request timestamps
698             // so that we know whether there was recent activity on this tracker
699             // and whether it should be expired
700             synchronized (this)
701             {
702                 _timestamps[_next]=now;
703                 _next= (_next+1)%_timestamps.length;
704             }
705 
706             return false;
707         }
708 
709         @Override
710         public String toString()
711         {
712             return "Fixed"+super.toString();
713         }
714     }
715 }