All downloads are free. Search and download functionality uses the official Maven repository.

ca.uhn.fhir.test.utilities.server.BaseJettyServerExtension Maven / Gradle / Ivy

/*-
 * #%L
 * HAPI FHIR Test Utilities
 * %%
 * Copyright (C) 2014 - 2024 Smile CDR, Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package ca.uhn.fhir.test.utilities.server;

import ca.uhn.fhir.rest.api.Constants;
import ca.uhn.fhir.test.utilities.JettyUtil;
import jakarta.annotation.PreDestroy;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.FilterConfig;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.lang3.Validate;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.eclipse.jetty.ee10.servlet.FilterHolder;
import org.eclipse.jetty.ee10.servlet.ServletContextHandler;
import org.eclipse.jetty.ee10.servlet.ServletHolder;
import org.eclipse.jetty.ee10.websocket.jakarta.server.config.JakartaWebSocketServletContainerInitializer;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.Connection.Listener;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.HttpConnectionFactory;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.AnnotatedBeanDefinitionReader;
import org.springframework.web.context.support.GenericWebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.Enumeration;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import static org.apache.commons.lang3.StringUtils.defaultString;
import static org.apache.commons.lang3.StringUtils.isNotBlank;

public abstract class BaseJettyServerExtension> implements BeforeEachCallback, AfterEachCallback, AfterAllCallback {
	private static final Logger ourLog = LoggerFactory.getLogger(BaseJettyServerExtension.class);
	private final List> myRequestHeaders = new ArrayList<>();
	private final List myRequestContentTypes = new ArrayList<>();
	private final List myServletFilters = new ArrayList<>();
	private String myServletPath = "/*";
	private Server myServer;
	private CloseableHttpClient myHttpClient;
	private int myPort = 0;
	private boolean myKeepAliveBetweenTests;
	private String myContextPath = "";
	private AtomicLong myConnectionsOpenedCounter;
	private Class myEnableSpringWebsocketSupport;
	private String myEnableSpringWebsocketContextPath;
	private long myIdleTimeoutMillis = 30000;
	private final List> myBeforeStartServerConsumers = new ArrayList<>();

	/**
	 * Sets the Jetty server "idle timeout" in millis. This is the amount of time that
	 * the HTTP processor will allow a request to take before it hangs up on the
	 * client. This means the amount of time receiving the request over the network or
	 * streaming the response, not the amount of time spent actually processing the
	 * request (ie this is a network timeout, not a CPU timeout). Default is
	 * 30000.
	 */
	public T withIdleTimeout(long theIdleTimeoutMillis) {
		Validate.isTrue(myServer == null, "Server is already started");
		myIdleTimeoutMillis = theIdleTimeoutMillis;
		return (T) this;
	}

	@SuppressWarnings("unchecked")
	public T withContextPath(String theContextPath) {
		Validate.isTrue(myServer == null, "Server is already started");
		myContextPath = defaultString(theContextPath);
		return (T) this;
	}

	public T withServletFilter(Filter theFilter) {
		Validate.isTrue(myServer == null, "Server is already started");
		Validate.notNull(theFilter, "theFilter must not be null");
		myServletFilters.add(theFilter);
		return (T) this;
	}

	@SuppressWarnings("unchecked")
	public T withServerBeforeStarted(Consumer theConsumer) {
		Validate.isTrue(myServer == null, "Server is already started");
		Validate.notNull(theConsumer, "theConsumer must not be null");
		myBeforeStartServerConsumers.add(theConsumer);
		return (T) this;
	}


	/**
	 * Returns the total number of connections that this server has received. This
	 * is not the current number of open connections, it's the number of new
	 * connections that have been opened at any point.
	 */
	public long getConnectionsOpenedCount() {
		return myConnectionsOpenedCounter.get();
	}

	public void resetConnectionsOpenedCount() {
		myConnectionsOpenedCounter.set(0);
	}

	public CloseableHttpClient getHttpClient() {
		return myHttpClient;
	}

	public List getRequestContentTypes() {
		return myRequestContentTypes;
	}

	public List> getRequestHeaders() {
		return myRequestHeaders;
	}

	@PreDestroy
	public void stopServer() throws Exception {
		if (!isRunning()) {
			return;
		}
		JettyUtil.closeServer(myServer);
		myServer = null;

		myHttpClient.close();
		myHttpClient = null;
	}

	protected void startServer() throws Exception {
		if (isRunning()) {
			return;
		}

		myServer = new Server();
		myConnectionsOpenedCounter = new AtomicLong(0);

		ServerConnector connector = new ServerConnector(myServer);
		connector.setIdleTimeout(myIdleTimeoutMillis);
		connector.setPort(myPort);
		myServer.setConnectors(new Connector[]{connector});

		HttpConnectionFactory connectionFactory = (HttpConnectionFactory) connector.getConnectionFactories().iterator().next();
		connectionFactory.addBean(new Listener() {
			@Override
			public void onOpened(Connection connection) {
				myConnectionsOpenedCounter.incrementAndGet();
			}

			@Override
			public void onClosed(Connection connection) {
				// nothing
			}
		});

		ServletHolder servletHolder = new ServletHolder(provideServlet());

		List handlerList = new ArrayList<>();

		ServletContextHandler contextHandler = new ServletContextHandler();
		contextHandler.setContextPath(myContextPath);
		contextHandler.addServlet(servletHolder, myServletPath);
		contextHandler.addFilter(new FilterHolder(requestCapturingFilter()), "/*", EnumSet.allOf(DispatcherType.class));
		for (Filter next : myServletFilters) {
			contextHandler.addFilter(new FilterHolder(next), "/*", EnumSet.allOf(DispatcherType.class));
		}
		handlerList.add(contextHandler);

		if (myEnableSpringWebsocketSupport != null) {

			GenericWebApplicationContext wac = new GenericWebApplicationContext();
			wac.setParent(SpringContextGrabbingTestExecutionListener.getApplicationContext());
			AnnotatedBeanDefinitionReader reader = new AnnotatedBeanDefinitionReader(wac);
			reader.register(myEnableSpringWebsocketSupport);

			DispatcherServlet dispatcherServlet = new DispatcherServlet();
			dispatcherServlet.setApplicationContext(wac);
			ServletHolder subsServletHolder = new ServletHolder();
			subsServletHolder.setServlet(dispatcherServlet);

			ServletContextHandler servletContextHandler = new ServletContextHandler();
			servletContextHandler.setContextPath(myEnableSpringWebsocketContextPath);
			servletContextHandler.setAllowNullPathInContext(true);
			servletContextHandler.addServlet(new ServletHolder(dispatcherServlet), "/*");

			JakartaWebSocketServletContainerInitializer.configure(servletContextHandler, null);

			handlerList.add(servletContextHandler);
		}

		myServer.setHandler(new Handler.Sequence(handlerList));

		for (Consumer next : myBeforeStartServerConsumers) {
			next.accept(myServer);
		}

		myServer.start();

		myPort = JettyUtil.getPortForStartedServer(myServer);
		ourLog.info("Server has started on port {}", myPort);
		PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager(5000, TimeUnit.MILLISECONDS);
		HttpClientBuilder builder = HttpClientBuilder.create();
		builder.setConnectionManager(connectionManager);
		myHttpClient = builder.build();
	}

	private Filter requestCapturingFilter() {
		return new RequestCapturingFilter();
	}

	public int getPort() {
		return myPort;
	}

	protected abstract HttpServlet provideServlet();


	public String getWebsocketContextPath() {
		return myEnableSpringWebsocketContextPath;
	}

	/**
	 * Should be in the format /the/path/*
	 */
	@SuppressWarnings("unchecked")
	public T withServletPath(String theServletPath) {
		Validate.isTrue(theServletPath.startsWith("/"), "Servlet path should start with /");
		Validate.isTrue(theServletPath.endsWith("/*"), "Servlet path should end with /*");
		myServletPath = theServletPath;
		return (T) this;
	}

	@SuppressWarnings("unchecked")
	public T withPort(int thePort) {
		myPort = thePort;
		return (T) this;
	}

	@SuppressWarnings("unchecked")
	public T keepAliveBetweenTests() {
		myKeepAliveBetweenTests = true;
		return (T) this;
	}

	protected boolean isRunning() {
		return myServer != null;
	}

	/**
	 * Returns the server base URL with no trailing slash
	 */
	public String getBaseUrl() {
		assert myServletPath.endsWith("/*");
		return "http://localhost:" + myPort + myContextPath + myServletPath.substring(0, myServletPath.length() - 2);
	}

	@Override
	public void beforeEach(ExtensionContext context) throws Exception {
		startServer();
		myRequestContentTypes.clear();
		myRequestHeaders.clear();
	}

	@Override
	public void afterEach(ExtensionContext context) throws Exception {
		if (!myKeepAliveBetweenTests) {
			stopServer();
		}
	}

	@Override
	public void afterAll(ExtensionContext context) throws Exception {
		stopServer();
	}

	/**
	 * To use this method, you need to add the following to your
	 * test class:
	 * @TestExecutionListeners(value = SpringContextGrabbingTestExecutionListener.class, mergeMode = TestExecutionListeners.MergeMode.MERGE_WITH_DEFAULTS)
	 */
	@SuppressWarnings("unchecked")
	public T withSpringWebsocketSupport(String theContextPath, Class theContextConfigClass) {
		assert !isRunning();
		assert theContextConfigClass != null;
		myEnableSpringWebsocketSupport = theContextConfigClass;
		myEnableSpringWebsocketContextPath = theContextPath;
		return (T) this;
	}

	private class RequestCapturingFilter implements Filter {
		@Override
		public void init(FilterConfig filterConfig) throws ServletException {
			// nothing
		}

		@Override
		public void doFilter(ServletRequest theRequest, ServletResponse theResponse, FilterChain theChain) throws IOException, ServletException {
			HttpServletRequest request = (HttpServletRequest) theRequest;

			String header = request.getHeader(Constants.HEADER_CONTENT_TYPE);
			if (isNotBlank(header)) {
				myRequestContentTypes.add(header.replaceAll(";.*", ""));
			} else {
				myRequestContentTypes.add(null);
			}

			java.util.Enumeration headerNamesEnum = request.getHeaderNames();
			List requestHeaders = new ArrayList<>();
			myRequestHeaders.add(requestHeaders);
			while (headerNamesEnum.hasMoreElements()) {
				String nextName = headerNamesEnum.nextElement();
				Enumeration valueEnum = request.getHeaders(nextName);
				while (valueEnum.hasMoreElements()) {
					String nextValue = valueEnum.nextElement();
					requestHeaders.add(nextName + ": " + nextValue);
				}
			}

			theChain.doFilter(theRequest, theResponse);
		}

		@Override
		public void destroy() {
			// nothing
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy