All downloads are free. Search and download functionality uses the official Maven repository.

fr.lirmm.boreal.util.evaluator.BatchProcessor Maven / Gradle / Ivy

package fr.lirmm.boreal.util.evaluator;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A generic class for batch processing of input objects.
 *
 * <p>
 * Every input is handed to a transformation function that runs on a virtual
 * thread and is bounded by a per-task timeout expressed in seconds. When a task
 * times out, throws, is interrupted, or yields {@code null}, the optional
 * {@code outputIfTimeout} function (if configured) builds the replacement
 * result. Both processing methods therefore return exactly one entry per input,
 * in the iteration order of the input collection.
 *
 * @param <I>          The input type.
 * @param <OutputType> The output type.
 */
public class BatchProcessor<I, OutputType> {

	/** Inputs to transform; the collection is referenced, not copied. */
	private final Collection<I> inputs;
	/** Transformation applied to each input. */
	private final Function<I, OutputType> batchTransformationFunction;
	/** Per-task timeout, in seconds. */
	private final Integer timeout;
	/** Maximum number of transformations running at once in {@link #processBatchParallel()}. */
	private final int maxParallelTasks;
	/**
	 * Builds the output for an input whose task produced no result. It receives
	 * the input and the timeout (in seconds) rendered as a string. May be
	 * {@code null}, in which case such inputs map to {@code null}.
	 */
	private final BiFunction<I, String, OutputType> outputIfTimeout;

	protected static final Logger LOG = LoggerFactory.getLogger(BatchProcessor.class);

	/**
	 * Constructs a BatchProcessor with the given inputs and a transformation
	 * function. Parallel processing is limited to twice the number of available
	 * processors.
	 *
	 * @param inputs                 The collection of input objects.
	 * @param transformationFunction The function to apply to each input object.
	 * @param timeout                Timeout for each task, in seconds (optional;
	 *                               {@code EvaluatorConstants.DEFAULT_TIMEOUT}
	 *                               when {@code null})
	 * @param outputIfTimeout        Default output in case of timeout (may be
	 *                               {@code null})
	 * @throws NullPointerException if {@code inputs} or
	 *                              {@code transformationFunction} is {@code null}
	 */
	public BatchProcessor(Collection<I> inputs, Function<I, OutputType> transformationFunction, Integer timeout,
			BiFunction<I, String, OutputType> outputIfTimeout) {
		this(inputs, transformationFunction, timeout, outputIfTimeout,
				Runtime.getRuntime().availableProcessors() * 2);
	}

	/**
	 * Constructs a BatchProcessor with an explicit bound on the number of
	 * transformations that {@link #processBatchParallel()} runs concurrently.
	 *
	 * @param inputs                 The collection of input objects.
	 * @param transformationFunction The function to apply to each input object.
	 * @param timeout                Timeout for each task, in seconds (optional;
	 *                               {@code EvaluatorConstants.DEFAULT_TIMEOUT}
	 *                               when {@code null})
	 * @param outputIfTimeout        Default output in case of timeout (may be
	 *                               {@code null})
	 * @param maxParallelTasks       Maximum number of concurrent transformations;
	 *                               must be at least 1.
	 * @throws NullPointerException     if {@code inputs} or
	 *                                  {@code transformationFunction} is
	 *                                  {@code null}
	 * @throws IllegalArgumentException if {@code maxParallelTasks < 1}
	 */
	public BatchProcessor(Collection<I> inputs, Function<I, OutputType> transformationFunction, Integer timeout,
			BiFunction<I, String, OutputType> outputIfTimeout, int maxParallelTasks) {
		if (maxParallelTasks < 1) {
			// A zero-permit semaphore would block every task forever.
			throw new IllegalArgumentException("maxParallelTasks must be at least 1, got " + maxParallelTasks);
		}
		this.inputs = Objects.requireNonNull(inputs, "inputs");
		this.batchTransformationFunction = Objects.requireNonNull(transformationFunction, "transformationFunction");
		this.timeout = timeout != null ? timeout : EvaluatorConstants.DEFAULT_TIMEOUT;
		this.outputIfTimeout = outputIfTimeout;
		this.maxParallelTasks = maxParallelTasks;
	}

	/**
	 * Applies the transformation function to all input objects, one at a time,
	 * and returns the results.
	 *
	 * <p>
	 * Each transformation runs on its own virtual thread and is given at most
	 * {@code timeout} seconds. If it times out, fails, or returns {@code null},
	 * the {@code outputIfTimeout} fallback (when configured) supplies the result.
	 *
	 * @return A list of output objects resulting from the transformation of each
	 *         input object, in input iteration order.
	 */
	public List<OutputType> processBatch() {
		// Create a virtual thread executor
		ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
		try {
			List<OutputType> results = new ArrayList<>(inputs.size());
			for (I input : inputs) {
				LOG.debug("{}", input);
				OutputType result = timeoutEval(() -> batchTransformationFunction.apply(input), executor);
				results.add(withFallback(input, result));
			}
			LOG.debug("{}", results);
			return results;
		} finally {
			// shutdownNow (rather than close) does not wait for termination, so a
			// cancelled task that ignores interruption cannot block the caller.
			executor.shutdownNow();
		}
	}

	/**
	 * Submits {@code task} to {@code executor} and waits at most {@code timeout}
	 * seconds for its result.
	 *
	 * @param task     The computation to run.
	 * @param executor The executor the task is submitted to.
	 * @return The task's result, or {@code null} in any of these cases:
	 *         <ul>
	 *         <li>the task timed out (it is then cancelled with interruption);</li>
	 *         <li>the task threw an exception;</li>
	 *         <li>the waiting thread was interrupted (the interrupt flag is
	 *         restored).</li>
	 *         </ul>
	 */
	private OutputType timeoutEval(Callable<OutputType> task, ExecutorService executor) {
		Future<OutputType> future = executor.submit(task);
		try {
			return future.get(this.timeout, TimeUnit.SECONDS);
		} catch (TimeoutException e) {
			LOG.error("The task did not complete within the timeout of {} (seconds)", this.timeout);
			future.cancel(true);
			return null;
		} catch (InterruptedException e) {
			// Preserve the cancellation request for our caller and stop the orphaned task.
			Thread.currentThread().interrupt();
			future.cancel(true);
			LOG.error("An error occurred during task execution.", e);
			return null;
		} catch (ExecutionException e) {
			LOG.error("An error occurred during task execution.", e);
			return null;
		}
	}

	/**
	 * Applies the transformation function to all input objects in parallel while
	 * controlling the level of parallelization, and returns the results.
	 *
	 * <p>
	 * At most {@code maxParallelTasks} transformations hold a permit at once. The
	 * per-task timeout starts when a task obtains its permit, not when it is
	 * queued. A timed-out task is cancelled and gives back its permit, so one slow
	 * input cannot starve the remaining ones. Timeouts, failures and {@code null}
	 * results are handled exactly as in {@link #processBatch()}. Note: a cancelled
	 * transformation that ignores interruption may keep running after its permit
	 * has been released.
	 *
	 * @return A list of output objects resulting from the transformation of each
	 *         input object, in input iteration order.
	 */
	public List<OutputType> processBatchParallel() {
		// Create a virtual thread executor
		ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
		Semaphore semaphore = new Semaphore(maxParallelTasks);
		try {
			List<CompletableFuture<OutputType>> futures = inputs.stream()
					.map(input -> CompletableFuture.supplyAsync(() -> evaluateWithPermit(input, semaphore, executor),
							executor))
					.collect(Collectors.toList());

			// Join all futures and collect the results, preserving input order.
			return futures.stream().map(CompletableFuture::join).collect(Collectors.toList());
		} finally {
			executor.shutdownNow();
		}
	}

	/**
	 * Transforms a single input while holding one permit of {@code semaphore},
	 * then applies the timeout fallback if no result was produced.
	 */
	private OutputType evaluateWithPermit(I input, Semaphore semaphore, ExecutorService executor) {
		OutputType result = null;
		try {
			semaphore.acquire(); // Acquire a permit, blocking if necessary
			try {
				result = timeoutEval(() -> batchTransformationFunction.apply(input), executor);
			} finally {
				// Only reached once a permit was actually acquired.
				semaphore.release();
			}
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
		}
		return withFallback(input, result);
	}

	/**
	 * Returns {@code result} unchanged when it is non-null. Otherwise returns the
	 * output built by {@code outputIfTimeout} for {@code input}, or {@code null}
	 * when no fallback function was configured.
	 */
	private OutputType withFallback(I input, OutputType result) {
		if (result != null || outputIfTimeout == null) {
			return result;
		}
		return outputIfTimeout.apply(input, String.valueOf(timeout));
	}

}