View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2016 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.servlets;
20  
21  import java.io.IOException;
22  import java.io.Serializable;
23  import java.util.ArrayList;
24  import java.util.Iterator;
25  import java.util.List;
26  import java.util.Queue;
27  import java.util.concurrent.ConcurrentHashMap;
28  import java.util.concurrent.ConcurrentLinkedQueue;
29  import java.util.concurrent.CopyOnWriteArrayList;
30  import java.util.concurrent.Semaphore;
31  import java.util.concurrent.TimeUnit;
32  import java.util.regex.Matcher;
33  import java.util.regex.Pattern;
34  
35  import javax.servlet.AsyncContext;
36  import javax.servlet.AsyncEvent;
37  import javax.servlet.AsyncListener;
38  import javax.servlet.DispatcherType;
39  import javax.servlet.Filter;
40  import javax.servlet.FilterChain;
41  import javax.servlet.FilterConfig;
42  import javax.servlet.ServletContext;
43  import javax.servlet.ServletException;
44  import javax.servlet.ServletRequest;
45  import javax.servlet.ServletResponse;
46  import javax.servlet.http.HttpServletRequest;
47  import javax.servlet.http.HttpServletResponse;
48  import javax.servlet.http.HttpSession;
49  import javax.servlet.http.HttpSessionActivationListener;
50  import javax.servlet.http.HttpSessionBindingEvent;
51  import javax.servlet.http.HttpSessionBindingListener;
52  import javax.servlet.http.HttpSessionEvent;
53  
54  import org.eclipse.jetty.server.handler.ContextHandler;
55  import org.eclipse.jetty.util.StringUtil;
56  import org.eclipse.jetty.util.annotation.ManagedAttribute;
57  import org.eclipse.jetty.util.annotation.ManagedObject;
58  import org.eclipse.jetty.util.annotation.ManagedOperation;
59  import org.eclipse.jetty.util.annotation.Name;
60  import org.eclipse.jetty.util.log.Log;
61  import org.eclipse.jetty.util.log.Logger;
62  import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
63  import org.eclipse.jetty.util.thread.Scheduler;
64  
65  /**
66   * Denial of Service filter
67   * <p>
68   * This filter is useful for limiting
69   * exposure to abuse from request flooding, whether malicious, or as a result of
70   * a misconfigured client.
71   * <p>
72   * The filter keeps track of the number of requests from a connection per
73   * second. If a limit is exceeded, the request is either rejected, delayed, or
74   * throttled.
75   * <p>
76   * When a request is throttled, it is placed in a priority queue. Priority is
77   * given first to authenticated users and users with an HttpSession, then
78   * connections which can be identified by their IP addresses. Connections with
79   * no way to identify them are given lowest priority.
80   * <p>
 * The {@link #extractUserId(ServletRequest request)} method should be
 * overridden in order to uniquely identify authenticated users.
83   * <p>
84   * The following init parameters control the behavior of the filter:
85   * <dl>
86   * <dt>maxRequestsPerSec</dt>
87   * <dd>the maximum number of requests from a connection per
88   * second. Requests in excess of this are first delayed,
89   * then throttled.</dd>
90   * <dt>delayMs</dt>
91   * <dd>is the delay given to all requests over the rate limit,
92   * before they are considered at all. -1 means just reject request,
93   * 0 means no delay, otherwise it is the delay.</dd>
94   * <dt>maxWaitMs</dt>
 * <dd>how long to wait (blocking) for the throttle semaphore.</dd>
96   * <dt>throttledRequests</dt>
97   * <dd>is the number of requests over the rate limit able to be
98   * considered at once.</dd>
99   * <dt>throttleMs</dt>
 * <dd>how long to wait asynchronously for the throttle semaphore.</dd>
101  * <dt>maxRequestMs</dt>
102  * <dd>how long to allow this request to run.</dd>
103  * <dt>maxIdleTrackerMs</dt>
104  * <dd>how long to keep track of request rates for a connection,
105  * before deciding that the user has gone away, and discarding it</dd>
106  * <dt>insertHeaders</dt>
 * <dd>if true, insert the DoSFilter headers into the response. Defaults to true.</dd>
108  * <dt>trackSessions</dt>
109  * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
110  * <dt>remotePort</dt>
111  * <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
112  * <dt>ipWhitelist</dt>
113  * <dd>a comma-separated list of IP addresses that will not be rate limited</dd>
114  * <dt>managedAttr</dt>
 * <dd>if set to true, then this filter is set as a {@link ServletContext} attribute with the
 * filter name as the attribute name. This allows mechanisms external to the context (e.g. JMX via {@link ContextHandler#MANAGED_ATTRIBUTES}) to
 * manage the configuration of the filter.</dd>
118  * <dt>tooManyCode</dt>
 * <dd>The status code to send if there are too many requests. Defaults to 429 (Too Many Requests);
 * 503 (Service Unavailable) is another option.</dd>
121  * </dl>
122  * <p>
123  * This filter should be configured for {@link DispatcherType#REQUEST} and {@link DispatcherType#ASYNC} and with 
124  * <code>&lt;async-supported&gt;true&lt;/async-supported&gt;</code>.
125  * </p>
126  */
127 @ManagedObject("limits exposure to abuse from request flooding, whether malicious, or as a result of a misconfigured client")
128 public class DoSFilter implements Filter
129 {
130     private static final Logger LOG = Log.getLogger(DoSFilter.class);
131 
132     private static final String IPv4_GROUP = "(\\d{1,3})";
133     private static final Pattern IPv4_PATTERN = Pattern.compile(IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP);
134     private static final String IPv6_GROUP = "(\\p{XDigit}{1,4})";
135     private static final Pattern IPv6_PATTERN = Pattern.compile(IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP);
136     private static final Pattern CIDR_PATTERN = Pattern.compile("([^/]+)/(\\d+)");
137 
138     private static final String __TRACKER = "DoSFilter.Tracker";
139     private static final String __THROTTLED = "DoSFilter.Throttled";
140 
141     private static final int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
142     private static final int __DEFAULT_DELAY_MS = 100;
143     private static final int __DEFAULT_THROTTLE = 5;
144     private static final int __DEFAULT_MAX_WAIT_MS = 50;
145     private static final long __DEFAULT_THROTTLE_MS = 30000L;
146     private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L;
147     private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L;
148 
149     static final String MANAGED_ATTR_INIT_PARAM = "managedAttr";
150     static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
151     static final String DELAY_MS_INIT_PARAM = "delayMs";
152     static final String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
153     static final String MAX_WAIT_INIT_PARAM = "maxWaitMs";
154     static final String THROTTLE_MS_INIT_PARAM = "throttleMs";
155     static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
156     static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
157     static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
158     static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
159     static final String REMOTE_PORT_INIT_PARAM = "remotePort";
160     static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
161     static final String ENABLED_INIT_PARAM = "enabled";
162     static final String TOO_MANY_CODE = "tooManyCode";
163 
164     private static final int USER_AUTH = 2;
165     private static final int USER_SESSION = 2;
166     private static final int USER_IP = 1;
167     private static final int USER_UNKNOWN = 0;
168 
169     private final String _suspended = "DoSFilter@" + Integer.toHexString(hashCode()) + ".SUSPENDED";
170     private final String _resumed = "DoSFilter@" + Integer.toHexString(hashCode()) + ".RESUMED";
171     private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<>();
172     private final List<String> _whitelist = new CopyOnWriteArrayList<>();
173     private int _tooManyCode;
174     private volatile long _delayMs;
175     private volatile long _throttleMs;
176     private volatile long _maxWaitMs;
177     private volatile long _maxRequestMs;
178     private volatile long _maxIdleTrackerMs;
179     private volatile boolean _insertHeaders;
180     private volatile boolean _trackSessions;
181     private volatile boolean _remotePort;
182     private volatile boolean _enabled;
183     private Semaphore _passes;
184     private volatile int _throttledRequests;
185     private volatile int _maxRequestsPerSec;
186     private Queue<AsyncContext>[] _queues;
187     private AsyncListener[] _listeners;
188     private Scheduler _scheduler;
189 
190     public void init(FilterConfig filterConfig) throws ServletException
191     {
192         _queues = new Queue[getMaxPriority() + 1];
193         _listeners = new AsyncListener[_queues.length];
194         for (int p = 0; p < _queues.length; p++)
195         {
196             _queues[p] = new ConcurrentLinkedQueue<>();
197             _listeners[p] = new DoSAsyncListener(p);
198         }
199 
200         _rateTrackers.clear();
201 
202         int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
203         String parameter = filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM);
204         if (parameter != null)
205             maxRequests = Integer.parseInt(parameter);
206         setMaxRequestsPerSec(maxRequests);
207 
208         long delay = __DEFAULT_DELAY_MS;
209         parameter = filterConfig.getInitParameter(DELAY_MS_INIT_PARAM);
210         if (parameter != null)
211             delay = Long.parseLong(parameter);
212         setDelayMs(delay);
213 
214         int throttledRequests = __DEFAULT_THROTTLE;
215         parameter = filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM);
216         if (parameter != null)
217             throttledRequests = Integer.parseInt(parameter);
218         setThrottledRequests(throttledRequests);
219 
220         long maxWait = __DEFAULT_MAX_WAIT_MS;
221         parameter = filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM);
222         if (parameter != null)
223             maxWait = Long.parseLong(parameter);
224         setMaxWaitMs(maxWait);
225 
226         long throttle = __DEFAULT_THROTTLE_MS;
227         parameter = filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM);
228         if (parameter != null)
229             throttle = Long.parseLong(parameter);
230         setThrottleMs(throttle);
231 
232         long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
233         parameter = filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM);
234         if (parameter != null)
235             maxRequestMs = Long.parseLong(parameter);
236         setMaxRequestMs(maxRequestMs);
237 
238         long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
239         parameter = filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM);
240         if (parameter != null)
241             maxIdleTrackerMs = Long.parseLong(parameter);
242         setMaxIdleTrackerMs(maxIdleTrackerMs);
243 
244         String whiteList = "";
245         parameter = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
246         if (parameter != null)
247             whiteList = parameter;
248         setWhitelist(whiteList);
249 
250         parameter = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
251         setInsertHeaders(parameter == null || Boolean.parseBoolean(parameter));
252 
253         parameter = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
254         setTrackSessions(parameter == null || Boolean.parseBoolean(parameter));
255 
256         parameter = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
257         setRemotePort(parameter != null && Boolean.parseBoolean(parameter));
258 
259         parameter = filterConfig.getInitParameter(ENABLED_INIT_PARAM);
260         setEnabled(parameter == null || Boolean.parseBoolean(parameter));
261         
262         parameter = filterConfig.getInitParameter(TOO_MANY_CODE);
263         setTooManyCode(parameter==null?429:Integer.parseInt(parameter));
264         
265         _scheduler = startScheduler();
266 
267         ServletContext context = filterConfig.getServletContext();
268         if (context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
269             context.setAttribute(filterConfig.getFilterName(), this);
270     }
271 
272     protected Scheduler startScheduler() throws ServletException
273     {
274         try
275         {
276             Scheduler result = new ScheduledExecutorScheduler();
277             result.start();
278             return result;
279         }
280         catch (Exception x)
281         {
282             throw new ServletException(x);
283         }
284     }
285 
286     public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException
287     {
288         doFilter((HttpServletRequest)request, (HttpServletResponse)response, filterChain);
289     }
290 
291     protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException
292     {
293         if (!isEnabled())
294         {
295             filterChain.doFilter(request, response);
296             return;
297         }
298 
299         // Look for the rate tracker for this request.
300         RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
301         if (tracker == null)
302         {
303             // This is the first time we have seen this request.
304             if (LOG.isDebugEnabled())
305                 LOG.debug("Filtering {}", request);
306 
307             // Get a rate tracker associated with this request, and record one hit.
308             tracker = getRateTracker(request);
309 
310             // Calculate the rate and check it is over the allowed limit
311             final boolean overRateLimit = tracker.isRateExceeded(System.currentTimeMillis());
312 
313             // Pass it through if  we are not currently over the rate limit.
314             if (!overRateLimit)
315             {
316                 if (LOG.isDebugEnabled())
317                     LOG.debug("Allowing {}", request);
318                 doFilterChain(filterChain, request, response);
319                 return;
320             }
321 
322             // We are over the limit.
323 
324             // So either reject it, delay it or throttle it.
325             long delayMs = getDelayMs();
326             boolean insertHeaders = isInsertHeaders();
327             switch ((int)delayMs)
328             {
329                 case -1:
330                 {
331                     // Reject this request.
332                     LOG.warn("DOS ALERT: Request rejected ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
333                     if (insertHeaders)
334                         response.addHeader("DoSFilter", "unavailable");
335                     response.sendError(getTooManyCode());
336                     return;
337                 }
338                 case 0:
339                 {
340                     // Fall through to throttle the request.
341                     LOG.warn("DOS ALERT: Request throttled ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
342                     request.setAttribute(__TRACKER, tracker);
343                     break;
344                 }
345                 default:
346                 {
347                     // Insert a delay before throttling the request,
348                     // using the suspend+timeout mechanism of AsyncContext.
349                     LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, session={}, user={}", delayMs, request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
350                     if (insertHeaders)
351                         response.addHeader("DoSFilter", "delayed");
352                     request.setAttribute(__TRACKER, tracker);
353                     AsyncContext asyncContext = request.startAsync();
354                     if (delayMs > 0)
355                         asyncContext.setTimeout(delayMs);
356                     asyncContext.addListener(new DoSTimeoutAsyncListener());
357                     return;
358                 }
359             }
360         }
361 
362         if (LOG.isDebugEnabled())
363             LOG.debug("Throttling {}", request);
364 
365         // Throttle the request.
366         boolean accepted = false;
367         try
368         {
369             // Check if we can afford to accept another request at this time.
370             accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS);
371             if (!accepted)
372             {
373                 // We were not accepted, so either we suspend to wait,
374                 // or if we were woken up we insist or we fail.
375                 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
376                 long throttleMs = getThrottleMs();
377                 if (throttled != Boolean.TRUE && throttleMs > 0)
378                 {
379                     int priority = getPriority(request, tracker);
380                     request.setAttribute(__THROTTLED, Boolean.TRUE);
381                     if (isInsertHeaders())
382                         response.addHeader("DoSFilter", "throttled");
383                     AsyncContext asyncContext = request.startAsync();
384                     request.setAttribute(_suspended, Boolean.TRUE);
385                     if (throttleMs > 0)
386                         asyncContext.setTimeout(throttleMs);
387                     asyncContext.addListener(_listeners[priority]);
388                     _queues[priority].add(asyncContext);
389                     if (LOG.isDebugEnabled())
390                         LOG.debug("Throttled {}, {}ms", request, throttleMs);
391                     return;
392                 }
393 
394                 Boolean resumed = (Boolean)request.getAttribute(_resumed);
395                 if (resumed == Boolean.TRUE)
396                 {
397                     // We were resumed, we wait for the next pass.
398                     _passes.acquire();
399                     accepted = true;
400                 }
401             }
402 
403             // If we were accepted (either immediately or after throttle)...
404             if (accepted)
405             {
406                 // ...call the chain.
407                 if (LOG.isDebugEnabled())
408                     LOG.debug("Allowing {}", request);
409                 doFilterChain(filterChain, request, response);
410             }
411             else
412             {
413                 // ...otherwise fail the request.
414                 if (LOG.isDebugEnabled())
415                     LOG.debug("Rejecting {}", request);
416                 if (isInsertHeaders())
417                     response.addHeader("DoSFilter", "unavailable");
418                 response.sendError(getTooManyCode());
419             }
420         }
421         catch (InterruptedException e)
422         {
423             LOG.ignore(e);
424             response.sendError(getTooManyCode());
425         }
426         finally
427         {
428             if (accepted)
429             {
430                 try
431                 {
432                     // Wake up the next highest priority request.
433                     for (int p = _queues.length - 1; p >= 0; --p)
434                     {
435                         AsyncContext asyncContext = _queues[p].poll();
436                         if (asyncContext != null)
437                         {
438                             ServletRequest candidate = asyncContext.getRequest();
439                             Boolean suspended = (Boolean)candidate.getAttribute(_suspended);
440                             if (suspended == Boolean.TRUE)
441                             {
442                                 if (LOG.isDebugEnabled())
443                                     LOG.debug("Resuming {}", request);
444                                 candidate.setAttribute(_resumed, Boolean.TRUE);
445                                 asyncContext.dispatch();
446                                 break;
447                             }
448                         }
449                     }
450                 }
451                 finally
452                 {
453                     _passes.release();
454                 }
455             }
456         }
457     }
458 
459     protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) throws IOException, ServletException
460     {
461         final Thread thread = Thread.currentThread();
462         Runnable requestTimeout = new Runnable()
463         {
464             @Override
465             public void run()
466             {
467                 closeConnection(request, response, thread);
468             }
469         };
470         Scheduler.Task task = _scheduler.schedule(requestTimeout, getMaxRequestMs(), TimeUnit.MILLISECONDS);
471         try
472         {
473             chain.doFilter(request, response);
474         }
475         finally
476         {
477             task.cancel();
478         }
479     }
480 
481     /**
482      * Takes drastic measures to return this response and stop this thread.
483      * Due to the way the connection is interrupted, may return mixed up headers.
484      *
485      * @param request  current request
486      * @param response current response, which must be stopped
487      * @param thread   the handling thread
488      */
489     protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
490     {
491         // take drastic measures to return this response and stop this thread.
492         if (!response.isCommitted())
493         {
494             response.setHeader("Connection", "close");
495         }
496         try
497         {
498             try
499             {
500                 response.getWriter().close();
501             }
502             catch (IllegalStateException e)
503             {
504                 response.getOutputStream().close();
505             }
506         }
507         catch (IOException e)
508         {
509             LOG.warn(e);
510         }
511 
512         // interrupt the handling thread
513         thread.interrupt();
514     }
515 
516     /**
517      * Get priority for this request, based on user type
518      *
519      * @param request the current request
520      * @param tracker the rate tracker for this request
521      * @return the priority for this request
522      */
523     protected int getPriority(HttpServletRequest request, RateTracker tracker)
524     {
525         if (extractUserId(request) != null)
526             return USER_AUTH;
527         if (tracker != null)
528             return tracker.getType();
529         return USER_UNKNOWN;
530     }
531 
532     /**
533      * @return the maximum priority that we can assign to a request
534      */
535     protected int getMaxPriority()
536     {
537         return USER_AUTH;
538     }
539 
540     /**
541      * Return a request rate tracker associated with this connection; keeps
542      * track of this connection's request rate. If this is not the first request
543      * from this connection, return the existing object with the stored stats.
544      * If it is the first request, then create a new request tracker.
545      * <p>
546      * Assumes that each connection has an identifying characteristic, and goes
547      * through them in order, taking the first that matches: user id (logged
548      * in), session id, client IP address. Unidentifiable connections are lumped
549      * into one.
550      * <p>
551      * When a session expires, its rate tracker is automatically deleted.
552      *
553      * @param request the current request
554      * @return the request rate tracker for the current connection
555      */
556     public RateTracker getRateTracker(ServletRequest request)
557     {
558         HttpSession session = ((HttpServletRequest)request).getSession(false);
559 
560         String loadId = extractUserId(request);
561         final int type;
562         if (loadId != null)
563         {
564             type = USER_AUTH;
565         }
566         else
567         {
568             if (isTrackSessions() && session != null && !session.isNew())
569             {
570                 loadId = session.getId();
571                 type = USER_SESSION;
572             }
573             else
574             {
575                 loadId = isRemotePort() ? (request.getRemoteAddr() + request.getRemotePort()) : request.getRemoteAddr();
576                 type = USER_IP;
577             }
578         }
579 
580         RateTracker tracker = _rateTrackers.get(loadId);
581 
582         if (tracker == null)
583         {
584             boolean allowed = checkWhitelist(request.getRemoteAddr());
585             int maxRequestsPerSec = getMaxRequestsPerSec();
586             tracker = allowed ? new FixedRateTracker(loadId, type, maxRequestsPerSec)
587                     : new RateTracker(loadId, type, maxRequestsPerSec);
588             RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
589             if (existing != null)
590                 tracker = existing;
591 
592             if (type == USER_IP)
593             {
594                 // USER_IP expiration from _rateTrackers is handled by the _scheduler
595                 _scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
596             }
597             else if (session != null)
598             {
599                 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
600                 session.setAttribute(__TRACKER, tracker);
601             }
602         }
603 
604         return tracker;
605     }
606 
607     protected boolean checkWhitelist(String candidate)
608     {
609         for (String address : _whitelist)
610         {
611             if (address.contains("/"))
612             {
613                 if (subnetMatch(address, candidate))
614                     return true;
615             }
616             else
617             {
618                 if (address.equals(candidate))
619                     return true;
620             }
621         }
622         return false;
623     }
624 
625     @Deprecated
626     protected boolean checkWhitelist(List<String> whitelist, String candidate)
627     {
628         for (String address : whitelist)
629         {
630             if (address.contains("/"))
631             {
632                 if (subnetMatch(address, candidate))
633                     return true;
634             }
635             else
636             {
637                 if (address.equals(candidate))
638                     return true;
639             }
640         }
641         return false;
642     }
643 
644     protected boolean subnetMatch(String subnetAddress, String address)
645     {
646         Matcher cidrMatcher = CIDR_PATTERN.matcher(subnetAddress);
647         if (!cidrMatcher.matches())
648             return false;
649 
650         String subnet = cidrMatcher.group(1);
651         int prefix;
652         try
653         {
654             prefix = Integer.parseInt(cidrMatcher.group(2));
655         }
656         catch (NumberFormatException x)
657         {
658             LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
659             return false;
660         }
661 
662         byte[] subnetBytes = addressToBytes(subnet);
663         if (subnetBytes == null)
664         {
665             LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
666             return false;
667         }
668         byte[] addressBytes = addressToBytes(address);
669         if (addressBytes == null)
670         {
671             LOG.info("Ignoring malformed remote address {}", address);
672             return false;
673         }
674 
675         // Comparing IPv4 with IPv6 ?
676         int length = subnetBytes.length;
677         if (length != addressBytes.length)
678             return false;
679 
680         byte[] mask = prefixToBytes(prefix, length);
681 
682         for (int i = 0; i < length; ++i)
683         {
684             if ((subnetBytes[i] & mask[i]) != (addressBytes[i] & mask[i]))
685                 return false;
686         }
687 
688         return true;
689     }
690 
691     private byte[] addressToBytes(String address)
692     {
693         Matcher ipv4Matcher = IPv4_PATTERN.matcher(address);
694         if (ipv4Matcher.matches())
695         {
696             byte[] result = new byte[4];
697             for (int i = 0; i < result.length; ++i)
698                 result[i] = Integer.valueOf(ipv4Matcher.group(i + 1)).byteValue();
699             return result;
700         }
701         else
702         {
703             Matcher ipv6Matcher = IPv6_PATTERN.matcher(address);
704             if (ipv6Matcher.matches())
705             {
706                 byte[] result = new byte[16];
707                 for (int i = 0; i < result.length; i += 2)
708                 {
709                     int word = Integer.valueOf(ipv6Matcher.group(i / 2 + 1), 16);
710                     result[i] = (byte)((word & 0xFF00) >>> 8);
711                     result[i + 1] = (byte)(word & 0xFF);
712                 }
713                 return result;
714             }
715         }
716         return null;
717     }
718 
719     private byte[] prefixToBytes(int prefix, int length)
720     {
721         byte[] result = new byte[length];
722         int index = 0;
723         while (prefix / 8 > 0)
724         {
725             result[index] = -1;
726             prefix -= 8;
727             ++index;
728         }
729         
730         if (index == result.length)
731             return result;
732                
733         // Sets the _prefix_ most significant bits to 1
734         result[index] = (byte)~((1 << (8 - prefix)) - 1);
735         return result;
736     }
737 
738     public void destroy()
739     {
740         LOG.debug("Destroy {}",this);
741         stopScheduler();
742         _rateTrackers.clear();
743         _whitelist.clear();
744     }
745 
746     protected void stopScheduler()
747     {
748         try
749         {
750             _scheduler.stop();
751         }
752         catch (Exception x)
753         {
754             LOG.ignore(x);
755         }
756     }
757 
758     /**
759      * Returns the user id, used to track this connection.
760      * This SHOULD be overridden by subclasses.
761      *
762      * @param request the current request
763      * @return a unique user id, if logged in; otherwise null.
764      */
765     protected String extractUserId(ServletRequest request)
766     {
767         return null;
768     }
769 
770     /**
771      * Get maximum number of requests from a connection per
772      * second. Requests in excess of this are first delayed,
773      * then throttled.
774      *
775      * @return maximum number of requests
776      */
777     @ManagedAttribute("maximum number of requests allowed from a connection per second")
778     public int getMaxRequestsPerSec()
779     {
780         return _maxRequestsPerSec;
781     }
782 
783     /**
784      * Get maximum number of requests from a connection per
785      * second. Requests in excess of this are first delayed,
786      * then throttled.
787      *
788      * @param value maximum number of requests
789      */
790     public void setMaxRequestsPerSec(int value)
791     {
792         _maxRequestsPerSec = value;
793     }
794 
795     /**
796      * Get delay (in milliseconds) that is applied to all requests
797      * over the rate limit, before they are considered at all.
798      * @return the delay in milliseconds
799      */
800     @ManagedAttribute("delay applied to all requests over the rate limit (in ms)")
801     public long getDelayMs()
802     {
803         return _delayMs;
804     }
805 
806     /**
807      * Set delay (in milliseconds) that is applied to all requests
808      * over the rate limit, before they are considered at all.
809      *
810      * @param value delay (in milliseconds), 0 - no delay, -1 - reject request
811      */
812     public void setDelayMs(long value)
813     {
814         _delayMs = value;
815     }
816 
817     /**
818      * Get maximum amount of time (in milliseconds) the filter will
819      * blocking wait for the throttle semaphore.
820      *
821      * @return maximum wait time
822      */
823     @ManagedAttribute("maximum time the filter will block waiting throttled connections, (0 for no delay, -1 to reject requests)")
824     public long getMaxWaitMs()
825     {
826         return _maxWaitMs;
827     }
828 
829     /**
830      * Set maximum amount of time (in milliseconds) the filter will
831      * blocking wait for the throttle semaphore.
832      *
833      * @param value maximum wait time
834      */
835     public void setMaxWaitMs(long value)
836     {
837         _maxWaitMs = value;
838     }
839 
840     /**
841      * Get number of requests over the rate limit able to be
842      * considered at once.
843      *
844      * @return number of requests
845      */
846     @ManagedAttribute("number of requests over rate limit")
847     public int getThrottledRequests()
848     {
849         return _throttledRequests;
850     }
851 
852     /**
853      * Set number of requests over the rate limit able to be
854      * considered at once.
855      *
856      * @param value number of requests
857      */
858     public void setThrottledRequests(int value)
859     {
860         int permits = _passes == null ? 0 : _passes.availablePermits();
861         _passes = new Semaphore((value - _throttledRequests + permits), true);
862         _throttledRequests = value;
863     }
864 
865     /**
866      * Get amount of time (in milliseconds) to async wait for semaphore.
867      *
868      * @return wait time
869      */
870     @ManagedAttribute("amount of time to async wait for semaphore")
871     public long getThrottleMs()
872     {
873         return _throttleMs;
874     }
875 
876     /**
877      * Set amount of time (in milliseconds) to async wait for semaphore.
878      *
879      * @param value wait time
880      */
881     public void setThrottleMs(long value)
882     {
883         _throttleMs = value;
884     }
885 
886     /**
887      * Get maximum amount of time (in milliseconds) to allow
888      * the request to process.
889      *
890      * @return maximum processing time
891      */
892     @ManagedAttribute("maximum time to allow requests to process (in ms)")
893     public long getMaxRequestMs()
894     {
895         return _maxRequestMs;
896     }
897 
898     /**
899      * Set maximum amount of time (in milliseconds) to allow
900      * the request to process.
901      *
902      * @param value maximum processing time
903      */
904     public void setMaxRequestMs(long value)
905     {
906         _maxRequestMs = value;
907     }
908 
909     /**
910      * Get maximum amount of time (in milliseconds) to keep track
911      * of request rates for a connection, before deciding that
912      * the user has gone away, and discarding it.
913      *
914      * @return maximum tracking time
915      */
916     @ManagedAttribute("maximum time to track of request rates for connection before discarding")
917     public long getMaxIdleTrackerMs()
918     {
919         return _maxIdleTrackerMs;
920     }
921 
922     /**
923      * Set maximum amount of time (in milliseconds) to keep track
924      * of request rates for a connection, before deciding that
925      * the user has gone away, and discarding it.
926      *
927      * @param value maximum tracking time
928      */
929     public void setMaxIdleTrackerMs(long value)
930     {
931         _maxIdleTrackerMs = value;
932     }
933 
934     /**
935      * Check flag to insert the DoSFilter headers into the response.
936      *
937      * @return value of the flag
938      */
939     @ManagedAttribute("inser DoSFilter headers in response")
940     public boolean isInsertHeaders()
941     {
942         return _insertHeaders;
943     }
944 
945     /**
946      * Set flag to insert the DoSFilter headers into the response.
947      *
948      * @param value value of the flag
949      */
950     public void setInsertHeaders(boolean value)
951     {
952         _insertHeaders = value;
953     }
954 
955     /**
956      * Get flag to have usage rate tracked by session if a session exists.
957      *
958      * @return value of the flag
959      */
960     @ManagedAttribute("usage rate is tracked by session if one exists")
961     public boolean isTrackSessions()
962     {
963         return _trackSessions;
964     }
965 
966     /**
967      * Set flag to have usage rate tracked by session if a session exists.
968      *
969      * @param value value of the flag
970      */
971     public void setTrackSessions(boolean value)
972     {
973         _trackSessions = value;
974     }
975 
976     /**
977      * Get flag to have usage rate tracked by IP+port (effectively connection)
978      * if session tracking is not used.
979      *
980      * @return value of the flag
981      */
982     @ManagedAttribute("usage rate is tracked by IP+port is session tracking not used")
983     public boolean isRemotePort()
984     {
985         return _remotePort;
986     }
987 
988     /**
989      * Set flag to have usage rate tracked by IP+port (effectively connection)
990      * if session tracking is not used.
991      *
992      * @param value value of the flag
993      */
994     public void setRemotePort(boolean value)
995     {
996         _remotePort = value;
997     }
998 
999     /**
1000      * @return whether this filter is enabled
1001      */
1002     @ManagedAttribute("whether this filter is enabled")
1003     public boolean isEnabled()
1004     {
1005         return _enabled;
1006     }
1007 
1008     /**
1009      * @param enabled whether this filter is enabled
1010      */
1011     public void setEnabled(boolean enabled)
1012     {
1013         _enabled = enabled;
1014     }
1015     
1016     public int getTooManyCode()
1017     {
1018         return _tooManyCode;
1019     }
1020 
1021     public void setTooManyCode(int tooManyCode)
1022     {
1023         _tooManyCode = tooManyCode;
1024     }
1025 
1026     /**
1027      * Get a list of IP addresses that will not be rate limited.
1028      *
1029      * @return comma-separated whitelist
1030      */
1031     @ManagedAttribute("list of IPs that will not be rate limited")
1032     public String getWhitelist()
1033     {
1034         StringBuilder result = new StringBuilder();
1035         for (Iterator<String> iterator = _whitelist.iterator(); iterator.hasNext();)
1036         {
1037             String address = iterator.next();
1038             result.append(address);
1039             if (iterator.hasNext())
1040                 result.append(",");
1041         }
1042         return result.toString();
1043     }
1044 
1045     /**
1046      * Set a list of IP addresses that will not be rate limited.
1047      *
1048      * @param commaSeparatedList comma-separated whitelist
1049      */
1050     public void setWhitelist(String commaSeparatedList)
1051     {
1052         List<String> result = new ArrayList<>();
1053         for (String address : StringUtil.csvSplit(commaSeparatedList))
1054             addWhitelistAddress(result, address);
1055         clearWhitelist();
1056         _whitelist.addAll(result);
1057         LOG.debug("Whitelisted IP addresses: {}", result);
1058     }
1059 
1060     /**
1061      * Clears the list of whitelisted IP addresses
1062      */
1063     @ManagedOperation("clears the list of IP addresses that will not be rate limited")
1064     public void clearWhitelist()
1065     {
1066         _whitelist.clear();
1067     }
1068 
1069     /**
1070      * Adds the given IP address, either in the form of a dotted decimal notation A.B.C.D
1071      * or in the CIDR notation A.B.C.D/M, to the list of whitelisted IP addresses.
1072      *
1073      * @param address the address to add
1074      * @return whether the address was added to the list
1075      * @see #removeWhitelistAddress(String)
1076      */
1077     @ManagedOperation("adds an IP address that will not be rate limited")
1078     public boolean addWhitelistAddress(@Name("address") String address)
1079     {
1080         return addWhitelistAddress(_whitelist, address);
1081     }
1082 
1083     private boolean addWhitelistAddress(List<String> list, String address)
1084     {
1085         address = address.trim();
1086         return address.length() > 0 && list.add(address);
1087     }
1088 
1089     /**
1090      * Removes the given address from the list of whitelisted IP addresses.
1091      *
1092      * @param address the address to remove
1093      * @return whether the address was removed from the list
1094      * @see #addWhitelistAddress(String)
1095      */
1096     @ManagedOperation("removes an IP address that will not be rate limited")
1097     public boolean removeWhitelistAddress(@Name("address") String address)
1098     {
1099         return _whitelist.remove(address);
1100     }
1101 
1102     /**
1103      * A RateTracker is associated with a connection, and stores request rate
1104      * data.
1105      */
1106     class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable
1107     {
1108         private static final long serialVersionUID = 3534663738034577872L;
1109 
1110         protected final String _id;
1111         protected final int _type;
1112         protected final long[] _timestamps;
1113         protected int _next;
1114 
1115         public RateTracker(String id, int type, int maxRequestsPerSecond)
1116         {
1117             _id = id;
1118             _type = type;
1119             _timestamps = new long[maxRequestsPerSecond];
1120             _next = 0;
1121         }
1122 
1123         /**
1124          * @param now the time now (in milliseconds)
1125          * @return the current calculated request rate over the last second
1126          */
1127         public boolean isRateExceeded(long now)
1128         {
1129             final long last;
1130             synchronized (this)
1131             {
1132                 last = _timestamps[_next];
1133                 _timestamps[_next] = now;
1134                 _next = (_next + 1) % _timestamps.length;
1135             }
1136 
1137             return last != 0 && (now - last) < 1000L;
1138         }
1139 
1140         public String getId()
1141         {
1142             return _id;
1143         }
1144 
1145         public int getType()
1146         {
1147             return _type;
1148         }
1149 
1150         public void valueBound(HttpSessionBindingEvent event)
1151         {
1152             if (LOG.isDebugEnabled())
1153                 LOG.debug("Value bound: {}", getId());
1154         }
1155 
1156         public void valueUnbound(HttpSessionBindingEvent event)
1157         {
1158             //take the tracker out of the list of trackers
1159             _rateTrackers.remove(_id);
1160             if (LOG.isDebugEnabled())
1161                 LOG.debug("Tracker removed: {}", getId());
1162         }
1163 
1164         public void sessionWillPassivate(HttpSessionEvent se)
1165         {
1166             //take the tracker of the list of trackers (if its still there)
1167             _rateTrackers.remove(_id);
1168         }
1169 
1170         public void sessionDidActivate(HttpSessionEvent se)
1171         {
1172             RateTracker tracker = (RateTracker)se.getSession().getAttribute(__TRACKER);
1173             if (tracker!=null)
1174                 _rateTrackers.put(tracker.getId(),tracker);
1175         }
1176 
1177         @Override
1178         public void run()
1179         {
1180             int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
1181             long last = _timestamps[latestIndex];
1182             boolean hasRecentRequest = last != 0 && (System.currentTimeMillis() - last) < 1000L;
1183 
1184             if (hasRecentRequest)
1185                 _scheduler.schedule(this, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
1186             else
1187                 _rateTrackers.remove(_id);
1188         }
1189 
1190         @Override
1191         public String toString()
1192         {
1193             return "RateTracker/" + _id + "/" + _type;
1194         }
1195     }
1196 
1197     class FixedRateTracker extends RateTracker
1198     {
1199         public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
1200         {
1201             super(id, type, numRecentRequestsTracked);
1202         }
1203 
1204         @Override
1205         public boolean isRateExceeded(long now)
1206         {
1207             // rate limit is never exceeded, but we keep track of the request timestamps
1208             // so that we know whether there was recent activity on this tracker
1209             // and whether it should be expired
1210             synchronized (this)
1211             {
1212                 _timestamps[_next] = now;
1213                 _next = (_next + 1) % _timestamps.length;
1214             }
1215 
1216             return false;
1217         }
1218 
1219         @Override
1220         public String toString()
1221         {
1222             return "Fixed" + super.toString();
1223         }
1224     }
1225 
1226     private class DoSTimeoutAsyncListener implements AsyncListener
1227     {
1228         @Override
1229         public void onStartAsync(AsyncEvent event) throws IOException
1230         {
1231         }
1232 
1233         @Override
1234         public void onComplete(AsyncEvent event) throws IOException
1235         {
1236         }
1237 
1238         @Override
1239         public void onTimeout(AsyncEvent event) throws IOException
1240         {
1241             event.getAsyncContext().dispatch();
1242         }
1243 
1244         @Override
1245         public void onError(AsyncEvent event) throws IOException
1246         {
1247         }
1248     }
1249 
1250     private class DoSAsyncListener extends DoSTimeoutAsyncListener
1251     {
1252         private final int priority;
1253 
1254         public DoSAsyncListener(int priority)
1255         {
1256             this.priority = priority;
1257         }
1258 
1259         @Override
1260         public void onTimeout(AsyncEvent event) throws IOException
1261         {
1262             _queues[priority].remove(event.getAsyncContext());
1263             super.onTimeout(event);
1264         }
1265     }
1266 }