1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.eclipse.jetty.servlets;
15
16 import java.io.IOException;
17 import java.util.HashSet;
18 import java.util.Queue;
19 import java.util.StringTokenizer;
20 import java.util.concurrent.ConcurrentHashMap;
21 import java.util.concurrent.ConcurrentLinkedQueue;
22 import java.util.concurrent.Semaphore;
23 import java.util.concurrent.TimeUnit;
24
25 import javax.servlet.Filter;
26 import javax.servlet.FilterChain;
27 import javax.servlet.FilterConfig;
28 import javax.servlet.ServletContext;
29 import javax.servlet.ServletException;
30 import javax.servlet.ServletRequest;
31 import javax.servlet.ServletResponse;
32 import javax.servlet.http.HttpServletRequest;
33 import javax.servlet.http.HttpServletResponse;
34 import javax.servlet.http.HttpSession;
35 import javax.servlet.http.HttpSessionBindingEvent;
36 import javax.servlet.http.HttpSessionBindingListener;
37
38 import org.eclipse.jetty.continuation.Continuation;
39 import org.eclipse.jetty.continuation.ContinuationListener;
40 import org.eclipse.jetty.continuation.ContinuationSupport;
41 import org.eclipse.jetty.server.handler.ContextHandler;
42 import org.eclipse.jetty.util.log.Log;
43 import org.eclipse.jetty.util.log.Logger;
44 import org.eclipse.jetty.util.thread.Timeout;
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115 public class DoSFilter implements Filter
116 {
117 private static final Logger LOG = Log.getLogger(DoSFilter.class);
118
119 final static String __TRACKER = "DoSFilter.Tracker";
120 final static String __THROTTLED = "DoSFilter.Throttled";
121
122 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
123 final static int __DEFAULT_DELAY_MS = 100;
124 final static int __DEFAULT_THROTTLE = 5;
125 final static int __DEFAULT_WAIT_MS=50;
126 final static long __DEFAULT_THROTTLE_MS = 30000L;
127 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
128 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
129
130 final static String MANAGED_ATTR_INIT_PARAM="managedAttr";
131 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
132 final static String DELAY_MS_INIT_PARAM = "delayMs";
133 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
134 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
135 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
136 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
137 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
138 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
139 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
140 final static String REMOTE_PORT_INIT_PARAM="remotePort";
141 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
142
143 final static int USER_AUTH = 2;
144 final static int USER_SESSION = 2;
145 final static int USER_IP = 1;
146 final static int USER_UNKNOWN = 0;
147
148 ServletContext _context;
149
150 protected String _name;
151 protected long _delayMs;
152 protected long _throttleMs;
153 protected long _maxWaitMs;
154 protected long _maxRequestMs;
155 protected long _maxIdleTrackerMs;
156 protected boolean _insertHeaders;
157 protected boolean _trackSessions;
158 protected boolean _remotePort;
159 protected int _throttledRequests;
160 protected Semaphore _passes;
161 protected Queue<Continuation>[] _queue;
162 protected ContinuationListener[] _listener;
163
164 protected int _maxRequestsPerSec;
165 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
166 protected String _whitelistStr;
167 private final HashSet<String> _whitelist = new HashSet<String>();
168
169 private final Timeout _requestTimeoutQ = new Timeout();
170 private final Timeout _trackerTimeoutQ = new Timeout();
171
172 private Thread _timerThread;
173 private volatile boolean _running;
174
175 public void init(FilterConfig filterConfig)
176 {
177 _context = filterConfig.getServletContext();
178
179 _queue = new Queue[getMaxPriority() + 1];
180 _listener = new ContinuationListener[getMaxPriority() + 1];
181 for (int p = 0; p < _queue.length; p++)
182 {
183 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
184
185 final int priority=p;
186 _listener[p] = new ContinuationListener()
187 {
188 public void onComplete(Continuation continuation)
189 {
190 }
191
192 public void onTimeout(Continuation continuation)
193 {
194 _queue[priority].remove(continuation);
195 }
196 };
197 }
198
199 _rateTrackers.clear();
200
201 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
202 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
203 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
204 _maxRequestsPerSec = baseRateLimit;
205
206 long delay = __DEFAULT_DELAY_MS;
207 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
208 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
209 _delayMs = delay;
210
211 int throttledRequests = __DEFAULT_THROTTLE;
212 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
213 throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
214 _passes = new Semaphore(throttledRequests,true);
215 _throttledRequests = throttledRequests;
216
217 long wait = __DEFAULT_WAIT_MS;
218 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
219 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
220 _maxWaitMs = wait;
221
222 long suspend = __DEFAULT_THROTTLE_MS;
223 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
224 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
225 _throttleMs = suspend;
226
227 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
228 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
229 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
230 _maxRequestMs = maxRequestMs;
231
232 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
233 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
234 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
235 _maxIdleTrackerMs = maxIdleTrackerMs;
236
237 _whitelistStr = "";
238 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
239 _whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
240 initWhitelist();
241
242 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
243 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
244
245 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
246 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
247
248 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
249 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
250
251 _requestTimeoutQ.setNow();
252 _requestTimeoutQ.setDuration(_maxRequestMs);
253
254 _trackerTimeoutQ.setNow();
255 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
256
257 _running=true;
258 _timerThread = (new Thread()
259 {
260 public void run()
261 {
262 try
263 {
264 while (_running)
265 {
266 long now;
267 synchronized (_requestTimeoutQ)
268 {
269 now = _requestTimeoutQ.setNow();
270 _requestTimeoutQ.tick();
271 }
272 synchronized (_trackerTimeoutQ)
273 {
274 _trackerTimeoutQ.setNow(now);
275 _trackerTimeoutQ.tick();
276 }
277 try
278 {
279 Thread.sleep(100);
280 }
281 catch (InterruptedException e)
282 {
283 LOG.ignore(e);
284 }
285 }
286 }
287 finally
288 {
289 LOG.info("DoSFilter timer exited");
290 }
291 }
292 });
293 _timerThread.start();
294
295 if (_context!=null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
296 _context.setAttribute(filterConfig.getFilterName(),this);
297 }
298
299
300 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
301 {
302 final HttpServletRequest srequest = (HttpServletRequest)request;
303 final HttpServletResponse sresponse = (HttpServletResponse)response;
304
305 final long now=_requestTimeoutQ.getNow();
306
307
308 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
309
310 if (tracker==null)
311 {
312
313
314
315 tracker = getRateTracker(request);
316
317
318 final boolean overRateLimit = tracker.isRateExceeded(now);
319
320
321 if (!overRateLimit)
322 {
323 doFilterChain(filterchain,srequest,sresponse);
324 return;
325 }
326
327
328 LOG.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
329
330
331 switch((int)_delayMs)
332 {
333 case -1:
334 {
335
336 if (_insertHeaders)
337 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
338 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
339 return;
340 }
341 case 0:
342 {
343
344 request.setAttribute(__TRACKER,tracker);
345 break;
346 }
347 default:
348 {
349
350 if (_insertHeaders)
351 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
352 Continuation continuation = ContinuationSupport.getContinuation(request);
353 request.setAttribute(__TRACKER,tracker);
354 if (_delayMs > 0)
355 continuation.setTimeout(_delayMs);
356 continuation.suspend();
357 return;
358 }
359 }
360 }
361
362
363 boolean accepted = false;
364 try
365 {
366
367 accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS);
368
369 if (!accepted)
370 {
371
372 final Continuation continuation = ContinuationSupport.getContinuation(request);
373
374 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
375 if (throttled!=Boolean.TRUE && _throttleMs>0)
376 {
377 int priority = getPriority(request,tracker);
378 request.setAttribute(__THROTTLED,Boolean.TRUE);
379 if (_insertHeaders)
380 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
381 if (_throttleMs > 0)
382 continuation.setTimeout(_throttleMs);
383 continuation.suspend();
384
385 continuation.addContinuationListener(_listener[priority]);
386 _queue[priority].add(continuation);
387 return;
388 }
389
390 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
391 {
392
393 _passes.acquire();
394 accepted = true;
395 }
396 }
397
398
399 if (accepted)
400
401 doFilterChain(filterchain,srequest,sresponse);
402 else
403 {
404
405 if (_insertHeaders)
406 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
407 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
408 }
409 }
410 catch (InterruptedException e)
411 {
412 _context.log("DoS",e);
413 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
414 }
415 finally
416 {
417 if (accepted)
418 {
419
420 for (int p = _queue.length; p-- > 0;)
421 {
422 Continuation continuation = _queue[p].poll();
423 if (continuation != null && continuation.isSuspended())
424 {
425 continuation.resume();
426 break;
427 }
428 }
429 _passes.release();
430 }
431 }
432 }
433
434
435
436
437
438
439
440
441 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
442 throws IOException, ServletException
443 {
444 final Thread thread=Thread.currentThread();
445
446 final Timeout.Task requestTimeout = new Timeout.Task()
447 {
448 public void expired()
449 {
450 closeConnection(request, response, thread);
451 }
452 };
453
454 try
455 {
456 synchronized (_requestTimeoutQ)
457 {
458 _requestTimeoutQ.schedule(requestTimeout);
459 }
460 chain.doFilter(request,response);
461 }
462 finally
463 {
464 synchronized (_requestTimeoutQ)
465 {
466 requestTimeout.cancel();
467 }
468 }
469 }
470
471
472
473
474
475
476
477
478 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
479 {
480
481 if( !response.isCommitted() )
482 {
483 response.setHeader("Connection", "close");
484 }
485 try
486 {
487 try
488 {
489 response.getWriter().close();
490 }
491 catch (IllegalStateException e)
492 {
493 response.getOutputStream().close();
494 }
495 }
496 catch (IOException e)
497 {
498 LOG.warn(e);
499 }
500
501
502 thread.interrupt();
503 }
504
505
506
507
508
509
510
511
512 protected int getPriority(ServletRequest request, RateTracker tracker)
513 {
514 if (extractUserId(request)!=null)
515 return USER_AUTH;
516 if (tracker!=null)
517 return tracker.getType();
518 return USER_UNKNOWN;
519 }
520
521
522
523
524 protected int getMaxPriority()
525 {
526 return USER_AUTH;
527 }
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545 public RateTracker getRateTracker(ServletRequest request)
546 {
547 HttpServletRequest srequest = (HttpServletRequest)request;
548 HttpSession session=srequest.getSession(false);
549
550 String loadId = extractUserId(request);
551 final int type;
552 if (loadId != null)
553 {
554 type = USER_AUTH;
555 }
556 else
557 {
558 if (_trackSessions && session!=null && !session.isNew())
559 {
560 loadId=session.getId();
561 type = USER_SESSION;
562 }
563 else
564 {
565 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
566 type = USER_IP;
567 }
568 }
569
570 RateTracker tracker=_rateTrackers.get(loadId);
571
572 if (tracker==null)
573 {
574 RateTracker t;
575 if (_whitelist.contains(request.getRemoteAddr()))
576 {
577 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
578 }
579 else
580 {
581 t = new RateTracker(loadId,type,_maxRequestsPerSec);
582 }
583
584 tracker=_rateTrackers.putIfAbsent(loadId,t);
585 if (tracker==null)
586 tracker=t;
587
588 if (type == USER_IP)
589 {
590
591 synchronized (_trackerTimeoutQ)
592 {
593 _trackerTimeoutQ.schedule(tracker);
594 }
595 }
596 else if (session!=null)
597
598 session.setAttribute(__TRACKER,tracker);
599 }
600
601 return tracker;
602 }
603
604 public void destroy()
605 {
606 _running=false;
607 _timerThread.interrupt();
608 synchronized (_requestTimeoutQ)
609 {
610 _requestTimeoutQ.cancelAll();
611 }
612 synchronized (_trackerTimeoutQ)
613 {
614 _trackerTimeoutQ.cancelAll();
615 }
616 _rateTrackers.clear();
617 _whitelist.clear();
618 }
619
620
621
622
623
624
625
626
627 protected String extractUserId(ServletRequest request)
628 {
629 return null;
630 }
631
632
633
634
635
636 protected void initWhitelist()
637 {
638 _whitelist.clear();
639 StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ",");
640 while (tokenizer.hasMoreTokens())
641 _whitelist.add(tokenizer.nextToken().trim());
642
643 LOG.info("Whitelisted IP addresses: {}", _whitelist.toString());
644 }
645
646
647
648
649
650
651
652
653
654 public int getMaxRequestsPerSec()
655 {
656 return _maxRequestsPerSec;
657 }
658
659
660
661
662
663
664
665
666
667 public void setMaxRequestsPerSec(int value)
668 {
669 _maxRequestsPerSec = value;
670 }
671
672
673
674
675
676
677 public long getDelayMs()
678 {
679 return _delayMs;
680 }
681
682
683
684
685
686
687
688
689 public void setDelayMs(long value)
690 {
691 _delayMs = value;
692 }
693
694
695
696
697
698
699
700
701 public long getMaxWaitMs()
702 {
703 return _maxWaitMs;
704 }
705
706
707
708
709
710
711
712
713 public void setMaxWaitMs(long value)
714 {
715 _maxWaitMs = value;
716 }
717
718
719
720
721
722
723
724
725 public int getThrottledRequests()
726 {
727 return _throttledRequests;
728 }
729
730
731
732
733
734
735
736
737 public void setThrottledRequests(int value)
738 {
739 _passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true);
740 _throttledRequests = value;
741 }
742
743
744
745
746
747
748
749 public long getThrottleMs()
750 {
751 return _throttleMs;
752 }
753
754
755
756
757
758
759
760 public void setThrottleMs(long value)
761 {
762 _throttleMs = value;
763 }
764
765
766
767
768
769
770
771
772 public long getMaxRequestMs()
773 {
774 return _maxRequestMs;
775 }
776
777
778
779
780
781
782
783
784 public void setMaxRequestMs(long value)
785 {
786 _maxRequestMs = value;
787 }
788
789
790
791
792
793
794
795
796
797 public long getMaxIdleTrackerMs()
798 {
799 return _maxIdleTrackerMs;
800 }
801
802
803
804
805
806
807
808
809
810 public void setMaxIdleTrackerMs(long value)
811 {
812 _maxIdleTrackerMs = value;
813 }
814
815
816
817
818
819
820
821 public boolean isInsertHeaders()
822 {
823 return _insertHeaders;
824 }
825
826
827
828
829
830
831
832 public void setInsertHeaders(boolean value)
833 {
834 _insertHeaders = value;
835 }
836
837
838
839
840
841
842
843 public boolean isTrackSessions()
844 {
845 return _trackSessions;
846 }
847
848
849
850
851
852
853 public void setTrackSessions(boolean value)
854 {
855 _trackSessions = value;
856 }
857
858
859
860
861
862
863
864
865 public boolean isRemotePort()
866 {
867 return _remotePort;
868 }
869
870
871
872
873
874
875
876
877
878 public void setRemotePort(boolean value)
879 {
880 _remotePort = value;
881 }
882
883
884
885
886
887
888
889 public String getWhitelist()
890 {
891 return _whitelistStr;
892 }
893
894
895
896
897
898
899
900
901 public void setWhitelist(String value)
902 {
903 _whitelistStr = value;
904 initWhitelist();
905 }
906
907
908
909
910
911 class RateTracker extends Timeout.Task implements HttpSessionBindingListener
912 {
913 protected final String _id;
914 protected final int _type;
915 protected final long[] _timestamps;
916 protected int _next;
917
918 public RateTracker(String id, int type,int maxRequestsPerSecond)
919 {
920 _id = id;
921 _type = type;
922 _timestamps=new long[maxRequestsPerSecond];
923 _next=0;
924 }
925
926
927
928
929 public boolean isRateExceeded(long now)
930 {
931 final long last;
932 synchronized (this)
933 {
934 last=_timestamps[_next];
935 _timestamps[_next]=now;
936 _next= (_next+1)%_timestamps.length;
937 }
938
939 boolean exceeded=last!=0 && (now-last)<1000L;
940 return exceeded;
941 }
942
943
944 public String getId()
945 {
946 return _id;
947 }
948
949 public int getType()
950 {
951 return _type;
952 }
953
954
955 public void valueBound(HttpSessionBindingEvent event)
956 {
957 }
958
959 public void valueUnbound(HttpSessionBindingEvent event)
960 {
961 _rateTrackers.remove(_id);
962 }
963
964 public void expired()
965 {
966 long now = _trackerTimeoutQ.getNow();
967 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
968 long last=_timestamps[latestIndex];
969 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
970
971 if (hasRecentRequest)
972 reschedule();
973 else
974 _rateTrackers.remove(_id);
975 }
976
977 @Override
978 public String toString()
979 {
980 return "RateTracker/"+_id+"/"+_type;
981 }
982 }
983
984 class FixedRateTracker extends RateTracker
985 {
986 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
987 {
988 super(id,type,numRecentRequestsTracked);
989 }
990
991 @Override
992 public boolean isRateExceeded(long now)
993 {
994
995
996
997 synchronized (this)
998 {
999 _timestamps[_next]=now;
1000 _next= (_next+1)%_timestamps.length;
1001 }
1002
1003 return false;
1004 }
1005
1006 @Override
1007 public String toString()
1008 {
1009 return "Fixed"+super.toString();
1010 }
1011 }
1012 }