All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.operator.scalar.ArrayVectorFunctions Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.scalar;

import io.trino.spi.block.Block;
import io.trino.spi.function.Description;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.StandardTypes;

import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.util.Failures.checkCondition;

public final class ArrayVectorFunctions
{
    private ArrayVectorFunctions() {}

    @Description("Calculates the euclidean distance between two vectors")
    @ScalarFunction
    @SqlType(StandardTypes.DOUBLE)
    public static double euclideanDistance(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second)
    {
        checkCondition(first.getPositionCount() == second.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "The arguments must have the same length");

        double sum = 0.0;
        for (int i = 0; i < first.getPositionCount(); i++) {
            double diff = DOUBLE.getDouble(first, i) - DOUBLE.getDouble(second, i);
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    @Description("Calculates the dot product between two vectors")
    @ScalarFunction
    @SqlType(StandardTypes.DOUBLE)
    public static double dotProduct(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second)
    {
        checkCondition(first.getPositionCount() == second.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "The arguments must have the same length");

        double dotProduct = 0.0;
        for (int i = 0; i < first.getPositionCount(); i++) {
            dotProduct += DOUBLE.getDouble(first, i) * DOUBLE.getDouble(second, i);
        }
        return dotProduct;
    }

    @Description("Calculates the cosine distance between two vectors")
    @ScalarFunction
    @SqlType(StandardTypes.DOUBLE)
    public static double cosineDistance(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second)
    {
        checkCondition(first.getPositionCount() == second.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "The arguments must have the same length");

        double firstMagnitude = 0.0;
        double secondMagnitude = 0.0;
        double dotProduct = 0.0;
        for (int i = 0; i < first.getPositionCount(); i++) {
            double firstValue = DOUBLE.getDouble(first, i);
            double secondValue = DOUBLE.getDouble(second, i);
            firstMagnitude += firstValue * firstValue;
            secondMagnitude += secondValue * secondValue;
            dotProduct += firstValue * secondValue;
        }

        checkCondition(firstMagnitude != 0 && secondMagnitude != 0, INVALID_FUNCTION_ARGUMENT, "Vector magnitude cannot be zero");
        double cosineSimilarity = dotProduct / Math.sqrt(firstMagnitude * secondMagnitude);
        return 1.0 - cosineSimilarity;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy