
/*
* The MIT License
*
* Copyright (c) 2018 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*/
package com.bakdata.deduplication.clustering;
import com.bakdata.deduplication.candidate_selection.Candidate;
import com.bakdata.deduplication.classifier.Classification;
import com.bakdata.deduplication.classifier.ClassifiedCandidate;
import com.bakdata.deduplication.classifier.Classifier;
import com.google.common.primitives.Bytes;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;
import java.util.Spliterators;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import lombok.Builder;
import lombok.NonNull;
import lombok.Value;
import lombok.experimental.Wither;
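/**
 * Refines clusters produced by the transitive closure: over-merged clusters are split into more
 * coherent sub-clusters by (re-)classifying record pairs with the configured {@link Classifier}.
 * <p>
 * A minimal usage sketch; {@code personClassifier}, {@code idGenerator}, {@code closedClusters},
 * and {@code knownClassifications} are hypothetical placeholders, not part of this class:
 * <pre>{@code
 * RefineCluster<Long, Person> refineCluster = RefineCluster.<Long, Person>builder()
 *         .classifier(personClassifier)      // assumed to be a Classifier<Person>
 *         .clusterIdGenerator(idGenerator)   // assumed to be a Function<Iterable<Person>, Long>
 *         .build();
 * List<Cluster<Long, Person>> refined = refineCluster.refine(closedClusters, knownClassifications);
 * }</pre>
 */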
@Value
@Builder
public class RefineCluster<C extends Comparable<C>, T> {
private static final int MAX_SUB_CLUSTERS = 100;
@Builder.Default
int maxSmallClusterSize = 10;
@NonNull
Classifier<T> classifier;
@NonNull
Function<Iterable<T>, C> clusterIdGenerator;
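/**
 * Maps a classification onto a signed edge weight: duplicates contribute their confidence
 * positively, non-duplicates negatively, and unknown classifications are neutral.
 */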
private static float getWeight(final Classification classification) {
switch (classification.getResult()) {
case DUPLICATE:
return classification.getConfidence();
case NON_DUPLICATE:
return -classification.getConfidence();
case UNKNOWN:
return -0.0f;
default:
throw new IllegalStateException();
}
}
private static int getNumEdges(final int n) {
return n * (n - 1) / 2;
}
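/**
 * Scores a clustering against the pair-wise weight matrix: weights of intra-cluster pairs are
 * added (normalized by the partition size) and weights of inter-cluster pairs are subtracted
 * (normalized by the respective complement sizes), so duplicates in the same partition and
 * non-duplicates in separate partitions both increase the score.
 * <p>
 * Illustrative example (not from the code base): for partitions [0, 0, 1] and a single known
 * weight w[0][1] = 0.8, the intra-cluster pair (0, 1) contributes +0.8 / 2 = 0.4, while the
 * inter-cluster pairs (0, 2) and (1, 2) contribute their (here zero) weights negatively.
 */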
private static float scoreClustering(final byte[] partitions, final float[][] weightMatrix) {
final int n = partitions.length;
final int[] partitionSizes = new int[n];
for (final byte partition : partitions) {
partitionSizes[partition]++;
}
float score = 0;
for (int rowIndex = 0; rowIndex < n; rowIndex++) {
for (int colIndex = rowIndex + 1; colIndex < n; colIndex++) {
if (partitions[rowIndex] == partitions[colIndex]) {
score += weightMatrix[rowIndex][colIndex] / partitionSizes[partitions[rowIndex]];
} else {
score -= weightMatrix[rowIndex][colIndex] / (n - partitionSizes[partitions[rowIndex]]) +
weightMatrix[rowIndex][colIndex] / (n - partitionSizes[partitions[colIndex]]);
}
}
}
return score;
}
private List<ClassifiedCandidate<T>> getRelevantClassifications(final Cluster<C, T> cluster,
final Map<T, List<ClassifiedCandidate<T>>> relevantClassificationIndex) {
return cluster.getElements().stream()
.flatMap(record -> relevantClassificationIndex.getOrDefault(record, List.of()).stream()
.filter(classifiedCandidate -> cluster.contains(classifiedCandidate.getCandidate().getOldRecord())))
.collect(Collectors.toList());
}
public List<Cluster<C, T>> refine(final Collection<Cluster<C, T>> transitiveClosure,
final Iterable<ClassifiedCandidate<T>> knownClassifications) {
final Map<T, List<ClassifiedCandidate<T>>> relevantClassificationIndex =
this.getRelevantClassificationIndex(knownClassifications);
return transitiveClosure.stream()
.flatMap(cluster -> this.refineCluster(cluster,
this.getRelevantClassifications(cluster, relevantClassificationIndex)))
.collect(Collectors.toList());
}
private Map<T, List<ClassifiedCandidate<T>>> getRelevantClassificationIndex(
final Iterable<ClassifiedCandidate<T>> knownClassifications) {
final Map<T, List<ClassifiedCandidate<T>>> relevantClassifications = new HashMap<>();
for (final ClassifiedCandidate<T> knownClassification : knownClassifications) {
final Candidate<T> candidate = knownClassification.getCandidate();
relevantClassifications.computeIfAbsent(candidate.getNewRecord(), r -> new LinkedList<>())
.add(knownClassification);
}
return relevantClassifications;
}
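/**
 * For clusters above {@code maxSmallClusterSize}, scoring every possible clustering is infeasible.
 * Instead, the known classifications are complemented with a bounded number of sampled edges and
 * the cluster is split greedily.
 */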
private byte[] refineBigCluster(final Cluster<C, T> cluster, final List<ClassifiedCandidate<T>> knownClassifications) {
final List<WeightedEdge> duplicates = this.toWeightedEdges(knownClassifications, cluster);
final int desiredNumEdges = RefineCluster.getNumEdges(this.maxSmallClusterSize);
return this.greedyCluster(cluster, this.getWeightedEdges(cluster, duplicates, desiredNumEdges));
}
/**
 * Performs exhaustive clustering by maximizing intra-cluster similarity and minimizing
 * inter-cluster similarity. Quite compute-heavy for larger clusters: we perform a complete
 * pair-wise comparison (expensive and quadratic) and score EVERY possible clustering (cheap per
 * clustering, but the number of clusterings grows with the Bell numbers, i.e., super-exponentially).
 *
 * @return the best clustering
 */
private byte[] refineSmallCluster(final Cluster<C, T> cluster, final List<ClassifiedCandidate<T>> knownClassifications) {
final float[][] weightMatrix = this.getKnownWeightMatrix(cluster, knownClassifications);
final int n = cluster.size();
for (int rowIndex = 0; rowIndex < n; rowIndex++) {
for (int colIndex = rowIndex + 1; colIndex < n; colIndex++) {
if (Float.isNaN(weightMatrix[rowIndex][colIndex])) {
weightMatrix[rowIndex][colIndex] =
getWeight(this.classifier.classify(new Candidate<>(cluster.get(rowIndex), cluster.get(colIndex))));
}
}
}
return StreamSupport.stream(Spliterators.spliteratorUnknownSize(new ClusteringGenerator((byte) n), 0), false)
.map(clustering -> new AbstractMap.SimpleEntry<>(clustering.clone(),
RefineCluster.scoreClustering(clustering, weightMatrix)))
.max(Comparator.comparingDouble(Map.Entry::getValue))
.map(Map.Entry::getKey)
.orElseThrow(() -> new IllegalStateException("Non-empty clusters should have one valid clustering"));
}
private List<WeightedEdge> toWeightedEdges(final Collection<ClassifiedCandidate<T>> knownClassifications,
final Cluster<C, T> cluster) {
final Map<T, Integer> clusterIndex =
IntStream.range(0, cluster.size()).boxed().collect(Collectors.toMap(cluster::get, i -> i));
return knownClassifications.stream()
.map(knownClassification ->
WeightedEdge.of(clusterIndex.get(knownClassification.getCandidate().getNewRecord()),
clusterIndex.get(knownClassification.getCandidate().getOldRecord()),
getWeight(knownClassification.getClassification())))
.collect(Collectors.toList());
}
private Stream<Cluster<C, T>> refineCluster(final Cluster<C, T> cluster,
final List<ClassifiedCandidate<T>> knownClassifications) {
if (cluster.size() <= 2) {
return Stream.of(cluster);
}
final byte[] bestClustering;
if (cluster.size() > this.maxSmallClusterSize) {
// large cluster with high probability of error
bestClustering = this.refineBigCluster(cluster, knownClassifications);
} else {
bestClustering = this.refineSmallCluster(cluster, knownClassifications);
}
return this.getSubClusters(bestClustering, cluster);
}
private float[][] getKnownWeightMatrix(final Cluster<C, T> cluster,
final Iterable<ClassifiedCandidate<T>> knownClassifications) {
final var n = cluster.size();
final var weightMatrix = new float[n][n];
for (final var row : weightMatrix) {
Arrays.fill(row, Float.NaN);
}
final var clusterIndex =
IntStream.range(0, n).boxed().collect(Collectors.toMap(cluster::get, i -> i));
for (final ClassifiedCandidate<T> knownClassification : knownClassifications) {
final var firstIndex = clusterIndex.get(knownClassification.getCandidate().getNewRecord());
final var secondIndex = clusterIndex.get(knownClassification.getCandidate().getOldRecord());
weightMatrix[Math.min(firstIndex, secondIndex)][Math.max(firstIndex, secondIndex)] =
getWeight(knownClassification.getClassification());
}
return weightMatrix;
}
private Stream<Cluster<C, T>> getSubClusters(final byte[] bestClustering, final Cluster<C, T> cluster) {
final Map<Byte, List<T>> subClusters = IntStream.range(0, bestClustering.length)
.mapToObj(index -> new AbstractMap.SimpleEntry<>(bestClustering[index], cluster.get(index)))
.collect(Collectors.groupingBy(Map.Entry::getKey, Collectors.mapping(Map.Entry::getValue, Collectors.toList())));
return subClusters.values().stream()
.map(records -> new Cluster<>(this.clusterIdGenerator.apply(records), records));
}
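/**
 * Greedy agglomerative refinement: starting with each record in a singleton sub-cluster, edges
 * are processed in ascending weight order and the two incident sub-clusters are merged whenever
 * the merge improves the overall clustering score.
 */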
private byte[] greedyCluster(final Cluster<C, T> cluster, final Collection<? extends WeightedEdge> edges) {
// iterate edges in ascending weight order; a PriorityQueue's for-each traverses heap order,
// not sorted order, so the edges are sorted explicitly
final List<WeightedEdge> sortedEdges = new ArrayList<>(edges);
sortedEdges.sort(Comparator.comparing(WeightedEdge::getWeight));
final float[][] weightMatrix = new float[cluster.size()][cluster.size()];
for (final WeightedEdge edge : edges) {
weightMatrix[edge.left][edge.right] = edge.getWeight();
}
// start with each record in its own cluster
byte[] clustering = Bytes.toArray(IntStream.range(0, cluster.size()).boxed().collect(Collectors.toList()));
float score = 0;
for (final WeightedEdge edge : sortedEdges) {
final byte[] newClustering = clustering.clone();
final byte newClusterId = newClustering[edge.left];
final byte oldClusterId = newClustering[edge.right];
for (int i = 0; i < newClustering.length; i++) {
if (newClustering[i] == oldClusterId) {
newClustering[i] = newClusterId;
}
}
final float newScore = RefineCluster.scoreClustering(newClustering, weightMatrix);
if (newScore > score) {
score = newScore;
clustering = newClustering;
}
}
return clustering;
}
private List<WeightedEdge> addRandomEdges(final List<? extends WeightedEdge> edges, final int desiredNumEdges) {
// add random edges with distance 2..n of known edges (i.e., neighbors of known edges)
List<WeightedEdge> lastAddedEdges;
final Set<WeightedEdge> weightedEdges = new LinkedHashSet<>(edges);
for (int distance = 2; distance < this.maxSmallClusterSize && weightedEdges.size() < desiredNumEdges; distance++) {
lastAddedEdges = edges.stream()
.flatMap(e1 -> edges.stream().filter(e1::overlaps).map(e1::getTriangleEdge))
.filter(e -> !weightedEdges.contains(e))
.limit((long) desiredNumEdges - edges.size())
.collect(Collectors.toList());
weightedEdges.addAll(lastAddedEdges);
Collections.shuffle(lastAddedEdges);
}
if (weightedEdges.size() < desiredNumEdges) {
throw new IllegalStateException("We have a connected components, so we should get a fully connected graph");
}
return new ArrayList<>(weightedEdges);
}
private List<WeightedEdge> getRandomEdges(final int potentialNumEdges, final int desiredNumEdges) {
final List<WeightedEdge> weightedEdges;
weightedEdges = new Random().ints(0, potentialNumEdges)
.distinct()
.limit(desiredNumEdges)
.mapToObj(i -> {
// invert the Gauss sum formula to map the flat edge index back to a pair of node indices
int leftIndex = (int) (Math.sqrt(i + 0.25) - 0.5);
int rightIndex = i - RefineCluster.getNumEdges(leftIndex) + leftIndex;
return WeightedEdge.of(leftIndex, rightIndex, Float.NaN);
})
.collect(Collectors.toList());
return weightedEdges;
}
private List<WeightedEdge> getWeightedEdges(final Cluster<C, T> cluster,
final List<WeightedEdge> duplicates,
final int desiredNumEdges) {
final List<WeightedEdge> weightedEdges;
if (duplicates.isEmpty()) {
final int n = cluster.size();
weightedEdges = this.getRandomEdges(RefineCluster.getNumEdges(n), desiredNumEdges);
} else {
Collections.shuffle(duplicates);
weightedEdges = this.addRandomEdges(duplicates, desiredNumEdges);
}
return weightedEdges.stream().map(weightedEdge -> {
float weight = weightedEdge.getWeight();
if (Float.isNaN(weight)) {
// the edge was sampled without a known classification; classify it lazily now
T left = cluster.get(weightedEdge.getLeft());
T right = cluster.get(weightedEdge.getRight());
return weightedEdge.withWeight(getWeight(this.classifier.classify(new Candidate<>(left, right))));
}
return weightedEdge;
}).collect(Collectors.toList());
}
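/**
 * Enumerates all distinct partitions of n elements as restricted growth strings: the partition id
 * of each element is at most one larger than the largest id used to its left, so relabeled
 * duplicates of the same partition are generated only once. For n = 3 this yields 000, 001, 010,
 * 011, 012, i.e., the Bell number B(3) = 5 partitions.
 */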
private static final class ClusteringGenerator implements Iterator<byte[]> {
final byte n;
final byte[] clustering;
boolean hasNext = true;
ClusteringGenerator(final byte n) {
this.n = n;
this.clustering = new byte[n];
}
@Override
public boolean hasNext() {
if (this.hasNext) {
return true;
}
for (byte i = (byte) (this.n - (byte) 1); i > 0; i--) {
if (this.clustering[i] < this.n && !this.incrementWouldResultInSkippedInteger(i)) {
this.clustering[i]++;
Arrays.fill(this.clustering, i + 1, this.n, (byte) 0);
this.hasNext = true;
return true;
}
}
return false;
}
@Override
public byte[] next() {
if (!this.hasNext()) {
throw new NoSuchElementException();
}
this.hasNext = false;
return this.clustering;
}
private boolean incrementWouldResultInSkippedInteger(final byte i) {
for (byte j = (byte) (i - 1); j >= 0; j--) {
if (this.clustering[i] <= this.clustering[j]) {
return false;
}
}
return true;
}
}
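/**
 * An undirected, weighted edge between two cluster-local record indices; endpoints are normalized
 * so that {@code left <= right}. A weight of {@code Float.NaN} marks an edge that has not been
 * classified yet.
 */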
@Value
private static class WeightedEdge {
int left;
int right;
@Wither
float weight;
static WeightedEdge of(final int leftIndex, final int rightIndex, final float weight) {
return new WeightedEdge(Math.min(leftIndex, rightIndex), Math.max(leftIndex, rightIndex), weight);
}
WeightedEdge getTriangleEdge(final WeightedEdge e) {
if (this.left < e.left) {
return new WeightedEdge(this.left, e.left + e.right - this.right, Float.NaN);
} else if (this.left == e.left) {
return new WeightedEdge(Math.min(this.right, e.right), Math.max(this.right, e.right), Float.NaN);
}
return new WeightedEdge(e.left, this.left + this.right - e.right, Float.NaN);
}
boolean overlaps(final WeightedEdge e) {
return e.left == this.left || e.left == this.right || e.right == this.left || e.right == this.right;
}
}
}