All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.eclipse.jetty.server.handler.ThreadLimitHandler Maven / Gradle / Ivy

There is a newer version: 12.1.0.alpha0
Show newest version
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.server.handler;

import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.WritePendingException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import org.eclipse.jetty.http.HostPortHttpField;
import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.QuotedCSV;
import org.eclipse.jetty.server.ForwardedRequestCustomizer;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.annotation.ManagedAttribute;
import org.eclipse.jetty.util.annotation.ManagedOperation;
import org.eclipse.jetty.util.annotation.Name;
import org.eclipse.jetty.util.thread.AutoLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * 

Handler to limit the threads per IP address for DOS protection

*

The ThreadLimitHandler applies a limit to the number of Threads * that can be used simultaneously per remote IP address.

*

The handler makes a determination of the remote IP separately to * any that may be made by the {@link ForwardedRequestCustomizer} or similar:

*
    *
  • This handler will use only a single style of forwarded header. * This is on the assumption that a trusted local proxy * will produce only a single forwarded header and that any additional * headers are likely from untrusted client side proxies.
  • *
  • If multiple instances of a forwarded header are provided, this * handler will use the right-most instance, which will have been set from * the trusted local proxy
  • *
*

Requests in excess of the limit will be asynchronously suspended until * a thread is available.

*/ public class ThreadLimitHandler extends ConditionalHandler.Abstract { private static final Logger LOG = LoggerFactory.getLogger(ThreadLimitHandler.class); private final boolean _rfc7239; private final String _forwardedHeader; private final ConcurrentMap _remotes = new ConcurrentHashMap<>(); private volatile boolean _enabled; private int _threadLimit = 10; public ThreadLimitHandler() { this(null, null, true); } public ThreadLimitHandler(@Name("forwardedHeader") String forwardedHeader) { this(null, forwardedHeader, HttpHeader.FORWARDED.is(forwardedHeader)); } public ThreadLimitHandler(@Name("forwardedHeader") String forwardedHeader, @Name("rfc7239") boolean rfc7239) { this(null, forwardedHeader, rfc7239); } public ThreadLimitHandler(@Name("handler") Handler handler, @Name("forwardedHeader") String forwardedHeader, @Name("rfc7239") boolean rfc7239) { super(handler); _rfc7239 = rfc7239; _forwardedHeader = forwardedHeader; _enabled = true; } @Override protected void doStart() throws Exception { super.doStart(); LOG.info(String.format("ThreadLimitHandler enable=%b limit=%d", _enabled, _threadLimit)); } @ManagedAttribute("true if this handler is enabled") public boolean isEnabled() { return _enabled; } public void setEnabled(boolean enabled) { _enabled = enabled; LOG.info(String.format("ThreadLimitHandler enable=%b limit=%d", _enabled, _threadLimit)); } @ManagedAttribute("The maximum threads that can be dispatched per remote IP") public int getThreadLimit() { return _threadLimit; } protected int getThreadLimit(String ip) { return _threadLimit; } public void setThreadLimit(int threadLimit) { if (threadLimit <= 0) throw new IllegalArgumentException("limit must be >0"); _threadLimit = threadLimit; } @ManagedOperation("Include IP in thread limits") public void include(String inetAddressPattern) { includeInetAddressPattern(inetAddressPattern); } @ManagedOperation("Exclude IP from thread limits") public void exclude(String inetAddressPattern) { excludeInetAddressPattern(inetAddressPattern); } @Override public boolean onConditionsMet(Request request, Response response, Callback callback) throws Exception { Handler next = getHandler(); if (next == null) return false; if (!_enabled) return next.handle(request, response, callback); // Get the remote address of the request Remote remote = getRemote(request); if (remote == null) { // if remote is not known, handle normally return next.handle(request, response, callback); } // We accept the request and will always handle it. LimitedRequest limitedRequest = new LimitedRequest(remote, next, request, response, callback); limitedRequest.handle(); return true; } @Override protected boolean onConditionsNotMet(Request request, Response response, Callback callback) throws Exception { return nextHandler(request, response, callback); } private Remote getRemote(Request baseRequest) { String ip = getRemoteIP(baseRequest); LOG.debug("ip={}", ip); if (ip == null) return null; int limit = getThreadLimit(ip); if (limit <= 0) return null; Remote remote = _remotes.get(ip); if (remote == null) { Remote r = new Remote(baseRequest.getContext(), ip, limit); remote = _remotes.putIfAbsent(ip, r); if (remote == null) remote = r; } return remote; } protected String getRemoteIP(Request baseRequest) { // Do we have a forwarded header set? if (_forwardedHeader != null && !_forwardedHeader.isEmpty()) { // Yes, then try to get the remote IP from the header String remote = _rfc7239 ? getForwarded(baseRequest) : getXForwardedFor(baseRequest); if (remote != null && !remote.isEmpty()) return remote; } // If no remote IP from a header, determine it directly from the channel // Do not use the request methods, as they may have been lied to by the // RequestCustomizer! if (baseRequest.getConnectionMetaData().getRemoteSocketAddress() instanceof InetSocketAddress inetAddr) { // TODO ???? if (inetAddr.getAddress() != null) return inetAddr.getAddress().getHostAddress(); } return null; } private String getForwarded(Request request) { // Get the right most Forwarded for value. // This is the value from the closest proxy and the only one that // can be trusted. RFC7239 rfc7239 = new RFC7239(); for (HttpField field : request.getHeaders()) { if (_forwardedHeader.equalsIgnoreCase(field.getName())) rfc7239.addValue(field.getValue()); } if (rfc7239.getFor() != null) return new HostPortHttpField(rfc7239.getFor()).getHost(); return null; } private String getXForwardedFor(Request request) { // Get the right most XForwarded-For for value. // This is the value from the closest proxy and the only one that // can be trusted. String forwardedFor = null; for (HttpField field : request.getHeaders()) { if (_forwardedHeader.equalsIgnoreCase(field.getName())) forwardedFor = field.getValue(); } if (forwardedFor == null || forwardedFor.isEmpty()) return null; int comma = forwardedFor.lastIndexOf(','); return (comma >= 0) ? forwardedFor.substring(comma + 1).trim() : forwardedFor; } private static class LimitedRequest extends Request.Wrapper { private final Remote _remote; private final Handler _handler; private final LimitedResponse _response; private final Callback _callback; private final AtomicReference _onContent = new AtomicReference<>(); public LimitedRequest(Remote remote, Handler handler, Request request, Response response, Callback callback) { super(request); _remote = remote; _handler = Objects.requireNonNull(handler); _response = new LimitedResponse(this, response); _callback = Objects.requireNonNull(callback); } protected Handler getHandler() { return _handler; } protected Response getResponse() { return _response; } protected Callback getCallback() { return _callback; } protected void handle() throws Exception { Permit permit = _remote.acquire(); // Did we get a permit? if (permit.isAllocated()) { if (LOG.isDebugEnabled()) LOG.debug("Thread permitted {} {} {}", _remote, getWrapped(), _handler); handle(permit); } else { if (LOG.isDebugEnabled()) LOG.debug("Thread limited {} {} {}", _remote, getWrapped(), _handler); permit.whenAllocated(this::handle); } } protected void handle(Permit permit) { try { if (!_handler.handle(this, _response, _callback)) Response.writeError(this, _response, _callback, HttpStatus.NOT_FOUND_404); } catch (Throwable x) { _callback.failed(x); } finally { permit.release(); } } @Override public void demand(Runnable onContent) { if (!_onContent.compareAndSet(null, Objects.requireNonNull(onContent))) throw new IllegalStateException("Pending demand"); super.demand(this::onContent); } private void onContent() { Permit permit = _remote.acquire(); if (permit.isAllocated()) onPermittedContent(permit); else permit.whenAllocated(this::onPermittedContent); } private void onPermittedContent(Permit permit) { try { Runnable onContent = _onContent.getAndSet(null); onContent.run(); } finally { permit.release(); } } } private static class LimitedResponse extends Response.Wrapper implements Callback { private final Remote _remote; private final AtomicReference _writeCallback = new AtomicReference<>(); public LimitedResponse(LimitedRequest limitedRequest, Response response) { super(limitedRequest, response); _remote = limitedRequest._remote; } @Override public void write(boolean last, ByteBuffer byteBuffer, Callback callback) { if (!_writeCallback.compareAndSet(null, Objects.requireNonNull(callback))) throw new WritePendingException(); super.write(last, byteBuffer, this); } @Override public void succeeded() { Permit permit = _remote.acquire(); if (permit.isAllocated()) permittedSuccess(permit); else permit.whenAllocated(this::permittedSuccess); } private void permittedSuccess(Permit permit) { try { _writeCallback.getAndSet(null).succeeded(); } finally { permit.release(); } } @Override public void failed(Throwable x) { Permit permit = _remote.acquire(); if (permit.isAllocated()) permittedFailure(permit, x); else permit.whenAllocated(p -> permittedFailure(p, x)); } private void permittedFailure(Permit permit, Throwable x) { try { _writeCallback.getAndSet(null).failed(x); } finally { permit.release(); } } } private interface Permit { boolean isAllocated(); void whenAllocated(Consumer permitConsumer); void release(); } private static class NoopPermit implements Permit { @Override public boolean isAllocated() { return true; } @Override public void whenAllocated(Consumer permitConsumer) { throw new UnsupportedOperationException(); } @Override public void release() { } } private static class AllocatedPermit implements Permit { private final Remote _remote; private AllocatedPermit(Remote remote) { _remote = remote; } @Override public boolean isAllocated() { return true; } @Override public void whenAllocated(Consumer permitConsumer) { throw new UnsupportedOperationException(); } @Override public void release() { _remote.release(); } @Override public String toString() { return "AllocatedPermit:" + _remote; } } private static class FuturePermit implements Permit { private final CompletableFuture _future = new CompletableFuture<>(); private final Remote _remote; private FuturePermit(Remote remote) { _remote = remote; } public boolean isAllocated() { return _future.isDone(); } public void whenAllocated(Consumer permitConsumer) { _future.thenAccept(permitConsumer); } void complete() { if (!_future.complete(this)) throw new IllegalStateException(); } public void release() { _remote.release(); } } private static final class Remote { private final Executor _executor; private final String _ip; private final int _limit; private final AutoLock _lock = new AutoLock(); private int _permits; private final Deque _queue = new ArrayDeque<>(); private final Permit _permitted = new AllocatedPermit(this); private final ThreadLocal _threadPermit = new ThreadLocal<>(); private static final Permit NOOP = new NoopPermit(); public Remote(Executor executor, String ip, int limit) { _executor = executor; _ip = ip; _limit = limit; } Permit acquire() { try (AutoLock lock = _lock.lock()) { // Does this thread already have an available pass if (_threadPermit.get() == Boolean.TRUE) return NOOP; // Do we have available passes? if (_permits < _limit) { // Yes - increment the allocated passes _permits++; _threadPermit.set(Boolean.TRUE); // return the already completed future return _permitted; } // No pass available, so queue a new future FuturePermit futurePermit = new FuturePermit(this); _queue.addLast(futurePermit); return futurePermit; } } public void release() { FuturePermit pending; try (AutoLock lock = _lock.lock()) { // reduce the allocated passes _permits--; _threadPermit.set(Boolean.FALSE); // Are there any future passes pending? pending = _queue.pollFirst(); // yes, allocate them a permit if (pending != null) _permits++; } if (pending != null) { // We cannot complete the pending in this thread, as we may be in handle(), demand() or write // callback that is serialized and other actions are waiting for the return. Thus, we must execute. _executor.execute(pending::complete); } } @Override public String toString() { try (AutoLock lock = _lock.lock()) { return String.format("R[ip=%s,p=%d,l=%d,q=%d]", _ip, _permits, _limit, _queue.size()); } } } private static final class RFC7239 extends QuotedCSV { String _for; private RFC7239() { super(false); } String getFor() { return _for; } @Override protected void parsedParam(StringBuilder buffer, int valueLength, int paramName, int paramValue) { if (valueLength == 0 && paramValue > paramName) { String name = StringUtil.asciiToLowerCase(buffer.substring(paramName, paramValue - 1)); if ("for".equalsIgnoreCase(name)) { String value = buffer.substring(paramValue); // if unknown, clear any leftward values if ("unknown".equalsIgnoreCase(value)) _for = null; // Otherwise accept IP or token(starting with '_') as remote keys else _for = value; } } } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy