All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.tensor.functions.MapSubspaces Maven / Gradle / Ivy

Go to download

Library for use in Java components of Vespa: shared code that does not fit anywhere else.

There is a newer version: 8.441.21
Show newest version
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * The map_subspaces tensor function transforms each dense subspace in a (mixed) tensor
 *
 * @author arnej
 */
public class MapSubspaces extends PrimitiveTensorFunction {

    private final TensorFunction argument;
    private final DenseSubspaceFunction function;

    private MapSubspaces(TensorFunction argument, DenseSubspaceFunction function) {
        this.argument = argument;
        this.function = function;
    }
    public MapSubspaces(TensorFunction argument, String functionArg, TensorFunction function) {
        this(argument, new DenseSubspaceFunction<>(functionArg, function));
        Objects.requireNonNull(argument, "The argument cannot be null");
        Objects.requireNonNull(functionArg, "The functionArg cannot be null");
        Objects.requireNonNull(function, "The function cannot be null");
    }

    private TensorType outputType(TensorType inputType) {
        var m = inputType.mappedSubtype();
        var d = function.outputType(inputType.indexedSubtype());
        if (m.rank() == 0) {
            return d;
        }
        if (d.rank() == 0) {
            return TypeResolver.map(m); // decay cell type
        }
        TensorType.Value cellType = d.valueType();
        Map dims = new HashMap<>();
        for (var dim : m.dimensions()) {
            dims.put(dim.name(), dim);
        }
        for (var dim : d.dimensions()) {
            var old = dims.put(dim.name(), dim);
            if (old != null) {
                throw new IllegalArgumentException("dimension name collision in map_subspaces: " + m + " vs " + d);
            }
        }
        return new TensorType(cellType, dims.values());
    }

    public TensorFunction argument() { return argument; }

    @Override
    public List> arguments() { return List.of(argument); }

    @Override
    public TensorFunction withArguments(List> arguments) {
        if ( arguments.size() != 1)
            throw new IllegalArgumentException("MapSubspaces must have 1 argument, got " + arguments.size());
        return new MapSubspaces(arguments.get(0), function);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new MapSubspaces<>(argument.toPrimitive(), function);
    }

    @Override
    public TensorType type(TypeContext context) {
        return outputType(argument.type(context));
    }

    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor input = argument().evaluate(context);
        TensorType inputType = input.type();
        TensorType inputTypeMapped = inputType.mappedSubtype();
        TensorType inputTypeDense = inputType.indexedSubtype();
        Map builders = new HashMap<>();
        for (Iterator iter = input.cellIterator(); iter.hasNext(); ) {
            var cell = iter.next();
            var fullAddr = cell.getKey();
            var mapAddrBuilder = new TensorAddress.Builder(inputTypeMapped);
            var idxAddrBuilder = new TensorAddress.Builder(inputTypeDense);
            for (int i = 0; i < inputType.dimensions().size(); i++) {
                var dim = inputType.dimensions().get(i);
                if (dim.isMapped()) {
                    mapAddrBuilder.add(dim.name(), fullAddr.numericLabel(i));
                } else {
                    idxAddrBuilder.add(dim.name(), fullAddr.numericLabel(i));
                }
            }
            var mapAddr = mapAddrBuilder.build();
            var builder = builders.computeIfAbsent(mapAddr, k -> Tensor.Builder.of(inputTypeDense));
            var idxAddr = idxAddrBuilder.build();
            builder.cell(idxAddr, cell.getValue());
        }
        TensorType outputType = outputType(input.type());
        TensorType denseOutputType = outputType.indexedSubtype();
        var denseOutputDims = denseOutputType.dimensions();
        Tensor.Builder builder = Tensor.Builder.of(outputType);
        for (var entry : builders.entrySet()) {
            TensorAddress mappedAddr = entry.getKey();
            Tensor denseInput = entry.getValue().build();
            Tensor denseOutput = function.map(denseInput);
            // XXX check denseOutput.type().dimensions()
            for (Iterator iter = denseOutput.cellIterator(); iter.hasNext(); ) {
                var cell = iter.next();
                var denseAddr = cell.getKey();
                var addrBuilder = new TensorAddress.Builder(outputType);
                for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) {
                    var dim = inputTypeMapped.dimensions().get(i);
                    addrBuilder.add(dim.name(), mappedAddr.numericLabel(i));
                }
                for (int i = 0; i < denseOutputDims.size(); i++) {
                    var dim = denseOutputDims.get(i);
                    addrBuilder.add(dim.name(), denseAddr.numericLabel(i));
                }
                builder.cell(addrBuilder.build(), cell.getValue());
            }
        }
        return builder.build();
    }

    @Override
    public String toString(ToStringContext context) {
        return "map_subspaces(" + argument.toString(context) + ", " + function + ")";
    }

    @Override
    public int hashCode() { return Objects.hash("map_subspaces", argument, function); }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy