1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.eclipse.jetty.servlets;
15
16 import java.io.IOException;
17 import java.util.HashSet;
18 import java.util.Queue;
19 import java.util.StringTokenizer;
20 import java.util.concurrent.ConcurrentHashMap;
21 import java.util.concurrent.Semaphore;
22 import java.util.concurrent.TimeUnit;
23
24 import javax.servlet.Filter;
25 import javax.servlet.FilterChain;
26 import javax.servlet.FilterConfig;
27 import javax.servlet.ServletContext;
28 import javax.servlet.ServletException;
29 import javax.servlet.ServletRequest;
30 import javax.servlet.ServletResponse;
31 import javax.servlet.http.HttpServletRequest;
32 import javax.servlet.http.HttpServletResponse;
33 import javax.servlet.http.HttpSession;
34 import javax.servlet.http.HttpSessionBindingEvent;
35 import javax.servlet.http.HttpSessionBindingListener;
36
37 import org.eclipse.jetty.continuation.Continuation;
38 import org.eclipse.jetty.continuation.ContinuationSupport;
39 import org.eclipse.jetty.util.ArrayQueue;
40 import org.eclipse.jetty.util.log.Log;
41 import org.eclipse.jetty.util.thread.Timeout;
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94 public class DoSFilter implements Filter
95 {
96 final static String __TRACKER = "DoSFilter.Tracker";
97 final static String __THROTTLED = "DoSFilter.Throttled";
98
99 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
100 final static int __DEFAULT_DELAY_MS = 100;
101 final static int __DEFAULT_THROTTLE = 5;
102 final static int __DEFAULT_WAIT_MS=50;
103 final static long __DEFAULT_THROTTLE_MS = 30000L;
104 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
105 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
106
107 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
108 final static String DELAY_MS_INIT_PARAM = "delayMs";
109 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
110 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
111 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
112 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
113 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
114 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
115 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
116 final static String REMOTE_PORT_INIT_PARAM="remotePort";
117 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
118
119 final static int USER_AUTH = 2;
120 final static int USER_SESSION = 2;
121 final static int USER_IP = 1;
122 final static int USER_UNKNOWN = 0;
123
124 ServletContext _context;
125
126 protected long _delayMs;
127 protected long _throttleMs;
128 protected long _waitMs;
129 protected long _maxRequestMs;
130 protected long _maxIdleTrackerMs;
131 protected boolean _insertHeaders;
132 protected boolean _trackSessions;
133 protected boolean _remotePort;
134 protected Semaphore _passes;
135 protected Queue<Continuation>[] _queue;
136
137 protected int _maxRequestsPerSec;
138 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
139 private HashSet<String> _whitelist;
140
141 private final Timeout _requestTimeoutQ = new Timeout();
142 private final Timeout _trackerTimeoutQ = new Timeout();
143
144 private Thread _timerThread;
145 private volatile boolean _running;
146
147 public void init(FilterConfig filterConfig)
148 {
149 _context = filterConfig.getServletContext();
150
151 _queue = new Queue[getMaxPriority() + 1];
152 for (int p = 0; p < _queue.length; p++)
153 _queue[p] = new ArrayQueue<Continuation>();
154
155 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
156 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
157 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
158 _maxRequestsPerSec = baseRateLimit;
159
160 long delay = __DEFAULT_DELAY_MS;
161 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
162 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
163 _delayMs = delay;
164
165 int passes = __DEFAULT_THROTTLE;
166 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
167 passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
168 _passes = new Semaphore(passes,true);
169
170 long wait = __DEFAULT_WAIT_MS;
171 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
172 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
173 _waitMs = wait;
174
175 long suspend = __DEFAULT_THROTTLE_MS;
176 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
177 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
178 _throttleMs = suspend;
179
180 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
181 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
182 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
183 _maxRequestMs = maxRequestMs;
184
185 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
186 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
187 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
188 _maxIdleTrackerMs = maxIdleTrackerMs;
189
190 String whitelistString = "";
191 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
192 whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
193
194
195 if (whitelistString.length() == 0 )
196 _whitelist = new HashSet<String>();
197 else
198 {
199 StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
200 _whitelist = new HashSet<String>(tokenizer.countTokens());
201 while (tokenizer.hasMoreTokens())
202 _whitelist.add(tokenizer.nextToken().trim());
203
204 Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
205 }
206
207 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
208 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
209
210 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
211 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
212
213 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
214 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
215
216 _requestTimeoutQ.setNow();
217 _requestTimeoutQ.setDuration(_maxRequestMs);
218
219 _trackerTimeoutQ.setNow();
220 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
221
222 _running=true;
223 _timerThread = (new Thread()
224 {
225 public void run()
226 {
227 try
228 {
229 while (_running)
230 {
231 synchronized (_requestTimeoutQ)
232 {
233 _requestTimeoutQ.setNow();
234 _requestTimeoutQ.tick();
235
236 _trackerTimeoutQ.setNow(_requestTimeoutQ.getNow());
237 _trackerTimeoutQ.tick();
238 }
239 try
240 {
241 Thread.sleep(100);
242 }
243 catch (InterruptedException e)
244 {
245 Log.ignore(e);
246 }
247 }
248 }
249 finally
250 {
251 Log.info("DoSFilter timer exited");
252 }
253 }
254 });
255 _timerThread.start();
256 }
257
258
259 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
260 {
261 final HttpServletRequest srequest = (HttpServletRequest)request;
262 final HttpServletResponse sresponse = (HttpServletResponse)response;
263
264 final long now=_requestTimeoutQ.getNow();
265
266
267 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
268
269 if (tracker==null)
270 {
271
272
273
274 tracker = getRateTracker(request);
275
276
277 final boolean overRateLimit = tracker.isRateExceeded(now);
278
279
280 if (!overRateLimit)
281 {
282 doFilterChain(filterchain,srequest,sresponse);
283 return;
284 }
285
286
287 Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
288
289
290 switch((int)_delayMs)
291 {
292 case -1:
293 {
294
295 if (_insertHeaders)
296 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
297 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
298 return;
299 }
300 case 0:
301 {
302
303 request.setAttribute(__TRACKER,tracker);
304 break;
305 }
306 default:
307 {
308
309 if (_insertHeaders)
310 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
311 Continuation continuation = ContinuationSupport.getContinuation(request);
312 request.setAttribute(__TRACKER,tracker);
313 if (_delayMs > 0)
314 continuation.setTimeout(_delayMs);
315 continuation.suspend();
316 return;
317 }
318 }
319 }
320
321
322 boolean accepted = false;
323 try
324 {
325
326 accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
327
328 if (!accepted)
329 {
330
331 final Continuation continuation = ContinuationSupport.getContinuation(request,response);
332
333 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
334 if (throttled!=Boolean.TRUE && _throttleMs>0)
335 {
336 int priority = getPriority(request,tracker);
337 request.setAttribute(__THROTTLED,Boolean.TRUE);
338 if (_insertHeaders)
339 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
340 if (_throttleMs > 0)
341 continuation.setTimeout(_throttleMs);
342 continuation.suspend();
343
344 _queue[priority].add(continuation);
345 return;
346 }
347
348 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
349 {
350
351 _passes.acquire();
352 accepted = true;
353 }
354 }
355
356
357 if (accepted)
358
359 doFilterChain(filterchain,srequest,sresponse);
360 else
361 {
362
363 if (_insertHeaders)
364 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
365 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
366 }
367 }
368 catch (InterruptedException e)
369 {
370 _context.log("DoS",e);
371 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
372 }
373 finally
374 {
375 if (accepted)
376 {
377
378 synchronized (_queue)
379 {
380 for (int p = _queue.length; p-- > 0;)
381 {
382 Continuation continuation = _queue[p].poll();
383
384 if (continuation != null)
385 {
386 continuation.resume();
387 break;
388 }
389 }
390 }
391 _passes.release();
392 }
393 }
394 }
395
396
397
398
399
400
401
402
403 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
404 throws IOException, ServletException
405 {
406 final Thread thread=Thread.currentThread();
407
408 final Timeout.Task requestTimeout = new Timeout.Task()
409 {
410 public void expired()
411 {
412 closeConnection(request, response, thread);
413 }
414 };
415
416 try
417 {
418 synchronized (_requestTimeoutQ)
419 {
420 _requestTimeoutQ.schedule(requestTimeout);
421 }
422 chain.doFilter(request,response);
423 }
424 finally
425 {
426 synchronized (_requestTimeoutQ)
427 {
428 requestTimeout.cancel();
429 }
430 }
431 }
432
433
434
435
436
437
438
439
440 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
441 {
442
443 if( !response.isCommitted() )
444 {
445 response.setHeader("Connection", "close");
446 }
447 try
448 {
449 try
450 {
451 response.getWriter().close();
452 }
453 catch (IllegalStateException e)
454 {
455 response.getOutputStream().close();
456 }
457 }
458 catch (IOException e)
459 {
460 Log.warn(e);
461 }
462
463
464 thread.interrupt();
465 }
466
467
468
469
470
471
472
473
474 protected int getPriority(ServletRequest request, RateTracker tracker)
475 {
476 if (extractUserId(request)!=null)
477 return USER_AUTH;
478 if (tracker!=null)
479 return tracker.getType();
480 return USER_UNKNOWN;
481 }
482
483
484
485
486 protected int getMaxPriority()
487 {
488 return USER_AUTH;
489 }
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507 public RateTracker getRateTracker(ServletRequest request)
508 {
509 HttpServletRequest srequest = (HttpServletRequest)request;
510
511 String loadId;
512 final int type;
513
514 loadId = extractUserId(request);
515 HttpSession session=srequest.getSession(false);
516 if (_trackSessions && session!=null && !session.isNew())
517 {
518 loadId=session.getId();
519 type = USER_SESSION;
520 }
521 else
522 {
523 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
524 type = USER_IP;
525 }
526
527 RateTracker tracker=_rateTrackers.get(loadId);
528
529 if (tracker==null)
530 {
531 RateTracker t;
532 if (_whitelist.contains(request.getRemoteAddr()))
533 {
534 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
535 }
536 else
537 {
538 t = new RateTracker(loadId,type,_maxRequestsPerSec);
539 }
540
541 tracker=_rateTrackers.putIfAbsent(loadId,t);
542 if (tracker==null)
543 tracker=t;
544
545 if (type == USER_IP)
546 {
547
548 synchronized (_trackerTimeoutQ)
549 {
550 _trackerTimeoutQ.schedule(tracker);
551 }
552 }
553 else if (session!=null)
554
555 session.setAttribute(__TRACKER,tracker);
556 }
557
558 return tracker;
559 }
560
561 public void destroy()
562 {
563 _running=false;
564 _timerThread.interrupt();
565 synchronized (_requestTimeoutQ)
566 {
567 _requestTimeoutQ.cancelAll();
568 _trackerTimeoutQ.cancelAll();
569 }
570 }
571
572
573
574
575
576
577
578
579 protected String extractUserId(ServletRequest request)
580 {
581 return null;
582 }
583
584
585
586
587
588 class RateTracker extends Timeout.Task implements HttpSessionBindingListener
589 {
590 protected final String _id;
591 protected final int _type;
592 protected final long[] _timestamps;
593 protected int _next;
594
595 public RateTracker(String id, int type,int maxRequestsPerSecond)
596 {
597 _id = id;
598 _type = type;
599 _timestamps=new long[maxRequestsPerSecond];
600 _next=0;
601 }
602
603
604
605
606 public boolean isRateExceeded(long now)
607 {
608 final long last;
609 synchronized (this)
610 {
611 last=_timestamps[_next];
612 _timestamps[_next]=now;
613 _next= (_next+1)%_timestamps.length;
614 }
615
616 boolean exceeded=last!=0 && (now-last)<1000L;
617 return exceeded;
618 }
619
620
621 public String getId()
622 {
623 return _id;
624 }
625
626 public int getType()
627 {
628 return _type;
629 }
630
631
632 public void valueBound(HttpSessionBindingEvent event)
633 {
634 }
635
636 public void valueUnbound(HttpSessionBindingEvent event)
637 {
638 _rateTrackers.remove(_id);
639 }
640
641 public void expired()
642 {
643 long now = _trackerTimeoutQ.getNow();
644 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
645 long last=_timestamps[latestIndex];
646 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
647
648 if (hasRecentRequest)
649 reschedule();
650 else
651 _rateTrackers.remove(_id);
652 }
653 }
654
655 class FixedRateTracker extends RateTracker
656 {
657 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
658 {
659 super(id,type,numRecentRequestsTracked);
660 }
661
662 public boolean isRateExceeded(long now)
663 {
664
665
666
667 synchronized (this)
668 {
669 _timestamps[_next]=now;
670 _next= (_next+1)%_timestamps.length;
671 }
672
673 return false;
674 }
675 }
676 }