All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.commons.text.similarity.CosineSimilarity Maven / Gradle / Ivy

There is a newer version: 1.12.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.text.similarity;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Measures the Cosine similarity of two vectors of an inner product space and
 * compares the angle between them.
 *
 * 

* For further explanation about the Cosine Similarity, refer to * http://en.wikipedia.org/wiki/Cosine_similarity. *

* * @since 1.0 */ public class CosineSimilarity { /** * Calculates the cosine similarity for two given vectors. * * @param leftVector left vector * @param rightVector right vector * @return cosine similarity between the two vectors */ public Double cosineSimilarity(final Map leftVector, final Map rightVector) { if (leftVector == null || rightVector == null) { throw new IllegalArgumentException("Vectors must not be null"); } final Set intersection = getIntersection(leftVector, rightVector); final double dotProduct = dot(leftVector, rightVector, intersection); double d1 = 0.0d; for (final Integer value : leftVector.values()) { d1 += Math.pow(value, 2); } double d2 = 0.0d; for (final Integer value : rightVector.values()) { d2 += Math.pow(value, 2); } double cosineSimilarity; if (d1 <= 0.0 || d2 <= 0.0) { cosineSimilarity = 0.0; } else { cosineSimilarity = dotProduct / (Math.sqrt(d1) * Math.sqrt(d2)); } return cosineSimilarity; } /** * Returns a set with strings common to the two given maps. * * @param leftVector left vector map * @param rightVector right vector map * @return common strings */ private Set getIntersection(final Map leftVector, final Map rightVector) { final Set intersection = new HashSet<>(leftVector.keySet()); intersection.retainAll(rightVector.keySet()); return intersection; } /** * Computes the dot product of two vectors. It ignores remaining elements. It means * that if a vector is longer than other, then a smaller part of it will be used to compute * the dot product. * * @param leftVector left vector * @param rightVector right vector * @param intersection common elements * @return The dot product */ private double dot(final Map leftVector, final Map rightVector, final Set intersection) { long dotProduct = 0; for (final CharSequence key : intersection) { dotProduct += leftVector.get(key) * rightVector.get(key); } return dotProduct; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy