All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.junit.jupiter.api.AssertTimeout Maven / Gradle / Ivy

There is a newer version: 5.11.3
Show newest version
/*
 * Copyright 2015-2021 the original author or authors.
 *
 * All rights reserved. This program and the accompanying materials are
 * made available under the terms of the Eclipse Public License v2.0 which
 * accompanies this distribution and is available at
 *
 * https://www.eclipse.org/legal/epl-v20.html
 */

package org.junit.jupiter.api;

import static org.junit.jupiter.api.AssertionUtils.buildPrefix;
import static org.junit.jupiter.api.AssertionUtils.nullSafeGet;
import static org.junit.jupiter.api.Assertions.fail;

import java.time.Duration;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.api.function.ThrowingSupplier;
import org.junit.platform.commons.JUnitException;
import org.junit.platform.commons.util.ExceptionUtils;
import org.opentest4j.AssertionFailedError;

/**
 * {@code AssertTimeout} is a collection of utility methods that support asserting
 * the execution of the code under test did not take longer than the timeout duration.
 *
 * @since 5.0
 */
class AssertTimeout {

	private AssertTimeout() {
		/* no-op */
	}

	static void assertTimeout(Duration timeout, Executable executable) {
		assertTimeout(timeout, executable, (String) null);
	}

	static void assertTimeout(Duration timeout, Executable executable, String message) {
		assertTimeout(timeout, () -> {
			executable.execute();
			return null;
		}, message);
	}

	static void assertTimeout(Duration timeout, Executable executable, Supplier messageSupplier) {
		assertTimeout(timeout, () -> {
			executable.execute();
			return null;
		}, messageSupplier);
	}

	static  T assertTimeout(Duration timeout, ThrowingSupplier supplier) {
		return assertTimeout(timeout, supplier, (Object) null);
	}

	static  T assertTimeout(Duration timeout, ThrowingSupplier supplier, String message) {
		return assertTimeout(timeout, supplier, (Object) message);
	}

	static  T assertTimeout(Duration timeout, ThrowingSupplier supplier, Supplier messageSupplier) {
		return assertTimeout(timeout, supplier, (Object) messageSupplier);
	}

	private static  T assertTimeout(Duration timeout, ThrowingSupplier supplier, Object messageOrSupplier) {
		long timeoutInMillis = timeout.toMillis();
		long start = System.currentTimeMillis();
		T result = null;
		try {
			result = supplier.get();
		}
		catch (Throwable ex) {
			ExceptionUtils.throwAsUncheckedException(ex);
		}

		long timeElapsed = System.currentTimeMillis() - start;
		if (timeElapsed > timeoutInMillis) {
			fail(buildPrefix(nullSafeGet(messageOrSupplier)) + "execution exceeded timeout of " + timeoutInMillis
					+ " ms by " + (timeElapsed - timeoutInMillis) + " ms");
		}
		return result;
	}

	static void assertTimeoutPreemptively(Duration timeout, Executable executable) {
		assertTimeoutPreemptively(timeout, executable, (String) null);
	}

	static void assertTimeoutPreemptively(Duration timeout, Executable executable, String message) {
		assertTimeoutPreemptively(timeout, () -> {
			executable.execute();
			return null;
		}, message);
	}

	static void assertTimeoutPreemptively(Duration timeout, Executable executable, Supplier messageSupplier) {
		assertTimeoutPreemptively(timeout, () -> {
			executable.execute();
			return null;
		}, messageSupplier);
	}

	static  T assertTimeoutPreemptively(Duration timeout, ThrowingSupplier supplier) {
		return assertTimeoutPreemptively(timeout, supplier, (Object) null);
	}

	static  T assertTimeoutPreemptively(Duration timeout, ThrowingSupplier supplier, String message) {
		return assertTimeoutPreemptively(timeout, supplier, (Object) message);
	}

	static  T assertTimeoutPreemptively(Duration timeout, ThrowingSupplier supplier,
			Supplier messageSupplier) {

		return assertTimeoutPreemptively(timeout, supplier, (Object) messageSupplier);
	}

	private static  T assertTimeoutPreemptively(Duration timeout, ThrowingSupplier supplier,
			Object messageOrSupplier) {

		AtomicReference threadReference = new AtomicReference<>();
		ExecutorService executorService = Executors.newSingleThreadExecutor(new TimeoutThreadFactory());

		try {
			Future future = executorService.submit(() -> {
				try {
					threadReference.set(Thread.currentThread());
					return supplier.get();
				}
				catch (Throwable throwable) {
					throw ExceptionUtils.throwAsUncheckedException(throwable);
				}
			});

			long timeoutInMillis = timeout.toMillis();
			try {
				return future.get(timeoutInMillis, TimeUnit.MILLISECONDS);
			}
			catch (TimeoutException ex) {
				String message = buildPrefix(nullSafeGet(messageOrSupplier)) + "execution timed out after "
						+ timeoutInMillis + " ms";

				Thread thread = threadReference.get();
				if (thread != null) {
					ExecutionTimeoutException exception = new ExecutionTimeoutException(
						"Execution timed out in thread " + thread.getName());
					exception.setStackTrace(thread.getStackTrace());
					throw new AssertionFailedError(message, exception);
				}
				else {
					throw new AssertionFailedError(message);
				}
			}
			catch (ExecutionException ex) {
				throw ExceptionUtils.throwAsUncheckedException(ex.getCause());
			}
			catch (Throwable ex) {
				throw ExceptionUtils.throwAsUncheckedException(ex);
			}
		}
		finally {
			executorService.shutdownNow();
		}
	}

	private static class ExecutionTimeoutException extends JUnitException {

		private static final long serialVersionUID = 1L;

		ExecutionTimeoutException(String message) {
			super(message);
		}
	}

	/**
	 * The thread factory used for preemptive timeout.
	 *
	 * The factory creates threads with meaningful names, helpful for debugging purposes.
	 */
	private static class TimeoutThreadFactory implements ThreadFactory {
		private static final AtomicInteger threadNumber = new AtomicInteger(1);

		public Thread newThread(Runnable r) {
			return new Thread(r, "junit-timeout-thread-" + threadNumber.getAndIncrement());
		}
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy