All downloads are free. Search and download functionality uses the official Maven repository.

com.github.tomakehurst.wiremock.servlet.WireMockHandlerDispatchingServlet Maven / Gradle / Ivy

/*
 * Copyright (C) 2011-2024 Thomas Akehurst
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.github.tomakehurst.wiremock.servlet;

import static com.github.tomakehurst.wiremock.common.ContentTypes.CONTENT_LENGTH;
import static com.github.tomakehurst.wiremock.common.Exceptions.throwUnchecked;
import static com.github.tomakehurst.wiremock.common.ParameterUtils.getFirstNonNull;
import static com.github.tomakehurst.wiremock.core.Options.ChunkedEncodingPolicy.BODY_FILE;
import static com.github.tomakehurst.wiremock.core.Options.ChunkedEncodingPolicy.NEVER;
import static com.github.tomakehurst.wiremock.http.RequestMethod.GET;
import static com.github.tomakehurst.wiremock.servlet.WireMockHttpServletRequestAdapter.ORIGINAL_REQUEST_KEY;
import static com.github.tomakehurst.wiremock.stubbing.ServeEvent.ORIGINAL_SERVE_EVENT_KEY;
import static java.net.HttpURLConnection.HTTP_NOT_FOUND;
import static java.net.URLDecoder.decode;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

import com.github.tomakehurst.wiremock.common.LocalNotifier;
import com.github.tomakehurst.wiremock.common.Notifier;
import com.github.tomakehurst.wiremock.core.FaultInjector;
import com.github.tomakehurst.wiremock.core.Options;
import com.github.tomakehurst.wiremock.core.WireMockApp;
import com.github.tomakehurst.wiremock.http.*;
import com.github.tomakehurst.wiremock.jetty.JettyHttpUtils;
import com.github.tomakehurst.wiremock.stubbing.ServeEvent;
import com.github.tomakehurst.wiremock.verification.LoggedRequest;
import jakarta.servlet.*;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ScheduledExecutorService;

/**
 * Servlet that bridges the Servlet API to WireMock's request handling.
 *
 * <p>Each incoming {@link HttpServletRequest} is wrapped in a {@link
 * WireMockHttpServletRequestAdapter} and passed to the {@link RequestHandler} that {@link
 * #init(ServletConfig)} looked up from the servlet context. The handler replies through an inner
 * {@link HttpResponder}, which writes the WireMock {@link Response} back to the servlet response.
 * Depending on the response, that responder may:
 *
 * <ul>
 *   <li>delay the response, asynchronously when an executor is configured and the container
 *       supports async;
 *   <li>inject a fault;
 *   <li>dribble the body out in chunks;
 *   <li>forward unmatched GET requests to the files context.
 * </ul>
 */
public class WireMockHandlerDispatchingServlet extends HttpServlet {

  /**
   * Servlet init parameter. When it parses as {@code true}, unmatched GET requests are forwarded
   * to the files context instead of receiving a 404.
   */
  public static final String SHOULD_FORWARD_TO_FILES_CONTEXT = "shouldForwardToFilesContext";

  /**
   * Servlet context attribute holding the {@link ScheduledExecutorService} used to send delayed
   * responses asynchronously.
   */
  public static final String ASYNCHRONOUS_RESPONSE_EXECUTOR =
      WireMockHandlerDispatchingServlet.class.getSimpleName() + ".asynchronousResponseExecutor";

  /**
   * Servlet init parameter whose value, with any trailing '/' removed, is passed to the request
   * adapter as {@code mappedUnder}.
   */
  public static final String MAPPED_UNDER_KEY = "mappedUnder";

  private static final long serialVersionUID = -6602042274260495538L;

  // Null when no executor attribute is present; delays are then served synchronously.
  private ScheduledExecutorService scheduledExecutorService;

  private RequestHandler requestHandler;
  private FaultInjectorFactory faultHandlerFactory;
  private String mappedUnder;
  private Notifier notifier;
  // Prefix prepended to WireMockApp.FILES_ROOT when forwarding to the files context.
  private String wiremockFileSourceRoot = "/";
  private boolean shouldForwardToFilesContext;
  private MultipartRequestConfigurer multipartRequestConfigurer;
  private Options.ChunkedEncodingPolicy chunkedEncodingPolicy;
  private boolean browserProxyingEnabled;
  private JettyHttpUtils utils;

  /**
   * Reads this servlet's collaborators from its init parameters and from servlet context
   * attributes.
   *
   * <p>Defaults apply when a value is absent:
   *
   * <ul>
   *   <li>the file source root stays {@code "/"};
   *   <li>the fault injector factory is a {@link NoFaultInjectorFactory};
   *   <li>the chunked encoding policy is {@code ALWAYS};
   *   <li>browser proxying is disabled.
   * </ul>
   *
   * <p>NOTE(review): this override does not call {@code super.init(config)}. As a result, {@code
   * GenericServlet#getServletConfig()} is not populated for this servlet. Confirm nothing relies on
   * it.
   *
   * @param config the servlet configuration supplied by the container
   */
  @Override
  public void init(ServletConfig config) {
    ServletContext context = config.getServletContext();
    shouldForwardToFilesContext = getFileContextForwardingFlagFrom(config);

    if (context.getInitParameter("WireMockFileSourceRoot") != null) {
      wiremockFileSourceRoot = context.getInitParameter("WireMockFileSourceRoot");
    }

    scheduledExecutorService =
        (ScheduledExecutorService) context.getAttribute(ASYNCHRONOUS_RESPONSE_EXECUTOR);

    // The handler "class name" init parameter is used as the key of a context attribute that
    // holds the already-constructed handler instance.
    String handlerClassName = config.getInitParameter(RequestHandler.HANDLER_CLASS_KEY);
    String faultInjectorFactoryClassName =
        config.getInitParameter(FaultInjectorFactory.INJECTOR_CLASS_KEY);
    mappedUnder = getNormalizedMappedUnder(config);
    context.log(
        RequestHandler.HANDLER_CLASS_KEY
            + " from context returned "
            + handlerClassName
            + ". Normalized mapped under returned '"
            + mappedUnder
            + "'");
    requestHandler = (RequestHandler) context.getAttribute(handlerClassName);

    faultHandlerFactory =
        faultInjectorFactoryClassName != null
            ? (FaultInjectorFactory) context.getAttribute(faultInjectorFactoryClassName)
            : new NoFaultInjectorFactory();

    notifier = (Notifier) context.getAttribute(Notifier.KEY);

    multipartRequestConfigurer =
        (MultipartRequestConfigurer) context.getAttribute(MultipartRequestConfigurer.KEY);

    Object chunkedEncodingPolicyAttr =
        context.getAttribute(Options.ChunkedEncodingPolicy.class.getName());
    chunkedEncodingPolicy =
        chunkedEncodingPolicyAttr != null
            ? (Options.ChunkedEncodingPolicy) chunkedEncodingPolicyAttr
            : Options.ChunkedEncodingPolicy.ALWAYS;

    browserProxyingEnabled =
        Boolean.parseBoolean(
            getFirstNonNull(context.getAttribute("browserProxyingEnabled"), "false").toString());

    utils = (JettyHttpUtils) context.getAttribute(JettyHttpUtils.class.getName());
  }

  /**
   * Returns the {@link #MAPPED_UNDER_KEY} init parameter with at most one trailing '/' removed.
   *
   * @return the normalized value, or {@code null} if the parameter is not set
   */
  private String getNormalizedMappedUnder(ServletConfig config) {
    String mappedUnder = config.getInitParameter(MAPPED_UNDER_KEY);
    if (mappedUnder == null) {
      return null;
    }
    if (mappedUnder.endsWith("/")) {
      mappedUnder = mappedUnder.substring(0, mappedUnder.length() - 1);
    }
    return mappedUnder;
  }

  /**
   * Parses the {@link #SHOULD_FORWARD_TO_FILES_CONTEXT} init parameter. A missing or
   * non-{@code "true"} value yields {@code false}.
   */
  private boolean getFileContextForwardingFlagFrom(ServletConfig config) {
    String flagValue = config.getInitParameter(SHOULD_FORWARD_TO_FILES_CONTEXT);
    return Boolean.parseBoolean(flagValue);
  }

  /**
   * Adapts the servlet request and hands it to the configured {@link RequestHandler}.
   *
   * <p>The notifier is published to {@link LocalNotifier} for the current thread first. CONNECT
   * requests are returned from without any handling.
   */
  @Override
  protected void service(
      HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)
      throws ServletException, IOException {
    LocalNotifier.set(notifier);

    // TODO: The HTTP/1.x CONNECT is also forwarded to the servlet now. To keep backward
    // compatible behavior (with proxy involved), skipping the CONNECT handling altogether.
    if (Objects.equals(httpServletRequest.getMethod(), "CONNECT")) {
      return;
    }

    Request request =
        new WireMockHttpServletRequestAdapter(
            httpServletRequest,
            multipartRequestConfigurer,
            mappedUnder,
            browserProxyingEnabled,
            utils);

    ServletHttpResponder responder =
        new ServletHttpResponder(httpServletRequest, httpServletResponse);

    // Present when this request is a re-dispatch of an earlier serve; null otherwise.
    final ServeEvent originalServeEvent =
        (ServeEvent) httpServletRequest.getAttribute(ORIGINAL_SERVE_EVENT_KEY);

    requestHandler.handle(request, responder, originalServeEvent);
  }

  /**
   * {@link HttpResponder} bound to one servlet request/response pair. It is a non-static inner
   * class because it reads the enclosing servlet's configuration (executor, forwarding flag) and
   * calls its response-writing methods.
   */
  private class ServletHttpResponder implements HttpResponder {

    private final HttpServletRequest httpServletRequest;
    private final HttpServletResponse httpServletResponse;

    public ServletHttpResponder(
        HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
      this.httpServletRequest = httpServletRequest;
      this.httpServletResponse = httpServletResponse;
    }

    /**
     * Writes {@code response} for {@code request}. Does nothing if the current thread is already
     * interrupted.
     *
     * <p>Before writing, it stores a {@link LoggedRequest} snapshot under {@code
     * ORIGINAL_REQUEST_KEY} and copies every entry of {@code attributes} onto the servlet request.
     *
     * @param request the WireMock request being answered
     * @param response the response to send
     * @param attributes extra attributes to expose on the servlet request
     */
    @Override
    public void respond(
        final Request request, final Response response, Map<String, Object> attributes) {
      if (Thread.currentThread().isInterrupted()) {
        return;
      }

      httpServletRequest.setAttribute(ORIGINAL_REQUEST_KEY, LoggedRequest.createFrom(request));
      attributes.forEach(httpServletRequest::setAttribute);

      if (isAsyncSupported(response, httpServletRequest)) {
        respondAsync(request, response);
      } else {
        respondSync(request, response);
      }
    }

    /** Sleeps for the response's initial delay on the calling thread, then writes it. */
    private void respondSync(Request request, Response response) {
      delayIfRequired(response.getInitialDelay());
      respondTo(request, response);
    }

    /** Sleeps for {@code delayMillis}; restores the interrupt flag if interrupted. */
    private void delayIfRequired(long delayMillis) {
      try {
        MILLISECONDS.sleep(delayMillis);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
    }

    /**
     * Decides whether to use the async path. This requires an executor, a positive initial delay
     * and a request whose container supports async processing.
     */
    private boolean isAsyncSupported(Response response, HttpServletRequest httpServletRequest) {
      return scheduledExecutorService != null
          && response.getInitialDelay() > 0
          && httpServletRequest.isAsyncSupported();
    }

    /**
     * Puts the request into async mode and schedules the write after the initial delay. The async
     * context is always completed, even if writing fails.
     */
    private void respondAsync(final Request request, final Response response) {
      final AsyncContext asyncContext = httpServletRequest.startAsync();
      scheduledExecutorService.schedule(
          () -> {
            try {
              respondTo(request, response);
            } finally {
              asyncContext.complete();
            }
          },
          response.getInitialDelay(),
          MILLISECONDS);
    }

    /**
     * Chooses how to answer:
     *
     * <ul>
     *   <li>a configured response is written;
     *   <li>an unconfigured GET is forwarded to the files context when forwarding is enabled;
     *   <li>anything else gets a 404.
     * </ul>
     *
     * Checked exceptions are rethrown unchecked.
     */
    private void respondTo(Request request, Response response) {
      try {
        if (response.wasConfigured()) {
          applyResponse(response, httpServletRequest, httpServletResponse);
        } else if (request.getMethod().equals(GET) && shouldForwardToFilesContext) {
          forwardToFilesContext(httpServletRequest, httpServletResponse, request);
        } else {
          httpServletResponse.sendError(HTTP_NOT_FOUND);
        }
      } catch (Exception e) {
        throwUnchecked(e);
      }
    }
  }

