All downloads are free. Search and download functionality uses the official Maven repository.

com.graphaware.nlp.ml.similarity.CosineSimilarity Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2013-2018 GraphAware
 *
 * This file is part of the GraphAware Framework.
 *
 * GraphAware Framework is free software: you can redistribute it and/or modify it under the terms of
 * the GNU General Public License as published by the Free Software Foundation, either
 * version 3 of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
 * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details. You should have received a copy of
 * the GNU General Public License along with this program.  If not, see
 * <http://www.gnu.org/licenses/>.
 */
package com.graphaware.nlp.ml.similarity;

import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;

/**
 * Cosine similarity between two numeric vectors, given as sparse maps, lists or
 * {@code float} arrays.
 *
 * <p>Every variant computes {@code dot(x, y) / (|x| * |y|)}. All sums are accumulated in
 * {@code double} with plain loops. The class has no fields, so instances carry no state.
 */
public class CosineSimilarity implements Similarity {

    /**
     * Returns the cosine similarity of two sparse vectors stored as key-to-weight maps.
     *
     * <p>Only keys present in both maps contribute to the dot product; every value of a map
     * contributes to that map's norm. The parameters stay declared as raw {@code Map}s, exactly
     * as in the original signature, so the override relationship with {@link Similarity} does
     * not depend on the type arguments that interface declares. Each value is read as a
     * {@link Number}.
     *
     * @param xVector first vector
     * @param yVector second vector
     * @return the similarity narrowed to {@code float}, or 0 when the product of the two norms
     *         is not greater than zero (for example when either map is empty)
     * @throws NullPointerException if either map, or any value in either map, is null
     * @throws ClassCastException   if any value is not a {@link Number}
     */
    @Override
    @SuppressWarnings("rawtypes")
    public float getSimilarity(Map xVector, Map yVector) {
        final double dotProduct = getDotProduct(xVector, yVector);
        final double normProduct = getNorm(xVector) * getNorm(yVector);
        return (float) ratioOrZero(dotProduct, normProduct);
    }

    /**
     * Returns the cosine similarity of two dense vectors stored as lists (variant used by user
     * functions).
     *
     * <p>The dot product is taken over the indices of {@code xVector}. Both vectors are expected
     * to have the same length: a shorter {@code yVector} fails on the first missing index, and
     * extra trailing elements of a longer {@code yVector} add to its norm but not to the dot
     * product. The parameters stay declared as raw {@code List}s, as in the original signature;
     * each element is read as a {@link Number}.
     *
     * @param xVector first vector
     * @param yVector second vector
     * @return the similarity, or 0 when the product of the two norms is not greater than zero
     * @throws IndexOutOfBoundsException if {@code yVector} is shorter than {@code xVector}
     * @throws NullPointerException      if either list, or any element, is null
     * @throws ClassCastException        if any element is not a {@link Number}
     */
    @SuppressWarnings("rawtypes")
    public double getSimilarity(List xVector, List yVector) {
        final double dotProduct = getDotProduct(xVector, yVector);
        final double normProduct = getNorm(xVector) * getNorm(yVector);
        return ratioOrZero(dotProduct, normProduct);
    }

    /**
     * Returns the cosine similarity of two dense vectors stored as {@code float} arrays.
     *
     * <p>The dot product is taken over the indices of {@code xVector}; the norm of each array
     * covers all of its elements. Both arrays are expected to have the same length.
     *
     * @param xVector first vector
     * @param yVector second vector
     * @return the similarity, or 0 when the product of the two norms is not greater than zero
     * @throws ArrayIndexOutOfBoundsException if {@code yVector} is shorter than {@code xVector}
     * @throws NullPointerException           if either array is null
     */
    public double getSimilarity(float[] xVector, float[] yVector) {
        final double dotProduct = getDotProduct(xVector, yVector);
        final double normProduct = getNorm(xVector) * getNorm(yVector);
        return ratioOrZero(dotProduct, normProduct);
    }

