io.quarkus.resteasy.reactive.server.servlet.runtime.ServletRequestContext Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of quarkus-rest-servlet Show documentation
Servlet support for Quarkus REST
package io.quarkus.resteasy.reactive.server.servlet.runtime;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import jakarta.enterprise.event.Event;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.WriteListener;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.ws.rs.core.SecurityContext;
import org.jboss.resteasy.reactive.server.core.Deployment;
import org.jboss.resteasy.reactive.server.core.ResteasyReactiveRequestContext;
import org.jboss.resteasy.reactive.server.spi.ServerHttpRequest;
import org.jboss.resteasy.reactive.server.spi.ServerHttpResponse;
import org.jboss.resteasy.reactive.server.spi.ServerRestHandler;
import org.jboss.resteasy.reactive.spi.ThreadSetupAction;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.util.concurrent.ScheduledFuture;
import io.quarkus.arc.Arc;
import io.quarkus.arc.impl.LazyValue;
import io.quarkus.resteasy.reactive.server.runtime.ResteasyReactiveSecurityContext;
import io.quarkus.runtime.BlockingOperationControl;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.vertx.http.runtime.security.QuarkusHttpUser;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.ResponseCommitListener;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.net.impl.ConnectionBase;
import io.vertx.ext.web.RoutingContext;
public class ServletRequestContext extends ResteasyReactiveRequestContext
implements ServerHttpRequest, ServerHttpResponse, ResponseCommitListener {
private static final LazyValue> SECURITY_IDENTITY_EVENT = new LazyValue<>(
ServletRequestContext::createEvent);
final RoutingContext context;
final HttpServletRequest request;
final HttpServletResponse response;
AsyncContext asyncContext;
ServletWriteListener writeListener;
ServletReadListener readListener;
byte[] asyncWriteData;
boolean closed;
Consumer asyncWriteHandler;
protected Consumer preCommitTask;
public ServletRequestContext(Deployment deployment,
HttpServletRequest request, HttpServletResponse response,
ThreadSetupAction requestContext, ServerRestHandler[] handlerChain, ServerRestHandler[] abortHandlerChain,
RoutingContext context, HttpServerExchange exchange) {
super(deployment, requestContext, handlerChain, abortHandlerChain);
this.request = request;
this.response = response;
this.context = context;
exchange.addResponseCommitListener(this);
}
protected boolean isRequestScopeManagementRequired() {
return asyncContext != null;
}
protected void beginAsyncProcessing() {
asyncContext = request.startAsync();
}
@Override
public synchronized void close() {
if (asyncWriteData != null) {
closed = true;
} else {
super.close();
if (asyncContext != null) {
asyncContext.complete();
}
}
}
@Override
public ServerHttpResponse addCloseHandler(Runnable onClose) {
context.response().closeHandler(v -> onClose.run());
return this;
}
@Override
public ServerHttpRequest serverRequest() {
return this;
}
@Override
public ServerHttpResponse serverResponse() {
return this;
}
@Override
protected void setQueryParamsFrom(String uri) {
MultiMap map = context.queryParams();
map.clear();
Map> decodedParams = new QueryStringDecoder(uri).parameters();
for (Map.Entry> entry : decodedParams.entrySet()) {
map.add(entry.getKey(), entry.getValue());
}
}
protected void handleRequestScopeActivation() {
super.handleRequestScopeActivation();
QuarkusHttpUser user = (QuarkusHttpUser) context.user();
if (user != null) {
fireSecurityIdentity(user.getSecurityIdentity());
}
}
static void fireSecurityIdentity(SecurityIdentity identity) {
SECURITY_IDENTITY_EVENT.get().fire(identity);
}
static void clear() {
SECURITY_IDENTITY_EVENT.clear();
}
private static Event createEvent() {
return Arc.container().beanManager().getEvent().select(SecurityIdentity.class);
}
protected SecurityContext createSecurityContext() {
return new ResteasyReactiveSecurityContext(context);
}
@Override
protected EventLoop getEventLoop() {
return ((ConnectionBase) context.request().connection()).channel().eventLoop();
}
@Override
public Runnable registerTimer(long millis, Runnable task) {
ScheduledFuture> handle = getEventLoop().schedule(task, millis, TimeUnit.MILLISECONDS);
return new Runnable() {
@Override
public void run() {
handle.cancel(false);
}
};
}
@Override
public boolean resumeExternalProcessing() {
return false;
}
@Override
public String getRequestHeader(CharSequence name) {
return request.getHeader(name.toString());
}
@Override
public Iterable> getAllRequestHeaders() {
List> ret = new ArrayList<>();
Enumeration headerNames = request.getHeaderNames();
while (headerNames.hasMoreElements()) {
String name = headerNames.nextElement();
for (String v : new EnumerationIterable<>(request.getHeaders(name))) {
ret.add(new MapEntry<>(name, v));
}
}
return ret;
}
@Override
public List getAllRequestHeaders(String name) {
Enumeration headers = request.getHeaders(name);
if (headers == null) {
return Collections.emptyList();
}
List ret = new ArrayList<>();
while (headers.hasMoreElements()) {
ret.add(headers.nextElement());
}
return ret;
}
@Override
public boolean containsRequestHeader(CharSequence accept) {
return request.getHeader(accept.toString()) != null;
}
@Override
public String getRequestPath() {
return request.getServletPath();
}
@Override
public String getRequestMethod() {
return request.getMethod();
}
@Override
public String getRequestNormalisedPath() {
return context.normalizedPath();
}
@Override
public String getRequestAbsoluteUri() {
if (request.getQueryString() == null) {
return request.getRequestURL().toString();
} else {
return request.getRequestURL().append("?").append(request.getQueryString()).toString();
}
}
@Override
public String getRequestScheme() {
return request.getScheme();
}
@Override
public String getRequestHost() {
return context.request().authority().toString();
}
@Override
public void closeConnection() {
try {
response.getOutputStream().close();
} catch (IOException e) {
//ignore
}
context.request().connection().close();
}
@Override
public String getQueryParam(String name) {
if (!context.queryParams().contains(name)) {
return null;
}
return request.getParameter(name);
}
@Override
public List getAllQueryParams(String name) {
return context.queryParam(name);
}
@Override
public String query() {
return request.getQueryString();
}
@Override
public Collection queryParamNames() {
return context.queryParams().names();
}
@Override
public boolean isRequestEnded() {
return context.request().isEnded();
}
@Override
public InputStream createInputStream(ByteBuffer existingData) {
return new ServletResteasyReactiveInputStream(existingData, request);
}
@Override
public InputStream createInputStream() {
try {
return request.getInputStream();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public ServerHttpResponse pauseRequestInput() {
//TODO
return this;
}
@Override
public ServerHttpResponse resumeRequestInput() {
return this;
}
@Override
public ServerHttpResponse setReadListener(ReadCallback callback) {
try {
ServletInputStream in = request.getInputStream();
if (!request.isAsyncStarted()) {
request.startAsync();
}
in.setReadListener(new ServletReadListener(in, callback));
} catch (IOException e) {
resume(e);
}
return this;
}
@SuppressWarnings("unchecked")
@Override
public T unwrap(Class theType) {
if (theType == RoutingContext.class) {
return (T) context;
} else if (theType == HttpServerRequest.class) {
return (T) context.request();
} else if (theType == HttpServerResponse.class) {
return (T) context.response();
} else if (theType == HttpServletRequest.class) {
return (T) request;
} else if (theType == HttpServletResponse.class) {
return (T) response;
}
return null;
}
@Override
public boolean isOnIoThread() {
//does not really apply to Servlet
return true;
}
@Override
public ServerHttpResponse setStatusCode(int code) {
response.setStatus(code);
return this;
}
@Override
public ServerHttpResponse end() {
try {
response.getOutputStream().close();
} catch (IOException e) {
//ignore
}
return this;
}
@Override
public boolean headWritten() {
return response.isCommitted();
}
@Override
public ServerHttpResponse end(byte[] data) {
if (BlockingOperationControl.isBlockingAllowed()) {
try {
response.getOutputStream().write(data);
response.getOutputStream().close();
} catch (IOException e) {
log.debug("IoException writing response", e);
}
} else {
write(data, new Consumer() {
@Override
public void accept(Throwable throwable) {
try {
response.getOutputStream().close();
} catch (IOException e) {
log.debug("IoException writing response", e);
}
}
});
}
return this;
}
@Override
public ServerHttpResponse end(String data) {
end(data.getBytes(StandardCharsets.UTF_8));
return this;
}
@Override
public ServerHttpResponse addResponseHeader(CharSequence name, CharSequence value) {
response.addHeader(name.toString(), value.toString());
return this;
}
@Override
public ServerHttpResponse setResponseHeader(CharSequence name, CharSequence value) {
response.setHeader(name.toString(), value != null ? value.toString() : null);
return this;
}
@Override
public ServerHttpResponse setResponseHeader(CharSequence name, Iterable values) {
for (CharSequence v : values) {
response.addHeader(name.toString(), v.toString());
}
return this;
}
@Override
public Iterable> getAllResponseHeaders() {
List> ret = new ArrayList<>();
Collection headerNames = response.getHeaderNames();
for (String name : headerNames) {
for (String v : response.getHeaders(name)) {
ret.add(new MapEntry<>(name, v));
}
}
return ret;
}
@Override
public String getResponseHeader(String name) {
return response.getHeader(name);
}
@Override
public void removeResponseHeader(String name) {
// Servlet API does not support this functionality
throw new UnsupportedOperationException();
}
@Override
public boolean closed() {
return context.response().closed();
}
@Override
public ServerHttpResponse setChunked(boolean chunked) {
context.response().setChunked(chunked);
return this;
}
@Override
public ServerHttpResponse write(byte[] data, Consumer asyncResultHandler) {
if (asyncWriteData != null) {
asyncResultHandler.accept(new IllegalStateException("Cannot write before data has all been written"));
}
if (asyncContext == null) {
try {
response.getOutputStream().write(data);
asyncResultHandler.accept(null);
} catch (IOException e) {
asyncResultHandler.accept(e);
}
} else {
synchronized (this) {
if (asyncWriteData != null) {
throw new IllegalStateException("Cannot write more than one piece of async data at a time");
}
asyncWriteData = data;
asyncWriteHandler = asyncResultHandler;
if (writeListener == null) {
try {
ServletOutputStream outputStream = response.getOutputStream();
outputStream.setWriteListener(writeListener = new ServletWriteListener(outputStream));
} catch (IOException e) {
asyncResultHandler.accept(e);
}
} else {
writeListener.onWritePossible();
}
}
}
return this;
}
@Override
public CompletionStage write(byte[] data) {
CompletableFuture cf = new CompletableFuture<>();
write(data, new Consumer() {
@Override
public void accept(Throwable throwable) {
if (throwable == null) {
cf.complete(null);
} else {
cf.completeExceptionally(throwable);
}
}
});
return cf;
}
@Override
public OutputStream createResponseOutputStream() {
try {
return response.getOutputStream();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public void setPreCommitListener(Consumer task) {
preCommitTask = task;
}
@Override
public void beforeCommit(HttpServerExchange exchange) {
if (preCommitTask != null) {
preCommitTask.accept(this);
}
}
class ServletWriteListener implements WriteListener {
private final ServletOutputStream outputStream;
ServletWriteListener(ServletOutputStream outputStream) {
this.outputStream = outputStream;
}
@Override
public void onWritePossible() {
synchronized (ServletRequestContext.this) {
if (!outputStream.isReady()) {
return;
}
Consumer ctx = asyncWriteHandler;
byte[] data = asyncWriteData;
asyncWriteHandler = null;
asyncWriteData = null;
try {
outputStream.write(data);
ctx.accept(null);
} catch (IOException e) {
ctx.accept(e);
}
if (closed) {
close();
}
}
}
@Override
public synchronized void onError(Throwable t) {
synchronized (ServletRequestContext.this) {
if (asyncWriteHandler != null) {
Consumer ctx = asyncWriteHandler;
asyncWriteHandler = null;
asyncWriteData = null;
ctx.accept(t);
close();
}
}
}
}
class ServletReadListener implements ReadListener {
final ServletInputStream inputStream;
final ReadCallback readCallback;
boolean paused;
boolean allDone;
Throwable problem;
ServletReadListener(ServletInputStream inputStream, ReadCallback readCallback) {
this.inputStream = inputStream;
this.readCallback = readCallback;
}
@Override
public void onDataAvailable() throws IOException {
synchronized (this) {
if (paused) {
return;
}
}
doRead();
}
private void doRead() {
if (inputStream.isReady()) {
byte[] buf = new byte[1024];
try {
int r = inputStream.read(buf);
readCallback.data(ByteBuffer.wrap(buf, 0, r));
} catch (IOException e) {
ServletRequestContext.this.resume(problem);
}
}
}
synchronized void pause() {
paused = true;
}
void resume() {
boolean allDone;
Throwable problem;
synchronized (this) {
paused = false;
allDone = this.allDone;
this.allDone = false;
problem = this.problem;
this.problem = null;
}
if (problem != null) {
ServletRequestContext.this.resume(problem);
} else if (allDone) {
readCallback.done();
} else {
doRead();
}
}
@Override
public void onAllDataRead() throws IOException {
synchronized (this) {
if (paused) {
allDone = true;
return;
}
}
readCallback.done();
}
@Override
public void onError(Throwable t) {
synchronized (this) {
if (paused) {
problem = t;
return;
}
}
ServletRequestContext.this.resume(t);
}
}
static final class MapEntry implements Map.Entry {
private final K key;
private V value;
MapEntry(K key, V value) {
this.key = key;
this.value = value;
}
@Override
public K getKey() {
return key;
}
@Override
public V getValue() {
return value;
}
@Override
public V setValue(V value) {
V old = value;
this.value = value;
return old;
}
}
@Override
public ServerHttpResponse sendFile(String path, long offset, long length) {
context.response().sendFile(path, offset, length);
return this;
}
@Override
public boolean isWriteQueueFull() {
return context.response().writeQueueFull();
}
@Override
public ServerHttpResponse addDrainHandler(Runnable onDrain) {
context.response().drainHandler(new Handler() {
@Override
public void handle(Void event) {
onDrain.run();
}
});
return this;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy