View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2013 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.servlets;
20  
21  import java.io.IOException;
22  import java.io.Serializable;
23  import java.util.ArrayList;
24  import java.util.Iterator;
25  import java.util.List;
26  import java.util.Queue;
27  import java.util.concurrent.ConcurrentHashMap;
28  import java.util.concurrent.ConcurrentLinkedQueue;
29  import java.util.concurrent.CopyOnWriteArrayList;
30  import java.util.concurrent.Semaphore;
31  import java.util.concurrent.TimeUnit;
32  import java.util.regex.Matcher;
33  import java.util.regex.Pattern;
34  import javax.servlet.Filter;
35  import javax.servlet.FilterChain;
36  import javax.servlet.FilterConfig;
37  import javax.servlet.ServletContext;
38  import javax.servlet.ServletException;
39  import javax.servlet.ServletRequest;
40  import javax.servlet.ServletResponse;
41  import javax.servlet.http.HttpServletRequest;
42  import javax.servlet.http.HttpServletResponse;
43  import javax.servlet.http.HttpSession;
44  import javax.servlet.http.HttpSessionActivationListener;
45  import javax.servlet.http.HttpSessionBindingEvent;
46  import javax.servlet.http.HttpSessionBindingListener;
47  import javax.servlet.http.HttpSessionEvent;
48  
49  import org.eclipse.jetty.continuation.Continuation;
50  import org.eclipse.jetty.continuation.ContinuationListener;
51  import org.eclipse.jetty.continuation.ContinuationSupport;
52  import org.eclipse.jetty.server.handler.ContextHandler;
53  import org.eclipse.jetty.util.log.Log;
54  import org.eclipse.jetty.util.log.Logger;
55  import org.eclipse.jetty.util.thread.Timeout;
56  
57  /**
 * Denial of Service filter
 * <p>
61   * This filter is useful for limiting
62   * exposure to abuse from request flooding, whether malicious, or as a result of
63   * a misconfigured client.
64   * <p>
65   * The filter keeps track of the number of requests from a connection per
66   * second. If a limit is exceeded, the request is either rejected, delayed, or
67   * throttled.
68   * <p>
69   * When a request is throttled, it is placed in a priority queue. Priority is
70   * given first to authenticated users and users with an HttpSession, then
71   * connections which can be identified by their IP addresses. Connections with
72   * no way to identify them are given lowest priority.
73   * <p>
 * The {@link #extractUserId(ServletRequest request)} method should be
 * overridden in order to uniquely identify authenticated users.
76   * <p>
 * The following init parameters control the behavior of the filter:
 * <dl>
 * <dt>maxRequestsPerSec</dt>
 * <dd>the maximum number of requests from a connection per
 * second. Requests in excess of this are first delayed,
 * then throttled.</dd>
 * <dt>delayMs</dt>
 * <dd>the delay given to all requests over the rate limit,
 * before they are considered at all. -1 means just reject the request,
 * 0 means no delay, otherwise it is the delay.</dd>
 * <dt>maxWaitMs</dt>
 * <dd>how long to wait (blocking) for the throttle semaphore.</dd>
 * <dt>throttledRequests</dt>
 * <dd>the number of requests over the rate limit able to be
 * considered at once.</dd>
 * <dt>throttleMs</dt>
 * <dd>how long to wait asynchronously for the throttle semaphore.</dd>
 * <dt>maxRequestMs</dt>
 * <dd>how long to allow this request to run.</dd>
 * <dt>maxIdleTrackerMs</dt>
 * <dd>how long to keep track of request rates for a connection,
 * before deciding that the user has gone away, and discarding it.</dd>
 * <dt>insertHeaders</dt>
 * <dd>if true, insert the DoSFilter headers into the response. Defaults to true.</dd>
 * <dt>trackSessions</dt>
 * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
 * <dt>remotePort</dt>
 * <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
 * <dt>ipWhitelist</dt>
 * <dd>a comma-separated list of IP addresses, or CIDR ranges such as 192.168.1.0/24,
 * that will not be rate limited.</dd>
 * <dt>enabled</dt>
 * <dd>if false, requests pass through the filter without any rate checking. Defaults to true.</dd>
 * <dt>managedAttr</dt>
 * <dd>if set to true, then this filter is set as a {@link ServletContext} attribute with the
 * filter name as the attribute name. This allows a context-external mechanism (e.g. JMX via
 * {@link ContextHandler#MANAGED_ATTRIBUTES}) to manage the configuration of the filter.</dd>
 * </dl>
124  */
125 public class DoSFilter implements Filter
126 {
127     private static final Logger LOG = Log.getLogger(DoSFilter.class);
128 
129     private static final String IPv4_GROUP = "(\\d{1,3})";
130     private static final Pattern IPv4_PATTERN = Pattern.compile(IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP);
131     private static final String IPv6_GROUP = "(\\p{XDigit}{1,4})";
132     private static final Pattern IPv6_PATTERN = Pattern.compile(IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP);
133     private static final Pattern CIDR_PATTERN = Pattern.compile("([^/]+)/(\\d+)");
134 
135     private static final String __TRACKER = "DoSFilter.Tracker";
136     private static final String __THROTTLED = "DoSFilter.Throttled";
137 
138     private static final int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
139     private static final int __DEFAULT_DELAY_MS = 100;
140     private static final int __DEFAULT_THROTTLE = 5;
141     private static final int __DEFAULT_MAX_WAIT_MS = 50;
142     private static final long __DEFAULT_THROTTLE_MS = 30000L;
143     private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L;
144     private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L;
145 
146     static final String MANAGED_ATTR_INIT_PARAM = "managedAttr";
147     static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
148     static final String DELAY_MS_INIT_PARAM = "delayMs";
149     static final String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
150     static final String MAX_WAIT_INIT_PARAM = "maxWaitMs";
151     static final String THROTTLE_MS_INIT_PARAM = "throttleMs";
152     static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
153     static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
154     static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
155     static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
156     static final String REMOTE_PORT_INIT_PARAM = "remotePort";
157     static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
158     static final String ENABLED_INIT_PARAM = "enabled";
159 
160     private static final int USER_AUTH = 2;
161     private static final int USER_SESSION = 2;
162     private static final int USER_IP = 1;
163     private static final int USER_UNKNOWN = 0;
164 
165     private ServletContext _context;
166     private volatile long _delayMs;
167     private volatile long _throttleMs;
168     private volatile long _maxWaitMs;
169     private volatile long _maxRequestMs;
170     private volatile long _maxIdleTrackerMs;
171     private volatile boolean _insertHeaders;
172     private volatile boolean _trackSessions;
173     private volatile boolean _remotePort;
174     private volatile boolean _enabled;
175     private Semaphore _passes;
176     private volatile int _throttledRequests;
177     private volatile int _maxRequestsPerSec;
178     private Queue<Continuation>[] _queue;
179     private ContinuationListener[] _listeners;
180     private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<String, RateTracker>();
181     private final List<String> _whitelist = new CopyOnWriteArrayList<String>();
182     private final Timeout _requestTimeoutQ = new Timeout();
183     private final Timeout _trackerTimeoutQ = new Timeout();
184     private Thread _timerThread;
185     private volatile boolean _running;
186 
187     public void init(FilterConfig filterConfig)
188     {
189         _context = filterConfig.getServletContext();
190 
191         _queue = new Queue[getMaxPriority() + 1];
192         _listeners = new ContinuationListener[getMaxPriority() + 1];
193         for (int p = 0; p < _queue.length; p++)
194         {
195             _queue[p] = new ConcurrentLinkedQueue<Continuation>();
196 
197             final int priority = p;
198             _listeners[p] = new ContinuationListener()
199             {
200                 public void onComplete(Continuation continuation)
201                 {
202                 }
203 
204                 public void onTimeout(Continuation continuation)
205                 {
206                     _queue[priority].remove(continuation);
207                 }
208             };
209         }
210 
211         _rateTrackers.clear();
212 
213         int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
214         String parameter = filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM);
215         if (parameter != null)
216             maxRequests = Integer.parseInt(parameter);
217         setMaxRequestsPerSec(maxRequests);
218 
219         long delay = __DEFAULT_DELAY_MS;
220         parameter = filterConfig.getInitParameter(DELAY_MS_INIT_PARAM);
221         if (parameter != null)
222             delay = Long.parseLong(parameter);
223         setDelayMs(delay);
224 
225         int throttledRequests = __DEFAULT_THROTTLE;
226         parameter = filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM);
227         if (parameter != null)
228             throttledRequests = Integer.parseInt(parameter);
229         setThrottledRequests(throttledRequests);
230 
231         long maxWait = __DEFAULT_MAX_WAIT_MS;
232         parameter = filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM);
233         if (parameter != null)
234             maxWait = Long.parseLong(parameter);
235         setMaxWaitMs(maxWait);
236 
237         long throttle = __DEFAULT_THROTTLE_MS;
238         parameter = filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM);
239         if (parameter != null)
240             throttle = Long.parseLong(parameter);
241         setThrottleMs(throttle);
242 
243         long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
244         parameter = filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM);
245         if (parameter != null)
246             maxRequestMs = Long.parseLong(parameter);
247         setMaxRequestMs(maxRequestMs);
248 
249         long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
250         parameter = filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM);
251         if (parameter != null)
252             maxIdleTrackerMs = Long.parseLong(parameter);
253         setMaxIdleTrackerMs(maxIdleTrackerMs);
254 
255         String whiteList = "";
256         parameter = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
257         if (parameter != null)
258             whiteList = parameter;
259         setWhitelist(whiteList);
260 
261         parameter = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
262         setInsertHeaders(parameter == null || Boolean.parseBoolean(parameter));
263 
264         parameter = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
265         setTrackSessions(parameter == null || Boolean.parseBoolean(parameter));
266 
267         parameter = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
268         setRemotePort(parameter != null && Boolean.parseBoolean(parameter));
269 
270         parameter = filterConfig.getInitParameter(ENABLED_INIT_PARAM);
271         setEnabled(parameter == null || Boolean.parseBoolean(parameter));
272 
273         _requestTimeoutQ.setNow();
274         _requestTimeoutQ.setDuration(_maxRequestMs);
275 
276         _trackerTimeoutQ.setNow();
277         _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
278 
279         _running = true;
280         _timerThread = (new Thread()
281         {
282             public void run()
283             {
284                 try
285                 {
286                     while (_running)
287                     {
288                         long now = _requestTimeoutQ.setNow();
289                         _requestTimeoutQ.tick();
290                         _trackerTimeoutQ.setNow(now);
291                         _trackerTimeoutQ.tick();
292                         try
293                         {
294                             Thread.sleep(100);
295                         }
296                         catch (InterruptedException e)
297                         {
298                             LOG.ignore(e);
299                         }
300                     }
301                 }
302                 finally
303                 {
304                     LOG.debug("DoSFilter timer exited");
305                 }
306             }
307         });
308         _timerThread.start();
309 
310         if (_context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
311             _context.setAttribute(filterConfig.getFilterName(), this);
312     }
313 
314     public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException
315     {
316         doFilter((HttpServletRequest)request, (HttpServletResponse)response, filterChain);
317     }
318 
319     protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException
320     {
321         if (!isEnabled())
322         {
323             filterChain.doFilter(request, response);
324             return;
325         }
326 
327         final long now = _requestTimeoutQ.getNow();
328 
329         // Look for the rate tracker for this request
330         RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
331 
332         if (tracker == null)
333         {
334             // This is the first time we have seen this request.
335 
336             // get a rate tracker associated with this request, and record one hit
337             tracker = getRateTracker(request);
338 
339             // Calculate the rate and check it is over the allowed limit
340             final boolean overRateLimit = tracker.isRateExceeded(now);
341 
342             // pass it through if  we are not currently over the rate limit
343             if (!overRateLimit)
344             {
345                 doFilterChain(filterChain, request, response);
346                 return;
347             }
348 
349             // We are over the limit.
350             LOG.warn("DOS ALERT: ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
351 
352             // So either reject it, delay it or throttle it
353             long delayMs = getDelayMs();
354             boolean insertHeaders = isInsertHeaders();
355             switch ((int)delayMs)
356             {
357                 case -1:
358                 {
359                     // Reject this request
360                     if (insertHeaders)
361                         response.addHeader("DoSFilter", "unavailable");
362                     response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
363                     return;
364                 }
365                 case 0:
366                 {
367                     // fall through to throttle code
368                     request.setAttribute(__TRACKER, tracker);
369                     break;
370                 }
371                 default:
372                 {
373                     // insert a delay before throttling the request
374                     if (insertHeaders)
375                         response.addHeader("DoSFilter", "delayed");
376                     Continuation continuation = ContinuationSupport.getContinuation(request);
377                     request.setAttribute(__TRACKER, tracker);
378                     if (delayMs > 0)
379                         continuation.setTimeout(delayMs);
380                     continuation.suspend();
381                     return;
382                 }
383             }
384         }
385 
386         // Throttle the request
387         boolean accepted = false;
388         try
389         {
390             // check if we can afford to accept another request at this time
391             accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS);
392 
393             if (!accepted)
394             {
395                 // we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
396                 final Continuation continuation = ContinuationSupport.getContinuation(request);
397 
398                 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
399                 long throttleMs = getThrottleMs();
400                 if (throttled != Boolean.TRUE && throttleMs > 0)
401                 {
402                     int priority = getPriority(request, tracker);
403                     request.setAttribute(__THROTTLED, Boolean.TRUE);
404                     if (isInsertHeaders())
405                         response.addHeader("DoSFilter", "throttled");
406                     if (throttleMs > 0)
407                         continuation.setTimeout(throttleMs);
408                     continuation.suspend();
409 
410                     continuation.addContinuationListener(_listeners[priority]);
411                     _queue[priority].add(continuation);
412                     return;
413                 }
414                 // else were we resumed?
415                 else if (request.getAttribute("javax.servlet.resumed") == Boolean.TRUE)
416                 {
417                     // we were resumed and somebody stole our pass, so we wait for the next one.
418                     _passes.acquire();
419                     accepted = true;
420                 }
421             }
422 
423             // if we were accepted (either immediately or after throttle)
424             if (accepted)
425                 // call the chain
426                 doFilterChain(filterChain, request, response);
427             else
428             {
429                 // fail the request
430                 if (isInsertHeaders())
431                     response.addHeader("DoSFilter", "unavailable");
432                 response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
433             }
434         }
435         catch (InterruptedException e)
436         {
437             _context.log("DoS", e);
438             response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
439         }
440         finally
441         {
442             if (accepted)
443             {
444                 // wake up the next highest priority request.
445                 for (int p = _queue.length; p-- > 0; )
446                 {
447                     Continuation continuation = _queue[p].poll();
448                     if (continuation != null && continuation.isSuspended())
449                     {
450                         continuation.resume();
451                         break;
452                     }
453                 }
454                 _passes.release();
455             }
456         }
457     }
458 
459     protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) throws IOException, ServletException
460     {
461         final Thread thread = Thread.currentThread();
462 
463         final Timeout.Task requestTimeout = new Timeout.Task()
464         {
465             public void expired()
466             {
467                 closeConnection(request, response, thread);
468             }
469         };
470 
471         try
472         {
473             _requestTimeoutQ.schedule(requestTimeout);
474             chain.doFilter(request, response);
475         }
476         finally
477         {
478             requestTimeout.cancel();
479         }
480     }
481 
482     /**
483      * Takes drastic measures to return this response and stop this thread.
484      * Due to the way the connection is interrupted, may return mixed up headers.
485      *
486      * @param request  current request
487      * @param response current response, which must be stopped
488      * @param thread   the handling thread
489      */
490     protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
491     {
492         // take drastic measures to return this response and stop this thread.
493         if (!response.isCommitted())
494         {
495             response.setHeader("Connection", "close");
496         }
497         try
498         {
499             try
500             {
501                 response.getWriter().close();
502             }
503             catch (IllegalStateException e)
504             {
505                 response.getOutputStream().close();
506             }
507         }
508         catch (IOException e)
509         {
510             LOG.warn(e);
511         }
512 
513         // interrupt the handling thread
514         thread.interrupt();
515     }
516 
517     /**
518      * Get priority for this request, based on user type
519      *
520      * @param request the current request
521      * @param tracker the rate tracker for this request
522      * @return the priority for this request
523      */
524     protected int getPriority(HttpServletRequest request, RateTracker tracker)
525     {
526         if (extractUserId(request) != null)
527             return USER_AUTH;
528         if (tracker != null)
529             return tracker.getType();
530         return USER_UNKNOWN;
531     }
532 
533     /**
534      * @return the maximum priority that we can assign to a request
535      */
536     protected int getMaxPriority()
537     {
538         return USER_AUTH;
539     }
540 
541     /**
542      * Return a request rate tracker associated with this connection; keeps
543      * track of this connection's request rate. If this is not the first request
544      * from this connection, return the existing object with the stored stats.
545      * If it is the first request, then create a new request tracker.
546      * <p/>
547      * Assumes that each connection has an identifying characteristic, and goes
548      * through them in order, taking the first that matches: user id (logged
549      * in), session id, client IP address. Unidentifiable connections are lumped
550      * into one.
551      * <p/>
552      * When a session expires, its rate tracker is automatically deleted.
553      *
554      * @param request the current request
555      * @return the request rate tracker for the current connection
556      */
557     public RateTracker getRateTracker(ServletRequest request)
558     {
559         HttpSession session = ((HttpServletRequest)request).getSession(false);
560 
561         String loadId = extractUserId(request);
562         final int type;
563         if (loadId != null)
564         {
565             type = USER_AUTH;
566         }
567         else
568         {
569             if (_trackSessions && session != null && !session.isNew())
570             {
571                 loadId = session.getId();
572                 type = USER_SESSION;
573             }
574             else
575             {
576                 loadId = _remotePort ? (request.getRemoteAddr() + request.getRemotePort()) : request.getRemoteAddr();
577                 type = USER_IP;
578             }
579         }
580 
581         RateTracker tracker = _rateTrackers.get(loadId);
582 
583         if (tracker == null)
584         {
585             boolean allowed = checkWhitelist(_whitelist, request.getRemoteAddr());
586             tracker = allowed ? new FixedRateTracker(loadId, type, _maxRequestsPerSec)
587                     : new RateTracker(loadId, type, _maxRequestsPerSec);
588             RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
589             if (existing != null)
590                 tracker = existing;
591 
592             if (type == USER_IP)
593             {
594                 // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
595                 _trackerTimeoutQ.schedule(tracker);
596             }
597             else if (session != null)
598             {
599                 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
600                 session.setAttribute(__TRACKER, tracker);
601             }
602         }
603 
604         return tracker;
605     }
606 
607     protected boolean checkWhitelist(List<String> whitelist, String candidate)
608     {
609         for (String address : whitelist)
610         {
611             if (address.contains("/"))
612             {
613                 if (subnetMatch(address, candidate))
614                     return true;
615             }
616             else
617             {
618                 if (address.equals(candidate))
619                     return true;
620             }
621         }
622         return false;
623     }
624 
625     protected boolean subnetMatch(String subnetAddress, String address)
626     {
627         Matcher cidrMatcher = CIDR_PATTERN.matcher(subnetAddress);
628         if (!cidrMatcher.matches())
629             return false;
630 
631         String subnet = cidrMatcher.group(1);
632         int prefix;
633         try
634         {
635             prefix = Integer.parseInt(cidrMatcher.group(2));
636         }
637         catch (NumberFormatException x)
638         {
639             LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
640             return false;
641         }
642 
643         byte[] subnetBytes = addressToBytes(subnet);
644         if (subnetBytes == null)
645         {
646             LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
647             return false;
648         }
649         byte[] addressBytes = addressToBytes(address);
650         if (addressBytes == null)
651         {
652             LOG.info("Ignoring malformed remote address {}", address);
653             return false;
654         }
655 
656         // Comparing IPv4 with IPv6 ?
657         int length = subnetBytes.length;
658         if (length != addressBytes.length)
659             return false;
660 
661         byte[] mask = prefixToBytes(prefix, length);
662 
663         for (int i = 0; i < length; ++i)
664         {
665             if ((subnetBytes[i] & mask[i]) != (addressBytes[i] & mask[i]))
666                 return false;
667         }
668 
669         return true;
670     }
671 
672     private byte[] addressToBytes(String address)
673     {
674         Matcher ipv4Matcher = IPv4_PATTERN.matcher(address);
675         if (ipv4Matcher.matches())
676         {
677             byte[] result = new byte[4];
678             for (int i = 0; i < result.length; ++i)
679                 result[i] = Integer.valueOf(ipv4Matcher.group(i + 1)).byteValue();
680             return result;
681         }
682         else
683         {
684             Matcher ipv6Matcher = IPv6_PATTERN.matcher(address);
685             if (ipv6Matcher.matches())
686             {
687                 byte[] result = new byte[16];
688                 for (int i = 0; i < result.length; i += 2)
689                 {
690                     int word = Integer.valueOf(ipv6Matcher.group(i / 2 + 1), 16);
691                     result[i] = (byte)((word & 0xFF00) >>> 8);
692                     result[i + 1] = (byte)(word & 0xFF);
693                 }
694                 return result;
695             }
696         }
697         return null;
698     }
699 
700     private byte[] prefixToBytes(int prefix, int length)
701     {
702         byte[] result = new byte[length];
703         int index = 0;
704         while (prefix / 8 > 0)
705         {
706             result[index] = -1;
707             prefix -= 8;
708             ++index;
709         }
710         // Sets the _prefix_ most significant bits to 1
711         result[index] = (byte)~((1 << (8 - prefix)) - 1);
712         return result;
713     }
714 
715     public void destroy()
716     {
717         _running = false;
718         _timerThread.interrupt();
719         _requestTimeoutQ.cancelAll();
720         _trackerTimeoutQ.cancelAll();
721         _rateTrackers.clear();
722         _whitelist.clear();
723     }
724 
725     /**
726      * Returns the user id, used to track this connection.
727      * This SHOULD be overridden by subclasses.
728      *
729      * @param request the current request
730      * @return a unique user id, if logged in; otherwise null.
731      */
732     protected String extractUserId(ServletRequest request)
733     {
734         return null;
735     }
736 
737     /**
738      * Get maximum number of requests from a connection per
739      * second. Requests in excess of this are first delayed,
740      * then throttled.
741      *
742      * @return maximum number of requests
743      */
744     public int getMaxRequestsPerSec()
745     {
746         return _maxRequestsPerSec;
747     }
748 
749     /**
750      * Get maximum number of requests from a connection per
751      * second. Requests in excess of this are first delayed,
752      * then throttled.
753      *
754      * @param value maximum number of requests
755      */
756     public void setMaxRequestsPerSec(int value)
757     {
758         _maxRequestsPerSec = value;
759     }
760 
761     /**
762      * Get delay (in milliseconds) that is applied to all requests
763      * over the rate limit, before they are considered at all.
764      */
765     public long getDelayMs()
766     {
767         return _delayMs;
768     }
769 
770     /**
771      * Set delay (in milliseconds) that is applied to all requests
772      * over the rate limit, before they are considered at all.
773      *
774      * @param value delay (in milliseconds), 0 - no delay, -1 - reject request
775      */
776     public void setDelayMs(long value)
777     {
778         _delayMs = value;
779     }
780 
781     /**
782      * Get maximum amount of time (in milliseconds) the filter will
783      * blocking wait for the throttle semaphore.
784      *
785      * @return maximum wait time
786      */
787     public long getMaxWaitMs()
788     {
789         return _maxWaitMs;
790     }
791 
792     /**
793      * Set maximum amount of time (in milliseconds) the filter will
794      * blocking wait for the throttle semaphore.
795      *
796      * @param value maximum wait time
797      */
798     public void setMaxWaitMs(long value)
799     {
800         _maxWaitMs = value;
801     }
802 
803     /**
804      * Get number of requests over the rate limit able to be
805      * considered at once.
806      *
807      * @return number of requests
808      */
809     public int getThrottledRequests()
810     {
811         return _throttledRequests;
812     }
813 
814     /**
815      * Set number of requests over the rate limit able to be
816      * considered at once.
817      *
818      * @param value number of requests
819      */
820     public void setThrottledRequests(int value)
821     {
822         int permits = _passes == null ? 0 : _passes.availablePermits();
823         _passes = new Semaphore((value - _throttledRequests + permits), true);
824         _throttledRequests = value;
825     }
826 
827     /**
828      * Get amount of time (in milliseconds) to async wait for semaphore.
829      *
830      * @return wait time
831      */
832     public long getThrottleMs()
833     {
834         return _throttleMs;
835     }
836 
837     /**
838      * Set amount of time (in milliseconds) to async wait for semaphore.
839      *
840      * @param value wait time
841      */
842     public void setThrottleMs(long value)
843     {
844         _throttleMs = value;
845     }
846 
847     /**
848      * Get maximum amount of time (in milliseconds) to allow
849      * the request to process.
850      *
851      * @return maximum processing time
852      */
853     public long getMaxRequestMs()
854     {
855         return _maxRequestMs;
856     }
857 
858     /**
859      * Set maximum amount of time (in milliseconds) to allow
860      * the request to process.
861      *
862      * @param value maximum processing time
863      */
864     public void setMaxRequestMs(long value)
865     {
866         _maxRequestMs = value;
867     }
868 
869     /**
870      * Get maximum amount of time (in milliseconds) to keep track
871      * of request rates for a connection, before deciding that
872      * the user has gone away, and discarding it.
873      *
874      * @return maximum tracking time
875      */
876     public long getMaxIdleTrackerMs()
877     {
878         return _maxIdleTrackerMs;
879     }
880 
881     /**
882      * Set maximum amount of time (in milliseconds) to keep track
883      * of request rates for a connection, before deciding that
884      * the user has gone away, and discarding it.
885      *
886      * @param value maximum tracking time
887      */
888     public void setMaxIdleTrackerMs(long value)
889     {
890         _maxIdleTrackerMs = value;
891     }
892 
893     /**
894      * Check flag to insert the DoSFilter headers into the response.
895      *
896      * @return value of the flag
897      */
898     public boolean isInsertHeaders()
899     {
900         return _insertHeaders;
901     }
902 
903     /**
904      * Set flag to insert the DoSFilter headers into the response.
905      *
906      * @param value value of the flag
907      */
908     public void setInsertHeaders(boolean value)
909     {
910         _insertHeaders = value;
911     }
912 
913     /**
914      * Get flag to have usage rate tracked by session if a session exists.
915      *
916      * @return value of the flag
917      */
918     public boolean isTrackSessions()
919     {
920         return _trackSessions;
921     }
922 
923     /**
924      * Set flag to have usage rate tracked by session if a session exists.
925      *
926      * @param value value of the flag
927      */
928     public void setTrackSessions(boolean value)
929     {
930         _trackSessions = value;
931     }
932 
933     /**
934      * Get flag to have usage rate tracked by IP+port (effectively connection)
935      * if session tracking is not used.
936      *
937      * @return value of the flag
938      */
939     public boolean isRemotePort()
940     {
941         return _remotePort;
942     }
943 
944     /**
945      * Set flag to have usage rate tracked by IP+port (effectively connection)
946      * if session tracking is not used.
947      *
948      * @param value value of the flag
949      */
950     public void setRemotePort(boolean value)
951     {
952         _remotePort = value;
953     }
954 
955     /**
956      * @return whether this filter is enabled
957      */
958     public boolean isEnabled()
959     {
960         return _enabled;
961     }
962 
963     /**
964      * @param enabled whether this filter is enabled
965      */
966     public void setEnabled(boolean enabled)
967     {
968         _enabled = enabled;
969     }
970 
971     /**
972      * Get a list of IP addresses that will not be rate limited.
973      *
974      * @return comma-separated whitelist
975      */
976     public String getWhitelist()
977     {
978         StringBuilder result = new StringBuilder();
979         for (Iterator<String> iterator = _whitelist.iterator(); iterator.hasNext();)
980         {
981             String address = iterator.next();
982             result.append(address);
983             if (iterator.hasNext())
984                 result.append(",");
985         }
986         return result.toString();
987     }
988 
989     /**
990      * Set a list of IP addresses that will not be rate limited.
991      *
992      * @param value comma-separated whitelist
993      */
994     public void setWhitelist(String value)
995     {
996         List<String> result = new ArrayList<String>();
997         for (String address : value.split(","))
998             addWhitelistAddress(result, address);
999         _whitelist.clear();
1000         _whitelist.addAll(result);
1001         LOG.debug("Whitelisted IP addresses: {}", result);
1002     }
1003 
1004     public void clearWhitelist()
1005     {
1006         _whitelist.clear();
1007     }
1008 
1009     public boolean addWhitelistAddress(String address)
1010     {
1011         return addWhitelistAddress(_whitelist, address);
1012     }
1013 
1014     private boolean addWhitelistAddress(List<String> list, String address)
1015     {
1016         address = address.trim();
1017         return address.length() > 0 && list.add(address);
1018     }
1019 
1020     public boolean removeWhitelistAddress(String address)
1021     {
1022         return _whitelist.remove(address);
1023     }
1024 
1025     /**
1026      * A RateTracker is associated with a connection, and stores request rate
1027      * data.
1028      */
1029     class RateTracker extends Timeout.Task implements HttpSessionBindingListener, HttpSessionActivationListener, Serializable
1030     {
1031         private static final long serialVersionUID = 3534663738034577872L;
1032 
1033         transient protected final String _id;
1034         transient protected final int _type;
1035         transient protected final long[] _timestamps;
1036         transient protected int _next;
1037 
1038         public RateTracker(String id, int type, int maxRequestsPerSecond)
1039         {
1040             _id = id;
1041             _type = type;
1042             _timestamps = new long[maxRequestsPerSecond];
1043             _next = 0;
1044         }
1045 
1046         /**
1047          * @return the current calculated request rate over the last second
1048          */
1049         public boolean isRateExceeded(long now)
1050         {
1051             final long last;
1052             synchronized (this)
1053             {
1054                 last = _timestamps[_next];
1055                 _timestamps[_next] = now;
1056                 _next = (_next + 1) % _timestamps.length;
1057             }
1058 
1059             return last != 0 && (now - last) < 1000L;
1060         }
1061 
1062         public String getId()
1063         {
1064             return _id;
1065         }
1066 
1067         public int getType()
1068         {
1069             return _type;
1070         }
1071 
1072         public void valueBound(HttpSessionBindingEvent event)
1073         {
1074             if (LOG.isDebugEnabled())
1075                 LOG.debug("Value bound: {}", getId());
1076         }
1077 
1078         public void valueUnbound(HttpSessionBindingEvent event)
1079         {
1080             //take the tracker out of the list of trackers
1081             _rateTrackers.remove(_id);
1082             if (LOG.isDebugEnabled())
1083                 LOG.debug("Tracker removed: {}", getId());
1084         }
1085 
1086         public void sessionWillPassivate(HttpSessionEvent se)
1087         {
1088             //take the tracker of the list of trackers (if its still there)
1089             //and ensure that we take ourselves out of the session so we are not saved
1090             _rateTrackers.remove(_id);
1091             se.getSession().removeAttribute(__TRACKER);
1092             if (LOG.isDebugEnabled()) LOG.debug("Value removed: {}", getId());
1093         }
1094 
1095         public void sessionDidActivate(HttpSessionEvent se)
1096         {
1097             LOG.warn("Unexpected session activation");
1098         }
1099 
1100         public void expired()
1101         {
1102             long now = _trackerTimeoutQ.getNow();
1103             int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
1104             long last = _timestamps[latestIndex];
1105             boolean hasRecentRequest = last != 0 && (now - last) < 1000L;
1106 
1107             if (hasRecentRequest)
1108                 reschedule();
1109             else
1110                 _rateTrackers.remove(_id);
1111         }
1112 
1113         @Override
1114         public String toString()
1115         {
1116             return "RateTracker/" + _id + "/" + _type;
1117         }
1118     }
1119 
1120     class FixedRateTracker extends RateTracker
1121     {
1122         public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
1123         {
1124             super(id, type, numRecentRequestsTracked);
1125         }
1126 
1127         @Override
1128         public boolean isRateExceeded(long now)
1129         {
1130             // rate limit is never exceeded, but we keep track of the request timestamps
1131             // so that we know whether there was recent activity on this tracker
1132             // and whether it should be expired
1133             synchronized (this)
1134             {
1135                 _timestamps[_next] = now;
1136                 _next = (_next + 1) % _timestamps.length;
1137             }
1138 
1139             return false;
1140         }
1141 
1142         @Override
1143         public String toString()
1144         {
1145             return "Fixed" + super.toString();
1146         }
1147     }
1148 }