    /**
     * Returns the cosine similarity of two {@code float} arrays in a single pass.
     *
     * <p>Unlike {@link #getSimilarity(float[], float[])}, only the first
     * {@code vectorA.length} elements of {@code vectorB} are read, for the dot product and for
     * the norm of {@code vectorB} alike.
     *
     * @param vectorA first vector; its length bounds the iteration
     * @param vectorB second vector
     * @return the similarity, or 0 when the product of the two norms is not greater than zero,
     *         consistent with the {@code getSimilarity} methods of this class
     * @throws ArrayIndexOutOfBoundsException if {@code vectorB} is shorter than {@code vectorA}
     * @throws NullPointerException           if either array is null
     */
    public double cosineSimilarity(float[] vectorA, float[] vectorB) {
        double dotProduct = 0d;
        double squaresA = 0d;
        double squaresB = 0d;
        for (int i = 0; i < vectorA.length; i++) {
            // Widen before multiplying so the products themselves are not rounded to float.
            final double a = vectorA[i];
            final double b = vectorB[i];
            dotProduct += a * b;
            squaresA += a * a;
            squaresB += b * b;
        }
        return ratioOrZero(dotProduct, Math.sqrt(squaresA) * Math.sqrt(squaresB));
    }

    /** Sums the products of the values stored under keys that both maps contain. */
    private static double getDotProduct(final Map<?, ?> xVector, final Map<?, ?> yVector) {
        // Walking the smaller map and probing the larger one visits exactly the shared keys,
        // without building a sorted union of both key sets first.
        final boolean xIsSmaller = xVector.size() <= yVector.size();
        final Map<?, ?> smaller = xIsSmaller ? xVector : yVector;
        final Map<?, ?> larger = xIsSmaller ? yVector : xVector;

        double sum = 0d;
        for (final Map.Entry<?, ?> entry : smaller.entrySet()) {
            final Object key = entry.getKey();
            if (larger.containsKey(key)) {
                sum += asDouble(entry.getValue()) * asDouble(larger.get(key));
            }
        }
        return sum;
    }

    /** Sums the element-wise products over the indices of {@code xVector}. */
    private static double getDotProduct(final List<?> xVector, final List<?> yVector) {
        double sum = 0d;
        final int size = xVector.size();
        for (int i = 0; i < size; i++) {
            sum += asDouble(xVector.get(i)) * asDouble(yVector.get(i));
        }
        return sum;
    }

    /** Sums the element-wise products over the indices of {@code xVector}. */
    private static double getDotProduct(final float[] xVector, final float[] yVector) {
        double sum = 0d;
        for (int i = 0; i < xVector.length; i++) {
            sum += (double) xVector[i] * yVector[i];
        }
        return sum;
    }

    /** Euclidean norm of all values of the map. */
    private static double getNorm(final Map<?, ?> vector) {
        double sumOfSquares = 0d;
        for (final Object value : vector.values()) {
            final double v = asDouble(value);
            sumOfSquares += v * v;
        }
        return Math.sqrt(sumOfSquares);
    }

    /** Euclidean norm of all elements of the list. */
    private static double getNorm(final List<?> vector) {
        double sumOfSquares = 0d;
        for (final Object value : vector) {
            final double v = asDouble(value);
            sumOfSquares += v * v;
        }
        return Math.sqrt(sumOfSquares);
    }

    /** Euclidean norm of all elements of the array. */
    private static double getNorm(final float[] vector) {
        double sumOfSquares = 0d;
        for (final float value : vector) {
            sumOfSquares += (double) value * value;
        }
        return Math.sqrt(sumOfSquares);
    }

    /** Divides the dot product by the norm product, yielding 0 unless that product is positive. */
    private static double ratioOrZero(final double dotProduct, final double normProduct) {
        // A NaN norm product also fails this comparison and therefore yields 0.
        return normProduct > 0 ? dotProduct / normProduct : 0d;
    }

    /** Reads a vector component as a double; fails on null or on a non-numeric value. */
    private static double asDouble(final Object value) {
        return ((Number) value).doubleValue();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy