
ai.vespa.rankingexpression.importer.operations.Shape Maven / Gradle / Ivy
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
/**
 * Importer operation that produces the shape of its single input as a constant tensor.
 *
 * <p>The result is a one-dimensional indexed tensor whose dimension (named by {@code vespaName()})
 * has one entry per dimension of the input type. Entry {@code i} holds the size of input
 * dimension {@code i}, or {@code -1} when that size is not known. The value depends only on the
 * input <em>type</em>, so it is computed when the operation is constructed and the operation
 * reports itself as constant.
 */
public class Shape extends IntermediateOperation {

    /**
     * Creates a shape operation and, if the input type is already available, computes its
     * constant value right away.
     *
     * @param modelName the model name, passed through to {@link IntermediateOperation}
     * @param nodeName  the node name, passed through to {@link IntermediateOperation}
     * @param inputs    the inputs of this operation; only the first one is read
     */
    public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
        super(modelName, nodeName, inputs);
        createConstantValue();
    }

    /**
     * Returns a one-dimensional indexed type whose size equals the number of dimensions of
     * the input type, or null if the input type is not yet present.
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if ( ! allInputTypesPresent(1)) return null;

        OrderedTensorType inputType = inputs.get(0).type().get();
        return new OrderedTensorType.Builder(resultValueType())
                .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
                .build();
    }

    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        return null; // will be added by function() since this is constant.
    }

    /** Always true: the result is derived from the input type alone, not from input values. */
    @Override
    public boolean isConstant() {
        return true;
    }

    /** Returns a new Shape with the same model and node names but the given inputs. */
    @Override
    public Shape withInputs(List<IntermediateOperation> inputs) {
        return new Shape(modelName(), name(), inputs);
    }

    /**
     * Builds the shape tensor from the input type and stores it as this operation's constant
     * value. Does nothing when the input type is not yet available.
     */
    private void createConstantValue() {
        if ( ! allInputTypesPresent(1)) return;

        OrderedTensorType inputType = inputs.get(0).type().get();
        // The type built in lazyGetType() has a single sized indexed dimension, so the
        // builder returned for it is expected to be a bound (dense) builder.
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type());
        List<TensorType.Dimension> inputDimensions = inputType.dimensions();
        for (int i = 0; i < inputDimensions.size(); i++) {
            // A dimension whose size() is empty (size not known) is reported as -1.
            builder.cellByDirectIndex(i, inputDimensions.get(i).size().orElse(-1L));
        }
        setConstantValue(new TensorValue(builder.build()));
    }

    @Override
    public String operationName() { return "Shape"; }

}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy