All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.yahoo.document.json.readers.TensorModifyUpdateReader Maven / Gradle / Ivy

// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.document.json.readers;

import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
import com.yahoo.document.update.TensorModifyUpdate;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;

import java.util.Iterator;

import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart;
import static com.yahoo.document.json.readers.TensorReader.TENSOR_BLOCKS;
import static com.yahoo.document.json.readers.TensorReader.TENSOR_CELLS;
import static com.yahoo.document.json.readers.TensorReader.readTensorBlocks;
import static com.yahoo.document.json.readers.TensorReader.readTensorCells;

/**
 * Reader of a "modify" update of a tensor field.
 */
public class TensorModifyUpdateReader {

    public static final String UPDATE_MODIFY = "modify";
    private static final String MODIFY_OPERATION = "operation";
    private static final String MODIFY_REPLACE = "replace";
    private static final String MODIFY_ADD = "add";
    private static final String MODIFY_MULTIPLY = "multiply";

    public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
        expectFieldIsOfTypeTensor(field);
        expectTensorTypeHasNoIndexedUnboundDimensions(field);
        expectObjectStart(buffer.currentToken());

        ModifyUpdateResult result = createModifyUpdateResult(buffer, field);
        expectOperationSpecified(result.operation, field.getName());
        expectTensorSpecified(result.tensor, field.getName());

        return new TensorModifyUpdate(result.operation, result.tensor);
    }

    private static void expectFieldIsOfTypeTensor(Field field) {
        if ( ! (field.getDataType() instanceof TensorDataType)) {
            throw new IllegalArgumentException("A modify update can only be applied to tensor fields. " +
                                               "Field '" + field.getName() + "' is of type '" +
                                               field.getDataType().getName() + "'");
        }
    }

    private static void expectTensorTypeHasNoIndexedUnboundDimensions(Field field) {
        TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
        if (tensorType.dimensions().stream()
                .anyMatch(dim -> dim.type().equals(TensorType.Dimension.Type.indexedUnbound))) {
            throw new IllegalArgumentException("A modify update cannot be applied to tensor types with indexed unbound dimensions. " +
                                               "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
        }
    }

    private static void expectOperationSpecified(TensorModifyUpdate.Operation operation, String fieldName) {
        if (operation == null) {
            throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain an operation");
        }
    }

    private static void expectTensorSpecified(TensorFieldValue tensor, String fieldName) {
        if (tensor == null) {
            throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain tensor cells");
        }
    }

    private static class ModifyUpdateResult {
        TensorModifyUpdate.Operation operation = null;
        TensorFieldValue tensor = null;
    }

    private static ModifyUpdateResult createModifyUpdateResult(TokenBuffer buffer, Field field) {
        ModifyUpdateResult result = new ModifyUpdateResult();
        buffer.next();
        int localNesting = buffer.nesting();
        while (localNesting <= buffer.nesting()) {
            switch (buffer.currentName()) {
                case MODIFY_OPERATION:
                    result.operation = createOperation(buffer, field.getName());
                    break;
                case TENSOR_CELLS:
                    result.tensor = createTensorFromCells(buffer, field);
                    break;
                case TENSOR_BLOCKS:
                    result.tensor = createTensorFromBlocks(buffer, field);
                    break;
                default:
                    throw new IllegalArgumentException("Unknown JSON string '" + buffer.currentName() + "' in modify update for field '" + field.getName() + "'");
            }
            buffer.next();
        }
        return result;
    }

    private static TensorModifyUpdate.Operation createOperation(TokenBuffer buffer, String fieldName) {
        switch (buffer.currentText()) {
            case MODIFY_REPLACE:
                return TensorModifyUpdate.Operation.REPLACE;
            case MODIFY_ADD:
                return TensorModifyUpdate.Operation.ADD;
            case MODIFY_MULTIPLY:
                return TensorModifyUpdate.Operation.MULTIPLY;
            default:
                throw new IllegalArgumentException("Unknown operation '" + buffer.currentText() + "' in modify update for field '" + fieldName + "'");
        }
    }

    private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) {
        TensorDataType tensorDataType = (TensorDataType)field.getDataType();
        TensorType originalType = tensorDataType.getTensorType();
        TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType);

        Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType);
        readTensorCells(buffer, tensorBuilder);
        Tensor tensor = tensorBuilder.build();

        validateBounds(tensor, originalType);

        return new TensorFieldValue(tensor);
    }

    private static TensorFieldValue createTensorFromBlocks(TokenBuffer buffer, Field field) {
        TensorDataType tensorDataType = (TensorDataType)field.getDataType();
        TensorType type = tensorDataType.getTensorType();

        Tensor.Builder tensorBuilder = Tensor.Builder.of(type);
        readTensorBlocks(buffer, tensorBuilder);
        Tensor tensor = convertToSparse(tensorBuilder.build());
        validateBounds(tensor, type);

        return new TensorFieldValue(tensor);
    }

    private static Tensor convertToSparse(Tensor tensor) {
        if (tensor.type().dimensions().stream().noneMatch(dimension -> dimension.isIndexed())) return tensor;
        Tensor.Builder b = Tensor.Builder.of(TensorModifyUpdate.convertDimensionsToMapped(tensor.type()));
        for (Iterator i = tensor.cellIterator(); i.hasNext(); )
            b.cell(i.next());
        return b.build();
    }

    /** Only validate if original type has indexed bound dimensions */
    static void validateBounds(Tensor convertedTensor, TensorType originalType) {
        if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
            return;
        }
        for (Iterator cellIterator = convertedTensor.cellIterator(); cellIterator.hasNext(); ) {
            Tensor.Cell cell = cellIterator.next();
            TensorAddress address = cell.getKey();
            for (int i = 0; i < address.size(); ++i) {
                TensorType.Dimension dim = originalType.dimensions().get(i);
                if (dim instanceof TensorType.IndexedBoundDimension) {
                    long label = address.numericLabel(i);
                    long bound = dim.size().get();  // size is non-optional for indexed bound
                    if (label >= bound) {
                        throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
                                                            "' has label '" + label + "' but type is " + originalType.toString());
                    }
                }
            }
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy