1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.eclipse.jetty.servlets;
15
16 import java.io.IOException;
17 import java.util.HashSet;
18 import java.util.Queue;
19 import java.util.StringTokenizer;
20 import java.util.concurrent.ConcurrentHashMap;
21 import java.util.concurrent.ConcurrentLinkedQueue;
22 import java.util.concurrent.Semaphore;
23 import java.util.concurrent.TimeUnit;
24 import javax.servlet.Filter;
25 import javax.servlet.FilterChain;
26 import javax.servlet.FilterConfig;
27 import javax.servlet.ServletContext;
28 import javax.servlet.ServletException;
29 import javax.servlet.ServletRequest;
30 import javax.servlet.ServletResponse;
31 import javax.servlet.http.HttpServletRequest;
32 import javax.servlet.http.HttpServletResponse;
33 import javax.servlet.http.HttpSession;
34 import javax.servlet.http.HttpSessionBindingEvent;
35 import javax.servlet.http.HttpSessionBindingListener;
36
37 import org.eclipse.jetty.continuation.Continuation;
38 import org.eclipse.jetty.continuation.ContinuationListener;
39 import org.eclipse.jetty.continuation.ContinuationSupport;
40 import org.eclipse.jetty.server.handler.ContextHandler;
41 import org.eclipse.jetty.util.log.Log;
42 import org.eclipse.jetty.util.thread.Timeout;
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113 public class DoSFilter implements Filter
114 {
115 final static String __TRACKER = "DoSFilter.Tracker";
116 final static String __THROTTLED = "DoSFilter.Throttled";
117
118 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
119 final static int __DEFAULT_DELAY_MS = 100;
120 final static int __DEFAULT_THROTTLE = 5;
121 final static int __DEFAULT_WAIT_MS=50;
122 final static long __DEFAULT_THROTTLE_MS = 30000L;
123 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
124 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
125
126 final static String MANAGED_ATTR_INIT_PARAM="managedAttr";
127 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
128 final static String DELAY_MS_INIT_PARAM = "delayMs";
129 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
130 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
131 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
132 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
133 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
134 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
135 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
136 final static String REMOTE_PORT_INIT_PARAM="remotePort";
137 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
138
139 final static int USER_AUTH = 2;
140 final static int USER_SESSION = 2;
141 final static int USER_IP = 1;
142 final static int USER_UNKNOWN = 0;
143
144 ServletContext _context;
145
146 protected String _name;
147 protected long _delayMs;
148 protected long _throttleMs;
149 protected long _maxWaitMs;
150 protected long _maxRequestMs;
151 protected long _maxIdleTrackerMs;
152 protected boolean _insertHeaders;
153 protected boolean _trackSessions;
154 protected boolean _remotePort;
155 protected int _throttledRequests;
156 protected Semaphore _passes;
157 protected Queue<Continuation>[] _queue;
158 protected ContinuationListener[] _listener;
159
160 protected int _maxRequestsPerSec;
161 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
162 protected String _whitelistStr;
163 private final HashSet<String> _whitelist = new HashSet<String>();
164
165 private final Timeout _requestTimeoutQ = new Timeout();
166 private final Timeout _trackerTimeoutQ = new Timeout();
167
168 private Thread _timerThread;
169 private volatile boolean _running;
170
171 public void init(FilterConfig filterConfig)
172 {
173 _context = filterConfig.getServletContext();
174
175 _queue = new Queue[getMaxPriority() + 1];
176 _listener = new ContinuationListener[getMaxPriority() + 1];
177 for (int p = 0; p < _queue.length; p++)
178 {
179 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
180
181 final int priority=p;
182 _listener[p] = new ContinuationListener()
183 {
184 public void onComplete(Continuation continuation)
185 {
186 }
187
188 public void onTimeout(Continuation continuation)
189 {
190 _queue[priority].remove(continuation);
191 }
192 };
193 }
194
195 _rateTrackers.clear();
196
197 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
198 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
199 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
200 _maxRequestsPerSec = baseRateLimit;
201
202 long delay = __DEFAULT_DELAY_MS;
203 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
204 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
205 _delayMs = delay;
206
207 int throttledRequests = __DEFAULT_THROTTLE;
208 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
209 throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
210 _passes = new Semaphore(throttledRequests,true);
211 _throttledRequests = throttledRequests;
212
213 long wait = __DEFAULT_WAIT_MS;
214 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
215 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
216 _maxWaitMs = wait;
217
218 long suspend = __DEFAULT_THROTTLE_MS;
219 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
220 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
221 _throttleMs = suspend;
222
223 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
224 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
225 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
226 _maxRequestMs = maxRequestMs;
227
228 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
229 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
230 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
231 _maxIdleTrackerMs = maxIdleTrackerMs;
232
233 _whitelistStr = "";
234 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
235 _whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
236 initWhitelist();
237
238 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
239 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
240
241 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
242 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
243
244 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
245 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
246
247 _requestTimeoutQ.setNow();
248 _requestTimeoutQ.setDuration(_maxRequestMs);
249
250 _trackerTimeoutQ.setNow();
251 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
252
253 _running=true;
254 _timerThread = (new Thread()
255 {
256 public void run()
257 {
258 try
259 {
260 while (_running)
261 {
262 long now;
263 synchronized (_requestTimeoutQ)
264 {
265 now = _requestTimeoutQ.setNow();
266 _requestTimeoutQ.tick();
267 }
268 synchronized (_trackerTimeoutQ)
269 {
270 _trackerTimeoutQ.setNow(now);
271 _trackerTimeoutQ.tick();
272 }
273 try
274 {
275 Thread.sleep(100);
276 }
277 catch (InterruptedException e)
278 {
279 Log.ignore(e);
280 }
281 }
282 }
283 finally
284 {
285 Log.info("DoSFilter timer exited");
286 }
287 }
288 });
289 _timerThread.start();
290
291 if (_context!=null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
292 _context.setAttribute(filterConfig.getFilterName(),this);
293 }
294
295
296 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
297 {
298 final HttpServletRequest srequest = (HttpServletRequest)request;
299 final HttpServletResponse sresponse = (HttpServletResponse)response;
300
301 final long now=_requestTimeoutQ.getNow();
302
303
304 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
305
306 if (tracker==null)
307 {
308
309
310
311 tracker = getRateTracker(request);
312
313
314 final boolean overRateLimit = tracker.isRateExceeded(now);
315
316
317 if (!overRateLimit)
318 {
319 doFilterChain(filterchain,srequest,sresponse);
320 return;
321 }
322
323
324 Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
325
326
327 switch((int)_delayMs)
328 {
329 case -1:
330 {
331
332 if (_insertHeaders)
333 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
334 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
335 return;
336 }
337 case 0:
338 {
339
340 request.setAttribute(__TRACKER,tracker);
341 break;
342 }
343 default:
344 {
345
346 if (_insertHeaders)
347 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
348 Continuation continuation = ContinuationSupport.getContinuation(request);
349 request.setAttribute(__TRACKER,tracker);
350 if (_delayMs > 0)
351 continuation.setTimeout(_delayMs);
352 continuation.suspend();
353 return;
354 }
355 }
356 }
357
358
359 boolean accepted = false;
360 try
361 {
362
363 accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS);
364
365 if (!accepted)
366 {
367
368 final Continuation continuation = ContinuationSupport.getContinuation(request);
369
370 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
371 if (throttled!=Boolean.TRUE && _throttleMs>0)
372 {
373 int priority = getPriority(request,tracker);
374 request.setAttribute(__THROTTLED,Boolean.TRUE);
375 if (_insertHeaders)
376 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
377 if (_throttleMs > 0)
378 continuation.setTimeout(_throttleMs);
379 continuation.suspend();
380
381 continuation.addContinuationListener(_listener[priority]);
382 _queue[priority].add(continuation);
383 return;
384 }
385
386 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
387 {
388
389 _passes.acquire();
390 accepted = true;
391 }
392 }
393
394
395 if (accepted)
396
397 doFilterChain(filterchain,srequest,sresponse);
398 else
399 {
400
401 if (_insertHeaders)
402 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
403 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
404 }
405 }
406 catch (InterruptedException e)
407 {
408 _context.log("DoS",e);
409 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
410 }
411 finally
412 {
413 if (accepted)
414 {
415
416 for (int p = _queue.length; p-- > 0;)
417 {
418 Continuation continuation = _queue[p].poll();
419 if (continuation != null && continuation.isSuspended())
420 {
421 continuation.resume();
422 break;
423 }
424 }
425 _passes.release();
426 }
427 }
428 }
429
430
431
432
433
434
435
436
437 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
438 throws IOException, ServletException
439 {
440 final Thread thread=Thread.currentThread();
441
442 final Timeout.Task requestTimeout = new Timeout.Task()
443 {
444 public void expired()
445 {
446 closeConnection(request, response, thread);
447 }
448 };
449
450 try
451 {
452 synchronized (_requestTimeoutQ)
453 {
454 _requestTimeoutQ.schedule(requestTimeout);
455 }
456 chain.doFilter(request,response);
457 }
458 finally
459 {
460 synchronized (_requestTimeoutQ)
461 {
462 requestTimeout.cancel();
463 }
464 }
465 }
466
467
468
469
470
471
472
473
474 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
475 {
476
477 if( !response.isCommitted() )
478 {
479 response.setHeader("Connection", "close");
480 }
481 try
482 {
483 try
484 {
485 response.getWriter().close();
486 }
487 catch (IllegalStateException e)
488 {
489 response.getOutputStream().close();
490 }
491 }
492 catch (IOException e)
493 {
494 Log.warn(e);
495 }
496
497
498 thread.interrupt();
499 }
500
501
502
503
504
505
506
507
508 protected int getPriority(ServletRequest request, RateTracker tracker)
509 {
510 if (extractUserId(request)!=null)
511 return USER_AUTH;
512 if (tracker!=null)
513 return tracker.getType();
514 return USER_UNKNOWN;
515 }
516
517
518
519
520 protected int getMaxPriority()
521 {
522 return USER_AUTH;
523 }
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541 public RateTracker getRateTracker(ServletRequest request)
542 {
543 HttpServletRequest srequest = (HttpServletRequest)request;
544 HttpSession session=srequest.getSession(false);
545
546 String loadId = extractUserId(request);
547 final int type;
548 if (loadId != null)
549 {
550 type = USER_AUTH;
551 }
552 else
553 {
554 if (_trackSessions && session!=null && !session.isNew())
555 {
556 loadId=session.getId();
557 type = USER_SESSION;
558 }
559 else
560 {
561 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
562 type = USER_IP;
563 }
564 }
565
566 RateTracker tracker=_rateTrackers.get(loadId);
567
568 if (tracker==null)
569 {
570 RateTracker t;
571 if (_whitelist.contains(request.getRemoteAddr()))
572 {
573 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
574 }
575 else
576 {
577 t = new RateTracker(loadId,type,_maxRequestsPerSec);
578 }
579
580 tracker=_rateTrackers.putIfAbsent(loadId,t);
581 if (tracker==null)
582 tracker=t;
583
584 if (type == USER_IP)
585 {
586
587 synchronized (_trackerTimeoutQ)
588 {
589 _trackerTimeoutQ.schedule(tracker);
590 }
591 }
592 else if (session!=null)
593
594 session.setAttribute(__TRACKER,tracker);
595 }
596
597 return tracker;
598 }
599
600 public void destroy()
601 {
602 _running=false;
603 _timerThread.interrupt();
604 synchronized (_requestTimeoutQ)
605 {
606 _requestTimeoutQ.cancelAll();
607 }
608 synchronized (_trackerTimeoutQ)
609 {
610 _trackerTimeoutQ.cancelAll();
611 }
612 _rateTrackers.clear();
613 _whitelist.clear();
614 }
615
616
617
618
619
620
621
622
623 protected String extractUserId(ServletRequest request)
624 {
625 return null;
626 }
627
628
629
630
631
632 protected void initWhitelist()
633 {
634 _whitelist.clear();
635 StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ",");
636 while (tokenizer.hasMoreTokens())
637 _whitelist.add(tokenizer.nextToken().trim());
638
639 Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
640 }
641
642
643
644
645
646
647
648
649
650 public int getMaxRequestsPerSec()
651 {
652 return _maxRequestsPerSec;
653 }
654
655
656
657
658
659
660
661
662
663 public void setMaxRequestsPerSec(int value)
664 {
665 _maxRequestsPerSec = value;
666 }
667
668
669
670
671
672
673 public long getDelayMs()
674 {
675 return _delayMs;
676 }
677
678
679
680
681
682
683
684
685 public void setDelayMs(long value)
686 {
687 _delayMs = value;
688 }
689
690
691
692
693
694
695
696
697 public long getMaxWaitMs()
698 {
699 return _maxWaitMs;
700 }
701
702
703
704
705
706
707
708
709 public void setMaxWaitMs(long value)
710 {
711 _maxWaitMs = value;
712 }
713
714
715
716
717
718
719
720
721 public int getThrottledRequests()
722 {
723 return _throttledRequests;
724 }
725
726
727
728
729
730
731
732
733 public void setThrottledRequests(int value)
734 {
735 _passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true);
736 _throttledRequests = value;
737 }
738
739
740
741
742
743
744
745 public long getThrottleMs()
746 {
747 return _throttleMs;
748 }
749
750
751
752
753
754
755
756 public void setThrottleMs(long value)
757 {
758 _throttleMs = value;
759 }
760
761
762
763
764
765
766
767
768 public long getMaxRequestMs()
769 {
770 return _maxRequestMs;
771 }
772
773
774
775
776
777
778
779
780 public void setMaxRequestMs(long value)
781 {
782 _maxRequestMs = value;
783 }
784
785
786
787
788
789
790
791
792
793 public long getMaxIdleTrackerMs()
794 {
795 return _maxIdleTrackerMs;
796 }
797
798
799
800
801
802
803
804
805
806 public void setMaxIdleTrackerMs(long value)
807 {
808 _maxIdleTrackerMs = value;
809 }
810
811
812
813
814
815
816
817 public boolean isInsertHeaders()
818 {
819 return _insertHeaders;
820 }
821
822
823
824
825
826
827
828 public void setInsertHeaders(boolean value)
829 {
830 _insertHeaders = value;
831 }
832
833
834
835
836
837
838
839 public boolean isTrackSessions()
840 {
841 return _trackSessions;
842 }
843
844
845
846
847
848
849 public void setTrackSessions(boolean value)
850 {
851 _trackSessions = value;
852 }
853
854
855
856
857
858
859
860
861 public boolean isRemotePort()
862 {
863 return _remotePort;
864 }
865
866
867
868
869
870
871
872
873
874 public void setRemotePort(boolean value)
875 {
876 _remotePort = value;
877 }
878
879
880
881
882
883
884
885 public String getWhitelist()
886 {
887 return _whitelistStr;
888 }
889
890
891
892
893
894
895
896
897 public void setWhitelist(String value)
898 {
899 _whitelistStr = value;
900 initWhitelist();
901 }
902
903
904
905
906
907 class RateTracker extends Timeout.Task implements HttpSessionBindingListener
908 {
909 protected final String _id;
910 protected final int _type;
911 protected final long[] _timestamps;
912 protected int _next;
913
914 public RateTracker(String id, int type,int maxRequestsPerSecond)
915 {
916 _id = id;
917 _type = type;
918 _timestamps=new long[maxRequestsPerSecond];
919 _next=0;
920 }
921
922
923
924
925 public boolean isRateExceeded(long now)
926 {
927 final long last;
928 synchronized (this)
929 {
930 last=_timestamps[_next];
931 _timestamps[_next]=now;
932 _next= (_next+1)%_timestamps.length;
933 }
934
935 boolean exceeded=last!=0 && (now-last)<1000L;
936 return exceeded;
937 }
938
939
940 public String getId()
941 {
942 return _id;
943 }
944
945 public int getType()
946 {
947 return _type;
948 }
949
950
951 public void valueBound(HttpSessionBindingEvent event)
952 {
953 }
954
955 public void valueUnbound(HttpSessionBindingEvent event)
956 {
957 _rateTrackers.remove(_id);
958 }
959
960 public void expired()
961 {
962 long now = _trackerTimeoutQ.getNow();
963 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
964 long last=_timestamps[latestIndex];
965 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
966
967 if (hasRecentRequest)
968 reschedule();
969 else
970 _rateTrackers.remove(_id);
971 }
972
973 @Override
974 public String toString()
975 {
976 return "RateTracker/"+_id+"/"+_type;
977 }
978 }
979
980 class FixedRateTracker extends RateTracker
981 {
982 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
983 {
984 super(id,type,numRecentRequestsTracked);
985 }
986
987 @Override
988 public boolean isRateExceeded(long now)
989 {
990
991
992
993 synchronized (this)
994 {
995 _timestamps[_next]=now;
996 _next= (_next+1)%_timestamps.length;
997 }
998
999 return false;
1000 }
1001
1002 @Override
1003 public String toString()
1004 {
1005 return "Fixed"+super.toString();
1006 }
1007 }
1008 }