  /**
   * Writes a WireMock {@link Response} onto a servlet response.
   *
   * <p>If the response carries a fault, the fault is applied through a {@link FaultInjector}, the
   * fault name is added as a header, and nothing else is written.
   *
   * <p>Otherwise the following are written, in order:
   *
   * <ol>
   *   <li>the status, with a custom reason phrase when one is set;
   *   <li>all headers;
   *   <li>an explicit Content-Length, only when the chunked encoding policy is {@code NEVER}, or is
   *       {@code BODY_FILE} and the body is inline, and no Content-Length header was already
   *       supplied;
   *   <li>the body, dribbled in chunks if requested.
   * </ol>
   *
   * @param response the response to write
   * @param httpServletRequest the originating servlet request (used for fault injection)
   * @param httpServletResponse the servlet response to write to
   */
  public void applyResponse(
      Response response,
      HttpServletRequest httpServletRequest,
      HttpServletResponse httpServletResponse) {
    Fault fault = response.getFault();
    if (fault != null) {
      FaultInjector faultInjector = buildFaultInjector(httpServletRequest, httpServletResponse);
      fault.apply(faultInjector);
      httpServletResponse.addHeader(Fault.class.getName(), fault.name());
      return;
    }

    if (response.getStatusMessage() == null) {
      httpServletResponse.setStatus(response.getStatus());
    } else {
      utils.setStatusWithReason(
          response.getStatus(), response.getStatusMessage(), httpServletResponse);
    }

    for (HttpHeader header : response.getHeaders().all()) {
      for (String value : header.values()) {
        httpServletResponse.addHeader(header.key(), value);
      }
    }

    if ((chunkedEncodingPolicy == NEVER
            || (chunkedEncodingPolicy == BODY_FILE && response.hasInlineBody()))
        && httpServletResponse.getHeader(CONTENT_LENGTH) == null) {
      httpServletResponse.setContentLength(response.getBody().length);
    }

    if (response.shouldAddChunkedDribbleDelay()) {
      writeAndTranslateExceptionsWithChunkedDribbleDelay(
          httpServletResponse, response.getBodyStream(), response.getChunkedDribbleDelay());
    } else {
      writeAndTranslateExceptions(httpServletResponse, response.getBodyStream());
    }
  }

  /** Builds a fault injector for this request/response pair via the configured factory. */
  private FaultInjector buildFaultInjector(
      HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
    return faultHandlerFactory.buildFaultInjector(httpServletRequest, httpServletResponse);
  }

  /**
   * Copies {@code content} to the response output stream in one go.
   *
   * <p>Write failures are rethrown unchecked. {@code content} is always closed afterwards, and any
   * error while closing it is ignored.
   */
  private static void writeAndTranslateExceptions(
      HttpServletResponse httpServletResponse, InputStream content) {
    try (ServletOutputStream out = httpServletResponse.getOutputStream()) {
      content.transferTo(out);
      out.flush();
    } catch (IOException e) {
      throwUnchecked(e);
    } finally {
      closeQuietly(content);
    }
  }

  /**
   * Writes the body in chunks with a pause before each chunk, spreading it over the configured
   * total duration.
   *
   * <p>An empty body is reported to the notifier and nothing is written. An interrupt stops the
   * dribble silently. The body stream is always closed, on every exit path, just as in {@link
   * #writeAndTranslateExceptions}.
   */
  private void writeAndTranslateExceptionsWithChunkedDribbleDelay(
      HttpServletResponse httpServletResponse,
      InputStream bodyStream,
      ChunkedDribbleDelay chunkedDribbleDelay) {
    try (ServletOutputStream out = httpServletResponse.getOutputStream()) {
      byte[] body = bodyStream.readAllBytes();

      if (body.length < 1) {
        notifier.error("Cannot chunk dribble delay when no body set");
        out.flush();
        return;
      }

      byte[][] chunkedBody = BodyChunker.chunkBody(body, chunkedDribbleDelay.getNumberOfChunks());

      int chunkInterval = chunkedDribbleDelay.getTotalDuration() / chunkedBody.length;

      for (byte[] bodyChunk : chunkedBody) {
        Thread.sleep(chunkInterval);
        out.write(bodyChunk);
        out.flush();
      }

    } catch (IOException e) {
      throwUnchecked(e);
    } catch (InterruptedException ignored) {
      // Ignore the interrupt quietly since it's probably the client timing out, which is a
      // completely valid outcome
    } finally {
      closeQuietly(bodyStream);
    }
  }

  /** Closes {@code content} if non-null, ignoring any failure (best-effort cleanup). */
  private static void closeQuietly(InputStream content) {
    if (content == null) {
      return;
    }
    try {
      content.close();
    } catch (IOException e) {
      // well, we tried
    }
  }

  /**
   * Forwards the request through the servlet's request dispatcher.
   *
   * <p>The target path is {@code wiremockFileSourceRoot + WireMockApp.FILES_ROOT + request URL},
   * URL-decoded as UTF-8.
   */
  private void forwardToFilesContext(
      HttpServletRequest httpServletRequest,
      HttpServletResponse httpServletResponse,
      Request request)
      throws ServletException, IOException {
    String forwardUrl = wiremockFileSourceRoot + WireMockApp.FILES_ROOT + request.getUrl();
    RequestDispatcher dispatcher =
        httpServletRequest.getRequestDispatcher(decode(forwardUrl, UTF_8));
    dispatcher.forward(httpServletRequest, httpServletResponse);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy