1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.eclipse.jetty.servlets;
15
16 import java.io.IOException;
17 import java.io.Serializable;
18 import java.util.HashSet;
19 import java.util.Map;
20 import java.util.Queue;
21 import java.util.StringTokenizer;
22 import java.util.concurrent.ConcurrentHashMap;
23 import java.util.concurrent.ConcurrentLinkedQueue;
24 import java.util.concurrent.Semaphore;
25 import java.util.concurrent.TimeUnit;
26
27 import javax.servlet.Filter;
28 import javax.servlet.FilterChain;
29 import javax.servlet.FilterConfig;
30 import javax.servlet.ServletContext;
31 import javax.servlet.ServletException;
32 import javax.servlet.ServletRequest;
33 import javax.servlet.ServletResponse;
34 import javax.servlet.http.HttpServletRequest;
35 import javax.servlet.http.HttpServletResponse;
36 import javax.servlet.http.HttpSession;
37 import javax.servlet.http.HttpSessionActivationListener;
38 import javax.servlet.http.HttpSessionBindingEvent;
39 import javax.servlet.http.HttpSessionBindingListener;
40 import javax.servlet.http.HttpSessionEvent;
41
42 import org.eclipse.jetty.continuation.Continuation;
43 import org.eclipse.jetty.continuation.ContinuationListener;
44 import org.eclipse.jetty.continuation.ContinuationSupport;
45 import org.eclipse.jetty.server.handler.ContextHandler;
46 import org.eclipse.jetty.util.log.Log;
47 import org.eclipse.jetty.util.log.Logger;
48 import org.eclipse.jetty.util.thread.Timeout;
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119 public class DoSFilter implements Filter
120 {
121 private static final Logger LOG = Log.getLogger(DoSFilter.class);
122
123 final static String __TRACKER = "DoSFilter.Tracker";
124 final static String __THROTTLED = "DoSFilter.Throttled";
125
126 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
127 final static int __DEFAULT_DELAY_MS = 100;
128 final static int __DEFAULT_THROTTLE = 5;
129 final static int __DEFAULT_WAIT_MS=50;
130 final static long __DEFAULT_THROTTLE_MS = 30000L;
131 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
132 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
133
134 final static String MANAGED_ATTR_INIT_PARAM="managedAttr";
135 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
136 final static String DELAY_MS_INIT_PARAM = "delayMs";
137 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
138 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
139 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
140 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
141 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
142 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
143 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
144 final static String REMOTE_PORT_INIT_PARAM="remotePort";
145 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
146
147 final static int USER_AUTH = 2;
148 final static int USER_SESSION = 2;
149 final static int USER_IP = 1;
150 final static int USER_UNKNOWN = 0;
151
152 ServletContext _context;
153
154 protected String _name;
155 protected long _delayMs;
156 protected long _throttleMs;
157 protected long _maxWaitMs;
158 protected long _maxRequestMs;
159 protected long _maxIdleTrackerMs;
160 protected boolean _insertHeaders;
161 protected boolean _trackSessions;
162 protected boolean _remotePort;
163 protected int _throttledRequests;
164 protected Semaphore _passes;
165 protected Queue<Continuation>[] _queue;
166 protected ContinuationListener[] _listener;
167
168 protected int _maxRequestsPerSec;
169 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
170 protected String _whitelistStr;
171 private final HashSet<String> _whitelist = new HashSet<String>();
172
173 private final Timeout _requestTimeoutQ = new Timeout();
174 private final Timeout _trackerTimeoutQ = new Timeout();
175
176 private Thread _timerThread;
177 private volatile boolean _running;
178
179 public void init(FilterConfig filterConfig)
180 {
181 _context = filterConfig.getServletContext();
182
183 _queue = new Queue[getMaxPriority() + 1];
184 _listener = new ContinuationListener[getMaxPriority() + 1];
185 for (int p = 0; p < _queue.length; p++)
186 {
187 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
188
189 final int priority=p;
190 _listener[p] = new ContinuationListener()
191 {
192 public void onComplete(Continuation continuation)
193 {
194 }
195
196 public void onTimeout(Continuation continuation)
197 {
198 _queue[priority].remove(continuation);
199 }
200 };
201 }
202
203 _rateTrackers.clear();
204
205 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
206 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
207 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
208 _maxRequestsPerSec = baseRateLimit;
209
210 long delay = __DEFAULT_DELAY_MS;
211 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
212 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
213 _delayMs = delay;
214
215 int throttledRequests = __DEFAULT_THROTTLE;
216 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
217 throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
218 _passes = new Semaphore(throttledRequests,true);
219 _throttledRequests = throttledRequests;
220
221 long wait = __DEFAULT_WAIT_MS;
222 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
223 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
224 _maxWaitMs = wait;
225
226 long suspend = __DEFAULT_THROTTLE_MS;
227 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
228 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
229 _throttleMs = suspend;
230
231 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
232 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
233 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
234 _maxRequestMs = maxRequestMs;
235
236 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
237 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
238 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
239 _maxIdleTrackerMs = maxIdleTrackerMs;
240
241 _whitelistStr = "";
242 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
243 _whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
244 initWhitelist();
245
246 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
247 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
248
249 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
250 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
251
252 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
253 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
254
255 _requestTimeoutQ.setNow();
256 _requestTimeoutQ.setDuration(_maxRequestMs);
257
258 _trackerTimeoutQ.setNow();
259 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
260
261 _running=true;
262 _timerThread = (new Thread()
263 {
264 public void run()
265 {
266 try
267 {
268 while (_running)
269 {
270 long now;
271 synchronized (_requestTimeoutQ)
272 {
273 now = _requestTimeoutQ.setNow();
274 _requestTimeoutQ.tick();
275 }
276 synchronized (_trackerTimeoutQ)
277 {
278 _trackerTimeoutQ.setNow(now);
279 _trackerTimeoutQ.tick();
280 }
281 try
282 {
283 Thread.sleep(100);
284 }
285 catch (InterruptedException e)
286 {
287 LOG.ignore(e);
288 }
289 }
290 }
291 finally
292 {
293 LOG.info("DoSFilter timer exited");
294 }
295 }
296 });
297 _timerThread.start();
298
299 if (_context!=null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
300 _context.setAttribute(filterConfig.getFilterName(),this);
301 }
302
303
304 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
305 {
306 final HttpServletRequest srequest = (HttpServletRequest)request;
307 final HttpServletResponse sresponse = (HttpServletResponse)response;
308
309 final long now=_requestTimeoutQ.getNow();
310
311
312 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
313
314 if (tracker==null)
315 {
316
317
318
319 tracker = getRateTracker(request);
320
321
322 final boolean overRateLimit = tracker.isRateExceeded(now);
323
324
325 if (!overRateLimit)
326 {
327 doFilterChain(filterchain,srequest,sresponse);
328 return;
329 }
330
331
332 LOG.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
333
334
335 switch((int)_delayMs)
336 {
337 case -1:
338 {
339
340 if (_insertHeaders)
341 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
342 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
343 return;
344 }
345 case 0:
346 {
347
348 request.setAttribute(__TRACKER,tracker);
349 break;
350 }
351 default:
352 {
353
354 if (_insertHeaders)
355 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
356 Continuation continuation = ContinuationSupport.getContinuation(request);
357 request.setAttribute(__TRACKER,tracker);
358 if (_delayMs > 0)
359 continuation.setTimeout(_delayMs);
360 continuation.suspend();
361 return;
362 }
363 }
364 }
365
366
367 boolean accepted = false;
368 try
369 {
370
371 accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS);
372
373 if (!accepted)
374 {
375
376 final Continuation continuation = ContinuationSupport.getContinuation(request);
377
378 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
379 if (throttled!=Boolean.TRUE && _throttleMs>0)
380 {
381 int priority = getPriority(request,tracker);
382 request.setAttribute(__THROTTLED,Boolean.TRUE);
383 if (_insertHeaders)
384 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
385 if (_throttleMs > 0)
386 continuation.setTimeout(_throttleMs);
387 continuation.suspend();
388
389 continuation.addContinuationListener(_listener[priority]);
390 _queue[priority].add(continuation);
391 return;
392 }
393
394 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
395 {
396
397 _passes.acquire();
398 accepted = true;
399 }
400 }
401
402
403 if (accepted)
404
405 doFilterChain(filterchain,srequest,sresponse);
406 else
407 {
408
409 if (_insertHeaders)
410 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
411 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
412 }
413 }
414 catch (InterruptedException e)
415 {
416 _context.log("DoS",e);
417 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
418 }
419 finally
420 {
421 if (accepted)
422 {
423
424 for (int p = _queue.length; p-- > 0;)
425 {
426 Continuation continuation = _queue[p].poll();
427 if (continuation != null && continuation.isSuspended())
428 {
429 continuation.resume();
430 break;
431 }
432 }
433 _passes.release();
434 }
435 }
436 }
437
438
439
440
441
442
443
444
445 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
446 throws IOException, ServletException
447 {
448 final Thread thread=Thread.currentThread();
449
450 final Timeout.Task requestTimeout = new Timeout.Task()
451 {
452 public void expired()
453 {
454 closeConnection(request, response, thread);
455 }
456 };
457
458 try
459 {
460 synchronized (_requestTimeoutQ)
461 {
462 _requestTimeoutQ.schedule(requestTimeout);
463 }
464 chain.doFilter(request,response);
465 }
466 finally
467 {
468 synchronized (_requestTimeoutQ)
469 {
470 requestTimeout.cancel();
471 }
472 }
473 }
474
475
476
477
478
479
480
481
482 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
483 {
484
485 if( !response.isCommitted() )
486 {
487 response.setHeader("Connection", "close");
488 }
489 try
490 {
491 try
492 {
493 response.getWriter().close();
494 }
495 catch (IllegalStateException e)
496 {
497 response.getOutputStream().close();
498 }
499 }
500 catch (IOException e)
501 {
502 LOG.warn(e);
503 }
504
505
506 thread.interrupt();
507 }
508
509
510
511
512
513
514
515
516 protected int getPriority(ServletRequest request, RateTracker tracker)
517 {
518 if (extractUserId(request)!=null)
519 return USER_AUTH;
520 if (tracker!=null)
521 return tracker.getType();
522 return USER_UNKNOWN;
523 }
524
525
526
527
528 protected int getMaxPriority()
529 {
530 return USER_AUTH;
531 }
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549 public RateTracker getRateTracker(ServletRequest request)
550 {
551 HttpServletRequest srequest = (HttpServletRequest)request;
552 HttpSession session=srequest.getSession(false);
553
554 String loadId = extractUserId(request);
555 final int type;
556 if (loadId != null)
557 {
558 type = USER_AUTH;
559 }
560 else
561 {
562 if (_trackSessions && session!=null && !session.isNew())
563 {
564 loadId=session.getId();
565 type = USER_SESSION;
566 }
567 else
568 {
569 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
570 type = USER_IP;
571 }
572 }
573
574 RateTracker tracker=_rateTrackers.get(loadId);
575
576 if (tracker==null)
577 {
578 RateTracker t;
579 if (_whitelist.contains(request.getRemoteAddr()))
580 {
581 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
582 }
583 else
584 {
585 t = new RateTracker(loadId,type,_maxRequestsPerSec);
586 }
587
588 tracker=_rateTrackers.putIfAbsent(loadId,t);
589 if (tracker==null)
590 tracker=t;
591
592 if (type == USER_IP)
593 {
594
595 synchronized (_trackerTimeoutQ)
596 {
597 _trackerTimeoutQ.schedule(tracker);
598 }
599 }
600 else if (session!=null)
601
602 session.setAttribute(__TRACKER,tracker);
603 }
604
605 return tracker;
606 }
607
608 public void destroy()
609 {
610 _running=false;
611 _timerThread.interrupt();
612 synchronized (_requestTimeoutQ)
613 {
614 _requestTimeoutQ.cancelAll();
615 }
616 synchronized (_trackerTimeoutQ)
617 {
618 _trackerTimeoutQ.cancelAll();
619 }
620 _rateTrackers.clear();
621 _whitelist.clear();
622 }
623
624
625
626
627
628
629
630
631 protected String extractUserId(ServletRequest request)
632 {
633 return null;
634 }
635
636
637
638
639
640 protected void initWhitelist()
641 {
642 _whitelist.clear();
643 StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ",");
644 while (tokenizer.hasMoreTokens())
645 _whitelist.add(tokenizer.nextToken().trim());
646
647 LOG.info("Whitelisted IP addresses: {}", _whitelist.toString());
648 }
649
650
651
652
653
654
655
656
657
658 public int getMaxRequestsPerSec()
659 {
660 return _maxRequestsPerSec;
661 }
662
663
664
665
666
667
668
669
670
671 public void setMaxRequestsPerSec(int value)
672 {
673 _maxRequestsPerSec = value;
674 }
675
676
677
678
679
680
681 public long getDelayMs()
682 {
683 return _delayMs;
684 }
685
686
687
688
689
690
691
692
693 public void setDelayMs(long value)
694 {
695 _delayMs = value;
696 }
697
698
699
700
701
702
703
704
705 public long getMaxWaitMs()
706 {
707 return _maxWaitMs;
708 }
709
710
711
712
713
714
715
716
717 public void setMaxWaitMs(long value)
718 {
719 _maxWaitMs = value;
720 }
721
722
723
724
725
726
727
728
729 public int getThrottledRequests()
730 {
731 return _throttledRequests;
732 }
733
734
735
736
737
738
739
740
741 public void setThrottledRequests(int value)
742 {
743 _passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true);
744 _throttledRequests = value;
745 }
746
747
748
749
750
751
752
753 public long getThrottleMs()
754 {
755 return _throttleMs;
756 }
757
758
759
760
761
762
763
764 public void setThrottleMs(long value)
765 {
766 _throttleMs = value;
767 }
768
769
770
771
772
773
774
775
776 public long getMaxRequestMs()
777 {
778 return _maxRequestMs;
779 }
780
781
782
783
784
785
786
787
788 public void setMaxRequestMs(long value)
789 {
790 _maxRequestMs = value;
791 }
792
793
794
795
796
797
798
799
800
801 public long getMaxIdleTrackerMs()
802 {
803 return _maxIdleTrackerMs;
804 }
805
806
807
808
809
810
811
812
813
814 public void setMaxIdleTrackerMs(long value)
815 {
816 _maxIdleTrackerMs = value;
817 }
818
819
820
821
822
823
824
825 public boolean isInsertHeaders()
826 {
827 return _insertHeaders;
828 }
829
830
831
832
833
834
835
836 public void setInsertHeaders(boolean value)
837 {
838 _insertHeaders = value;
839 }
840
841
842
843
844
845
846
847 public boolean isTrackSessions()
848 {
849 return _trackSessions;
850 }
851
852
853
854
855
856
857 public void setTrackSessions(boolean value)
858 {
859 _trackSessions = value;
860 }
861
862
863
864
865
866
867
868
869 public boolean isRemotePort()
870 {
871 return _remotePort;
872 }
873
874
875
876
877
878
879
880
881
882 public void setRemotePort(boolean value)
883 {
884 _remotePort = value;
885 }
886
887
888
889
890
891
892
893 public String getWhitelist()
894 {
895 return _whitelistStr;
896 }
897
898
899
900
901
902
903
904
905 public void setWhitelist(String value)
906 {
907 _whitelistStr = value;
908 initWhitelist();
909 }
910
911
912
913
914
915 class RateTracker extends Timeout.Task implements HttpSessionBindingListener, HttpSessionActivationListener
916 {
917 transient protected final String _id;
918 transient protected final int _type;
919 transient protected final long[] _timestamps;
920 transient protected int _next;
921
922
923 public RateTracker(String id, int type,int maxRequestsPerSecond)
924 {
925 _id = id;
926 _type = type;
927 _timestamps=new long[maxRequestsPerSecond];
928 _next=0;
929 }
930
931
932
933
934 public boolean isRateExceeded(long now)
935 {
936 final long last;
937 synchronized (this)
938 {
939 last=_timestamps[_next];
940 _timestamps[_next]=now;
941 _next= (_next+1)%_timestamps.length;
942 }
943
944 boolean exceeded=last!=0 && (now-last)<1000L;
945 return exceeded;
946 }
947
948
949 public String getId()
950 {
951 return _id;
952 }
953
954 public int getType()
955 {
956 return _type;
957 }
958
959
960 public void valueBound(HttpSessionBindingEvent event)
961 {
962 if (LOG.isDebugEnabled())
963 LOG.debug("Value bound:"+_id);
964 }
965
966 public void valueUnbound(HttpSessionBindingEvent event)
967 {
968
969 if (_rateTrackers != null)
970 _rateTrackers.remove(_id);
971 if (LOG.isDebugEnabled()) LOG.debug("Tracker removed: "+_id);
972 }
973
974 public void sessionWillPassivate(HttpSessionEvent se)
975 {
976
977
978 if (_rateTrackers != null)
979 _rateTrackers.remove(_id);
980 se.getSession().removeAttribute(__TRACKER);
981 if (LOG.isDebugEnabled()) LOG.debug("Value removed: "+_id);
982 }
983
984 public void sessionDidActivate(HttpSessionEvent se)
985 {
986 LOG.warn("Unexpected session activation");
987 }
988
989
990 public void expired()
991 {
992 if (_rateTrackers != null && _trackerTimeoutQ != null)
993 {
994 long now = _trackerTimeoutQ.getNow();
995 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
996 long last=_timestamps[latestIndex];
997 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
998
999 if (hasRecentRequest)
1000 reschedule();
1001 else
1002 _rateTrackers.remove(_id);
1003 }
1004 }
1005
1006 @Override
1007 public String toString()
1008 {
1009 return "RateTracker/"+_id+"/"+_type;
1010 }
1011
1012
1013 }
1014
1015 class FixedRateTracker extends RateTracker
1016 {
1017 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
1018 {
1019 super(id,type,numRecentRequestsTracked);
1020 }
1021
1022 @Override
1023 public boolean isRateExceeded(long now)
1024 {
1025
1026
1027
1028 synchronized (this)
1029 {
1030 _timestamps[_next]=now;
1031 _next= (_next+1)%_timestamps.length;
1032 }
1033
1034 return false;
1035 }
1036
1037 @Override
1038 public String toString()
1039 {
1040 return "Fixed"+super.toString();
1041 }
1042 }
1043 }