com.expleague.ml.clustering.impl.NearestNeighborDRAlgorithm Maven / Gradle / Ivy
package com.expleague.ml.clustering.impl;
import com.expleague.commons.func.Computable;
import com.expleague.commons.math.metrics.Metric;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.util.Factories;
import com.expleague.ml.clustering.ClusterizationAlgorithm;
import org.jetbrains.annotations.NotNull;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
/**
* User: terry
* Date: 16.01.2010
*/
public class NearestNeighborDRAlgorithm implements ClusterizationAlgorithm {
private final Metric metric;
private final double acceptanceDistance;
private final double distanceRatio;
public NearestNeighborDRAlgorithm(final Metric metric, final double acceptanceDistance, final double distanceRatio) {
this.metric = metric;
this.acceptanceDistance = acceptanceDistance;
this.distanceRatio = distanceRatio;
}
@NotNull
@Override
public Collection extends Collection> cluster(final Collection dataSet, final Computable data2DVector) {
final Collection> clusters = Factories.hashSet();
for (final X data : dataSet) {
final Vec dataVector = data2DVector.compute(data);
Collection nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
double nearest2Distance = Double.MAX_VALUE;
for (final Collection cluster : clusters) {
double minDistance = Double.MAX_VALUE;
for (final X element : cluster) {
final double candidateDistance = metric.distance(data2DVector.compute(element), dataVector);
minDistance = Math.min(minDistance, candidateDistance);
}
if (minDistance < nearestDistance) {
nearestDistance = minDistance;
nearestCluster = cluster;
}
else if (minDistance < nearest2Distance) {
nearest2Distance = minDistance;
}
}
final boolean good =
(nearestDistance < acceptanceDistance && (nearest2Distance == Double.MAX_VALUE || nearestDistance / nearest2Distance < distanceRatio));
if (nearestCluster == null || !good) {
clusters.add(Factories.hashSet(data));
} else {
nearestCluster.add(data);
}
}
return clusters;
}
}