1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.eclipse.jetty.servlets;
15
16 import java.io.IOException;
17 import java.util.HashSet;
18 import java.util.Queue;
19 import java.util.StringTokenizer;
20 import java.util.concurrent.ConcurrentHashMap;
21 import java.util.concurrent.ConcurrentLinkedQueue;
22 import java.util.concurrent.Semaphore;
23 import java.util.concurrent.TimeUnit;
24 import javax.servlet.Filter;
25 import javax.servlet.FilterChain;
26 import javax.servlet.FilterConfig;
27 import javax.servlet.ServletContext;
28 import javax.servlet.ServletException;
29 import javax.servlet.ServletRequest;
30 import javax.servlet.ServletResponse;
31 import javax.servlet.http.HttpServletRequest;
32 import javax.servlet.http.HttpServletResponse;
33 import javax.servlet.http.HttpSession;
34 import javax.servlet.http.HttpSessionBindingEvent;
35 import javax.servlet.http.HttpSessionBindingListener;
36
37 import org.eclipse.jetty.continuation.Continuation;
38 import org.eclipse.jetty.continuation.ContinuationListener;
39 import org.eclipse.jetty.continuation.ContinuationSupport;
40 import org.eclipse.jetty.server.handler.ContextHandler;
41 import org.eclipse.jetty.util.log.Log;
42 import org.eclipse.jetty.util.thread.Timeout;
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113 public class DoSFilter implements Filter
114 {
115 final static String __TRACKER = "DoSFilter.Tracker";
116 final static String __THROTTLED = "DoSFilter.Throttled";
117
118 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
119 final static int __DEFAULT_DELAY_MS = 100;
120 final static int __DEFAULT_THROTTLE = 5;
121 final static int __DEFAULT_WAIT_MS=50;
122 final static long __DEFAULT_THROTTLE_MS = 30000L;
123 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
124 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
125
126 final static String MANAGED_ATTR_INIT_PARAM="managedAttr";
127 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
128 final static String DELAY_MS_INIT_PARAM = "delayMs";
129 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
130 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
131 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
132 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
133 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
134 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
135 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
136 final static String REMOTE_PORT_INIT_PARAM="remotePort";
137 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
138
139 final static int USER_AUTH = 2;
140 final static int USER_SESSION = 2;
141 final static int USER_IP = 1;
142 final static int USER_UNKNOWN = 0;
143
144 ServletContext _context;
145
146 protected String _name;
147 protected long _delayMs;
148 protected long _throttleMs;
149 protected long _maxWaitMs;
150 protected long _maxRequestMs;
151 protected long _maxIdleTrackerMs;
152 protected boolean _insertHeaders;
153 protected boolean _trackSessions;
154 protected boolean _remotePort;
155 protected int _throttledRequests;
156 protected Semaphore _passes;
157 protected Queue<Continuation>[] _queue;
158 protected ContinuationListener[] _listener;
159
160 protected int _maxRequestsPerSec;
161 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
162 protected String _whitelistStr;
163 private final HashSet<String> _whitelist = new HashSet<String>();
164
165 private final Timeout _requestTimeoutQ = new Timeout();
166 private final Timeout _trackerTimeoutQ = new Timeout();
167
168 private Thread _timerThread;
169 private volatile boolean _running;
170
171 public void init(FilterConfig filterConfig)
172 {
173 _context = filterConfig.getServletContext();
174
175 _queue = new Queue[getMaxPriority() + 1];
176 _listener = new ContinuationListener[getMaxPriority() + 1];
177 for (int p = 0; p < _queue.length; p++)
178 {
179 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
180
181 final int priority=p;
182 _listener[p] = new ContinuationListener()
183 {
184 public void onComplete(Continuation continuation)
185 {
186 }
187
188 public void onTimeout(Continuation continuation)
189 {
190 _queue[priority].remove(continuation);
191 }
192 };
193 }
194
195 _rateTrackers.clear();
196
197 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
198 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
199 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
200 _maxRequestsPerSec = baseRateLimit;
201
202 long delay = __DEFAULT_DELAY_MS;
203 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
204 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
205 _delayMs = delay;
206
207 int throttledRequests = __DEFAULT_THROTTLE;
208 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
209 throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
210 _passes = new Semaphore(throttledRequests,true);
211 _throttledRequests = throttledRequests;
212
213 long wait = __DEFAULT_WAIT_MS;
214 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
215 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
216 _maxWaitMs = wait;
217
218 long suspend = __DEFAULT_THROTTLE_MS;
219 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
220 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
221 _throttleMs = suspend;
222
223 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
224 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
225 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
226 _maxRequestMs = maxRequestMs;
227
228 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
229 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
230 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
231 _maxIdleTrackerMs = maxIdleTrackerMs;
232
233 _whitelistStr = "";
234 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
235 _whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
236 initWhitelist();
237
238 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
239 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
240
241 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
242 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
243
244 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
245 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
246
247 _requestTimeoutQ.setNow();
248 _requestTimeoutQ.setDuration(_maxRequestMs);
249
250 _trackerTimeoutQ.setNow();
251 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
252
253 _running=true;
254 _timerThread = (new Thread()
255 {
256 public void run()
257 {
258 try
259 {
260 while (_running)
261 {
262 long now;
263 synchronized (_requestTimeoutQ)
264 {
265 now = _requestTimeoutQ.setNow();
266 _requestTimeoutQ.tick();
267 }
268 synchronized (_trackerTimeoutQ)
269 {
270 _trackerTimeoutQ.setNow(now);
271 _trackerTimeoutQ.tick();
272 }
273 try
274 {
275 Thread.sleep(100);
276 }
277 catch (InterruptedException e)
278 {
279 Log.ignore(e);
280 }
281 }
282 }
283 finally
284 {
285 Log.info("DoSFilter timer exited");
286 }
287 }
288 });
289 _timerThread.start();
290
291 if (_context!=null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
292 _context.setAttribute(filterConfig.getFilterName(),this);
293 }
294
295
296 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
297 {
298 final HttpServletRequest srequest = (HttpServletRequest)request;
299 final HttpServletResponse sresponse = (HttpServletResponse)response;
300
301 final long now=_requestTimeoutQ.getNow();
302
303
304 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
305
306 if (tracker==null)
307 {
308
309
310
311 tracker = getRateTracker(request);
312
313
314 final boolean overRateLimit = tracker.isRateExceeded(now);
315
316
317 if (!overRateLimit)
318 {
319 doFilterChain(filterchain,srequest,sresponse);
320 return;
321 }
322
323
324 Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
325
326
327 switch((int)_delayMs)
328 {
329 case -1:
330 {
331
332 if (_insertHeaders)
333 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
334
335 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
336 return;
337 }
338 case 0:
339 {
340
341 request.setAttribute(__TRACKER,tracker);
342 break;
343 }
344 default:
345 {
346
347 if (_insertHeaders)
348 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
349 Continuation continuation = ContinuationSupport.getContinuation(request);
350 request.setAttribute(__TRACKER,tracker);
351 if (_delayMs > 0)
352 continuation.setTimeout(_delayMs);
353 continuation.addContinuationListener(new ContinuationListener()
354 {
355
356 public void onComplete(Continuation continuation)
357 {
358 }
359
360 public void onTimeout(Continuation continuation)
361 {
362 }
363 });
364 continuation.suspend();
365 return;
366 }
367 }
368 }
369
370 boolean accepted = false;
371 try
372 {
373
374 accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS);
375
376 if (!accepted)
377 {
378
379 final Continuation continuation = ContinuationSupport.getContinuation(request);
380
381 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
382 if (throttled!=Boolean.TRUE && _throttleMs>0)
383 {
384 int priority = getPriority(request,tracker);
385 request.setAttribute(__THROTTLED,Boolean.TRUE);
386 if (_insertHeaders)
387 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
388 if (_throttleMs > 0)
389 continuation.setTimeout(_throttleMs);
390 continuation.suspend();
391
392 continuation.addContinuationListener(_listener[priority]);
393 _queue[priority].add(continuation);
394 return;
395 }
396
397 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
398 {
399
400 _passes.acquire();
401 accepted = true;
402 }
403 }
404
405
406 if (accepted)
407
408 doFilterChain(filterchain,srequest,sresponse);
409 else
410 {
411
412 if (_insertHeaders)
413 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
414 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
415 }
416 }
417 catch (InterruptedException e)
418 {
419 _context.log("DoS",e);
420 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
421 }
422 catch (Exception e)
423 {
424 e.printStackTrace();
425 }
426 finally
427 {
428 if (accepted)
429 {
430
431 for (int p = _queue.length; p-- > 0;)
432 {
433 Continuation continuation = _queue[p].poll();
434 if (continuation != null && continuation.isSuspended())
435 {
436 continuation.resume();
437 break;
438 }
439 }
440 _passes.release();
441 }
442 }
443 }
444
445
446
447
448
449
450
451
452 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
453 throws IOException, ServletException
454 {
455 final Thread thread=Thread.currentThread();
456
457 final Timeout.Task requestTimeout = new Timeout.Task()
458 {
459 public void expired()
460 {
461 closeConnection(request, response, thread);
462 }
463 };
464
465 try
466 {
467 synchronized (_requestTimeoutQ)
468 {
469 _requestTimeoutQ.schedule(requestTimeout);
470 }
471 chain.doFilter(request,response);
472 }
473 finally
474 {
475 synchronized (_requestTimeoutQ)
476 {
477 requestTimeout.cancel();
478 }
479 }
480 }
481
482
483
484
485
486
487
488
489 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
490 {
491
492 if( !response.isCommitted() )
493 {
494 response.setHeader("Connection", "close");
495 }
496 try
497 {
498 try
499 {
500 response.getWriter().close();
501 }
502 catch (IllegalStateException e)
503 {
504 response.getOutputStream().close();
505 }
506 }
507 catch (IOException e)
508 {
509 Log.warn(e);
510 }
511
512
513 thread.interrupt();
514 }
515
516
517
518
519
520
521
522
523 protected int getPriority(ServletRequest request, RateTracker tracker)
524 {
525 if (extractUserId(request)!=null)
526 return USER_AUTH;
527 if (tracker!=null)
528 return tracker.getType();
529 return USER_UNKNOWN;
530 }
531
532
533
534
535 protected int getMaxPriority()
536 {
537 return USER_AUTH;
538 }
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556 public RateTracker getRateTracker(ServletRequest request)
557 {
558 HttpServletRequest srequest = (HttpServletRequest)request;
559 HttpSession session=srequest.getSession(false);
560
561 String loadId = extractUserId(request);
562 final int type;
563 if (loadId != null)
564 {
565 type = USER_AUTH;
566 }
567 else
568 {
569 if (_trackSessions && session!=null && !session.isNew())
570 {
571 loadId=session.getId();
572 type = USER_SESSION;
573 }
574 else
575 {
576 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
577 type = USER_IP;
578 }
579 }
580
581 RateTracker tracker=_rateTrackers.get(loadId);
582
583 if (tracker==null)
584 {
585 RateTracker t;
586 if (_whitelist.contains(request.getRemoteAddr()))
587 {
588 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
589 }
590 else
591 {
592 t = new RateTracker(loadId,type,_maxRequestsPerSec);
593 }
594
595 tracker=_rateTrackers.putIfAbsent(loadId,t);
596 if (tracker==null)
597 tracker=t;
598
599 if (type == USER_IP)
600 {
601
602 synchronized (_trackerTimeoutQ)
603 {
604 _trackerTimeoutQ.schedule(tracker);
605 }
606 }
607 else if (session!=null)
608
609 session.setAttribute(__TRACKER,tracker);
610 }
611
612 return tracker;
613 }
614
615 public void destroy()
616 {
617 _running=false;
618 _timerThread.interrupt();
619 synchronized (_requestTimeoutQ)
620 {
621 _requestTimeoutQ.cancelAll();
622 }
623 synchronized (_trackerTimeoutQ)
624 {
625 _trackerTimeoutQ.cancelAll();
626 }
627 _rateTrackers.clear();
628 _whitelist.clear();
629 }
630
631
632
633
634
635
636
637
638 protected String extractUserId(ServletRequest request)
639 {
640 return null;
641 }
642
643
644
645
646
647 protected void initWhitelist()
648 {
649 _whitelist.clear();
650 StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ",");
651 while (tokenizer.hasMoreTokens())
652 _whitelist.add(tokenizer.nextToken().trim());
653
654 Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
655 }
656
657
658
659
660
661
662
663
664
665 public int getMaxRequestsPerSec()
666 {
667 return _maxRequestsPerSec;
668 }
669
670
671
672
673
674
675
676
677
678 public void setMaxRequestsPerSec(int value)
679 {
680 _maxRequestsPerSec = value;
681 }
682
683
684
685
686
687
688 public long getDelayMs()
689 {
690 return _delayMs;
691 }
692
693
694
695
696
697
698
699
700 public void setDelayMs(long value)
701 {
702 _delayMs = value;
703 }
704
705
706
707
708
709
710
711
712 public long getMaxWaitMs()
713 {
714 return _maxWaitMs;
715 }
716
717
718
719
720
721
722
723
724 public void setMaxWaitMs(long value)
725 {
726 _maxWaitMs = value;
727 }
728
729
730
731
732
733
734
735
736 public int getThrottledRequests()
737 {
738 return _throttledRequests;
739 }
740
741
742
743
744
745
746
747
748 public void setThrottledRequests(int value)
749 {
750 _passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true);
751 _throttledRequests = value;
752 }
753
754
755
756
757
758
759
760 public long getThrottleMs()
761 {
762 return _throttleMs;
763 }
764
765
766
767
768
769
770
771 public void setThrottleMs(long value)
772 {
773 _throttleMs = value;
774 }
775
776
777
778
779
780
781
782
783 public long getMaxRequestMs()
784 {
785 return _maxRequestMs;
786 }
787
788
789
790
791
792
793
794
795 public void setMaxRequestMs(long value)
796 {
797 _maxRequestMs = value;
798 }
799
800
801
802
803
804
805
806
807
808 public long getMaxIdleTrackerMs()
809 {
810 return _maxIdleTrackerMs;
811 }
812
813
814
815
816
817
818
819
820
821 public void setMaxIdleTrackerMs(long value)
822 {
823 _maxIdleTrackerMs = value;
824 }
825
826
827
828
829
830
831
832 public boolean isInsertHeaders()
833 {
834 return _insertHeaders;
835 }
836
837
838
839
840
841
842
843 public void setInsertHeaders(boolean value)
844 {
845 _insertHeaders = value;
846 }
847
848
849
850
851
852
853
854 public boolean isTrackSessions()
855 {
856 return _trackSessions;
857 }
858
859
860
861
862
863
864 public void setTrackSessions(boolean value)
865 {
866 _trackSessions = value;
867 }
868
869
870
871
872
873
874
875
876 public boolean isRemotePort()
877 {
878 return _remotePort;
879 }
880
881
882
883
884
885
886
887
888
889 public void setRemotePort(boolean value)
890 {
891 _remotePort = value;
892 }
893
894
895
896
897
898
899
900 public String getWhitelist()
901 {
902 return _whitelistStr;
903 }
904
905
906
907
908
909
910
911
912 public void setWhitelist(String value)
913 {
914 _whitelistStr = value;
915 initWhitelist();
916 }
917
918
919
920
921
922 class RateTracker extends Timeout.Task implements HttpSessionBindingListener
923 {
924 protected final String _id;
925 protected final int _type;
926 protected final long[] _timestamps;
927 protected int _next;
928
929 public RateTracker(String id, int type,int maxRequestsPerSecond)
930 {
931 _id = id;
932 _type = type;
933 _timestamps=new long[maxRequestsPerSecond];
934 _next=0;
935 }
936
937
938
939
940 public boolean isRateExceeded(long now)
941 {
942 final long last;
943 synchronized (this)
944 {
945 last=_timestamps[_next];
946 _timestamps[_next]=now;
947 _next= (_next+1)%_timestamps.length;
948 }
949
950 boolean exceeded=last!=0 && (now-last)<1000L;
951 return exceeded;
952 }
953
954
955 public String getId()
956 {
957 return _id;
958 }
959
960 public int getType()
961 {
962 return _type;
963 }
964
965
966 public void valueBound(HttpSessionBindingEvent event)
967 {
968 }
969
970 public void valueUnbound(HttpSessionBindingEvent event)
971 {
972 _rateTrackers.remove(_id);
973 }
974
975 public void expired()
976 {
977 long now = _trackerTimeoutQ.getNow();
978 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
979 long last=_timestamps[latestIndex];
980 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
981
982 if (hasRecentRequest)
983 reschedule();
984 else
985 _rateTrackers.remove(_id);
986 }
987
988 @Override
989 public String toString()
990 {
991 return "RateTracker/"+_id+"/"+_type;
992 }
993 }
994
995 class FixedRateTracker extends RateTracker
996 {
997 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
998 {
999 super(id,type,numRecentRequestsTracked);
1000 }
1001
1002 @Override
1003 public boolean isRateExceeded(long now)
1004 {
1005
1006
1007
1008 synchronized (this)
1009 {
1010 _timestamps[_next]=now;
1011 _next= (_next+1)%_timestamps.length;
1012 }
1013
1014 return false;
1015 }
1016
1017 @Override
1018 public String toString()
1019 {
1020 return "Fixed"+super.toString();
1021 }
1022 }
1023 }