All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.yahoo.tensor.functions.Rename Maven / Gradle / Ivy

Go to download

Library for use in Java components of Vespa. Shared code which does not fit anywhere else.

There is a newer version: 8.409.18
Show newest version
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * The rename tensor function returns a tensor where some dimensions are assigned new names.
 *
 * @author bratseth
 */
public class Rename extends PrimitiveTensorFunction {

    private final TensorFunction argument;
    private final List fromDimensions;
    private final List toDimensions;
    private final Map fromToMap;

    public Rename(TensorFunction argument, String fromDimension, String toDimension) {
        this(argument, List.of(fromDimension), List.of(toDimension));
    }

    public Rename(TensorFunction argument, List fromDimensions, List toDimensions) {
        Objects.requireNonNull(argument, "The argument tensor cannot be null");
        Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
        Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null");
        if (fromDimensions.isEmpty())
            throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension");
        if (fromDimensions.size() != toDimensions.size())
            throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " +
                                               fromDimensions.size() + " and " + toDimensions.size());
        this.argument = argument;
        this.fromDimensions = List.copyOf(fromDimensions);
        this.toDimensions = List.copyOf(toDimensions);
        this.fromToMap = fromToMap(fromDimensions, toDimensions);
    }

    public List fromDimensions() { return fromDimensions; }
    public List toDimensions() { return toDimensions; }

    private static Map fromToMap(List fromDimensions, List toDimensions) {
        Map map = new HashMap<>();
        for (int i = 0; i < fromDimensions.size(); i++)
            map.put(fromDimensions.get(i), toDimensions.get(i));
        return map;
    }

    @Override
    public List> arguments() { return List.of(argument); }

    @Override
    public TensorFunction withArguments(List> arguments) {
        if ( arguments.size() != 1)
            throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size());
        return new Rename<>(arguments.get(0), fromDimensions, toDimensions);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() { return this; }

    @Override
    public TensorType type(TypeContext context) {
        List resolvedFromDimensions = fromDimensions.stream().map(d -> context.resolveBinding(d)).toList();
        List resolvedToDimensions = toDimensions.stream().map(d -> context.resolveBinding(d)).toList();
        return TypeResolver.rename(argument.type(context), resolvedFromDimensions, resolvedToDimensions);
    }

    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor tensor = argument.evaluate(context);

        TensorType renamedType = TypeResolver.rename(tensor.type(), fromDimensions, toDimensions);

        // an array which lists the index of each label in the renamed type
        int[] toIndexes = new int[tensor.type().dimensions().size()];
        for (int i = 0; i < tensor.type().dimensions().size(); i++) {
            String dimensionName = tensor.type().dimensions().get(i).name();
            String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName);
            toIndexes[renamedType.indexOfDimension(newDimensionName).get()] = i;
        }

        // avoid building a new tensor if dimensions can simply be renamed
        if (simpleRenameIsPossible(toIndexes)) {
            return tensor.withType(renamedType);
        }

        Tensor.Builder builder = Tensor.Builder.of(renamedType);
        for (Iterator i = tensor.cellIterator(); i.hasNext(); ) {
            Map.Entry cell = i.next();
            TensorAddress renamedAddress = cell.getKey().partialCopy(toIndexes);
            builder.cell(renamedAddress, cell.getValue());
        }
        return builder.build();
    }

    /**
     * If none of the dimensions change order after rename we can do a simple rename.
     */
    private boolean simpleRenameIsPossible(int[] toIndexes) {
        for (int i = 0; i < toIndexes.length; ++i) {
            if (toIndexes[i] != i) {
                return false;
            }
        }
        return true;
    }

    private String toVectorString(List elements, ToStringContext context) {
        if (elements.size() == 1)
            return context.resolveBinding(elements.get(0));
        StringBuilder b = new StringBuilder("(");
        for (String element : elements)
            b.append(context.resolveBinding(element)).append(", ");
        b.setLength(b.length() - 2);
        b.append(")");
        return b.toString();
    }

    @Override
    public String toString(ToStringContext context) {
        return "rename(" + argument.toString(context) + ", " +
                       toVectorString(fromDimensions, context) + ", " + toVectorString(toDimensions, context) + ")";
    }

    @Override
    public int hashCode() { return Objects.hash("rename", argument, fromDimensions, toDimensions); }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy