All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.carrot2.clustering.kmeans.BisectingKMeansClusteringAlgorithm Maven / Gradle / Ivy


/*
 * Carrot2 project.
 *
 * Copyright (C) 2002-2016, Dawid Weiss, Stanisław Osiński.
 * All rights reserved.
 *
 * Refer to the full license file "carrot2.LICENSE"
 * in the root folder of the repository checkout or at:
 * http://www.carrot2.org/carrot2.LICENSE
 */

package org.carrot2.clustering.kmeans;

import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.lang.ObjectUtils;
import org.carrot2.core.Cluster;
import org.carrot2.core.Document;
import org.carrot2.core.IClusteringAlgorithm;
import org.carrot2.core.LanguageCode;
import org.carrot2.core.ProcessingComponentBase;
import org.carrot2.core.ProcessingException;
import org.carrot2.core.attribute.AttributeNames;
import org.carrot2.core.attribute.CommonAttributes;
import org.carrot2.core.attribute.Init;
import org.carrot2.core.attribute.Internal;
import org.carrot2.core.attribute.Processing;
import org.carrot2.text.analysis.ITokenizer;
import org.carrot2.text.clustering.IMonolingualClusteringAlgorithm;
import org.carrot2.text.clustering.MultilingualClustering;
import org.carrot2.text.preprocessing.LabelFormatter;
import org.carrot2.text.preprocessing.PreprocessingContext;
import org.carrot2.text.preprocessing.pipeline.BasicPreprocessingPipeline;
import org.carrot2.text.preprocessing.pipeline.IPreprocessingPipeline;
import org.carrot2.text.vsm.ReducedVectorSpaceModelContext;
import org.carrot2.text.vsm.TermDocumentMatrixBuilder;
import org.carrot2.text.vsm.TermDocumentMatrixReducer;
import org.carrot2.text.vsm.VectorSpaceModelContext;
import org.carrot2.util.attribute.Attribute;
import org.carrot2.util.attribute.AttributeLevel;
import org.carrot2.util.attribute.Bindable;
import org.carrot2.util.attribute.DefaultGroups;
import org.carrot2.util.attribute.Group;
import org.carrot2.util.attribute.Input;
import org.carrot2.util.attribute.Label;
import org.carrot2.util.attribute.Level;
import org.carrot2.util.attribute.Output;
import org.carrot2.util.attribute.Required;
import org.carrot2.util.attribute.constraint.ImplementingClasses;
import org.carrot2.util.attribute.constraint.IntRange;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.IntIntHashMap;
import com.carrotsearch.hppc.IntIntMap;
import com.carrotsearch.hppc.cursors.IntCursor;
import com.carrotsearch.hppc.cursors.IntIntCursor;
import com.carrotsearch.hppc.sorting.IndirectComparator;
import com.carrotsearch.hppc.sorting.IndirectSort;

import org.carrot2.mahout.math.function.Functions;
import org.carrot2.mahout.math.matrix.DoubleMatrix1D;
import org.carrot2.mahout.math.matrix.DoubleMatrix2D;
import org.carrot2.mahout.math.matrix.impl.DenseDoubleMatrix1D;
import org.carrot2.mahout.math.matrix.impl.DenseDoubleMatrix2D;
import org.carrot2.shaded.guava.common.collect.Lists;

/**
 * A very simple implementation of bisecting k-means clustering. Unlike other algorithms
 * in Carrot2, this one creates hard clusterings (one document belongs only to one
 * cluster). On the other hand, the clusters are labeled only with individual words that
 * may not always fully correspond to all documents in the cluster.
 */
@Bindable(prefix = "BisectingKMeansClusteringAlgorithm", inherit = CommonAttributes.class)
public class BisectingKMeansClusteringAlgorithm extends ProcessingComponentBase implements
    IClusteringAlgorithm
{
    /** {@link Group} name. */
    private final static String GROUP_KMEANS = "K-means";
    
    @Processing
    @Input
    @Required
    @Internal
    @Attribute(key = AttributeNames.DOCUMENTS, inherit = true)
    public List documents;

    @Processing
    @Output
    @Internal
    @Attribute(key = AttributeNames.CLUSTERS, inherit = true)
    public List clusters = null;

    /**
     * The number of clusters to create. The algorithm will create at most the specified
     * number of clusters.
     */
    @Processing
    @Input
    @Attribute
    @IntRange(min = 2)
    @Group(DefaultGroups.CLUSTERS)
    @Level(AttributeLevel.BASIC)
    @Label("Cluster count")
    public int clusterCount = 25;

    /**
     * The maximum number of k-means iterations to perform.
     */
    @Processing
    @Input
    @Attribute
    @IntRange(min = 1)
    @Group(GROUP_KMEANS)
    @Level(AttributeLevel.BASIC)
    @Label("Maximum iterations")
    public int maxIterations = 15;

    /**
     * Use dimensionality reduction. If true, k-means will be applied on the
     * dimensionality-reduced term-document matrix with the number of dimensions being
     * equal to twice the number of requested clusters. If the number of dimensions is 
     * lower than the number of input documents, reduction will not be performed.
     * If false, the k-means will
     * be performed directly on the original term-document matrix.
     */
    @Processing
    @Input
    @Attribute
    @Group(GROUP_KMEANS)
    @Level(AttributeLevel.BASIC)
    @Label("Use dimensionality reduction")
    public boolean useDimensionalityReduction = true;

    /**
     * Partition count. The number of partitions to create at each k-means clustering
     * iteration.
     */
    @Processing
    @Input
    @Attribute
    @IntRange(min = 2, max = 10)
    @Group(GROUP_KMEANS)
    @Level(AttributeLevel.BASIC)
    @Label("Partition count")
    public int partitionCount = 2;

    /**
     * Label count. The minimum number of labels to return for each cluster.
     */
    @Processing
    @Input
    @Attribute
    @IntRange(min = 1, max = 10)
    @Group(DefaultGroups.CLUSTERS)
    @Level(AttributeLevel.BASIC)
    @Label("Label count")
    public int labelCount = 3;

    /**
     * Common preprocessing tasks handler.
     */
    @Init
    @Input
    @Attribute
    @Internal
    @ImplementingClasses(classes = {}, strict = false)
    @Level(AttributeLevel.ADVANCED)
    public IPreprocessingPipeline preprocessingPipeline = new BasicPreprocessingPipeline();

    /**
     * Term-document matrix builder for the algorithm, contains bindable attributes.
     */
    public final TermDocumentMatrixBuilder matrixBuilder = new TermDocumentMatrixBuilder();

    /**
     * Term-document matrix reducer for the algorithm, contains bindable attributes.
     */
    public final TermDocumentMatrixReducer matrixReducer = new TermDocumentMatrixReducer();

    /**
     * Cluster label formatter, contains bindable attributes.
     */
    public final LabelFormatter labelFormatter = new LabelFormatter();

    /**
     * A helper for performing multilingual clustering.
     */
    public final MultilingualClustering multilingualClustering = new MultilingualClustering();

    @Override
    public void process() throws ProcessingException
    {
        // There is a tiny trick here to support multilingual clustering without
        // refactoring the whole component: we remember the original list of documents
        // and invoke clustering for each language separately within the 
        // IMonolingualClusteringAlgorithm implementation below. This is safe because
        // processing components are not thread-safe by definition and 
        // IMonolingualClusteringAlgorithm forbids concurrent execution by contract.
        final List originalDocuments = documents;
        clusters = multilingualClustering.process(documents,
            new IMonolingualClusteringAlgorithm()
            {
                public List process(List documents, LanguageCode language)
                {
                    BisectingKMeansClusteringAlgorithm.this.documents = documents;
                    BisectingKMeansClusteringAlgorithm.this.cluster(language);
                    return BisectingKMeansClusteringAlgorithm.this.clusters;
                }
            });
        documents = originalDocuments;
    }

    /**
     * Perform clustering for a given language.
     */
    protected void cluster(LanguageCode language)
    {
        // Preprocessing of documents
        final PreprocessingContext preprocessingContext = 
            preprocessingPipeline.preprocess(documents, null, language);

        // Add trivial AllLabels so that we can reuse the common TD matrix builder
        final int [] stemsMfow = preprocessingContext.allStems.mostFrequentOriginalWordIndex;
        final short [] wordsType = preprocessingContext.allWords.type;
        final IntArrayList featureIndices = new IntArrayList(stemsMfow.length);
        for (int i = 0; i < stemsMfow.length; i++)
        {
            final short flag = wordsType[stemsMfow[i]];
            if ((flag & (ITokenizer.TF_COMMON_WORD | ITokenizer.TF_QUERY_WORD | ITokenizer.TT_NUMERIC)) == 0)
            {
                featureIndices.add(stemsMfow[i]);
            }
        }
        preprocessingContext.allLabels.featureIndex = featureIndices.toArray();
        preprocessingContext.allLabels.firstPhraseIndex = -1;

        // Further processing only if there are words to process
        clusters = Lists.newArrayList();
        if (preprocessingContext.hasLabels())
        {
            // Term-document matrix building and reduction
            final VectorSpaceModelContext vsmContext = new VectorSpaceModelContext(
                preprocessingContext);
            final ReducedVectorSpaceModelContext reducedVsmContext = new ReducedVectorSpaceModelContext(
                vsmContext);

            matrixBuilder.buildTermDocumentMatrix(vsmContext);
            matrixBuilder.buildTermPhraseMatrix(vsmContext);

            // Prepare rowIndex -> stemIndex mapping for labeling
            final IntIntHashMap rowToStemIndex = new IntIntHashMap();
            for (IntIntCursor c : vsmContext.stemToRowIndex)
            {
                rowToStemIndex.put(c.value, c.key);
            }

            final DoubleMatrix2D tdMatrix;
            if (useDimensionalityReduction && clusterCount * 2 < preprocessingContext.documents.size())
            {
                matrixReducer.reduce(reducedVsmContext, clusterCount * 2);
                tdMatrix = reducedVsmContext.coefficientMatrix.viewDice();
            }
            else
            {
                tdMatrix = vsmContext.termDocumentMatrix;
            }

            // Initial selection containing all columns, initial clustering
            final IntArrayList columns = new IntArrayList(tdMatrix.columns());
            for (int c = 0; c < tdMatrix.columns(); c++)
            {
                columns.add(c);
            }
            final List rawClusters = Lists.newArrayList();
            rawClusters.addAll(split(partitionCount, tdMatrix, columns, maxIterations));
            Collections.sort(rawClusters, BY_SIZE_DESCENDING);
            
            int largestIndex = 0;
            while (rawClusters.size() < clusterCount && largestIndex < rawClusters.size())
            {
                // Find largest cluster to split
                IntArrayList largest = rawClusters.get(largestIndex);
                if (largest.size() <= partitionCount * 2) 
                {
                    // No cluster is large enough to produce a meaningful
                    // split (i.e. a split into subclusters with more than
                    // 1 member).
                    break;
                }

                final List split = split(partitionCount, tdMatrix, largest,
                    maxIterations);
                if (split.size() > 1)
                {
                    rawClusters.remove(largestIndex);
                    rawClusters.addAll(split);
                    Collections.sort(rawClusters, BY_SIZE_DESCENDING);
                    largestIndex = 0;
                }
                else
                {
                    largestIndex++;
                }
            }

            for (int i = 0; i < rawClusters.size(); i++)
            {
                final Cluster cluster = new Cluster();

                final IntArrayList rawCluster = rawClusters.get(i);
                if (rawCluster.size() > 1)
                {
                    cluster.addPhrases(getLabels(rawCluster,
                        vsmContext.termDocumentMatrix, rowToStemIndex,
                        preprocessingContext.allStems.mostFrequentOriginalWordIndex,
                        preprocessingContext.allWords.image));
                    for (int j = 0; j < rawCluster.size(); j++)
                    {
                        cluster.addDocuments(documents.get(rawCluster.get(j)));
                    }
                    clusters.add(cluster);
                }
            }
        }

        Collections.sort(clusters, Cluster.BY_REVERSED_SIZE_AND_LABEL_COMPARATOR);
        Cluster.appendOtherTopics(documents, clusters);
    }

    private static final Comparator BY_SIZE_DESCENDING = new Comparator()
    {
        @Override
        public int compare(IntArrayList o1, IntArrayList o2)
        {
            // We don't expect very large sizes here.
            return o2.size() - o1.size();
        }
    };
    
    private List getLabels(IntArrayList documents,
        DoubleMatrix2D termDocumentMatrix, IntIntHashMap rowToStemIndex,
        int [] mostFrequentOriginalWordIndex, char [][] wordImage)
    {
        // Prepare a centroid. If dimensionality reduction was used,
        // the centroid from k-means will not be based on real terms,
        // so we need to calculate the centroid here once again based
        // on the cluster's documents.
        final DoubleMatrix1D centroid = new DenseDoubleMatrix1D(termDocumentMatrix.rows());
        for (IntCursor d : documents)
        {
            centroid.assign(termDocumentMatrix.viewColumn(d.value), Functions.PLUS);
        }

        final List labels = Lists.newArrayListWithCapacity(labelCount);

        final int [] order = IndirectSort.mergesort(0, centroid.size(),
            new IndirectComparator()
            {
                @Override
                public int compare(int a, int b)
                {
                    final double valueA = centroid.get(a);
                    final double valueB = centroid.get(b);
                    return valueA < valueB ? -1 : valueA > valueB ? 1 : 0;
                }
            });
        final double minValueForLabel = centroid.get(order[order.length
            - Math.min(labelCount, order.length)]);

        for (int i = 0; i < centroid.size(); i++)
        {
            if (centroid.getQuick(i) >= minValueForLabel)
            {
                labels.add(LabelFormatter.format(new char [] []
                {
                    wordImage[mostFrequentOriginalWordIndex[rowToStemIndex.get(i)]]
                }, new boolean []
                {
                    false
                }, false));
            }
        }
        return labels;
    }

    /**
     * Splits the input documents into the specified number of partitions using the
     * standard k-means routine.
     */
    private List split(int partitions, DoubleMatrix2D input,
        IntArrayList columns, int iterations)
    {
        // Prepare selected matrix
        final DoubleMatrix2D selected = input.viewSelection(null, columns.toArray())
            .copy();
        final IntIntMap selectedToInput = new IntIntHashMap(selected.columns());
        for (int i = 0; i < columns.size(); i++)
        {
            selectedToInput.put(i, columns.get(i));
        }

        // Prepare results holders
        List result = Lists.newArrayList();
        List previousResult = null;
        for (int i = 0; i < partitions; i++)
        {
            result.add(new IntArrayList(selected.columns()));
        }
        for (int i = 0; i < selected.columns(); i++) 
        {
            result.get(i % partitions).add(i);
        }

        // Matrices for centroids and document-centroid similarities
        final DoubleMatrix2D centroids = new DenseDoubleMatrix2D(selected.rows(),
            partitions).assign(selected.viewPart(0, 0, selected.rows(), partitions));
        final DoubleMatrix2D similarities = new DenseDoubleMatrix2D(partitions,
            selected.columns());

        // Run a fixed number of K-means iterations
        for (int it = 0; it < iterations; it++)
        {
            // Update centroids
            for (int i = 0; i < result.size(); i++)
            {
                final IntArrayList cluster = result.get(i);
                for (int k = 0; k < selected.rows(); k++)
                {
                    double sum = 0;
                    for (int j = 0; j < cluster.size(); j++)
                    {
                        sum += selected.get(k, cluster.get(j));
                    }
                    centroids.setQuick(k, i, sum / cluster.size());
                }
            }

            previousResult = result;
            result = Lists.newArrayList();
            for (int i = 0; i < partitions; i++)
            {
                result.add(new IntArrayList(selected.columns()));
            }

            // Calculate similarity to centroids
            centroids.zMult(selected, similarities, 1, 0, true, false);

            // Assign documents to the nearest centroid
            for (int c = 0; c < similarities.columns(); c++)
            {
                int maxRow = 0;
                double max = similarities.get(0, c);
                for (int r = 1; r < similarities.rows(); r++)
                {
                    if (max < similarities.get(r, c))
                    {
                        max = similarities.get(r, c);
                        maxRow = r;
                    }
                }

                result.get(maxRow).add(c);
            }

            if (ObjectUtils.equals(previousResult, result))
            {
                // Unchanged result
                break;
            }
        }

        // Map the results back to the global indices
        for (Iterator it = result.iterator(); it.hasNext();)
        {
            final IntArrayList cluster = it.next();
            if (cluster.isEmpty())
            {
                it.remove();
            }
            else
            {
                for (int j = 0; j < cluster.size(); j++)
                {
                    cluster.set(j, selectedToInput.get(cluster.get(j)));
                }
            }
        }

        return result;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy