All downloads are free. Search and download functionality uses the official Maven repository.

fr.lirmm.boreal.util.evaluator.BatchProcessor Maven / Gradle / Ivy

There is a newer version: 1.6.3
Show newest version
package fr.lirmm.boreal.util.evaluator;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import fr.lirmm.boreal.util.externalHaltingConditions.ExternalAlgorithmHaltingConditions;

/**
 * A generic class for batch processing of input objects.
 *
 * <p>
 * Every input is transformed by the user-supplied function on a virtual thread
 * and is subject to the timeout returned by
 * {@code externalHaltingConditions.timeout()} (applied in milliseconds). When a
 * transformation does not complete in time, or yields {@code null}, the
 * {@code outputIfTimeout} function supplies the result for that input instead.
 *
 * @param <I>          The input type.
 * @param <OutputType> The output type.
 */
public class BatchProcessor<I, OutputType> {

	private final Collection<I> inputs;
	private final Function<I, OutputType> batchTransformationFunction;
	private final ExternalAlgorithmHaltingConditions externalHaltingConditions;
	/** Upper bound on transformations running concurrently in {@link #processBatchParallel()}. */
	private final int maxParallelTasks = Runtime.getRuntime().availableProcessors() * 2;
	/**
	 * Builds the substitute output from the input and the timeout value (in
	 * milliseconds, rendered as a string).
	 */
	private final BiFunction<I, String, OutputType> outputIfTimeout;

	protected static final Logger LOG = LoggerFactory.getLogger(BatchProcessor.class);

	/**
	 * Constructs a BatchProcessor with the given inputs and a transformation
	 * function.
	 *
	 * @param inputs                    The collection of input objects.
	 * @param transformationFunction    The function to apply to each input object.
	 * @param externalHaltingConditions Supplies, through {@code timeout()}, the
	 *                                  per-input time budget.
	 * @param outputIfTimeout           Default output in case of timeout; called
	 *                                  with the input and the timeout in
	 *                                  milliseconds (as a string).
	 */
	public BatchProcessor(Collection<I> inputs, Function<I, OutputType> transformationFunction,
			ExternalAlgorithmHaltingConditions externalHaltingConditions,
			BiFunction<I, String, OutputType> outputIfTimeout) {
		this.inputs = inputs;
		this.batchTransformationFunction = transformationFunction;
		this.externalHaltingConditions = externalHaltingConditions;
		this.outputIfTimeout = outputIfTimeout;
	}

	/**
	 * Applies the transformation function to all input objects, one at a time, and
	 * returns the results in the iteration order of the inputs.
	 *
	 * @return A list of output objects resulting from the transformation of each
	 *         input object.
	 * @throws RuntimeException if a transformation throws, or if the calling
	 *                          thread is interrupted while waiting.
	 */
	public List<OutputType> processBatch() {
		// Create a virtual thread executor
		ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
		try {
			List<OutputType> results = new ArrayList<>(inputs.size());
			for (I input : inputs) {
				LOG.debug("{}", input);
				// Read once per input so the wait and the value handed to outputIfTimeout agree.
				long timeoutMillis = this.externalHaltingConditions.timeout().toMillis();
				Callable<OutputType> task = () -> batchTransformationFunction.apply(input);
				OutputType result = timeoutEval(task, executor, timeoutMillis);
				if (result == null) {
					result = outputIfTimeout.apply(input, String.valueOf(timeoutMillis));
				}
				results.add(result);
			}
			LOG.debug("{}", results);
			return results;
		} finally {
			// shutdownNow (rather than close) interrupts tasks that outlived their
			// timeout instead of waiting for them, and runs even if a task failed.
			executor.shutdownNow();
		}
	}

	/**
	 * Submits {@code task} to {@code executor} and waits at most
	 * {@code timeoutMillis} milliseconds for its result.
	 *
	 * @return the task's result, or {@code null} if it timed out (the task is then
	 *         cancelled with interruption).
	 * @throws RuntimeException wrapping the failure if the task threw or the
	 *                          caller was interrupted.
	 */
	private OutputType timeoutEval(Callable<OutputType> task, ExecutorService executor, long timeoutMillis) {
		Future<OutputType> future = executor.submit(task);
		try {
			return future.get(timeoutMillis, TimeUnit.MILLISECONDS);
		} catch (TimeoutException e) {
			LOG.warn("The task did not complete within the timeout of {} (ms)", timeoutMillis);
			future.cancel(true);
			return null;
		} catch (InterruptedException e) {
			// Preserve the caller's interrupt status and stop the orphaned task.
			Thread.currentThread().interrupt();
			future.cancel(true);
			throw taskFailure(e);
		} catch (ExecutionException e) {
			throw taskFailure(e);
		}
	}

	/** Logs {@code e} and wraps it in the RuntimeException thrown by {@link #timeoutEval}. */
	private RuntimeException taskFailure(Exception e) {
		LOG.error("An error occurred during task execution.", e);
		return new RuntimeException(
				String.format(
						"[%s::timeoutEval] An error occurred during task execution: %s.",
						this.getClass(), e.getMessage()), e);
	}

	/**
	 * Applies the transformation function to all input objects in parallel while
	 * controlling the level of parallelization, and returns the results in the
	 * iteration order of the inputs. Inputs whose transformation times out (or
	 * yields {@code null}) are mapped through {@code outputIfTimeout}.
	 *
	 * <p>
	 * NOTE(review): each future's timeout starts when it is submitted, so time
	 * spent waiting for a semaphore permit counts against it; with many more
	 * inputs than {@code maxParallelTasks}, late inputs may time out before they
	 * start running.
	 *
	 * @return A list of output objects resulting from the transformation of each
	 *         input object.
	 * @throws java.util.concurrent.CompletionException if a transformation throws.
	 */
	public List<OutputType> processBatchParallel() {
		long timeoutMillis = this.externalHaltingConditions.timeout().toMillis();

		// Create a virtual thread executor
		ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
		Semaphore semaphore = new Semaphore(maxParallelTasks);
		try {
			// Snapshot so each future can be paired with its input afterwards.
			List<I> orderedInputs = new ArrayList<>(inputs);
			List<CompletableFuture<OutputType>> futures = new ArrayList<>(orderedInputs.size());
			for (I input : orderedInputs) {
				futures.add(CompletableFuture
						.supplyAsync(() -> applyWithPermit(input, semaphore), executor)
						// A timed-out task keeps running (and holding its permit)
						// until shutdownNow below interrupts it.
						.completeOnTimeout(null, timeoutMillis, TimeUnit.MILLISECONDS));
			}

			// Join all futures and collect the results
			List<OutputType> results = new ArrayList<>(futures.size());
			for (int i = 0; i < futures.size(); i++) {
				OutputType result = futures.get(i).join();
				if (result == null) {
					result = outputIfTimeout.apply(orderedInputs.get(i), String.valueOf(timeoutMillis));
				}
				results.add(result);
			}
			return results;
		} finally {
			executor.shutdownNow();
		}
	}

	/**
	 * Runs the transformation on {@code input} while holding one permit of
	 * {@code semaphore}. A permit is released only if it was actually acquired.
	 *
	 * @return the transformation result, or {@code null} if interrupted while
	 *         waiting for a permit (the interrupt status is restored).
	 */
	private OutputType applyWithPermit(I input, Semaphore semaphore) {
		try {
			semaphore.acquire(); // Acquire a permit, blocking if necessary
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			return null;
		}
		try {
			return batchTransformationFunction.apply(input);
		} finally {
			semaphore.release();
		}
	}

}