All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.tensor.functions.DynamicTensor Maven / Gradle / Ivy

Go to download

Library for use in Java components of Vespa. Shared code which does not fit anywhere else.

There is a newer version: 8.409.18
Show newest version
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.ArrayList;
import java.util.List;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/**
 * A function which is a tensor whose values are computed by individual lambda functions on evaluation.
 *
 * @author bratseth
 */
public abstract class DynamicTensor extends PrimitiveTensorFunction {

    private final TensorType type;

    DynamicTensor(TensorType type) {
        this.type = type;
    }

    @Override
    public TensorType type(TypeContext context) { return type; }

    @Override
    public List> arguments() { return List.of(); }

    public abstract List> cellGeneratorFunctions();

    @Override
    public TensorFunction withArguments(List> arguments) {
        if (!arguments.isEmpty())
            throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size());
        return this;
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() { return this; }

    TensorType type() { return type; }

    abstract String contentToString(ToStringContext context);

    @Override
    public String toString(ToStringContext context) {
        return type().toString() + ":" + contentToString(context);
    }

    /** Creates a dynamic tensor function. The cell addresses must match the type. */
    public static  DynamicTensor from(TensorType type, Map> cells) {
        return new MappedDynamicTensor<>(type, cells);
    }

    /** Creates a dynamic tensor function for a bound, indexed tensor */
    public static  DynamicTensor from(TensorType type, List> cells) {
        return new IndexedDynamicTensor<>(type, cells);
    }

    private static class MappedDynamicTensor extends DynamicTensor {

        private final ImmutableMap> cells;

        MappedDynamicTensor(TensorType type, Map> cells) {
            super(type);
            this.cells = ImmutableMap.copyOf(cells);
        }

        public List> cellGeneratorFunctions() {
            var result = new ArrayList>();
            for (var fun : cells.values()) {
                fun.asTensorFunction().ifPresent(result::add);
            }
            return result;
        }

        public TensorFunction withTransformedFunctions(
                Function, ScalarFunction> transformer)
        {
            Map> transformedCells = new LinkedHashMap<>();
            for (var orig : cells.entrySet()) {
                var transformed = transformer.apply(orig.getValue());
                transformedCells.put(orig.getKey(), transformed);
            }
            return new MappedDynamicTensor<>(type(), transformedCells);
        }

        @Override
        public Tensor evaluate(EvaluationContext context) {
            Tensor.Builder builder = Tensor.Builder.of(type());
            for (var cell : cells.entrySet())
                builder.cell(cell.getKey(), cell.getValue().apply(context));
            return builder.build();
        }

        @Override
        String contentToString(ToStringContext context) {
            if (type().dimensions().isEmpty()) {
                if (cells.isEmpty()) return "{}";
                return "{{}:" + cells.values().iterator().next().toString(context) + "}";
            }

            StringBuilder b = new StringBuilder("{");
            for (var cell : cells.entrySet()) {
                b.append(cell.getKey().toString(type())).append(":").append(cell.getValue().toString(context));
                b.append(",");
            }
            if (b.length() > 1)
                b.setLength(b.length() - 1);
            b.append("}");

            return b.toString();
        }

        @Override
        public int hashCode() { return Objects.hash("mappedDynamicTensor", type(), cells); }

    }

    private static class IndexedDynamicTensor extends DynamicTensor {

        private final List> cells;

        IndexedDynamicTensor(TensorType type, List> cells) {
            super(type);
            if ( ! type.hasOnlyIndexedBoundDimensions())
                throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " +
                                                   "only indexed, bound dimensions, but this has " + type);
            this.cells = List.copyOf(cells);
        }

        public List> cellGeneratorFunctions() {
            var result = new ArrayList>();
            for (var fun : cells) {
                fun.asTensorFunction().ifPresent(result::add);
            }
            return result;
        }

        public TensorFunction withTransformedFunctions(
                Function, ScalarFunction> transformer)
        {
            List> transformedCells = new ArrayList<>();
            for (var orig : cells) {
                var transformed = transformer.apply(orig);
                transformedCells.add(transformed);
            }
            return new IndexedDynamicTensor<>(type(), transformedCells);
        }

        @Override
        public Tensor evaluate(EvaluationContext context) {
            IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
            for (int i = 0; i < cells.size(); i++)
                builder.cellByDirectIndex(i, cells.get(i).apply(context));
            return builder.build();
        }

        @Override
        String contentToString(ToStringContext context) {
            if (type().dimensions().isEmpty()) {
                if (cells.isEmpty()) return "{}";
                return "{" + cells.get(0).toString(context) + "}";
            }

            IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type());
            StringBuilder b = new StringBuilder("{");
            for (var cell : cells) {
                indexes.next();
                b.append(indexes.toAddress().toString(type())).append(":").append(cell.toString(context));
                b.append(",");
            }
            if (b.length() > 1)
                b.setLength(b.length() - 1);
            b.append("}");

            return b.toString();
        }

        @Override
        public int hashCode() { return Objects.hash("indexedDynamicTensor", type(), cells); }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy