1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.eclipse.jetty.servlets;
15
16 import java.io.IOException;
17 import java.io.Serializable;
18 import java.util.HashSet;
19 import java.util.Map;
20 import java.util.Queue;
21 import java.util.StringTokenizer;
22 import java.util.concurrent.ConcurrentHashMap;
23 import java.util.concurrent.ConcurrentLinkedQueue;
24 import java.util.concurrent.Semaphore;
25 import java.util.concurrent.TimeUnit;
26
27 import javax.servlet.Filter;
28 import javax.servlet.FilterChain;
29 import javax.servlet.FilterConfig;
30 import javax.servlet.ServletContext;
31 import javax.servlet.ServletException;
32 import javax.servlet.ServletRequest;
33 import javax.servlet.ServletResponse;
34 import javax.servlet.http.HttpServletRequest;
35 import javax.servlet.http.HttpServletResponse;
36 import javax.servlet.http.HttpSession;
37 import javax.servlet.http.HttpSessionActivationListener;
38 import javax.servlet.http.HttpSessionBindingEvent;
39 import javax.servlet.http.HttpSessionBindingListener;
40 import javax.servlet.http.HttpSessionEvent;
41
42 import org.eclipse.jetty.continuation.Continuation;
43 import org.eclipse.jetty.continuation.ContinuationListener;
44 import org.eclipse.jetty.continuation.ContinuationSupport;
45 import org.eclipse.jetty.server.handler.ContextHandler;
46 import org.eclipse.jetty.util.log.Log;
47 import org.eclipse.jetty.util.log.Logger;
48 import org.eclipse.jetty.util.thread.Timeout;
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119 public class DoSFilter implements Filter
120 {
121 private static final Logger LOG = Log.getLogger(DoSFilter.class);
122
123 final static String __TRACKER = "DoSFilter.Tracker";
124 final static String __THROTTLED = "DoSFilter.Throttled";
125
126 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
127 final static int __DEFAULT_DELAY_MS = 100;
128 final static int __DEFAULT_THROTTLE = 5;
129 final static int __DEFAULT_WAIT_MS=50;
130 final static long __DEFAULT_THROTTLE_MS = 30000L;
131 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
132 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
133
134 final static String MANAGED_ATTR_INIT_PARAM="managedAttr";
135 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
136 final static String DELAY_MS_INIT_PARAM = "delayMs";
137 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
138 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
139 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
140 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
141 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
142 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
143 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
144 final static String REMOTE_PORT_INIT_PARAM="remotePort";
145 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
146
147 final static int USER_AUTH = 2;
148 final static int USER_SESSION = 2;
149 final static int USER_IP = 1;
150 final static int USER_UNKNOWN = 0;
151
152 ServletContext _context;
153
154 protected String _name;
155 protected long _delayMs;
156 protected long _throttleMs;
157 protected long _maxWaitMs;
158 protected long _maxRequestMs;
159 protected long _maxIdleTrackerMs;
160 protected boolean _insertHeaders;
161 protected boolean _trackSessions;
162 protected boolean _remotePort;
163 protected int _throttledRequests;
164 protected Semaphore _passes;
165 protected Queue<Continuation>[] _queue;
166 protected ContinuationListener[] _listener;
167
168 protected int _maxRequestsPerSec;
169 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
170 protected String _whitelistStr;
171 private final HashSet<String> _whitelist = new HashSet<String>();
172
173 private final Timeout _requestTimeoutQ = new Timeout();
174 private final Timeout _trackerTimeoutQ = new Timeout();
175
176 private Thread _timerThread;
177 private volatile boolean _running;
178
179 public void init(FilterConfig filterConfig)
180 {
181 _context = filterConfig.getServletContext();
182
183 _queue = new Queue[getMaxPriority() + 1];
184 _listener = new ContinuationListener[getMaxPriority() + 1];
185 for (int p = 0; p < _queue.length; p++)
186 {
187 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
188
189 final int priority=p;
190 _listener[p] = new ContinuationListener()
191 {
192 public void onComplete(Continuation continuation)
193 {
194 }
195
196 public void onTimeout(Continuation continuation)
197 {
198 _queue[priority].remove(continuation);
199 }
200 };
201 }
202
203 _rateTrackers.clear();
204
205 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
206 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
207 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
208 _maxRequestsPerSec = baseRateLimit;
209
210 long delay = __DEFAULT_DELAY_MS;
211 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
212 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
213 _delayMs = delay;
214
215 int throttledRequests = __DEFAULT_THROTTLE;
216 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
217 throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
218 _passes = new Semaphore(throttledRequests,true);
219 _throttledRequests = throttledRequests;
220
221 long wait = __DEFAULT_WAIT_MS;
222 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
223 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
224 _maxWaitMs = wait;
225
226 long suspend = __DEFAULT_THROTTLE_MS;
227 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
228 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
229 _throttleMs = suspend;
230
231 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
232 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
233 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
234 _maxRequestMs = maxRequestMs;
235
236 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
237 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
238 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
239 _maxIdleTrackerMs = maxIdleTrackerMs;
240
241 _whitelistStr = "";
242 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
243 _whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
244 initWhitelist();
245
246 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
247 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
248
249 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
250 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
251
252 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
253 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
254
255 _requestTimeoutQ.setNow();
256 _requestTimeoutQ.setDuration(_maxRequestMs);
257
258 _trackerTimeoutQ.setNow();
259 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
260
261 _running=true;
262 _timerThread = (new Thread()
263 {
264 public void run()
265 {
266 try
267 {
268 while (_running)
269 {
270 long now;
271 synchronized (_requestTimeoutQ)
272 {
273 now = _requestTimeoutQ.setNow();
274 _requestTimeoutQ.tick();
275 }
276 synchronized (_trackerTimeoutQ)
277 {
278 _trackerTimeoutQ.setNow(now);
279 _trackerTimeoutQ.tick();
280 }
281 try
282 {
283 Thread.sleep(100);
284 }
285 catch (InterruptedException e)
286 {
287 LOG.ignore(e);
288 }
289 }
290 }
291 finally
292 {
293 LOG.info("DoSFilter timer exited");
294 }
295 }
296 });
297 _timerThread.start();
298
299 if (_context!=null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
300 _context.setAttribute(filterConfig.getFilterName(),this);
301 }
302
303
304 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
305 {
306 final HttpServletRequest srequest = (HttpServletRequest)request;
307 final HttpServletResponse sresponse = (HttpServletResponse)response;
308
309 final long now=_requestTimeoutQ.getNow();
310
311
312 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
313
314 if (tracker==null)
315 {
316
317
318
319 tracker = getRateTracker(request);
320
321
322 final boolean overRateLimit = tracker.isRateExceeded(now);
323
324
325 if (!overRateLimit)
326 {
327 doFilterChain(filterchain,srequest,sresponse);
328 return;
329 }
330
331
332 LOG.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
333
334
335 switch((int)_delayMs)
336 {
337 case -1:
338 {
339
340 if (_insertHeaders)
341 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
342
343 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
344 return;
345 }
346 case 0:
347 {
348
349 request.setAttribute(__TRACKER,tracker);
350 break;
351 }
352 default:
353 {
354
355 if (_insertHeaders)
356 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
357 Continuation continuation = ContinuationSupport.getContinuation(request);
358 request.setAttribute(__TRACKER,tracker);
359 if (_delayMs > 0)
360 continuation.setTimeout(_delayMs);
361 continuation.addContinuationListener(new ContinuationListener()
362 {
363
364 public void onComplete(Continuation continuation)
365 {
366 }
367
368 public void onTimeout(Continuation continuation)
369 {
370 }
371 });
372 continuation.suspend();
373 return;
374 }
375 }
376 }
377
378 boolean accepted = false;
379 try
380 {
381
382 accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS);
383
384 if (!accepted)
385 {
386
387 final Continuation continuation = ContinuationSupport.getContinuation(request);
388
389 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
390 if (throttled!=Boolean.TRUE && _throttleMs>0)
391 {
392 int priority = getPriority(request,tracker);
393 request.setAttribute(__THROTTLED,Boolean.TRUE);
394 if (_insertHeaders)
395 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
396 if (_throttleMs > 0)
397 continuation.setTimeout(_throttleMs);
398 continuation.suspend();
399
400 continuation.addContinuationListener(_listener[priority]);
401 _queue[priority].add(continuation);
402 return;
403 }
404
405 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
406 {
407
408 _passes.acquire();
409 accepted = true;
410 }
411 }
412
413
414 if (accepted)
415
416 doFilterChain(filterchain,srequest,sresponse);
417 else
418 {
419
420 if (_insertHeaders)
421 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
422 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
423 }
424 }
425 catch (InterruptedException e)
426 {
427 _context.log("DoS",e);
428 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
429 }
430 catch (Exception e)
431 {
432 e.printStackTrace();
433 }
434 finally
435 {
436 if (accepted)
437 {
438
439 for (int p = _queue.length; p-- > 0;)
440 {
441 Continuation continuation = _queue[p].poll();
442 if (continuation != null && continuation.isSuspended())
443 {
444 continuation.resume();
445 break;
446 }
447 }
448 _passes.release();
449 }
450 }
451 }
452
453
454
455
456
457
458
459
460 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
461 throws IOException, ServletException
462 {
463 final Thread thread=Thread.currentThread();
464
465 final Timeout.Task requestTimeout = new Timeout.Task()
466 {
467 public void expired()
468 {
469 closeConnection(request, response, thread);
470 }
471 };
472
473 try
474 {
475 synchronized (_requestTimeoutQ)
476 {
477 _requestTimeoutQ.schedule(requestTimeout);
478 }
479 chain.doFilter(request,response);
480 }
481 finally
482 {
483 synchronized (_requestTimeoutQ)
484 {
485 requestTimeout.cancel();
486 }
487 }
488 }
489
490
491
492
493
494
495
496
497 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
498 {
499
500 if( !response.isCommitted() )
501 {
502 response.setHeader("Connection", "close");
503 }
504 try
505 {
506 try
507 {
508 response.getWriter().close();
509 }
510 catch (IllegalStateException e)
511 {
512 response.getOutputStream().close();
513 }
514 }
515 catch (IOException e)
516 {
517 LOG.warn(e);
518 }
519
520
521 thread.interrupt();
522 }
523
524
525
526
527
528
529
530
531 protected int getPriority(ServletRequest request, RateTracker tracker)
532 {
533 if (extractUserId(request)!=null)
534 return USER_AUTH;
535 if (tracker!=null)
536 return tracker.getType();
537 return USER_UNKNOWN;
538 }
539
540
541
542
543 protected int getMaxPriority()
544 {
545 return USER_AUTH;
546 }
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564 public RateTracker getRateTracker(ServletRequest request)
565 {
566 HttpServletRequest srequest = (HttpServletRequest)request;
567 HttpSession session=srequest.getSession(false);
568
569 String loadId = extractUserId(request);
570 final int type;
571 if (loadId != null)
572 {
573 type = USER_AUTH;
574 }
575 else
576 {
577 if (_trackSessions && session!=null && !session.isNew())
578 {
579 loadId=session.getId();
580 type = USER_SESSION;
581 }
582 else
583 {
584 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
585 type = USER_IP;
586 }
587 }
588
589 RateTracker tracker=_rateTrackers.get(loadId);
590
591 if (tracker==null)
592 {
593 RateTracker t;
594 if (_whitelist.contains(request.getRemoteAddr()))
595 {
596 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
597 }
598 else
599 {
600 t = new RateTracker(loadId,type,_maxRequestsPerSec);
601 }
602
603 tracker=_rateTrackers.putIfAbsent(loadId,t);
604 if (tracker==null)
605 tracker=t;
606
607 if (type == USER_IP)
608 {
609
610 synchronized (_trackerTimeoutQ)
611 {
612 _trackerTimeoutQ.schedule(tracker);
613 }
614 }
615 else if (session!=null)
616
617 session.setAttribute(__TRACKER,tracker);
618 }
619
620 return tracker;
621 }
622
623 public void destroy()
624 {
625 _running=false;
626 _timerThread.interrupt();
627 synchronized (_requestTimeoutQ)
628 {
629 _requestTimeoutQ.cancelAll();
630 }
631 synchronized (_trackerTimeoutQ)
632 {
633 _trackerTimeoutQ.cancelAll();
634 }
635 _rateTrackers.clear();
636 _whitelist.clear();
637 }
638
639
640
641
642
643
644
645
646 protected String extractUserId(ServletRequest request)
647 {
648 return null;
649 }
650
651
652
653
654
655 protected void initWhitelist()
656 {
657 _whitelist.clear();
658 StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ",");
659 while (tokenizer.hasMoreTokens())
660 _whitelist.add(tokenizer.nextToken().trim());
661
662 LOG.info("Whitelisted IP addresses: {}", _whitelist.toString());
663 }
664
665
666
667
668
669
670
671
672
673 public int getMaxRequestsPerSec()
674 {
675 return _maxRequestsPerSec;
676 }
677
678
679
680
681
682
683
684
685
686 public void setMaxRequestsPerSec(int value)
687 {
688 _maxRequestsPerSec = value;
689 }
690
691
692
693
694
695
696 public long getDelayMs()
697 {
698 return _delayMs;
699 }
700
701
702
703
704
705
706
707
708 public void setDelayMs(long value)
709 {
710 _delayMs = value;
711 }
712
713
714
715
716
717
718
719
720 public long getMaxWaitMs()
721 {
722 return _maxWaitMs;
723 }
724
725
726
727
728
729
730
731
732 public void setMaxWaitMs(long value)
733 {
734 _maxWaitMs = value;
735 }
736
737
738
739
740
741
742
743
744 public int getThrottledRequests()
745 {
746 return _throttledRequests;
747 }
748
749
750
751
752
753
754
755
756 public void setThrottledRequests(int value)
757 {
758 _passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true);
759 _throttledRequests = value;
760 }
761
762
763
764
765
766
767
768 public long getThrottleMs()
769 {
770 return _throttleMs;
771 }
772
773
774
775
776
777
778
779 public void setThrottleMs(long value)
780 {
781 _throttleMs = value;
782 }
783
784
785
786
787
788
789
790
791 public long getMaxRequestMs()
792 {
793 return _maxRequestMs;
794 }
795
796
797
798
799
800
801
802
803 public void setMaxRequestMs(long value)
804 {
805 _maxRequestMs = value;
806 }
807
808
809
810
811
812
813
814
815
816 public long getMaxIdleTrackerMs()
817 {
818 return _maxIdleTrackerMs;
819 }
820
821
822
823
824
825
826
827
828
829 public void setMaxIdleTrackerMs(long value)
830 {
831 _maxIdleTrackerMs = value;
832 }
833
834
835
836
837
838
839
840 public boolean isInsertHeaders()
841 {
842 return _insertHeaders;
843 }
844
845
846
847
848
849
850
851 public void setInsertHeaders(boolean value)
852 {
853 _insertHeaders = value;
854 }
855
856
857
858
859
860
861
862 public boolean isTrackSessions()
863 {
864 return _trackSessions;
865 }
866
867
868
869
870
871
872 public void setTrackSessions(boolean value)
873 {
874 _trackSessions = value;
875 }
876
877
878
879
880
881
882
883
884 public boolean isRemotePort()
885 {
886 return _remotePort;
887 }
888
889
890
891
892
893
894
895
896
897 public void setRemotePort(boolean value)
898 {
899 _remotePort = value;
900 }
901
902
903
904
905
906
907
908 public String getWhitelist()
909 {
910 return _whitelistStr;
911 }
912
913
914
915
916
917
918
919
920 public void setWhitelist(String value)
921 {
922 _whitelistStr = value;
923 initWhitelist();
924 }
925
926
927
928
929
930 class RateTracker extends Timeout.Task implements HttpSessionBindingListener, HttpSessionActivationListener
931 {
932 transient protected final String _id;
933 transient protected final int _type;
934 transient protected final long[] _timestamps;
935 transient protected int _next;
936
937
938 public RateTracker(String id, int type,int maxRequestsPerSecond)
939 {
940 _id = id;
941 _type = type;
942 _timestamps=new long[maxRequestsPerSecond];
943 _next=0;
944 }
945
946
947
948
949 public boolean isRateExceeded(long now)
950 {
951 final long last;
952 synchronized (this)
953 {
954 last=_timestamps[_next];
955 _timestamps[_next]=now;
956 _next= (_next+1)%_timestamps.length;
957 }
958
959 boolean exceeded=last!=0 && (now-last)<1000L;
960 return exceeded;
961 }
962
963
964 public String getId()
965 {
966 return _id;
967 }
968
969 public int getType()
970 {
971 return _type;
972 }
973
974
975 public void valueBound(HttpSessionBindingEvent event)
976 {
977 if (LOG.isDebugEnabled())
978 LOG.debug("Value bound:"+_id);
979 }
980
981 public void valueUnbound(HttpSessionBindingEvent event)
982 {
983
984 if (_rateTrackers != null)
985 _rateTrackers.remove(_id);
986 if (LOG.isDebugEnabled()) LOG.debug("Tracker removed: "+_id);
987 }
988
989 public void sessionWillPassivate(HttpSessionEvent se)
990 {
991
992
993 if (_rateTrackers != null)
994 _rateTrackers.remove(_id);
995 se.getSession().removeAttribute(__TRACKER);
996 if (LOG.isDebugEnabled()) LOG.debug("Value removed: "+_id);
997 }
998
999 public void sessionDidActivate(HttpSessionEvent se)
1000 {
1001 LOG.warn("Unexpected session activation");
1002 }
1003
1004
1005 public void expired()
1006 {
1007 if (_rateTrackers != null && _trackerTimeoutQ != null)
1008 {
1009 long now = _trackerTimeoutQ.getNow();
1010 int latestIndex = _next == 0 ? (_timestamps.length-1) : (_next - 1 );
1011 long last=_timestamps[latestIndex];
1012 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
1013
1014 if (hasRecentRequest)
1015 reschedule();
1016 else
1017 _rateTrackers.remove(_id);
1018 }
1019 }
1020
1021 @Override
1022 public String toString()
1023 {
1024 return "RateTracker/"+_id+"/"+_type;
1025 }
1026
1027
1028 }
1029
1030 class FixedRateTracker extends RateTracker
1031 {
1032 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
1033 {
1034 super(id,type,numRecentRequestsTracked);
1035 }
1036
1037 @Override
1038 public boolean isRateExceeded(long now)
1039 {
1040
1041
1042
1043 synchronized (this)
1044 {
1045 _timestamps[_next]=now;
1046 _next= (_next+1)%_timestamps.length;
1047 }
1048
1049 return false;
1050 }
1051
1052 @Override
1053 public String toString()
1054 {
1055 return "Fixed"+super.toString();
1056 }
1057 }
1058 }