com.expleague.ml.clustering.impl.GenericNearestNeighborAlgoritm Maven / Gradle / Ivy
package com.expleague.ml.clustering.impl;
import com.expleague.commons.func.Computable;
import com.expleague.commons.math.metrics.Metric;
import com.expleague.commons.util.Factories;
import com.expleague.ml.clustering.GenericClusterizationAlgorithm;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
* User: terry
* Date: 16.01.2010
*/
public class GenericNearestNeighborAlgoritm implements GenericClusterizationAlgorithm {
private final Metric metric;
private final double acceptanceDistance;
private final double rejectionDistance;
public GenericNearestNeighborAlgoritm(final Metric metric, final double acceptanceDistance, final double rejectionDistance) {
this.metric = metric;
this.acceptanceDistance = acceptanceDistance;
this.rejectionDistance = rejectionDistance;
}
@NotNull
@Override
public Collection extends Collection> cluster(final Collection dataSet, final Computable data2DVector) {
final Collection> clusters = Factories.hashSet();
for (final X data : dataSet) {
final V dataVector = data2DVector.compute(data);
Collection nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
for (final Collection cluster : clusters) {
for (final X element : cluster) {
final double candidateDistance = metric.distance(data2DVector.compute(element), dataVector);
if (candidateDistance < nearestDistance && candidateDistance < acceptanceDistance) {
nearestDistance = candidateDistance;
nearestCluster = cluster;
} else if (candidateDistance > rejectionDistance) break;
}
}
if (nearestCluster == null) {
clusters.add(Factories.hashSet(data));
} else {
nearestCluster.add(data);
}
}
return clusters;
}
}