1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.eclipse.jetty.servlets;
15
16 import java.io.IOException;
17 import java.util.HashSet;
18 import java.util.Queue;
19 import java.util.StringTokenizer;
20 import java.util.concurrent.ConcurrentHashMap;
21 import java.util.concurrent.ConcurrentLinkedQueue;
22 import java.util.concurrent.Semaphore;
23 import java.util.concurrent.TimeUnit;
24 import javax.servlet.Filter;
25 import javax.servlet.FilterChain;
26 import javax.servlet.FilterConfig;
27 import javax.servlet.ServletContext;
28 import javax.servlet.ServletException;
29 import javax.servlet.ServletRequest;
30 import javax.servlet.ServletResponse;
31 import javax.servlet.http.HttpServletRequest;
32 import javax.servlet.http.HttpServletResponse;
33 import javax.servlet.http.HttpSession;
34 import javax.servlet.http.HttpSessionBindingEvent;
35 import javax.servlet.http.HttpSessionBindingListener;
36
37 import org.eclipse.jetty.continuation.Continuation;
38 import org.eclipse.jetty.continuation.ContinuationListener;
39 import org.eclipse.jetty.continuation.ContinuationSupport;
40 import org.eclipse.jetty.server.handler.ContextHandler;
41 import org.eclipse.jetty.util.log.Log;
42 import org.eclipse.jetty.util.log.Logger;
43 import org.eclipse.jetty.util.thread.Timeout;
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114 public class DoSFilter implements Filter
115 {
116 private static final Logger LOG = Log.getLogger(DoSFilter.class);
117
118 final static String __TRACKER = "DoSFilter.Tracker";
119 final static String __THROTTLED = "DoSFilter.Throttled";
120
121 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
122 final static int __DEFAULT_DELAY_MS = 100;
123 final static int __DEFAULT_THROTTLE = 5;
124 final static int __DEFAULT_WAIT_MS=50;
125 final static long __DEFAULT_THROTTLE_MS = 30000L;
126 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
127 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
128
129 final static String MANAGED_ATTR_INIT_PARAM="managedAttr";
130 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
131 final static String DELAY_MS_INIT_PARAM = "delayMs";
132 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
133 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
134 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
135 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
136 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
137 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
138 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
139 final static String REMOTE_PORT_INIT_PARAM="remotePort";
140 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
141
142 final static int USER_AUTH = 2;
143 final static int USER_SESSION = 2;
144 final static int USER_IP = 1;
145 final static int USER_UNKNOWN = 0;
146
147 ServletContext _context;
148
149 protected String _name;
150 protected long _delayMs;
151 protected long _throttleMs;
152 protected long _maxWaitMs;
153 protected long _maxRequestMs;
154 protected long _maxIdleTrackerMs;
155 protected boolean _insertHeaders;
156 protected boolean _trackSessions;
157 protected boolean _remotePort;
158 protected int _throttledRequests;
159 protected Semaphore _passes;
160 protected Queue<Continuation>[] _queue;
161 protected ContinuationListener[] _listener;
162
163 protected int _maxRequestsPerSec;
164 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
165 protected String _whitelistStr;
166 private final HashSet<String> _whitelist = new HashSet<String>();
167
168 private final Timeout _requestTimeoutQ = new Timeout();
169 private final Timeout _trackerTimeoutQ = new Timeout();
170
171 private Thread _timerThread;
172 private volatile boolean _running;
173
174 public void init(FilterConfig filterConfig)
175 {
176 _context = filterConfig.getServletContext();
177
178 _queue = new Queue[getMaxPriority() + 1];
179 _listener = new ContinuationListener[getMaxPriority() + 1];
180 for (int p = 0; p < _queue.length; p++)
181 {
182 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
183
184 final int priority=p;
185 _listener[p] = new ContinuationListener()
186 {
187 public void onComplete(Continuation continuation)
188 {
189 }
190
191 public void onTimeout(Continuation continuation)
192 {
193 _queue[priority].remove(continuation);
194 }
195 };
196 }
197
198 _rateTrackers.clear();
199
200 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
201 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
202 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
203 _maxRequestsPerSec = baseRateLimit;
204
205 long delay = __DEFAULT_DELAY_MS;
206 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
207 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
208 _delayMs = delay;
209
210 int throttledRequests = __DEFAULT_THROTTLE;
211 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
212 throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
213 _passes = new Semaphore(throttledRequests,true);
214 _throttledRequests = throttledRequests;
215
216 long wait = __DEFAULT_WAIT_MS;
217 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
218 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
219 _maxWaitMs = wait;
220
221 long suspend = __DEFAULT_THROTTLE_MS;
222 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
223 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
224 _throttleMs = suspend;
225
226 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
227 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
228 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
229 _maxRequestMs = maxRequestMs;
230
231 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
232 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
233 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
234 _maxIdleTrackerMs = maxIdleTrackerMs;
235
236 _whitelistStr = "";
237 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
238 _whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
239 initWhitelist();
240
241 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
242 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
243
244 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
245 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
246
247 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
248 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
249
250 _requestTimeoutQ.setNow();
251 _requestTimeoutQ.setDuration(_maxRequestMs);
252
253 _trackerTimeoutQ.setNow();
254 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
255
256 _running=true;
257 _timerThread = (new Thread()
258 {
259 public void run()
260 {
261 try
262 {
263 while (_running)
264 {
265 long now;
266 synchronized (_requestTimeoutQ)
267 {
268 now = _requestTimeoutQ.setNow();
269 _requestTimeoutQ.tick();
270 }
271 synchronized (_trackerTimeoutQ)
272 {
273 _trackerTimeoutQ.setNow(now);
274 _trackerTimeoutQ.tick();
275 }
276 try
277 {
278 Thread.sleep(100);
279 }
280 catch (InterruptedException e)
281 {
282 LOG.ignore(e);
283 }
284 }
285 }
286 finally
287 {
288 LOG.info("DoSFilter timer exited");
289 }
290 }
291 });
292 _timerThread.start();
293
294 if (_context!=null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
295 _context.setAttribute(filterConfig.getFilterName(),this);
296 }
297
298
299 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
300 {
301 final HttpServletRequest srequest = (HttpServletRequest)request;
302 final HttpServletResponse sresponse = (HttpServletResponse)response;
303
304 final long now=_requestTimeoutQ.getNow();
305
306
307 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
308
309 if (tracker==null)
310 {
311
312
313
314 tracker = getRateTracker(request);
315
316
317 final boolean overRateLimit = tracker.isRateExceeded(now);
318
319
320 if (!overRateLimit)
321 {
322 doFilterChain(filterchain,srequest,sresponse);
323 return;
324 }
325
326
327 LOG.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
328
329
330 switch((int)_delayMs)
331 {
332 case -1:
333 {
334
335 if (_insertHeaders)
336 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
337
338 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
339 return;
340 }
341 case 0:
342 {
343
344 request.setAttribute(__TRACKER,tracker);
345 break;
346 }
347 default:
348 {
349
350 if (_insertHeaders)
351 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
352 Continuation continuation = ContinuationSupport.getContinuation(request);
353 request.setAttribute(__TRACKER,tracker);
354 if (_delayMs > 0)
355 continuation.setTimeout(_delayMs);
356 continuation.addContinuationListener(new ContinuationListener()
357 {
358
359 public void onComplete(Continuation continuation)
360 {
361 }
362
363 public void onTimeout(Continuation continuation)
364 {
365 }
366 });
367 continuation.suspend();
368 return;
369 }
370 }
371 }
372
373 boolean accepted = false;
374 try
375 {
376
377 accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS);
378
379 if (!accepted)
380 {
381
382 final Continuation continuation = ContinuationSupport.getContinuation(request);
383
384 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
385 if (throttled!=Boolean.TRUE && _throttleMs>0)
386 {
387 int priority = getPriority(request,tracker);
388 request.setAttribute(__THROTTLED,Boolean.TRUE);
389 if (_insertHeaders)
390 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
391 if (_throttleMs > 0)
392 continuation.setTimeout(_throttleMs);
393 continuation.suspend();
394
395 continuation.addContinuationListener(_listener[priority]);
396 _queue[priority].add(continuation);
397 return;
398 }
399
400 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
401 {
402
403 _passes.acquire();
404 accepted = true;
405 }
406 }
407
408
409 if (accepted)
410
411 doFilterChain(filterchain,srequest,sresponse);
412 else
413 {
414
415 if (_insertHeaders)
416 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
417 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
418 }
419 }
420 catch (InterruptedException e)
421 {
422 _context.log("DoS",e);
423 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
424 }
425 catch (Exception e)
426 {
427 e.printStackTrace();
428 }
429 finally
430 {
431 if (accepted)
432 {
433
434 for (int p = _queue.length; p-- > 0;)
435 {
436 Continuation continuation = _queue[p].poll();
437 if (continuation != null && continuation.isSuspended())
438 {
439 continuation.resume();
440 break;
441 }
442 }
443 _passes.release();
444 }
445 }
446 }
447
448
449
450
451
452
453
454
455 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
456 throws IOException, ServletException
457 {
458 final Thread thread=Thread.currentThread();
459
460 final Timeout.Task requestTimeout = new Timeout.Task()
461 {
462 public void expired()
463 {
464 closeConnection(request, response, thread);
465 }
466 };
467
468 try
469 {
470 synchronized (_requestTimeoutQ)
471 {
472 _requestTimeoutQ.schedule(requestTimeout);
473 }
474 chain.doFilter(request,response);
475 }
476 finally
477 {
478 synchronized (_requestTimeoutQ)
479 {
480 requestTimeout.cancel();
481 }
482 }
483 }
484
485
486
487
488
489
490
491
492 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
493 {
494
495 if( !response.isCommitted() )
496 {
497 response.setHeader("Connection", "close");
498 }
499 try
500 {
501 try
502 {
503 response.getWriter().close();
504 }
505 catch (IllegalStateException e)
506 {
507 response.getOutputStream().close();
508 }
509 }
510 catch (IOException e)
511 {
512 LOG.warn(e);
513 }
514
515
516 thread.interrupt();
517 }
518
519
520
521
522
523
524
525
526 protected int getPriority(ServletRequest request, RateTracker tracker)
527 {
528 if (extractUserId(request)!=null)
529 return USER_AUTH;
530 if (tracker!=null)
531 return tracker.getType();
532 return USER_UNKNOWN;
533 }
534
535
536
537
538 protected int getMaxPriority()
539 {
540 return USER_AUTH;
541 }
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559 public RateTracker getRateTracker(ServletRequest request)
560 {
561 HttpServletRequest srequest = (HttpServletRequest)request;
562 HttpSession session=srequest.getSession(false);
563
564 String loadId = extractUserId(request);
565 final int type;
566 if (loadId != null)
567 {
568 type = USER_AUTH;
569 }
570 else
571 {
572 if (_trackSessions && session!=null && !session.isNew())
573 {
574 loadId=session.getId();
575 type = USER_SESSION;
576 }
577 else
578 {
579 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
580 type = USER_IP;
581 }
582 }
583
584 RateTracker tracker=_rateTrackers.get(loadId);
585
586 if (tracker==null)
587 {
588 RateTracker t;
589 if (_whitelist.contains(request.getRemoteAddr()))
590 {
591 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
592 }
593 else
594 {
595 t = new RateTracker(loadId,type,_maxRequestsPerSec);
596 }
597
598 tracker=_rateTrackers.putIfAbsent(loadId,t);
599 if (tracker==null)
600 tracker=t;
601
602 if (type == USER_IP)
603 {
604
605 synchronized (_trackerTimeoutQ)
606 {
607 _trackerTimeoutQ.schedule(tracker);
608 }
609 }
610 else if (session!=null)
611
612 session.setAttribute(__TRACKER,tracker);
613 }
614
615 return tracker;
616 }
617
618 public void destroy()
619 {
620 _running=false;
621 _timerThread.interrupt();
622 synchronized (_requestTimeoutQ)
623 {
624 _requestTimeoutQ.cancelAll();
625 }
626 synchronized (_trackerTimeoutQ)
627 {
628 _trackerTimeoutQ.cancelAll();
629 }
630 _rateTrackers.clear();
631 _whitelist.clear();
632 }
633
634
635
636
637
638
639
640
641 protected String extractUserId(ServletRequest request)
642 {
643 return null;
644 }
645
646
647
648
649
650 protected void initWhitelist()
651 {
652 _whitelist.clear();
653 StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ",");
654 while (tokenizer.hasMoreTokens())
655 _whitelist.add(tokenizer.nextToken().trim());
656
657 LOG.info("Whitelisted IP addresses: {}", _whitelist.toString());
658 }
659
660
661
662
663
664
665
666
667
668 public int getMaxRequestsPerSec()
669 {
670 return _maxRequestsPerSec;
671 }
672
673
674
675
676
677
678
679
680
681 public void setMaxRequestsPerSec(int value)
682 {
683 _maxRequestsPerSec = value;
684 }
685
686
687
688
689
690
691 public long getDelayMs()
692 {
693 return _delayMs;
694 }
695
696
697
698
699
700
701
702
703 public void setDelayMs(long value)
704 {
705 _delayMs = value;
706 }
707
708
709
710
711
712
713
714
715 public long getMaxWaitMs()
716 {
717 return _maxWaitMs;
718 }
719
720
721
722
723
724
725
726
727 public void setMaxWaitMs(long value)
728 {
729 _maxWaitMs = value;
730 }
731
732
733
734
735
736
737
738
739 public int getThrottledRequests()
740 {
741 return _throttledRequests;
742 }
743
744
745
746
747
748
749
750
751 public void setThrottledRequests(int value)
752 {
753 _passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true);
754 _throttledRequests = value;
755 }
756
757
758
759
760
761
762
763 public long getThrottleMs()
764 {
765 return _throttleMs;
766 }
767
768
769
770
771
772
773
774 public void setThrottleMs(long value)
775 {
776 _throttleMs = value;
777 }
778
779
780
781
782
783
784
785
786 public long getMaxRequestMs()
787 {
788 return _maxRequestMs;
789 }
790
791
792
793
794
795
796
797
798 public void setMaxRequestMs(long value)
799 {
800 _maxRequestMs = value;
801 }
802
803
804
805
806
807
808
809
810
811 public long getMaxIdleTrackerMs()
812 {
813 return _maxIdleTrackerMs;
814 }
815
816
817
818
819
820
821
822
823
824 public void setMaxIdleTrackerMs(long value)
825 {
826 _maxIdleTrackerMs = value;
827 }
828
829
830
831
832
833
834
835 public boolean isInsertHeaders()
836 {
837 return _insertHeaders;
838 }
839
840
841
842
843
844
845
846 public void setInsertHeaders(boolean value)
847 {
848 _insertHeaders = value;
849 }
850
851
852
853
854
855
856
857 public boolean isTrackSessions()
858 {
859 return _trackSessions;
860 }
861
862
863
864
865
866
867 public void setTrackSessions(boolean value)
868 {
869 _trackSessions = value;
870 }
871
872
873
874
875
876
877
878
879 public boolean isRemotePort()
880 {
881 return _remotePort;
882 }
883
884
885
886
887
888
889
890
891
892 public void setRemotePort(boolean value)
893 {
894 _remotePort = value;
895 }
896
897
898
899
900
901
902
903 public String getWhitelist()
904 {
905 return _whitelistStr;
906 }
907
908
909
910
911
912
913
914
915 public void setWhitelist(String value)
916 {
917 _whitelistStr = value;
918 initWhitelist();
919 }
920
921
922
923
924
925 class RateTracker extends Timeout.Task implements HttpSessionBindingListener
926 {
927 protected final String _id;
928 protected final int _type;
929 protected final long[] _timestamps;
930 protected int _next;
931
932 public RateTracker(String id, int type,int maxRequestsPerSecond)
933 {
934 _id = id;
935 _type = type;
936 _timestamps=new long[maxRequestsPerSecond];
937 _next=0;
938 }
939
940
941
942
943 public boolean isRateExceeded(long now)
944 {
945 final long last;
946 synchronized (this)
947 {
948 last=_timestamps[_next];
949 _timestamps[_next]=now;
950 _next= (_next+1)%_timestamps.length;
951 }
952
953 boolean exceeded=last!=0 && (now-last)<1000L;
954 return exceeded;
955 }
956
957
958 public String getId()
959 {
960 return _id;
961 }
962
963 public int getType()
964 {
965 return _type;
966 }
967
968
969 public void valueBound(HttpSessionBindingEvent event)
970 {
971 }
972
973 public void valueUnbound(HttpSessionBindingEvent event)
974 {
975 _rateTrackers.remove(_id);
976 }
977
978 public void expired()
979 {
980 long now = _trackerTimeoutQ.getNow();
981 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
982 long last=_timestamps[latestIndex];
983 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
984
985 if (hasRecentRequest)
986 reschedule();
987 else
988 _rateTrackers.remove(_id);
989 }
990
991 @Override
992 public String toString()
993 {
994 return "RateTracker/"+_id+"/"+_type;
995 }
996 }
997
998 class FixedRateTracker extends RateTracker
999 {
1000 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
1001 {
1002 super(id,type,numRecentRequestsTracked);
1003 }
1004
1005 @Override
1006 public boolean isRateExceeded(long now)
1007 {
1008
1009
1010
1011 synchronized (this)
1012 {
1013 _timestamps[_next]=now;
1014 _next= (_next+1)%_timestamps.length;
1015 }
1016
1017 return false;
1018 }
1019
1020 @Override
1021 public String toString()
1022 {
1023 return "Fixed"+super.toString();
1024 }
1025 }
1026 }