All Downloads are FREE. Search and download functionalities are using the official Maven repository.

neureka.backend.main.memory.MemUtil Maven / Gradle / Ivy

The newest version!
package neureka.backend.main.memory;

import neureka.Neureka;
import neureka.Tensor;

import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 *  Utility methods for deleting tensors or preventing their deletion.
 *  In essence, it exposes convenience methods for setting and resetting
 *  the {@link Tensor#isIntermediate} flag of supplied tensors...
 *  This is an internal library class which should not be used
 *  anywhere but in Neureka's backend.
 *  Do not use this anywhere else!
 */
public final class MemUtil {

    private MemUtil() {}

    /**
     *  This method will try to delete the provided array of tensors
     *  if the tensors are not important computation
     *  graph components (like derivatives for example).
     *  Nothing is deleted unless the debug setting
     *  "isDeletingIntermediateTensors" is enabled.
     *
     * @param tensors The tensors which should be deleted if possible.
     */
    public static void autoDelete( Tensor<?>... tensors ) {
        /*
             When we are purely in the JVM world, then the garbage
             collector will take care of freeing our memory,
             and we don't really have a say in when something gets collected...
             However, this is different for native memory (for example the GPU memory)!
             In that case we can manually free up the data array of a tensor.
             The code below will delete intermediate tensors which are expected
             to be no longer used.
        */
        if ( !Neureka.get().settings().debug().isDeletingIntermediateTensors() ) {
            return;
        }
        for ( Tensor<?> t : tensors ) {
            // Only tensors flagged as 'intermediate' (and not already deleted) are candidates.
            if ( t.isDeleted() || !t.isIntermediate() ) {
                continue;
            }
            boolean deletable =
                    t.getGraphNode()
                     .map( n -> !n.isUsedAsDerivative() ) // Derivatives must survive!
                     .orElse( true ); // No graph, we can delete it!
            if ( deletable ) {
                t.mut().delete();
            }
        }
    }

    /**
     *  This method makes sure that the provided tensors do not get deleted
     *  by setting the {@link Tensor#isIntermediate} flag to off
     *  during the execution of the provided {@link Supplier} lambda!
     *  Whatever said lambda supplies will ultimately be returned by
     *  this method...
     *  All provided tensors will have the {@link Tensor#isIntermediate} flag
     *  set to their original state after execution, even if the lambda
     *  terminates by throwing an exception.
     *
     * @param tensors An array of tensors which should not be deleted during the execution of the supplied lambda.
     * @param during A lambda producing a result during which the provided tensors should not be deleted.
     * @param <T> The type of the result produced by the provided lambda.
     * @return The result produced by the provided lambda.
     */
    public static <T> T keep( Tensor<?>[] tensors, Supplier<T> during ) {
        return keepDuring( Arrays.stream(tensors), during );
    }

    /**
     *  This method makes sure that the provided tensors do not get deleted
     *  by setting the {@link Tensor#isIntermediate} flag to off
     *  during the execution of the provided {@link Supplier} lambda!
     *  Whatever said lambda supplies will ultimately be returned by
     *  this method...
     *  Both of the provided tensors will have the {@link Tensor#isIntermediate} flag
     *  set to their original state after execution, even if the lambda
     *  terminates by throwing an exception.
     *
     * @param a The first tensor which should not be deleted during the execution of the provided lambda.
     * @param b The second tensor which should not be deleted during the execution of the provided lambda.
     * @param during A lambda producing a result during whose execution the first two arguments should not be deleted.
     * @param <T> The type of the result produced by the provided lambda.
     * @return The result produced by the provided lambda.
     */
    public static <T> T keep( Tensor<?> a, Tensor<?> b, Supplier<T> during ) {
        return keepDuring( Stream.of(a, b), during );
    }

    /**
     *  Shared implementation of both {@code keep} variants:
     *  temporarily clears the {@link Tensor#isIntermediate} flag of those
     *  tensors which currently have it set, runs the lambda, and sets the flag again.
     *
     * @param tensors The tensors to protect while the lambda runs.
     * @param during The lambda to execute.
     * @param <T> The type of the result produced by the provided lambda.
     * @return The result produced by the provided lambda.
     */
    private static <T> T keepDuring( Stream<Tensor<?>> tensors, Supplier<T> during ) {
        // Only tensors that are intermediate right now are touched, so that
        // non-intermediate tensors are not wrongly flagged afterwards.
        List<Tensor<?>> doNotDelete = tensors.filter(Tensor::isIntermediate).collect(Collectors.toList());
        doNotDelete.forEach( t -> t.mut().setIsIntermediate( false ) );
        try {
            return during.get();
        } finally {
            // Restore the flag in every case: without the 'finally' an exception
            // thrown by the lambda would leave the tensors non-intermediate for good.
            doNotDelete.forEach( t -> t.mut().setIsIntermediate( true ) );
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy