org.eclipse.jetty.server.handler.DoSHandler Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jetty-server Show documentation
Show all versions of jetty-server Show documentation
The core jetty server artifact.
The newest version!
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.server.handler;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.time.Duration;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.io.CyclicTimeouts;
import org.eclipse.jetty.server.ConnectionMetaData;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.NanoTime;
import org.eclipse.jetty.util.annotation.ManagedObject;
import org.eclipse.jetty.util.annotation.Name;
import org.eclipse.jetty.util.thread.AutoLock;
import org.eclipse.jetty.util.thread.Scheduler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A Denial of Service Handler that protects from attacks by limiting the request rate from remote clients.
*/
@ManagedObject("DoS Prevention Handler")
public class DoSHandler extends ConditionalHandler.ElseNext
{
private static final Logger LOG = LoggerFactory.getLogger(DoSHandler.class);
/**
* A {@link Function} to create a remote client identifier from the remote address and remote port of a {@link Request}.
*/
public static final Function ID_FROM_REMOTE_ADDRESS_PORT = request ->
{
SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress();
if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress)
return inetSocketAddress.toString();
return remoteSocketAddress.toString();
};
/**
* A {@link Function} to create a remote client identifier from the remote address of a {@link Request}.
*/
public static final Function ID_FROM_REMOTE_ADDRESS = request ->
{
SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress();
if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress)
return inetSocketAddress.getAddress().toString();
return remoteSocketAddress.toString();
};
/**
* A {@link Function} to create a remote client identifier from the remote port of a {@link Request}.
* Useful if there is an untrusted intermediary, where the remote port can be a surrogate for the connection.
*/
public static final Function ID_FROM_REMOTE_PORT = request ->
{
SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress();
if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress)
return Integer.toString(inetSocketAddress.getPort());
return remoteSocketAddress.toString();
};
/**
* A {@link Function} to create a remote client identifier from {@link ConnectionMetaData#getId()} of a {@link Request}.
*/
public static final Function ID_FROM_CONNECTION = request -> request.getConnectionMetaData().getId();
private final Map _trackers = new ConcurrentHashMap<>();
private final Function _clientIdFn;
private final Tracker.Factory _trackerFactory;
private final Request.Handler _rejectHandler;
private final int _maxTrackers;
private final boolean _rejectUntracked;
private CyclicTimeouts _cyclicTimeouts;
/**
* @param trackerFactory Factory to create a Tracker
*/
public DoSHandler(@Name("trackerFactory") Tracker.Factory trackerFactory)
{
this(null, trackerFactory, null, -1);
}
/**
* @param clientIdFn Function to extract a remote client identifier from a request.
* @param trackerFactory Factory to create a Tracker
* @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default.
* @param maxTrackers The maximum number of remote clients to track or -1 for a default value, 0 for unlimited.
* If this limit is exceeded, then requests from additional remote clients are rejected.
*/
public DoSHandler(
@Name("clientIdFn") Function clientIdFn,
@Name("trackerFactory") Tracker.Factory trackerFactory,
@Name("rejectHandler") Request.Handler rejectHandler,
@Name("maxTrackers") int maxTrackers)
{
this(null, clientIdFn, trackerFactory, rejectHandler, maxTrackers);
}
/**
* @param handler Then next {@link Handler} or {@code null}
* @param clientIdFn Function to extract a remote client identifier from a request.
* @param trackerFactory Factory to create a Tracker
* @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default.
* @param maxTrackers The maximum number of remote clients to track or -1 for a default value, 0 for unlimited.
* If this limit is exceeded, then requests from additional remote clients are rejected.
*/
public DoSHandler(
@Name("handler") Handler handler,
@Name("clientIdFn") Function clientIdFn,
@Name("trackerFactory") Tracker.Factory trackerFactory,
@Name("rejectHandler") Request.Handler rejectHandler,
@Name("maxTrackers") int maxTrackers)
{
this(handler, clientIdFn, trackerFactory, rejectHandler, maxTrackers, false);
}
/**
* @param handler Then next {@link Handler} or {@code null}
* @param clientIdFn Function to extract a remote client identifier from a request.
* @param trackerFactory Factory to create a Tracker
* @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default.
* @param maxTrackers The maximum number of remote clients to track or -1 for a default value, 0 for unlimited.
* If this limit is exceeded, then requests from additional remote clients are rejected.
*/
public DoSHandler(
@Name("handler") Handler handler,
@Name("clientIdFn") Function clientIdFn,
@Name("trackerFactory") Tracker.Factory trackerFactory,
@Name("rejectHandler") Request.Handler rejectHandler,
@Name("maxTrackers") int maxTrackers,
@Name("rejectUntracked") boolean rejectUntracked)
{
super(handler);
installBean(_trackers);
_clientIdFn = Objects.requireNonNullElse(clientIdFn, ID_FROM_REMOTE_ADDRESS);
installBean(_clientIdFn);
_trackerFactory = Objects.requireNonNull(trackerFactory);
installBean(_trackerFactory);
// default max trackers to a large, effectively infinite number, but ultimately bounded.
_maxTrackers = maxTrackers < 0 ? 100_000 : maxTrackers;
_rejectHandler = Objects.requireNonNullElseGet(rejectHandler, StatusRejectHandler::new);
installBean(_rejectHandler);
_rejectUntracked = rejectUntracked;
}
@Override
public void setServer(Server server)
{
super.setServer(server);
if (_rejectHandler instanceof Handler handler)
handler.setServer(server);
}
@Override
protected boolean onConditionsMet(Request request, Response response, Callback callback) throws Exception
{
// Calculate an id for the request (which may be global empty string).
String id = _clientIdFn.apply(request);
// Reject or handle untracked request
if (id == null)
id = "";
// Obtain a tracker, creating a new one if necessary (and not too many)
// Trackers are removed if CyclicTimeouts#onExpired returns true.
Tracker tracker = _trackers.computeIfAbsent(id, this::newTracker);
// If we have too many trackers, then we will have a null tracker
if (tracker == null)
return _rejectUntracked ? _rejectHandler.handle(request, response, callback) : nextHandler(request, response, callback);
// IS the request allowed by the tracker?
boolean allowed = tracker.onRequest(NanoTime.now());
if (LOG.isDebugEnabled())
LOG.debug("allowed={} {}", allowed, tracker);
if (allowed)
return nextHandler(request, response, callback);
// Otherwise reject the request as it is over rate
return _rejectHandler.handle(request, response, callback);
}
Tracker newTracker(String id)
{
if (_maxTrackers > 0 && _trackers.size() >= _maxTrackers)
return null;
Tracker tracker = _trackerFactory.newTracker(id);
_cyclicTimeouts.schedule(tracker);
return tracker;
}
@Override
protected void doStart() throws Exception
{
_cyclicTimeouts = new CyclicTimeouts<>(getServer().getScheduler())
{
@Override
protected Iterator iterator()
{
return _trackers.values().iterator();
}
@Override
protected boolean onExpired(Tracker tracker)
{
return true;
}
};
addBean(_cyclicTimeouts);
super.doStart();
}
@Override
protected void doStop() throws Exception
{
super.doStop();
removeBean(_cyclicTimeouts);
_cyclicTimeouts.destroy();
_cyclicTimeouts = null;
}
/**
* A RateTracker is associated with an id, and stores request rate data.
*/
public interface Tracker extends CyclicTimeouts.Expirable
{
/**
* Add a request to the tracker and check the rate limit
*
* @param now The timestamp of the request
* @return {@code true} if the request is below the limit
*/
boolean onRequest(long now);
interface Factory
{
Tracker newTracker(String id);
}
}
/**
* The Tracker implements the classic Leaky Bucket Algorithm.
*/
public static class LeakingBucketTrackerFactory implements Tracker.Factory
{
private final int _maxRequestsPerSecond;
private final int _bucketSize;
private final long _nanosPerDrip;
private final long _idleTimeout;
/**
* @param maxRequestsPerSecond the maximum requests per second allowed by this tracker
*/
public LeakingBucketTrackerFactory(
@Name("maxRequestsPerSecond") int maxRequestsPerSecond)
{
this(maxRequestsPerSecond, -1, null);
}
/**
* @param maxRequestsPerSecond the maximum requests per second allowed by this tracker
* @param bucketSize the size of the bucket in request/drips, which is effectively the burst capacity, giving the number
* of request that can be handled in excess of the short term rate, before being rejected.
* Use -1 for a heuristic value.
* @param idleTimeout The period to keep an empty bucket before removal, or null to remove a bucket immediately when
* empty
*/
public LeakingBucketTrackerFactory(
@Name("maxRequestsPerSecond") int maxRequestsPerSecond,
@Name("bucketSize") int bucketSize,
@Name("idleTimeout") Duration idleTimeout)
{
_maxRequestsPerSecond = maxRequestsPerSecond;
_nanosPerDrip = TimeUnit.SECONDS.toNanos(1) / _maxRequestsPerSecond;
_bucketSize = (bucketSize < 0)
? _maxRequestsPerSecond
: bucketSize;
_idleTimeout = idleTimeout == null ? 0 : idleTimeout.toNanos();
}
@Override
public Tracker newTracker(String id)
{
return new LeakingBucketTracker(id);
}
private class LeakingBucketTracker implements Tracker
{
private final AutoLock _lock = new AutoLock();
private final String _id;
private long _lastDripNanoTime;
private long _expireNanoTime;
private int _bucket;
public LeakingBucketTracker(String id)
{
_id = id;
long now = NanoTime.now();
_lastDripNanoTime = now;
_expireNanoTime = now + _nanosPerDrip + _idleTimeout;
}
@Override
public long getExpireNanoTime()
{
try (AutoLock ignored = _lock.lock())
{
return _expireNanoTime;
}
}
@Override
public boolean onRequest(long now)
{
try (AutoLock ignored = _lock.lock())
{
long elapsedSinceLastDrip = NanoTime.elapsed(_lastDripNanoTime, now);
long drips = elapsedSinceLastDrip / _nanosPerDrip;
_lastDripNanoTime = _lastDripNanoTime + drips * _nanosPerDrip;
_bucket = Math.min(_bucketSize, Math.toIntExact(Math.max(0L, _bucket - drips) + 1));
_expireNanoTime = now + _bucket * _nanosPerDrip + _idleTimeout;
return _bucket < _bucketSize;
}
}
@Override
public String toString()
{
try (AutoLock ignored = _lock.lock())
{
return "%s@%s{%d/%d}".formatted(getClass().getSimpleName(), _id, _bucket, _maxRequestsPerSecond);
}
}
}
}
/**
* A Handler to reject DoS requests with a status code or failure.
*/
public static class StatusRejectHandler implements Request.Handler
{
private final int _status;
public StatusRejectHandler()
{
this(-1);
}
/**
* @param status The status used to reject a request, or 0 to fail the request or -1 for a default ({@link HttpStatus#TOO_MANY_REQUESTS_429}.
*/
public StatusRejectHandler(int status)
{
_status = status >= 0 ? status : HttpStatus.TOO_MANY_REQUESTS_429;
if (_status != 0 && _status != HttpStatus.OK_200 && !HttpStatus.isClientError(_status) && !HttpStatus.isServerError(_status))
throw new IllegalArgumentException("status must be a client or server error");
}
@Override
public boolean handle(Request request, Response response, Callback callback) throws Exception
{
if (_status == 0)
callback.failed(new RejectedExecutionException());
else
Response.writeError(request, response, callback, _status);
return true;
}
}
/**
* A Handler to reject DoS requests after first delaying them.
*/
public static class DelayedRejectHandler extends Handler.Abstract
{
private record Exchange(Request request, Response response, Callback callback)
{
}
private final AutoLock _lock = new AutoLock();
private final Deque _delayQueue = new ArrayDeque<>();
private final int _maxDelayQueue;
private final long _delayMs;
private final Request.Handler _reject;
private Scheduler _scheduler;
public DelayedRejectHandler()
{
this(-1, -1, null);
}
/**
* @param delayMs The delay in milliseconds to hold rejected requests before sending a response or -1 for a default (1000ms)
* @param maxDelayQueue The maximum number of delayed requests to hold or -1 for a default (1000ms).
* @param reject The {@link Request.Handler} used to reject {@link Request}s or null for a default ({@link HttpStatus#TOO_MANY_REQUESTS_429}).
*/
public DelayedRejectHandler(
@Name("delayMs") long delayMs,
@Name("maxDelayQueue") int maxDelayQueue,
@Name("reject") Request.Handler reject)
{
_delayMs = delayMs >= 0 ? delayMs : 1000;
_maxDelayQueue = maxDelayQueue >= 0 ? maxDelayQueue : 1000;
_reject = Objects.requireNonNullElseGet(reject, () -> new StatusRejectHandler(HttpStatus.TOO_MANY_REQUESTS_429));
}
@Override
protected void doStart() throws Exception
{
super.doStart();
_scheduler = getServer().getScheduler();
addBean(_scheduler);
}
@Override
protected void doStop() throws Exception
{
super.doStop();
removeBean(_scheduler);
_scheduler = null;
}
@Override
public boolean handle(Request request, Response response, Callback callback) throws Exception
{
List rejects = null;
try (AutoLock ignored = _lock.lock())
{
while (_delayQueue.size() >= _maxDelayQueue)
{
Exchange exchange = _delayQueue.removeFirst();
if (rejects == null)
rejects = new ArrayList<>();
rejects.add(exchange);
}
if (_delayQueue.isEmpty())
_scheduler.schedule(this::onTick, _delayMs / 2, TimeUnit.MILLISECONDS);
_delayQueue.addLast(new Exchange(request, response, callback));
}
reject(rejects);
return true;
}
private void onTick()
{
long expired = NanoTime.now() - TimeUnit.MILLISECONDS.toNanos(_delayMs);
List rejects = null;
try (AutoLock ignored = _lock.lock())
{
Iterator iterator = _delayQueue.iterator();
while (iterator.hasNext())
{
Exchange exchange = iterator.next();
if (NanoTime.isBeforeOrSame(exchange.request.getBeginNanoTime(), expired))
{
iterator.remove();
if (rejects == null)
rejects = new ArrayList<>();
rejects.add(exchange);
}
}
if (!_delayQueue.isEmpty())
_scheduler.schedule(this::onTick, _delayMs / 2, TimeUnit.MILLISECONDS);
}
reject(rejects);
}
private void reject(List rejects)
{
if (rejects != null)
{
for (Exchange exchange : rejects)
{
try
{
if (!_reject.handle(exchange.request, exchange.response, exchange.callback))
exchange.callback.failed(new RejectedExecutionException());
}
catch (Throwable t)
{
exchange.callback.failed(t);
}
}
}
}
}
}