1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.eclipse.jetty.servlets;
20
21 import java.io.IOException;
22 import java.io.Serializable;
23 import java.util.HashSet;
24 import java.util.Map;
25 import java.util.Queue;
26 import java.util.StringTokenizer;
27 import java.util.concurrent.ConcurrentHashMap;
28 import java.util.concurrent.ConcurrentLinkedQueue;
29 import java.util.concurrent.Semaphore;
30 import java.util.concurrent.TimeUnit;
31
32 import javax.servlet.Filter;
33 import javax.servlet.FilterChain;
34 import javax.servlet.FilterConfig;
35 import javax.servlet.ServletContext;
36 import javax.servlet.ServletException;
37 import javax.servlet.ServletRequest;
38 import javax.servlet.ServletResponse;
39 import javax.servlet.http.HttpServletRequest;
40 import javax.servlet.http.HttpServletResponse;
41 import javax.servlet.http.HttpSession;
42 import javax.servlet.http.HttpSessionActivationListener;
43 import javax.servlet.http.HttpSessionBindingEvent;
44 import javax.servlet.http.HttpSessionBindingListener;
45 import javax.servlet.http.HttpSessionEvent;
46
47 import org.eclipse.jetty.continuation.Continuation;
48 import org.eclipse.jetty.continuation.ContinuationListener;
49 import org.eclipse.jetty.continuation.ContinuationSupport;
50 import org.eclipse.jetty.server.handler.ContextHandler;
51 import org.eclipse.jetty.util.log.Log;
52 import org.eclipse.jetty.util.log.Logger;
53 import org.eclipse.jetty.util.thread.Timeout;
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124 public class DoSFilter implements Filter
125 {
126 private static final Logger LOG = Log.getLogger(DoSFilter.class);
127
128 final static String __TRACKER = "DoSFilter.Tracker";
129 final static String __THROTTLED = "DoSFilter.Throttled";
130
131 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
132 final static int __DEFAULT_DELAY_MS = 100;
133 final static int __DEFAULT_THROTTLE = 5;
134 final static int __DEFAULT_WAIT_MS=50;
135 final static long __DEFAULT_THROTTLE_MS = 30000L;
136 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
137 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
138
139 final static String MANAGED_ATTR_INIT_PARAM="managedAttr";
140 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
141 final static String DELAY_MS_INIT_PARAM = "delayMs";
142 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
143 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
144 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
145 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
146 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
147 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
148 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
149 final static String REMOTE_PORT_INIT_PARAM="remotePort";
150 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
151
152 final static int USER_AUTH = 2;
153 final static int USER_SESSION = 2;
154 final static int USER_IP = 1;
155 final static int USER_UNKNOWN = 0;
156
157 ServletContext _context;
158
159 protected String _name;
160 protected long _delayMs;
161 protected long _throttleMs;
162 protected long _maxWaitMs;
163 protected long _maxRequestMs;
164 protected long _maxIdleTrackerMs;
165 protected boolean _insertHeaders;
166 protected boolean _trackSessions;
167 protected boolean _remotePort;
168 protected int _throttledRequests;
169 protected Semaphore _passes;
170 protected Queue<Continuation>[] _queue;
171 protected ContinuationListener[] _listener;
172
173 protected int _maxRequestsPerSec;
174 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
175 protected String _whitelistStr;
176 private final HashSet<String> _whitelist = new HashSet<String>();
177
178 private final Timeout _requestTimeoutQ = new Timeout();
179 private final Timeout _trackerTimeoutQ = new Timeout();
180
181 private Thread _timerThread;
182 private volatile boolean _running;
183
184 public void init(FilterConfig filterConfig)
185 {
186 _context = filterConfig.getServletContext();
187
188 _queue = new Queue[getMaxPriority() + 1];
189 _listener = new ContinuationListener[getMaxPriority() + 1];
190 for (int p = 0; p < _queue.length; p++)
191 {
192 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
193
194 final int priority=p;
195 _listener[p] = new ContinuationListener()
196 {
197 public void onComplete(Continuation continuation)
198 {
199 }
200
201 public void onTimeout(Continuation continuation)
202 {
203 _queue[priority].remove(continuation);
204 }
205 };
206 }
207
208 _rateTrackers.clear();
209
210 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
211 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
212 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
213 _maxRequestsPerSec = baseRateLimit;
214
215 long delay = __DEFAULT_DELAY_MS;
216 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
217 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
218 _delayMs = delay;
219
220 int throttledRequests = __DEFAULT_THROTTLE;
221 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
222 throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
223 _passes = new Semaphore(throttledRequests,true);
224 _throttledRequests = throttledRequests;
225
226 long wait = __DEFAULT_WAIT_MS;
227 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
228 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
229 _maxWaitMs = wait;
230
231 long suspend = __DEFAULT_THROTTLE_MS;
232 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
233 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
234 _throttleMs = suspend;
235
236 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
237 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
238 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
239 _maxRequestMs = maxRequestMs;
240
241 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
242 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
243 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
244 _maxIdleTrackerMs = maxIdleTrackerMs;
245
246 _whitelistStr = "";
247 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
248 _whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
249 initWhitelist();
250
251 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
252 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
253
254 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
255 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
256
257 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
258 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
259
260 _requestTimeoutQ.setNow();
261 _requestTimeoutQ.setDuration(_maxRequestMs);
262
263 _trackerTimeoutQ.setNow();
264 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
265
266 _running=true;
267 _timerThread = (new Thread()
268 {
269 public void run()
270 {
271 try
272 {
273 while (_running)
274 {
275 long now;
276 synchronized (_requestTimeoutQ)
277 {
278 now = _requestTimeoutQ.setNow();
279 _requestTimeoutQ.tick();
280 }
281 synchronized (_trackerTimeoutQ)
282 {
283 _trackerTimeoutQ.setNow(now);
284 _trackerTimeoutQ.tick();
285 }
286 try
287 {
288 Thread.sleep(100);
289 }
290 catch (InterruptedException e)
291 {
292 LOG.ignore(e);
293 }
294 }
295 }
296 finally
297 {
298 LOG.info("DoSFilter timer exited");
299 }
300 }
301 });
302 _timerThread.start();
303
304 if (_context!=null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
305 _context.setAttribute(filterConfig.getFilterName(),this);
306 }
307
308
309 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
310 {
311 final HttpServletRequest srequest = (HttpServletRequest)request;
312 final HttpServletResponse sresponse = (HttpServletResponse)response;
313
314 final long now=_requestTimeoutQ.getNow();
315
316
317 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
318
319 if (tracker==null)
320 {
321
322
323
324 tracker = getRateTracker(request);
325
326
327 final boolean overRateLimit = tracker.isRateExceeded(now);
328
329
330 if (!overRateLimit)
331 {
332 doFilterChain(filterchain,srequest,sresponse);
333 return;
334 }
335
336
337 LOG.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
338
339
340 switch((int)_delayMs)
341 {
342 case -1:
343 {
344
345 if (_insertHeaders)
346 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
347
348 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
349 return;
350 }
351 case 0:
352 {
353
354 request.setAttribute(__TRACKER,tracker);
355 break;
356 }
357 default:
358 {
359
360 if (_insertHeaders)
361 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
362 Continuation continuation = ContinuationSupport.getContinuation(request);
363 request.setAttribute(__TRACKER,tracker);
364 if (_delayMs > 0)
365 continuation.setTimeout(_delayMs);
366 continuation.addContinuationListener(new ContinuationListener()
367 {
368
369 public void onComplete(Continuation continuation)
370 {
371 }
372
373 public void onTimeout(Continuation continuation)
374 {
375 }
376 });
377 continuation.suspend();
378 return;
379 }
380 }
381 }
382
383 boolean accepted = false;
384 try
385 {
386
387 accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS);
388
389 if (!accepted)
390 {
391
392 final Continuation continuation = ContinuationSupport.getContinuation(request);
393
394 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
395 if (throttled!=Boolean.TRUE && _throttleMs>0)
396 {
397 int priority = getPriority(request,tracker);
398 request.setAttribute(__THROTTLED,Boolean.TRUE);
399 if (_insertHeaders)
400 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
401 if (_throttleMs > 0)
402 continuation.setTimeout(_throttleMs);
403 continuation.suspend();
404
405 continuation.addContinuationListener(_listener[priority]);
406 _queue[priority].add(continuation);
407 return;
408 }
409
410 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
411 {
412
413 _passes.acquire();
414 accepted = true;
415 }
416 }
417
418
419 if (accepted)
420
421 doFilterChain(filterchain,srequest,sresponse);
422 else
423 {
424
425 if (_insertHeaders)
426 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
427 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
428 }
429 }
430 catch (InterruptedException e)
431 {
432 _context.log("DoS",e);
433 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
434 }
435 catch (Exception e)
436 {
437 e.printStackTrace();
438 }
439 finally
440 {
441 if (accepted)
442 {
443
444 for (int p = _queue.length; p-- > 0;)
445 {
446 Continuation continuation = _queue[p].poll();
447 if (continuation != null && continuation.isSuspended())
448 {
449 continuation.resume();
450 break;
451 }
452 }
453 _passes.release();
454 }
455 }
456 }
457
458
459
460
461
462
463
464
465 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
466 throws IOException, ServletException
467 {
468 final Thread thread=Thread.currentThread();
469
470 final Timeout.Task requestTimeout = new Timeout.Task()
471 {
472 public void expired()
473 {
474 closeConnection(request, response, thread);
475 }
476 };
477
478 try
479 {
480 synchronized (_requestTimeoutQ)
481 {
482 _requestTimeoutQ.schedule(requestTimeout);
483 }
484 chain.doFilter(request,response);
485 }
486 finally
487 {
488 synchronized (_requestTimeoutQ)
489 {
490 requestTimeout.cancel();
491 }
492 }
493 }
494
495
496
497
498
499
500
501
502 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
503 {
504
505 if( !response.isCommitted() )
506 {
507 response.setHeader("Connection", "close");
508 }
509 try
510 {
511 try
512 {
513 response.getWriter().close();
514 }
515 catch (IllegalStateException e)
516 {
517 response.getOutputStream().close();
518 }
519 }
520 catch (IOException e)
521 {
522 LOG.warn(e);
523 }
524
525
526 thread.interrupt();
527 }
528
529
530
531
532
533
534
535
536 protected int getPriority(ServletRequest request, RateTracker tracker)
537 {
538 if (extractUserId(request)!=null)
539 return USER_AUTH;
540 if (tracker!=null)
541 return tracker.getType();
542 return USER_UNKNOWN;
543 }
544
545
546
547
548 protected int getMaxPriority()
549 {
550 return USER_AUTH;
551 }
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569 public RateTracker getRateTracker(ServletRequest request)
570 {
571 HttpServletRequest srequest = (HttpServletRequest)request;
572 HttpSession session=srequest.getSession(false);
573
574 String loadId = extractUserId(request);
575 final int type;
576 if (loadId != null)
577 {
578 type = USER_AUTH;
579 }
580 else
581 {
582 if (_trackSessions && session!=null && !session.isNew())
583 {
584 loadId=session.getId();
585 type = USER_SESSION;
586 }
587 else
588 {
589 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
590 type = USER_IP;
591 }
592 }
593
594 RateTracker tracker=_rateTrackers.get(loadId);
595
596 if (tracker==null)
597 {
598 RateTracker t;
599 if (_whitelist.contains(request.getRemoteAddr()))
600 {
601 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
602 }
603 else
604 {
605 t = new RateTracker(loadId,type,_maxRequestsPerSec);
606 }
607
608 tracker=_rateTrackers.putIfAbsent(loadId,t);
609 if (tracker==null)
610 tracker=t;
611
612 if (type == USER_IP)
613 {
614
615 synchronized (_trackerTimeoutQ)
616 {
617 _trackerTimeoutQ.schedule(tracker);
618 }
619 }
620 else if (session!=null)
621
622 session.setAttribute(__TRACKER,tracker);
623 }
624
625 return tracker;
626 }
627
628 public void destroy()
629 {
630 _running=false;
631 _timerThread.interrupt();
632 synchronized (_requestTimeoutQ)
633 {
634 _requestTimeoutQ.cancelAll();
635 }
636 synchronized (_trackerTimeoutQ)
637 {
638 _trackerTimeoutQ.cancelAll();
639 }
640 _rateTrackers.clear();
641 _whitelist.clear();
642 }
643
644
645
646
647
648
649
650
651 protected String extractUserId(ServletRequest request)
652 {
653 return null;
654 }
655
656
657
658
659
660 protected void initWhitelist()
661 {
662 _whitelist.clear();
663 StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ",");
664 while (tokenizer.hasMoreTokens())
665 _whitelist.add(tokenizer.nextToken().trim());
666
667 LOG.info("Whitelisted IP addresses: {}", _whitelist.toString());
668 }
669
670
671
672
673
674
675
676
677
678 public int getMaxRequestsPerSec()
679 {
680 return _maxRequestsPerSec;
681 }
682
683
684
685
686
687
688
689
690
691 public void setMaxRequestsPerSec(int value)
692 {
693 _maxRequestsPerSec = value;
694 }
695
696
697
698
699
700
701 public long getDelayMs()
702 {
703 return _delayMs;
704 }
705
706
707
708
709
710
711
712
713 public void setDelayMs(long value)
714 {
715 _delayMs = value;
716 }
717
718
719
720
721
722
723
724
725 public long getMaxWaitMs()
726 {
727 return _maxWaitMs;
728 }
729
730
731
732
733
734
735
736
737 public void setMaxWaitMs(long value)
738 {
739 _maxWaitMs = value;
740 }
741
742
743
744
745
746
747
748
749 public int getThrottledRequests()
750 {
751 return _throttledRequests;
752 }
753
754
755
756
757
758
759
760
761 public void setThrottledRequests(int value)
762 {
763 _passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true);
764 _throttledRequests = value;
765 }
766
767
768
769
770
771
772
773 public long getThrottleMs()
774 {
775 return _throttleMs;
776 }
777
778
779
780
781
782
783
784 public void setThrottleMs(long value)
785 {
786 _throttleMs = value;
787 }
788
789
790
791
792
793
794
795
796 public long getMaxRequestMs()
797 {
798 return _maxRequestMs;
799 }
800
801
802
803
804
805
806
807
808 public void setMaxRequestMs(long value)
809 {
810 _maxRequestMs = value;
811 }
812
813
814
815
816
817
818
819
820
821 public long getMaxIdleTrackerMs()
822 {
823 return _maxIdleTrackerMs;
824 }
825
826
827
828
829
830
831
832
833
834 public void setMaxIdleTrackerMs(long value)
835 {
836 _maxIdleTrackerMs = value;
837 }
838
839
840
841
842
843
844
845 public boolean isInsertHeaders()
846 {
847 return _insertHeaders;
848 }
849
850
851
852
853
854
855
856 public void setInsertHeaders(boolean value)
857 {
858 _insertHeaders = value;
859 }
860
861
862
863
864
865
866
867 public boolean isTrackSessions()
868 {
869 return _trackSessions;
870 }
871
872
873
874
875
876
877 public void setTrackSessions(boolean value)
878 {
879 _trackSessions = value;
880 }
881
882
883
884
885
886
887
888
889 public boolean isRemotePort()
890 {
891 return _remotePort;
892 }
893
894
895
896
897
898
899
900
901
902 public void setRemotePort(boolean value)
903 {
904 _remotePort = value;
905 }
906
907
908
909
910
911
912
913 public String getWhitelist()
914 {
915 return _whitelistStr;
916 }
917
918
919
920
921
922
923
924
925 public void setWhitelist(String value)
926 {
927 _whitelistStr = value;
928 initWhitelist();
929 }
930
931
932
933
934
935 class RateTracker extends Timeout.Task implements HttpSessionBindingListener, HttpSessionActivationListener
936 {
937 transient protected final String _id;
938 transient protected final int _type;
939 transient protected final long[] _timestamps;
940 transient protected int _next;
941
942
943 public RateTracker(String id, int type,int maxRequestsPerSecond)
944 {
945 _id = id;
946 _type = type;
947 _timestamps=new long[maxRequestsPerSecond];
948 _next=0;
949 }
950
951
952
953
954 public boolean isRateExceeded(long now)
955 {
956 final long last;
957 synchronized (this)
958 {
959 last=_timestamps[_next];
960 _timestamps[_next]=now;
961 _next= (_next+1)%_timestamps.length;
962 }
963
964 boolean exceeded=last!=0 && (now-last)<1000L;
965 return exceeded;
966 }
967
968
969 public String getId()
970 {
971 return _id;
972 }
973
974 public int getType()
975 {
976 return _type;
977 }
978
979
980 public void valueBound(HttpSessionBindingEvent event)
981 {
982 if (LOG.isDebugEnabled())
983 LOG.debug("Value bound:"+_id);
984 }
985
986 public void valueUnbound(HttpSessionBindingEvent event)
987 {
988
989 if (_rateTrackers != null)
990 _rateTrackers.remove(_id);
991 if (LOG.isDebugEnabled()) LOG.debug("Tracker removed: "+_id);
992 }
993
994 public void sessionWillPassivate(HttpSessionEvent se)
995 {
996
997
998 if (_rateTrackers != null)
999 _rateTrackers.remove(_id);
1000 se.getSession().removeAttribute(__TRACKER);
1001 if (LOG.isDebugEnabled()) LOG.debug("Value removed: "+_id);
1002 }
1003
1004 public void sessionDidActivate(HttpSessionEvent se)
1005 {
1006 LOG.warn("Unexpected session activation");
1007 }
1008
1009
1010 public void expired()
1011 {
1012 if (_rateTrackers != null && _trackerTimeoutQ != null)
1013 {
1014 long now = _trackerTimeoutQ.getNow();
1015 int latestIndex = _next == 0 ? (_timestamps.length-1) : (_next - 1 );
1016 long last=_timestamps[latestIndex];
1017 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
1018
1019 if (hasRecentRequest)
1020 reschedule();
1021 else
1022 _rateTrackers.remove(_id);
1023 }
1024 }
1025
1026 @Override
1027 public String toString()
1028 {
1029 return "RateTracker/"+_id+"/"+_type;
1030 }
1031
1032
1033 }
1034
1035 class FixedRateTracker extends RateTracker
1036 {
1037 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
1038 {
1039 super(id,type,numRecentRequestsTracked);
1040 }
1041
1042 @Override
1043 public boolean isRateExceeded(long now)
1044 {
1045
1046
1047
1048 synchronized (this)
1049 {
1050 _timestamps[_next]=now;
1051 _next= (_next+1)%_timestamps.length;
1052 }
1053
1054 return false;
1055 }
1056
1057 @Override
1058 public String toString()
1059 {
1060 return "Fixed"+super.toString();
1061 }
1062 }
1063 }