
com.bakdata.deduplication.similarity.CommonSimilarityMeasures Maven / Gradle / Ivy
/*
* The MIT License
*
* Copyright (c) 2018 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*/
package com.bakdata.deduplication.similarity;
import static com.bakdata.deduplication.similarity.SimilarityMeasure.unknown;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import java.time.temporal.Temporal;
import java.time.temporal.TemporalUnit;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import lombok.Builder;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Singular;
import lombok.Value;
import lombok.experimental.UtilityClass;
import org.apache.commons.codec.StringEncoder;
import org.apache.commons.codec.language.ColognePhonetic;
import org.apache.commons.codec.language.RefinedSoundex;
import org.apache.commons.codec.language.Soundex;
import org.apache.commons.codec.language.bm.BeiderMorseEncoder;
import org.apache.commons.text.similarity.JaroWinklerDistance;
import org.apache.commons.text.similarity.LevenshteinDistance;
import org.apache.commons.text.similarity.SimilarityScore;
@UtilityClass
public class CommonSimilarityMeasures {
private static final Splitter WHITE_SPACE_SPLITTER = Splitter.on(Pattern.compile("\\s+"));
public static SimilarityTransformation> bigram() {
return ngram(2);
}
public static SimilarityMeasure equality() {
return ((left, right, context) -> left.equals(right) ? 1 : 0);
}
public static SimilarityMeasure inequality() {
return ((left, right, context) -> left.equals(right) ? 0 : 1);
}
public static > SimilarityMeasure jaccard() {
return (left, right, context) -> {
@SuppressWarnings("unchecked") final Set leftSet =
left instanceof Set ? (Set) left : new HashSet<>(left);
@SuppressWarnings("unchecked") final Set rightSet =
left instanceof Set ? (Set) right : new HashSet<>(right);
final long intersectCount = leftSet.stream().filter(rightSet::contains).count();
return (float) intersectCount / (rightSet.size() + leftSet.size() - intersectCount);
};
}
public static SimilarityMeasure levenshtein() {
return new Levensthein<>(0);
}
public static SimilarityMeasure jaroWinkler() {
return new SimilarityScoreMeasure<>(new JaroWinklerDistance());
}
public static > SimilarityMeasure mongeElkan(final SimilarityMeasure pairMeasure) {
return mongeElkan(pairMeasure, Integer.MAX_VALUE / 2);
}
public static > SimilarityMeasure cosine() {
return (left, right, context) -> {
if (left == null || right == null) {
return unknown();
}
final Map leftHistogram =
left.stream().collect(Collectors.groupingBy(w -> w, Collectors.counting()));
final Map rightHistogram =
right.stream().collect(Collectors.groupingBy(w -> w, Collectors.counting()));
float dotProduct = 0;
for (final Map.Entry leftEntry : leftHistogram.entrySet()) {
final Long rightCount = rightHistogram.get(leftEntry.getKey());
if (rightCount != null) {
dotProduct += leftEntry.getValue() * rightCount;
}
}
return dotProduct / getLength(leftHistogram) / getLength(rightHistogram);
};
}
private static float getLength(final Map histogram) {
return (float) Math.sqrt(histogram.values().stream().mapToDouble(count -> count * count).sum());
}
public static > SimilarityMeasure mongeElkan(final SimilarityMeasure pairMeasure, final int maxPositionDiff) {
return new MongeElkan<>(pairMeasure, maxPositionDiff, 0);
}
public static SimilarityMeasure negate(final SimilarityMeasure super T> measure) {
return (left, right, context) -> 1 - measure.getSimilarity(left, right, context);
}
@SuppressWarnings("unchecked")
private static > List ensureList(final C leftCollection) {
return leftCollection instanceof List ? (List) leftCollection : List.copyOf(leftCollection);
}
public static > SimilarityMeasure positionWise(final SimilarityMeasure pairMeasure) {
return mongeElkan(pairMeasure, 0);
}
public static SimilarityTransformation colognePhonetic() {
return codec(new ColognePhonetic());
}
@SafeVarargs
public static SimilarityMeasure max(final SimilarityMeasure super T>... measures) {
if (measures.length == 0) {
throw new IllegalArgumentException();
}
return (left, right, context) -> {
if (left == null || right == null) {
return unknown();
}
float max = -1;
for (int i = 0; max < 1 && i < measures.length; i++) {
final float similarity = measures[i].getSimilarity(left, right, context);
if (!Float.isNaN(similarity)) {
max = Math.max(similarity, max);
}
}
return max == -1 ? Float.NaN : max;
};
}
public static SimilarityMeasure maxDiff(final int diff, final TemporalUnit unit) {
return (left, right, context) ->
Math.max(0, 1 - (float) Math.abs(left.until(right, unit)) / diff);
}
@SafeVarargs
public static SimilarityMeasure min(final SimilarityMeasure super T>... measures) {
if (measures.length == 0) {
throw new IllegalArgumentException();
}
return (left, right, context) -> {
if (left == null || right == null) {
return unknown();
}
float min = 2;
for (int i = 0; min > 0 && i < measures.length; i++) {
final float similarity = measures[i].getSimilarity(left, right, context);
if (!Float.isNaN(similarity)) {
min = Math.min(similarity, min);
}
}
return min == 2 ? Float.NaN : min;
};
}
public static SimilarityTransformation> ngram(final int n) {
return (t, context) -> IntStream.range(0, t.length() - n + 1)
.mapToObj(i -> t.subSequence(i, i + n))
.collect(Collectors.toList());
}
public static SimilarityTransformation soundex() {
return codec(new Soundex());
}
public static SimilarityTransformation refinedSoundex(final char[] mapping) {
return codec(new RefinedSoundex(mapping));
}
public static SimilarityTransformation beiderMorse() {
return codec(new BeiderMorseEncoder());
}
public static SimilarityTransformation codec(final StringEncoder encoder) {
return (s, context) -> encoder.encode(s);
}
public static SimilarityTransformation> words() {
return (t, context) -> Lists.newArrayList(WHITE_SPACE_SPLITTER.split(t));
}
public static SimilarityTransformation transform(final Function function) {
return (t, context) -> function.apply(t);
}
public static SimilarityTransformation> trigram() {
return ngram(3);
}
public static WeightedAggregation.WeightedAggregationBuilder weightedAggregation(final BiFunction, List, Float> aggregator) {
return WeightedAggregation.builder().aggregator(aggregator);
}
public static WeightedAggregation.WeightedAggregationBuilder weightedAverage() {
return weightedAggregation((weightedSims, weights) ->
(float) (weightedSims.stream().mapToDouble(sim -> sim).sum() / weights.stream().mapToDouble(w -> w).sum()));
}
static int getMaxLen(final CharSequence left, final CharSequence right) {
return Math.max(left.length(), right.length());
}
/**
* Used to translate {@link SimilarityScore} that are actually distance functions to similarity scores
*/
@RequiredArgsConstructor
public static class DistanceSimilarityMeasure implements SimilarityMeasure {
private final SimilarityScore extends Number> score;
@Override
public float getSimilarity(final CharSequence left, final CharSequence right, final SimilarityContext context) {
final float distance = this.score.apply(left, right).floatValue();
if (distance == -1) {
return 0;
}
return 1.0f - distance / getMaxLen(left, right);
}
}
@RequiredArgsConstructor
public static class SimilarityScoreMeasure implements SimilarityMeasure {
private final SimilarityScore extends Number> score;
@Override
public float getSimilarity(final CharSequence left, final CharSequence right, final SimilarityContext context) {
return this.score.apply(left, right).floatValue();
}
}
public static class Levensthein implements SimilarityMeasure {
private final float threshold;
public Levensthein(final float threshold) {
this.threshold = threshold;
}
@Override
public float getSimilarity(final CharSequence left, final CharSequence right, final SimilarityContext context) {
final var maxLen = getMaxLen(left, right);
final var maxDiff = (int) (maxLen * (1 - this.threshold));
final var measure = new DistanceSimilarityMeasure(new LevenshteinDistance(maxDiff));
return measure.getSimilarity(left, right, context);
}
@Override
public SimilarityMeasure cutoff(final float threshold) {
if (threshold < this.threshold) {
return this;
}
return new Levensthein<>(threshold);
}
}
@Builder
@Value
public static class WeightedAggregation implements SimilarityMeasure {
BiFunction, List, Float> aggregator;
@Singular
List> weightedSimilarities;
@Getter(lazy = true)
List weights = this.weightedSimilarities.stream().map(WeightedSimilarity::getWeight).collect(Collectors.toList());
@Override
public float getSimilarity(final R left, final R right, final SimilarityContext context) {
final var weightedSims = this.weightedSimilarities.stream()
.map(ws -> ws.getMeasure().getSimilarity(left, right, context) * ws.getWeight())
.collect(Collectors.toList());
List adjustedWeights = null;
for (int i = 0; i < weightedSims.size(); i++) {
if (weightedSims.get(i).isNaN()) {
if (adjustedWeights == null) {
adjustedWeights = new ArrayList<>(this.getWeights());
}
adjustedWeights.set(i, 0.0f);
weightedSims.set(i, 0.0f);
}
}
return this.aggregator.apply(weightedSims, adjustedWeights == null ? this.getWeights() : adjustedWeights);
}
@Value
public static class WeightedSimilarity {
float weight;
SimilarityMeasure measure;
}
public static class WeightedAggregationBuilder {
public WeightedAggregationBuilder add(final float weight, final Function extractor, final SimilarityMeasure super T> measure) {
return this.add(weight, measure.of(extractor));
}
public WeightedAggregationBuilder add(final float weight, final SimilarityMeasure measure) {
return this.weightedSimilarity(new WeightedSimilarity<>(weight, measure));
}
}
}
@Value
private static class MongeElkan, T> implements SimilarityMeasure {
private final SimilarityMeasure pairMeasure;
private final int maxPositionDiff;
private final float cutoff;
@Override
public float getSimilarity(final C leftCollection, final C rightCollection, final SimilarityContext context) {
if (leftCollection.isEmpty() || rightCollection.isEmpty()) {
return 0;
}
final List leftList = ensureList(leftCollection);
final List rightList = ensureList(rightCollection);
// when cutoff is .9 and |left| = 3, then on average each element has .1 buffer
// as soon as the current sum + buffer < index, the cutoff threshold cannot be passed (buffer used up)
final float cutoffBuffer = (1 - this.cutoff) * leftCollection.size();
float sum = 0;
for (int leftIndex = 0; leftIndex < leftCollection.size() && (cutoffBuffer + sum) >= leftIndex; leftIndex++) {
float max = 0;
for (int rightIndex = Math.max(0, leftIndex - this.maxPositionDiff),
rightMax = Math.min(rightCollection.size(), leftIndex + this.maxPositionDiff); max < 1.0 && rightIndex < rightMax; rightIndex++) {
max = Math.max(max, this.pairMeasure
.getSimilarity(leftList.get(leftIndex), rightList.get(rightIndex), context));
}
sum += max;
}
return CutoffSimiliarityMeasure.cutoff(sum / leftCollection.size(), this.cutoff);
}
@Override
public SimilarityMeasure cutoff(final float threshold) {
return new MongeElkan<>(this.pairMeasure, this.maxPositionDiff, threshold);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy