All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.tensor.functions.Generate Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

/**
 * An indexed tensor whose values are generated by a function
 *
 * @author bratseth
 */
public class Generate extends PrimitiveTensorFunction {

    private final TensorType type;

    // One of these are null
    private final Function, Double> freeGenerator;
    private final ScalarFunction boundGenerator;

    /** The same as Generate.free */
    public Generate(TensorType type, Function, Double> generator) {
        this(type, Objects.requireNonNull(generator), null);
    }

    /**
     * Creates a generated tensor from a free function
     *
     * @param type the type of the tensor
     * @param generator the function generating values from a list of numbers specifying the indexes of the
     *                  tensor cell which will receive the value
     * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
     */
    public static  Generate free(TensorType type, Function, Double> generator) {
        return new Generate<>(type, Objects.requireNonNull(generator), null);
    }

    /**
     * Creates a generated tensor from a bound function
     *
     * @param type the type of the tensor
     * @param generator the function generating values from a list of numbers specifying the indexes of the
     *                  tensor cell which will receive the value
     * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
     */
    public static  Generate bound(TensorType type, ScalarFunction generator) {
        return new Generate<>(type, null, Objects.requireNonNull(generator));
    }

    private Generate(TensorType type, Function, Double> freeGenerator, ScalarFunction boundGenerator) {
        Objects.requireNonNull(type, "The argument tensor type cannot be null");
        validateType(type);
        this.type = type;
        this.freeGenerator = freeGenerator;
        this.boundGenerator = boundGenerator;
    }

    private void validateType(TensorType type) {
        for (TensorType.Dimension dimension : type.dimensions())
            if (dimension.type() != TensorType.Dimension.Type.indexedBound)
                throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions");
    }

    @Override
    public List> arguments() {
        return boundGenerator != null && boundGenerator.asTensorFunction().isPresent()
               ? List.of(boundGenerator.asTensorFunction().get())
               : List.of();
    }

    @Override
    public TensorFunction withArguments(List> arguments) {
        if ( arguments.size() > 1)
            throw new IllegalArgumentException("Generate must have 0 or 1 arguments, got " + arguments.size());
        if (arguments.isEmpty()) return this;

        if (arguments.get(0).asScalarFunction().isEmpty())
            throw new IllegalArgumentException("The argument to generate must be convertible to a tensor function, " +
                                               "but got " + arguments.get(0));

        return new Generate<>(type, null, arguments.get(0).asScalarFunction().get());
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() { return this; }

    @Override
    public TensorType type(TypeContext context) { return type; }

    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor.Builder builder = Tensor.Builder.of(type);
        IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
        GenerateEvaluationContext generateContext = new GenerateEvaluationContext(type, context);
        for (int i = 0; i < indexes.size(); i++) {
            indexes.next();
            builder.cell(generateContext.apply(indexes), indexes.indexesForReading());
        }
        return builder.build();
    }

    private DimensionSizes dimensionSizes(TensorType type) {
        DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
        for (int i = 0; i < b.dimensions(); i++)
            b.set(i, type.dimensions().get(i).size().get());
        return b.build();
    }

    @Override
    public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; }

    private String generatorToString(ToStringContext context) {
        if (freeGenerator != null)
            return freeGenerator.toString();
        else
            return boundGenerator.toString(new GenerateToStringContext(context));
    }

    @Override
    public int hashCode() { return Objects.hash("generate", type, freeGenerator, boundGenerator); }

    /**
     * A context for generating all the values of a tensor produced by evaluating Generate.
     * This returns all the current index values as variables and falls back to delivering from the given
     * evaluation context.
     */
    private class GenerateEvaluationContext implements EvaluationContext {

        private final TensorType type;
        private final EvaluationContext context;

        private IndexedTensor.Indexes indexes;

        GenerateEvaluationContext(TensorType type, EvaluationContext context) {
            this.type = type;
            this.context = context;
        }

        double apply(IndexedTensor.Indexes indexes) {
            if (freeGenerator != null) {
                return freeGenerator.apply(indexes.toList());
            }
            else {
                this.indexes = indexes;
                return boundGenerator.apply(this);
            }
        }

        @Override
        public Tensor getTensor(String name) {
            Optional index = type.indexOfDimension(name);
            if (index.isPresent()) // this is the name of a dimension
                return Tensor.from(indexes.indexesForReading()[index.get()]);
            else
                return context.getTensor(name);
        }

        @Override
        public TensorType getType(NAMETYPE name) {
            Optional index = type.indexOfDimension(name.name());
            if (index.isPresent()) // this is the name of a dimension
                return TensorType.empty;
            else
                return context.getType(name);
        }

        @Override
        public TensorType getType(String name) {
            Optional index = type.indexOfDimension(name);
            if (index.isPresent()) // this is the name of a dimension
                return TensorType.empty;
            else
                return context.getType(name);
        }

        @Override
        public String resolveBinding(String name) {
            return context.resolveBinding(name);
        }

    }

    /** A context which adds the bindings of the generate dimension names to the given context. */
    private class GenerateToStringContext implements ToStringContext {

        private final ToStringContext context;

        public GenerateToStringContext(ToStringContext context) {
            this.context = context;
        }

        @Override
        public String getBinding(String identifier) {
            if (type.dimension(identifier).isPresent())
                return identifier; // dimension names are bound but not substituted in the generate context
            else
                return context.getBinding(identifier);
        }

        @Override
        public ToStringContext parent() { return context; }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy