1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.eclipse.jetty.servlets;
15
16 import java.io.IOException;
17 import java.util.HashSet;
18 import java.util.Queue;
19 import java.util.StringTokenizer;
20 import java.util.concurrent.ConcurrentHashMap;
21 import java.util.concurrent.ConcurrentLinkedQueue;
22 import java.util.concurrent.Semaphore;
23 import java.util.concurrent.TimeUnit;
24 import javax.servlet.Filter;
25 import javax.servlet.FilterChain;
26 import javax.servlet.FilterConfig;
27 import javax.servlet.ServletContext;
28 import javax.servlet.ServletException;
29 import javax.servlet.ServletRequest;
30 import javax.servlet.ServletResponse;
31 import javax.servlet.http.HttpServletRequest;
32 import javax.servlet.http.HttpServletResponse;
33 import javax.servlet.http.HttpSession;
34 import javax.servlet.http.HttpSessionBindingEvent;
35 import javax.servlet.http.HttpSessionBindingListener;
36
37 import org.eclipse.jetty.continuation.Continuation;
38 import org.eclipse.jetty.continuation.ContinuationListener;
39 import org.eclipse.jetty.continuation.ContinuationSupport;
40 import org.eclipse.jetty.server.handler.ContextHandler;
41 import org.eclipse.jetty.util.log.Log;
42 import org.eclipse.jetty.util.log.Logger;
43 import org.eclipse.jetty.util.thread.Timeout;
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114 public class DoSFilter implements Filter
115 {
116 private static final Logger LOG = Log.getLogger(DoSFilter.class);
117
118 final static String __TRACKER = "DoSFilter.Tracker";
119 final static String __THROTTLED = "DoSFilter.Throttled";
120
121 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
122 final static int __DEFAULT_DELAY_MS = 100;
123 final static int __DEFAULT_THROTTLE = 5;
124 final static int __DEFAULT_WAIT_MS=50;
125 final static long __DEFAULT_THROTTLE_MS = 30000L;
126 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
127 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
128
129 final static String MANAGED_ATTR_INIT_PARAM="managedAttr";
130 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
131 final static String DELAY_MS_INIT_PARAM = "delayMs";
132 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
133 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
134 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
135 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
136 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
137 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
138 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
139 final static String REMOTE_PORT_INIT_PARAM="remotePort";
140 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
141
142 final static int USER_AUTH = 2;
143 final static int USER_SESSION = 2;
144 final static int USER_IP = 1;
145 final static int USER_UNKNOWN = 0;
146
147 ServletContext _context;
148
149 protected String _name;
150 protected long _delayMs;
151 protected long _throttleMs;
152 protected long _maxWaitMs;
153 protected long _maxRequestMs;
154 protected long _maxIdleTrackerMs;
155 protected boolean _insertHeaders;
156 protected boolean _trackSessions;
157 protected boolean _remotePort;
158 protected int _throttledRequests;
159 protected Semaphore _passes;
160 protected Queue<Continuation>[] _queue;
161 protected ContinuationListener[] _listener;
162
163 protected int _maxRequestsPerSec;
164 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
165 protected String _whitelistStr;
166 private final HashSet<String> _whitelist = new HashSet<String>();
167
168 private final Timeout _requestTimeoutQ = new Timeout();
169 private final Timeout _trackerTimeoutQ = new Timeout();
170
171 private Thread _timerThread;
172 private volatile boolean _running;
173
174 public void init(FilterConfig filterConfig)
175 {
176 _context = filterConfig.getServletContext();
177
178 _queue = new Queue[getMaxPriority() + 1];
179 _listener = new ContinuationListener[getMaxPriority() + 1];
180 for (int p = 0; p < _queue.length; p++)
181 {
182 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
183
184 final int priority=p;
185 _listener[p] = new ContinuationListener()
186 {
187 public void onComplete(Continuation continuation)
188 {
189 }
190
191 public void onTimeout(Continuation continuation)
192 {
193 _queue[priority].remove(continuation);
194 }
195 };
196 }
197
198 _rateTrackers.clear();
199
200 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
201 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
202 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
203 _maxRequestsPerSec = baseRateLimit;
204
205 long delay = __DEFAULT_DELAY_MS;
206 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
207 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
208 _delayMs = delay;
209
210 int throttledRequests = __DEFAULT_THROTTLE;
211 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
212 throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
213 _passes = new Semaphore(throttledRequests,true);
214 _throttledRequests = throttledRequests;
215
216 long wait = __DEFAULT_WAIT_MS;
217 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
218 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
219 _maxWaitMs = wait;
220
221 long suspend = __DEFAULT_THROTTLE_MS;
222 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
223 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
224 _throttleMs = suspend;
225
226 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
227 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
228 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
229 _maxRequestMs = maxRequestMs;
230
231 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
232 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
233 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
234 _maxIdleTrackerMs = maxIdleTrackerMs;
235
236 _whitelistStr = "";
237 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
238 _whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
239 initWhitelist();
240
241 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
242 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
243
244 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
245 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
246
247 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
248 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
249
250 _requestTimeoutQ.setNow();
251 _requestTimeoutQ.setDuration(_maxRequestMs);
252
253 _trackerTimeoutQ.setNow();
254 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
255
256 _running=true;
257 _timerThread = (new Thread()
258 {
259 public void run()
260 {
261 try
262 {
263 while (_running)
264 {
265 long now;
266 synchronized (_requestTimeoutQ)
267 {
268 now = _requestTimeoutQ.setNow();
269 _requestTimeoutQ.tick();
270 }
271 synchronized (_trackerTimeoutQ)
272 {
273 _trackerTimeoutQ.setNow(now);
274 _trackerTimeoutQ.tick();
275 }
276 try
277 {
278 Thread.sleep(100);
279 }
280 catch (InterruptedException e)
281 {
282 LOG.ignore(e);
283 }
284 }
285 }
286 finally
287 {
288 LOG.info("DoSFilter timer exited");
289 }
290 }
291 });
292 _timerThread.start();
293
294 if (_context!=null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
295 _context.setAttribute(filterConfig.getFilterName(),this);
296 }
297
298
299 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
300 {
301 final HttpServletRequest srequest = (HttpServletRequest)request;
302 final HttpServletResponse sresponse = (HttpServletResponse)response;
303
304 final long now=_requestTimeoutQ.getNow();
305
306
307 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
308
309 if (tracker==null)
310 {
311
312
313
314 tracker = getRateTracker(request);
315
316
317 final boolean overRateLimit = tracker.isRateExceeded(now);
318
319
320 if (!overRateLimit)
321 {
322 doFilterChain(filterchain,srequest,sresponse);
323 return;
324 }
325
326
327 LOG.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
328
329
330 switch((int)_delayMs)
331 {
332 case -1:
333 {
334
335 if (_insertHeaders)
336 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
337 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
338 return;
339 }
340 case 0:
341 {
342
343 request.setAttribute(__TRACKER,tracker);
344 break;
345 }
346 default:
347 {
348
349 if (_insertHeaders)
350 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
351 Continuation continuation = ContinuationSupport.getContinuation(request);
352 request.setAttribute(__TRACKER,tracker);
353 if (_delayMs > 0)
354 continuation.setTimeout(_delayMs);
355 continuation.suspend();
356 return;
357 }
358 }
359 }
360
361
362 boolean accepted = false;
363 try
364 {
365
366 accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS);
367
368 if (!accepted)
369 {
370
371 final Continuation continuation = ContinuationSupport.getContinuation(request);
372
373 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
374 if (throttled!=Boolean.TRUE && _throttleMs>0)
375 {
376 int priority = getPriority(request,tracker);
377 request.setAttribute(__THROTTLED,Boolean.TRUE);
378 if (_insertHeaders)
379 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
380 if (_throttleMs > 0)
381 continuation.setTimeout(_throttleMs);
382 continuation.suspend();
383
384 continuation.addContinuationListener(_listener[priority]);
385 _queue[priority].add(continuation);
386 return;
387 }
388
389 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
390 {
391
392 _passes.acquire();
393 accepted = true;
394 }
395 }
396
397
398 if (accepted)
399
400 doFilterChain(filterchain,srequest,sresponse);
401 else
402 {
403
404 if (_insertHeaders)
405 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
406 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
407 }
408 }
409 catch (InterruptedException e)
410 {
411 _context.log("DoS",e);
412 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
413 }
414 finally
415 {
416 if (accepted)
417 {
418
419 for (int p = _queue.length; p-- > 0;)
420 {
421 Continuation continuation = _queue[p].poll();
422 if (continuation != null && continuation.isSuspended())
423 {
424 continuation.resume();
425 break;
426 }
427 }
428 _passes.release();
429 }
430 }
431 }
432
433
434
435
436
437
438
439
440 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
441 throws IOException, ServletException
442 {
443 final Thread thread=Thread.currentThread();
444
445 final Timeout.Task requestTimeout = new Timeout.Task()
446 {
447 public void expired()
448 {
449 closeConnection(request, response, thread);
450 }
451 };
452
453 try
454 {
455 synchronized (_requestTimeoutQ)
456 {
457 _requestTimeoutQ.schedule(requestTimeout);
458 }
459 chain.doFilter(request,response);
460 }
461 finally
462 {
463 synchronized (_requestTimeoutQ)
464 {
465 requestTimeout.cancel();
466 }
467 }
468 }
469
470
471
472
473
474
475
476
477 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
478 {
479
480 if( !response.isCommitted() )
481 {
482 response.setHeader("Connection", "close");
483 }
484 try
485 {
486 try
487 {
488 response.getWriter().close();
489 }
490 catch (IllegalStateException e)
491 {
492 response.getOutputStream().close();
493 }
494 }
495 catch (IOException e)
496 {
497 LOG.warn(e);
498 }
499
500
501 thread.interrupt();
502 }
503
504
505
506
507
508
509
510
511 protected int getPriority(ServletRequest request, RateTracker tracker)
512 {
513 if (extractUserId(request)!=null)
514 return USER_AUTH;
515 if (tracker!=null)
516 return tracker.getType();
517 return USER_UNKNOWN;
518 }
519
520
521
522
523 protected int getMaxPriority()
524 {
525 return USER_AUTH;
526 }
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544 public RateTracker getRateTracker(ServletRequest request)
545 {
546 HttpServletRequest srequest = (HttpServletRequest)request;
547 HttpSession session=srequest.getSession(false);
548
549 String loadId = extractUserId(request);
550 final int type;
551 if (loadId != null)
552 {
553 type = USER_AUTH;
554 }
555 else
556 {
557 if (_trackSessions && session!=null && !session.isNew())
558 {
559 loadId=session.getId();
560 type = USER_SESSION;
561 }
562 else
563 {
564 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
565 type = USER_IP;
566 }
567 }
568
569 RateTracker tracker=_rateTrackers.get(loadId);
570
571 if (tracker==null)
572 {
573 RateTracker t;
574 if (_whitelist.contains(request.getRemoteAddr()))
575 {
576 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
577 }
578 else
579 {
580 t = new RateTracker(loadId,type,_maxRequestsPerSec);
581 }
582
583 tracker=_rateTrackers.putIfAbsent(loadId,t);
584 if (tracker==null)
585 tracker=t;
586
587 if (type == USER_IP)
588 {
589
590 synchronized (_trackerTimeoutQ)
591 {
592 _trackerTimeoutQ.schedule(tracker);
593 }
594 }
595 else if (session!=null)
596
597 session.setAttribute(__TRACKER,tracker);
598 }
599
600 return tracker;
601 }
602
603 public void destroy()
604 {
605 _running=false;
606 _timerThread.interrupt();
607 synchronized (_requestTimeoutQ)
608 {
609 _requestTimeoutQ.cancelAll();
610 }
611 synchronized (_trackerTimeoutQ)
612 {
613 _trackerTimeoutQ.cancelAll();
614 }
615 _rateTrackers.clear();
616 _whitelist.clear();
617 }
618
619
620
621
622
623
624
625
626 protected String extractUserId(ServletRequest request)
627 {
628 return null;
629 }
630
631
632
633
634
635 protected void initWhitelist()
636 {
637 _whitelist.clear();
638 StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ",");
639 while (tokenizer.hasMoreTokens())
640 _whitelist.add(tokenizer.nextToken().trim());
641
642 LOG.info("Whitelisted IP addresses: {}", _whitelist.toString());
643 }
644
645
646
647
648
649
650
651
652
653 public int getMaxRequestsPerSec()
654 {
655 return _maxRequestsPerSec;
656 }
657
658
659
660
661
662
663
664
665
666 public void setMaxRequestsPerSec(int value)
667 {
668 _maxRequestsPerSec = value;
669 }
670
671
672
673
674
675
676 public long getDelayMs()
677 {
678 return _delayMs;
679 }
680
681
682
683
684
685
686
687
688 public void setDelayMs(long value)
689 {
690 _delayMs = value;
691 }
692
693
694
695
696
697
698
699
700 public long getMaxWaitMs()
701 {
702 return _maxWaitMs;
703 }
704
705
706
707
708
709
710
711
712 public void setMaxWaitMs(long value)
713 {
714 _maxWaitMs = value;
715 }
716
717
718
719
720
721
722
723
724 public int getThrottledRequests()
725 {
726 return _throttledRequests;
727 }
728
729
730
731
732
733
734
735
736 public void setThrottledRequests(int value)
737 {
738 _passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true);
739 _throttledRequests = value;
740 }
741
742
743
744
745
746
747
748 public long getThrottleMs()
749 {
750 return _throttleMs;
751 }
752
753
754
755
756
757
758
759 public void setThrottleMs(long value)
760 {
761 _throttleMs = value;
762 }
763
764
765
766
767
768
769
770
771 public long getMaxRequestMs()
772 {
773 return _maxRequestMs;
774 }
775
776
777
778
779
780
781
782
783 public void setMaxRequestMs(long value)
784 {
785 _maxRequestMs = value;
786 }
787
788
789
790
791
792
793
794
795
796 public long getMaxIdleTrackerMs()
797 {
798 return _maxIdleTrackerMs;
799 }
800
801
802
803
804
805
806
807
808
809 public void setMaxIdleTrackerMs(long value)
810 {
811 _maxIdleTrackerMs = value;
812 }
813
814
815
816
817
818
819
820 public boolean isInsertHeaders()
821 {
822 return _insertHeaders;
823 }
824
825
826
827
828
829
830
831 public void setInsertHeaders(boolean value)
832 {
833 _insertHeaders = value;
834 }
835
836
837
838
839
840
841
842 public boolean isTrackSessions()
843 {
844 return _trackSessions;
845 }
846
847
848
849
850
851
852 public void setTrackSessions(boolean value)
853 {
854 _trackSessions = value;
855 }
856
857
858
859
860
861
862
863
864 public boolean isRemotePort()
865 {
866 return _remotePort;
867 }
868
869
870
871
872
873
874
875
876
877 public void setRemotePort(boolean value)
878 {
879 _remotePort = value;
880 }
881
882
883
884
885
886
887
888 public String getWhitelist()
889 {
890 return _whitelistStr;
891 }
892
893
894
895
896
897
898
899
900 public void setWhitelist(String value)
901 {
902 _whitelistStr = value;
903 initWhitelist();
904 }
905
906
907
908
909
910 class RateTracker extends Timeout.Task implements HttpSessionBindingListener
911 {
912 protected final String _id;
913 protected final int _type;
914 protected final long[] _timestamps;
915 protected int _next;
916
917 public RateTracker(String id, int type,int maxRequestsPerSecond)
918 {
919 _id = id;
920 _type = type;
921 _timestamps=new long[maxRequestsPerSecond];
922 _next=0;
923 }
924
925
926
927
928 public boolean isRateExceeded(long now)
929 {
930 final long last;
931 synchronized (this)
932 {
933 last=_timestamps[_next];
934 _timestamps[_next]=now;
935 _next= (_next+1)%_timestamps.length;
936 }
937
938 boolean exceeded=last!=0 && (now-last)<1000L;
939 return exceeded;
940 }
941
942
943 public String getId()
944 {
945 return _id;
946 }
947
948 public int getType()
949 {
950 return _type;
951 }
952
953
954 public void valueBound(HttpSessionBindingEvent event)
955 {
956 }
957
958 public void valueUnbound(HttpSessionBindingEvent event)
959 {
960 _rateTrackers.remove(_id);
961 }
962
963 public void expired()
964 {
965 long now = _trackerTimeoutQ.getNow();
966 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
967 long last=_timestamps[latestIndex];
968 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
969
970 if (hasRecentRequest)
971 reschedule();
972 else
973 _rateTrackers.remove(_id);
974 }
975
976 @Override
977 public String toString()
978 {
979 return "RateTracker/"+_id+"/"+_type;
980 }
981 }
982
983 class FixedRateTracker extends RateTracker
984 {
985 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
986 {
987 super(id,type,numRecentRequestsTracked);
988 }
989
990 @Override
991 public boolean isRateExceeded(long now)
992 {
993
994
995
996 synchronized (this)
997 {
998 _timestamps[_next]=now;
999 _next= (_next+1)%_timestamps.length;
1000 }
1001
1002 return false;
1003 }
1004
1005 @Override
1006 public String toString()
1007 {
1008 return "Fixed"+super.toString();
1009 }
1010 }
1011 }