All downloads are free. Search and download functionality uses the official Maven repository.

dev.resteasy.jetty.client.engine.JettyClientEngine Maven / Gradle / Ivy

The newest version!
package dev.resteasy.jetty.client.engine;

import java.io.InputStream;
import java.io.OutputStream;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;

import jakarta.ws.rs.ProcessingException;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.client.Invocation;
import jakarta.ws.rs.client.InvocationCallback;
import jakarta.ws.rs.client.ResponseProcessingException;
import jakarta.ws.rs.core.EntityPart;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;

import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.client.InputStreamResponseListener;
import org.eclipse.jetty.client.OutputStreamRequestContent;
import org.eclipse.jetty.client.Request;
import org.eclipse.jetty.client.Response;
import org.eclipse.jetty.http.HttpFields;
import org.jboss.logging.Logger;
import org.jboss.resteasy.client.jaxrs.engines.AsyncClientHttpEngine;
import org.jboss.resteasy.client.jaxrs.internal.ClientInvocation;
import org.jboss.resteasy.client.jaxrs.internal.ClientResponse;

/**
 * An {@link AsyncClientHttpEngine} that sends requests through a Jetty {@link HttpClient}.
 * <p>
 * Every request is sent asynchronously. When the invocation has an entity, the body is written on the invocation's
 * async executor; when the response headers arrive, the {@link ClientResponse} is optionally buffered, passed through
 * the {@link ResultExtractor} (if any) and delivered on that same executor. The synchronous
 * {@link #invoke(Invocation)} blocks on the asynchronous path.
 * </p>
 * <p>
 * Invocation properties are copied onto the Jetty request as attributes, which is where {@link #REQUEST_TIMEOUT_MS},
 * {@link #IDLE_TIMEOUT_MS} and {@link #FOLLOW_REDIRECTS} are read from.
 * </p>
 */
public class JettyClientEngine implements AsyncClientHttpEngine {

    private static final Logger LOGGER = Logger.getLogger(JettyClientEngine.class);
    private static final MediaType MULTIPART_WILDCARD = new MediaType("multipart", "*");
    /**
     * {@code org.jboss.resteasy.plugins.providers.multipart.MultipartOutput} when it can be loaded, otherwise
     * {@code null}.
     */
    private static final Class<?> MULTIPART_OUTPUT;

    static {
        // Check if the org.jboss.resteasy.plugins.providers.multipart.MultipartOutput is on the class path
        final String className = "org.jboss.resteasy.plugins.providers.multipart.MultipartOutput";
        Class<?> multipartOutput = null;
        try {
            multipartOutput = Class.forName(className, false, resolveClassLoader());
        } catch (ClassNotFoundException e) {
            LOGGER.tracef(e, "Failed to load %s", className);
        }

        MULTIPART_OUTPUT = multipartOutput;
    }

    /**
     * Request property for the total request timeout in milliseconds. Accepts a {@link Duration}, a {@link Number}
     * or a numeric string; values {@code <= 0} are ignored.
     * <p>
     * The key is built with {@code Class.toString()}, so it begins with {@code "class "}.
     * </p>
     */
    public static final String REQUEST_TIMEOUT_MS = JettyClientEngine.class + "$RequestTimeout";
    /**
     * Request property for the idle timeout in milliseconds. Accepts the same value types as
     * {@link #REQUEST_TIMEOUT_MS}; values {@code <= 0} are ignored.
     */
    public static final String IDLE_TIMEOUT_MS = JettyClientEngine.class + "$IdleTimeout";
    // Yeah, this is the Jersey one, but there's no standard one and it makes more sense to reuse than make our own...
    /**
     * Request property that disables redirect following when set to {@code false} (either a {@link Boolean} or the
     * string {@code "false"}, ignoring case).
     */
    public static final String FOLLOW_REDIRECTS = "jersey.config.client.followRedirects";

    /** Callback that ignores every notification; used by the synchronous {@link #invoke(Invocation)} path. */
    private static final InvocationCallback<ClientResponse> NOP = new InvocationCallback<ClientResponse>() {
        @Override
        public void completed(ClientResponse response) {
        }

        @Override
        public void failed(Throwable throwable) {
        }
    };

    private final HttpClient client;

    /**
     * Creates an engine backed by {@code client}, starting the client if it has not been started yet. Note that
     * {@link #close()} stops the client.
     *
     * @param client the Jetty client used to send requests
     * @throws RuntimeException wrapping any exception thrown by {@link HttpClient#start()}
     */
    public JettyClientEngine(final HttpClient client) {
        if (!client.isStarted()) {
            try {
                client.start();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        this.client = client;
    }

    /**
     * @return the {@link SSLContext} of the Jetty client's SSL context factory
     */
    @Override
    public SSLContext getSslContext() {
        return client.getSslContextFactory().getSslContext();
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public HostnameVerifier getHostnameVerifier() {
        throw new UnsupportedOperationException();
    }

    /**
     * Sends the invocation and blocks for the response.
     *
     * @param invocation the invocation; must be a {@link ClientInvocation}
     * @return the response
     * @throws ProcessingException if the request fails, the thread is interrupted, or no response arrives within
     *                             one hour (the pending request is cancelled in the latter two cases)
     */
    @Override
    public ClientResponse invoke(Invocation invocation) {
        final Future<ClientResponse> future = submit((ClientInvocation) invocation, false, NOP, null);
        try {
            return future.get(1, TimeUnit.HOURS); // There's already an idle and connect timeout, do we need one here?
        } catch (InterruptedException e) {
            future.cancel(true);
            Thread.currentThread().interrupt();
            throw clientException(e, null);
        } catch (TimeoutException e) {
            future.cancel(true);
            // A TimeoutException carries no cause; report the timeout itself instead of a synthetic NPE
            throw clientException(e, null);
        } catch (ExecutionException e) {
            future.cancel(true);
            throw clientException(e.getCause(), null);
        }
    }

    @Override
    public <T> Future<T> submit(ClientInvocation invocation, boolean bufIn, InvocationCallback<T> callback,
            ResultExtractor<T> extractor) {
        return doSubmit(invocation, bufIn, callback, extractor);
    }

    /**
     * Sends the request asynchronously. The {@code executorService} argument is not used; the invocation's own async
     * executor is used instead.
     */
    @Override
    public <T> CompletableFuture<T> submit(ClientInvocation request, boolean buffered, ResultExtractor<T> extractor,
            ExecutorService executorService) {
        return doSubmit(request, buffered, null, extractor);
    }

    /**
     * Builds and sends the Jetty request for {@code invocation}.
     *
     * @param invocation the invocation to send
     * @param buffered   whether the response entity is buffered before it is handed to the extractor
     * @param callback   optional callback; notified at most once, and only if it settles the returned future
     * @param extractor  optional extractor; when {@code null} the {@link ClientResponse} itself is the result
     * @return a future completed with the result, or exceptionally with a {@link RuntimeException} produced by
     *         {@link #clientException(Throwable, jakarta.ws.rs.core.Response)}
     */
    private <T> CompletableFuture<T> doSubmit(ClientInvocation invocation, boolean buffered, InvocationCallback<T> callback,
            ResultExtractor<T> extractor) {
        final ExecutorService asyncExecutor = invocation.asyncInvocationExecutor();

        final Request request = client.newRequest(invocation.getUri());
        final CompletableFuture<T> future = new RequestFuture<>(request);

        // Determine if this is a multipart request
        final Object entity = invocation.getEntity();
        final boolean addBoundary = isMultipart(invocation) && canSetBoundary(entity);

        invocation.getMutableProperties().forEach(request::attribute);
        request.method(invocation.getMethod());
        copyHeaders(invocation, request, addBoundary);
        configureTimeout(request);
        if (isFalse(request.getAttributes().get(FOLLOW_REDIRECTS))) {
            request.followRedirects(false);
        }

        if (entity != null) {
            final OutputStreamRequestContent contentOut = new OutputStreamRequestContent(
                    Objects.toString(invocation.getHeaders().getMediaType(), null));
            asyncExecutor.execute(() -> {
                try {
                    try (OutputStream bodyOut = contentOut.getOutputStream()) {
                        invocation.writeRequestBody(bodyOut);
                    }
                } catch (Exception e) { // Also catch any exception thrown from close
                    failAndNotify(future, callback, clientException(e, null));
                    // Abort the exchange so a partially written body is not processed as a complete request
                    request.abort(e);
                }
            });
            request.body(contentOut);
        }

        request.send(new InputStreamResponseListener() {
            private ClientResponse cr;

            @Override
            @SuppressWarnings("unchecked")
            public void onHeaders(Response response) {
                super.onHeaders(response);
                final InputStream inputStream = getInputStream();
                cr = new JettyClientResponse(invocation.getClientConfiguration(), inputStream);
                cr.setProperties(invocation.getMutableProperties());
                cr.setStatus(response.getStatus());
                cr.setHeaders(extract(response.getHeaders()));
                asyncExecutor.submit(() -> {
                    try {
                        if (buffered) {
                            cr.bufferEntity();
                        }
                        complete(extractor == null ? (T) cr : extractor.extractResult(cr));
                    } catch (Exception e) {
                        try {
                            inputStream.close();
                        } catch (Exception e1) {
                            e.addSuppressed(e1);
                        }
                        onFailure(response, e);
                    }
                });
            }

            @Override
            public void onFailure(Response response, Throwable failure) {
                super.onFailure(response, failure);
                failed(failure);
            }

            private void complete(T result) {
                if (!completeAndNotify(future, callback, result)) {
                    // The future was already cancelled or failed, so nobody will consume this response
                    closeOrphanedResponse(cr);
                }
            }

            private void failed(Throwable t) {
                failAndNotify(future, callback, clientException(t, cr));
            }
        });
        return future;
    }

    /**
     * Copies the invocation's headers onto the Jetty request. When {@code addBoundary} is set and a
     * {@code Content-Type} header has no {@code boundary} parameter, a random boundary is appended and the
     * invocation's media type is replaced with the amended value.
     */
    private static void copyHeaders(final ClientInvocation invocation, final Request request,
            final boolean addBoundary) {
        request.headers(mutableHeaders -> invocation.getHeaders().asMap()
                .forEach((name, values) -> values.forEach(value -> {
                    String headerValue = value;
                    if (addBoundary && name.equalsIgnoreCase("content-type")) {
                        final MediaType mediaType = MediaType.valueOf(value);
                        // Set the boundary if needed
                        if (mediaType.getParameters().get("boundary") == null) {
                            headerValue = headerValue + "; boundary=" + UUID.randomUUID();
                            // Replace the MediaType on the invocation if we've added a boundary
                            invocation.getHeaders().setMediaType(MediaType.valueOf(headerValue));
                        }
                    }
                    mutableHeaders.add(name, headerValue);
                })));
    }

    /**
     * Completes {@code future} with {@code result} and notifies {@code callback} only if this call settled the
     * future, so the callback never sees more than one outcome.
     *
     * @return {@code true} if this call completed the future
     */
    private static <T> boolean completeAndNotify(final CompletableFuture<T> future,
            final InvocationCallback<T> callback, final T result) {
        final boolean completed = future.complete(result);
        if (completed && callback != null) {
            callback.completed(result);
        }
        return completed;
    }

    /**
     * Fails {@code future} with {@code failure} and notifies {@code callback} only if this call settled the future.
     */
    private static <T> void failAndNotify(final CompletableFuture<T> future, final InvocationCallback<T> callback,
            final RuntimeException failure) {
        if (future.completeExceptionally(failure) && callback != null) {
            callback.failed(failure);
        }
    }

    /** Closes a response nobody will consume, logging (not propagating) any failure. */
    private static void closeOrphanedResponse(final ClientResponse response) {
        if (response == null) {
            return;
        }
        try {
            response.close();
        } catch (RuntimeException e) {
            LOGGER.debug("Failed to close a response that is no longer needed", e);
        }
    }

    /** Returns {@code true} for {@link Boolean#FALSE} or the string {@code "false"} (case-insensitive). */
    private static boolean isFalse(final Object value) {
        if (value instanceof Boolean flag) {
            return !flag;
        }
        return value instanceof String text && Boolean.FALSE.toString().equalsIgnoreCase(text.trim());
    }

    /**
     * Applies the {@link #REQUEST_TIMEOUT_MS} and {@link #IDLE_TIMEOUT_MS} request attributes, ignoring absent or
     * non-positive values.
     */
    private void configureTimeout(final Request request) {
        final long timeoutMs = parseTimeoutMs(request.getAttributes().get(REQUEST_TIMEOUT_MS));
        final long idleTimeoutMs = parseTimeoutMs(request.getAttributes().get(IDLE_TIMEOUT_MS));
        if (timeoutMs > 0) {
            request.timeout(timeoutMs, TimeUnit.MILLISECONDS);
        }

        if (idleTimeoutMs > 0) {
            request.idleTimeout(idleTimeoutMs, TimeUnit.MILLISECONDS);
        }
    }

    /**
     * Converts a timeout attribute to milliseconds.
     *
     * @param timeout a {@link Duration}, a {@link Number} of milliseconds, any object whose {@code toString()} is a
     *                decimal number of milliseconds, or {@code null}
     * @return the timeout in milliseconds, or {@code -1} when {@code timeout} is {@code null}
     * @throws NumberFormatException if a non-numeric value cannot be parsed
     */
    private static long parseTimeoutMs(final Object timeout) {
        if (timeout instanceof Duration duration) {
            return duration.toMillis();
        }
        if (timeout instanceof Number number) {
            // longValue() so timeouts beyond Integer.MAX_VALUE ms are not truncated
            return number.longValue();
        }
        if (timeout != null) {
            return Long.parseLong(timeout.toString().trim());
        }
        return -1;
    }

    /**
     * Stops the underlying Jetty client.
     *
     * @throws RuntimeException wrapping any exception thrown by {@link HttpClient#stop()}
     */
    @Override
    public void close() {
        try {
            client.stop();
        } catch (Exception e) {
            throw new RuntimeException("Unable to close JettyClientEngine", e);
        }
    }

    /** Copies every Jetty header field (name and value) into a new multivalued map. */
    MultivaluedMap<String, String> extract(HttpFields headers) {
        final MultivaluedMap<String, String> extracted = new MultivaluedHashMap<>();
        headers.forEach(h -> extracted.add(h.getName(), h.getValue()));
        return extracted;
    }

    /**
     * Maps a failure to the exception reported to the client: {@link WebApplicationException} and
     * {@link ProcessingException} pass through, failures with a response become
     * {@link ResponseProcessingException}, anything else (including {@code null}) is wrapped in a
     * {@link ProcessingException}.
     */
    private static RuntimeException clientException(Throwable ex, jakarta.ws.rs.core.Response clientResponse) {
        RuntimeException ret;
        if (ex == null) {
            ret = new ProcessingException(new NullPointerException());
        } else if (ex instanceof WebApplicationException) {
            ret = (WebApplicationException) ex;
        } else if (ex instanceof ProcessingException) {
            ret = (ProcessingException) ex;
        } else if (clientResponse != null) {
            ret = new ResponseProcessingException(clientResponse, ex);
        } else {
            ret = new ProcessingException(ex);
        }
        // NOTE(review): this also overwrites the stack trace of a passed-through exception, presumably so it shows
        // the thread that reports the failure (e.g. the caller of invoke()) -- confirm this is intended.
        ret.fillInStackTrace();
        return ret;
    }

    /** Resolves the class loader, using a privileged action when a security manager is installed. */
    private static ClassLoader resolveClassLoader() {
        if (System.getSecurityManager() == null) {
            return findClassLoader();
        }
        return AccessController.doPrivileged((PrivilegedAction<ClassLoader>) JettyClientEngine::findClassLoader);
    }

    /** TCCL, else this class's loader, else the system class loader. */
    private static ClassLoader findClassLoader() {
        ClassLoader cl = Thread.currentThread().getContextClassLoader();
        if (cl == null) {
            cl = JettyClientEngine.class.getClassLoader();
        }
        return cl == null ? ClassLoader.getSystemClassLoader() : cl;
    }

    /** Returns {@code true} if the invocation's media type is compatible with {@code multipart/*}. */
    private static boolean isMultipart(final ClientInvocation invocation) {
        return MULTIPART_WILDCARD.isCompatible(invocation.getHeaders().getMediaType());
    }

    /**
     * Returns {@code true} if the entity is a {@code MultipartOutput}, an {@link EntityPart}, or a non-empty list
     * whose first element is an {@link EntityPart}.
     */
    private static boolean canSetBoundary(final Object entity) {
        if (MULTIPART_OUTPUT != null && MULTIPART_OUTPUT.isInstance(entity)) {
            return true;
        }
        if (entity instanceof EntityPart) {
            return true;
        }
        // We're a list, if we're not empty check the first type to see if it's an entity part
        return entity instanceof final List<?> list && !list.isEmpty() && list.get(0) instanceof EntityPart;
    }

    /**
     * A future tied to a Jetty request: a successful {@code cancel(true)} also aborts the request, while
     * {@code cancel(false)} leaves the request running.
     */
    static class RequestFuture<T> extends CompletableFuture<T> {
        private final Request request;

        RequestFuture(final Request request) {
            this.request = request;
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            final boolean cancelled = super.cancel(mayInterruptIfRunning);
            if (mayInterruptIfRunning && cancelled) {
                request.abort(new CancellationException());
            }
            return cancelled;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy