All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.tensor.functions.Slice Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.yahoo.api.annotations.Beta;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Returns a subspace of a tensor
 *
 * @author bratseth
 */
@Beta
public class Slice extends PrimitiveTensorFunction {

    private final TensorFunction argument;
    private final List> subspaceAddress;

    /**
     * Creates a value function
     *
     * @param argument the tensor to return a cell value from
     * @param subspaceAddress a description of the address of the cell to return the value of. This is not a TensorAddress
     *                        because those require a type, but a type is not resolved until this is evaluated
     */
    public Slice(TensorFunction argument, List> subspaceAddress) {
        this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
        if (subspaceAddress.size() > 1 && subspaceAddress.stream().anyMatch(c -> c.dimension().isEmpty()))
            throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " +
                                               "Specify dimension names explicitly instead");
        this.subspaceAddress = subspaceAddress;
    }

    @Override
    public List> arguments() { return List.of(argument); }

    public List> selectorFunctions() {
        var result = new ArrayList>();
        for (var dimVal : subspaceAddress) {
            dimVal.index().flatMap(ScalarFunction::asTensorFunction).ifPresent(result::add);
        }
        return result;
    }

    public TensorFunction withTransformedFunctions(
            Function, ScalarFunction> transformer)
    {
        List> transformedAddress = new ArrayList<>();
        for (var orig : subspaceAddress) {
            var idxFun = orig.index();
            if (idxFun.isPresent()) {
                var transformed = transformer.apply(idxFun.get());
                transformedAddress.add(new DimensionValue(orig.dimension(), transformed));
            } else {
                transformedAddress.add(orig);
            }
        }
        return new Slice<>(argument, transformedAddress);
    }

    @Override
    public Slice withArguments(List> arguments) {
        if (arguments.size() != 1)
            throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size());
        return new Slice<>(arguments.get(0), subspaceAddress);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() { return this; }

    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor tensor = argument.evaluate(context);
        TensorType resultType = resultType(tensor.type());

        PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context);
        if (resultType.rank() == 0) { // shortcut common case
            return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type())));
        }

        Tensor.Builder b = Tensor.Builder.of(resultType);
        for (Iterator i = tensor.cellIterator(); i.hasNext(); ) {
            Tensor.Cell cell = i.next();
            if (matches(subspaceAddress, cell.getKey(), tensor.type()))
                b.cell(remaining(resultType, cell.getKey(), tensor.type()), cell.getValue());
        }
        return b.build();
    }

    private PartialAddress subspaceToAddress(TensorType type, EvaluationContext context) {
        PartialAddress.Builder b = new PartialAddress.Builder(subspaceAddress.size());
        for (int i = 0; i < subspaceAddress.size(); i++) {
            if (subspaceAddress.get(i).label().isPresent())
                b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()),
                      subspaceAddress.get(i).label().get());
            else
                b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()),
                      subspaceAddress.get(i).index().get().apply(context).intValue());
        }
        return b.build();
    }

    private boolean matches(PartialAddress subspaceAddress,
                            TensorAddress address, TensorType type) {
        for (int i = 0; i < subspaceAddress.size(); i++) {
            String label = address.label(type.indexOfDimension(subspaceAddress.dimension(i)).get());
            if ( ! label.equals(subspaceAddress.label(i)))
                return false;
        }
        return true;
    }

    /** Returns the subset of the given address which is present in the subspace type */
    private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) {
        TensorAddress.Builder b = new TensorAddress.Builder(subspaceType);
        for (int i = 0; i < address.size(); i++) {
            String dimension = type.dimensions().get(i).name();
            if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent())
                b.add(dimension, address.numericLabel(i));
        }
        return b.build();
    }

    @Override
    public TensorType type(TypeContext context) {
        return resultType(argument.type(context));
    }

    private List findDimensions(List dims, Predicate pred) {
        return dims.stream().filter(pred).map(TensorType.Dimension::name).toList();
    }

    private TensorType resultType(TensorType argumentType) {
        List peekDimensions;
        if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
            // Special case where a single indexed or mapped dimension is sliced
            if (subspaceAddress.get(0).index().isPresent()) {
                peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed);
                if (peekDimensions.size() > 1) {
                    throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " +
                                                       "to " + argumentType + ", which has multiple");
                }
            }
            else {
                peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isMapped);
                if (peekDimensions.size() > 1)
                    throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " +
                                                       "to " + argumentType + ", which has multiple");
            }
        }
        else { // general slicing
            peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).toList();
        }
        try {
            return TypeResolver.peek(argumentType, peekDimensions);
        }
        catch (IllegalArgumentException e) {
            throw new IllegalArgumentException(this + " cannot slice type " + argumentType, e);
        }
    }

    @Override
    public String toString(ToStringContext context) {
        StringBuilder b = new StringBuilder(argument.toString(context));
        if (context.typeContext().isEmpty()
            && subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { // use short forms
            if (subspaceAddress.get(0).index().isPresent())
                b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]");
            else
                b.append("{").append(subspaceAddress.get(0).label().get()).append("}");
        }
        else { // general form
            b.append("{").append(subspaceAddress.stream()
                                                .map(i -> i.toString(context, this))
                                                .collect(Collectors.joining(", "))).append("}");
        }
        return b.toString();
    }

    @Override
    public int hashCode() { return Objects.hash("slice", argument, subspaceAddress); }

    public static class DimensionValue  {

        private final Optional dimension;

        /** The label of this, or null if index is set */
        private final String label;

        /** The function returning the index of this, or null if label is set */
        private final ScalarFunction index;

        public DimensionValue(String dimension, String label) {
            this(Optional.of(dimension), label, null);
        }

        public DimensionValue(String dimension, int index) {
            this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index));
        }

        public DimensionValue(int index) {
            this(Optional.empty(), null, new ConstantIntegerFunction<>(index));
        }

        public DimensionValue(String label) {
            this(Optional.empty(), label, null);
        }

        public DimensionValue(ScalarFunction index) {
            this(Optional.empty(), null, index);
        }

        public DimensionValue(Optional dimension, String label) {
            this(dimension, label, null);
        }

        public DimensionValue(Optional dimension, ScalarFunction index) {
            this(dimension, null, index);
        }

        public DimensionValue(String dimension, ScalarFunction index) {
            this(Optional.of(dimension), null, index);
        }

        private DimensionValue(Optional dimension, String label, ScalarFunction index) {
            this.dimension = dimension;
            this.label = label;
            this.index = index;
        }

        /**
         * Returns the given name of the dimension, or null if dense form is used, such that name
         * must be inferred from order
         */
        public Optional dimension() { return dimension; }

        /** Returns the label for this dimension or empty if it is provided by an index function */
        public Optional label() { return Optional.ofNullable(label); }

        /** Returns the index expression for this dimension, or empty if it is not a number */
        public Optional> index() { return Optional.ofNullable(index); }

        @Override
        public String toString() {
            return toString(null, null);
        }

        String toString(ToStringContext context, Slice owner) {
            StringBuilder b = new StringBuilder();
            Optional dimensionName = dimension;
            if (context != null && dimensionName.isEmpty()) { // This isn't just toString(): Output canonical form or fail
                TensorType type = context.typeContext().isPresent() ? owner.argument.type(context.typeContext().get()) : null;
                if (type == null || type.dimensions().size() != 1)
                    throw new IllegalArgumentException("The tensor dimension name being sliced by " + owner +
                                                       " cannot be uniquely resolved. Use the full form: " +
                                                       "'slice{myDimensionName:" + valueToString(context) + "}'");
                else
                    dimensionName = Optional.of(type.dimensions().get(0).name());
            }
            dimensionName.ifPresent(d -> b.append(d).append(":"));
            b.append(valueToString(context));
            return b.toString();
        }

        private String valueToString(ToStringContext context) {
            if (label != null) {
                return TensorAddress.labelToString(label);
            } else {
                return index.toString(context);
            }
        }

        @Override
        public int hashCode() { return Objects.hash(dimension, label, index); }


    }

    private static class ConstantIntegerFunction implements ScalarFunction {

        private final int value;

        public ConstantIntegerFunction(int value) {
            this.value = value;
        }

        @Override
        public Double apply(EvaluationContext context) {
            return (double)value;
        }

        @Override
        public String toString() { return String.valueOf(value); }

        @Override
        public int hashCode() { return Objects.hash("constantIntegerFunction", value); }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy