// com.expleague.ml.clustering.impl.ConnectedComponentOptimizer (Maven / Gradle / Ivy)
package com.expleague.ml.clustering.impl;
import com.expleague.commons.func.Computable;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecIterator;
import com.expleague.ml.clustering.ClusterizationAlgorithm;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.procedure.TObjectProcedure;
import org.jetbrains.annotations.NotNull;
import java.util.*;
/**
* User: solar
* Date: 14.02.2010
* Time: 0:48:33
*/
public class ConnectedComponentOptimizer implements ClusterizationAlgorithm {
private final ClusterizationAlgorithm algorithm;
private final double minToJoin;
public ConnectedComponentOptimizer(final ClusterizationAlgorithm algorithm, final double minToJoin) {
this.algorithm = algorithm;
this.minToJoin = minToJoin;
}
private static class IndexedVecIter {
VecIterator iter;
T t;
int componentIndex;
private IndexedVecIter(final VecIterator iter, final T t, final int index) {
this.iter = iter;
this.t = t;
componentIndex = index;
}
}
private static class VecIterEntry implements Comparable {
List iters = new LinkedList();
final int index;
public VecIterEntry(final int index) {
this.index = index;
}
@Override
public int compareTo(@NotNull final VecIterEntry node) {
return index - node.index;
}
}
private static void processIter(final Set iters, final TIntObjectHashMap cache, final IndexedVecIter iter) {
final int index = iter.iter.index();
VecIterEntry iterEntry = cache.get(index);
if (iterEntry == null) {
iterEntry = new VecIterEntry(index);
iters.add(iterEntry);
cache.put(index, iterEntry);
}
iterEntry.iters.add(iter);
}
@NotNull
@Override
public Collection extends Collection> cluster(final Collection dataSet, final Computable data2DVector) {
final TreeSet iters = new TreeSet<>();
final TIntObjectHashMap cache = new TIntObjectHashMap();
final List> entries = new ArrayList>();
final double minToJoin = this.minToJoin;// + 0.5 * (1 - Math.min(1, Math.log(2000) / Math.log(dataSet.size())));
{
int index = 1;
for (final T t : dataSet) {
final Vec vec = data2DVector.compute(t);
final VecIterator iter = vec.nonZeroes();
while (iter.advance() && iter.value() < minToJoin);
if (iter.isValid()) {
final IndexedVecIter entry = new IndexedVecIter(iter, t, index++);
entries.add(entry);
processIter(iters, cache, entry);
}
}
}
while (!iters.isEmpty()) {
final VecIterEntry topEntry = iters.pollFirst();
int maxComponentIndex = 0;
final boolean join = topEntry.iters.size() > 1 && topEntry.iters.size() < dataSet.size() / 10;
if (join) {
double sum = 0;
int count = 0;
int prev = 0;
for (final IndexedVecIter iter : topEntry.iters) {
count++;
sum += iter.iter.value();
if (prev != 0 && prev != iter.iter.index())
System.err.println("FUCK!!!");
prev = iter.iter.index();
if (iter.componentIndex > maxComponentIndex) {
maxComponentIndex = iter.componentIndex;
}
}
// System.out.println(termsBasis.fromIndex(topEntry.iters.at(0).iter.index()) + ": " + topEntry.iters.size()+ ":" + maxComponentIndex + ":" + (sum / Math.max(1, count)));
}
for (final IndexedVecIter iter : topEntry.iters) {
if (join)
iter.componentIndex = maxComponentIndex;
while (iter.iter.advance() && iter.iter.value() < minToJoin);
if (iter.iter.isValid())
processIter(iters, cache, iter);
}
}
final TIntObjectHashMap> components = new TIntObjectHashMap>();
for (final IndexedVecIter entry : entries) {
Set component = components.get(entry.componentIndex);
if (component == null)
components.put(entry.componentIndex, component = new HashSet());
component.add(entry.t);
}
// System.out.println(components.size() + " components found");
final List> clusters = new ArrayList>();
components.forEachValue(new TObjectProcedure>() {
@Override
public boolean execute(final Set ts) {
for (final Collection cluster : algorithm.cluster(ts, data2DVector)) {
clusters.add(cluster);
}
return true;
}
});
return clusters;
}
}