All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lucene.sandbox.codecs.quantization.KMeans Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.sandbox.codecs.quantization;

import static org.apache.lucene.sandbox.codecs.quantization.SampleReader.createSampleReader;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;

/** KMeans clustering algorithm for vectors */
public class KMeans {
  public static final int MAX_NUM_CENTROIDS = Short.MAX_VALUE; // 32767
  public static final int DEFAULT_RESTARTS = 5;
  public static final int DEFAULT_ITRS = 10;
  public static final int DEFAULT_SAMPLE_SIZE = 100_000;

  private final FloatVectorValues vectors;
  private final int numVectors;
  private final int numCentroids;
  private final Random random;
  private final KmeansInitializationMethod initializationMethod;
  private final int restarts;
  private final int iters;

  /**
   * Cluster vectors into a given number of clusters
   *
   * @param vectors float vectors
   * @param similarityFunction vector similarity function. For COSINE similarity, vectors must be
   *     normalized.
   * @param numClusters number of cluster to cluster vector into
   * @return results of clustering: produced centroids and for each vector its centroid
   * @throws IOException when if there is an error accessing vectors
   */
  public static Results cluster(
      FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int numClusters)
      throws IOException {
    return cluster(
        vectors,
        numClusters,
        true,
        42L,
        KmeansInitializationMethod.PLUS_PLUS,
        similarityFunction == VectorSimilarityFunction.COSINE,
        DEFAULT_RESTARTS,
        DEFAULT_ITRS,
        DEFAULT_SAMPLE_SIZE);
  }

  /**
   * Expert: Cluster vectors into a given number of clusters
   *
   * @param vectors float vectors
   * @param numClusters number of cluster to cluster vector into
   * @param assignCentroidsToVectors if {@code true} assign centroids for all vectors. Centroids are
   *     computed on a sample of vectors. If this parameter is {@code true}, in results also return
   *     for all vectors what centroids they belong to.
   * @param seed random seed
   * @param initializationMethod Kmeans initialization method
   * @param normalizeCenters for cosine distance, set to true, to use spherical k-means where
   *     centers are normalized
   * @param restarts how many times to run Kmeans algorithm
   * @param iters how many iterations to do within a single run
   * @param sampleSize sample size to select from all vectors on which to run Kmeans algorithm
   * @return results of clustering: produced centroids and if {@code assignCentroidsToVectors ==
   *     true} also for each vector its centroid
   * @throws IOException if there is error accessing vectors
   */
  public static Results cluster(
      FloatVectorValues vectors,
      int numClusters,
      boolean assignCentroidsToVectors,
      long seed,
      KmeansInitializationMethod initializationMethod,
      boolean normalizeCenters,
      int restarts,
      int iters,
      int sampleSize)
      throws IOException {
    if (vectors.size() == 0) {
      return null;
    }
    if (numClusters < 1 || numClusters > MAX_NUM_CENTROIDS) {
      throw new IllegalArgumentException(
          "[numClusters] must be between [1] and [" + MAX_NUM_CENTROIDS + "]");
    }
    // adjust sampleSize and numClusters
    sampleSize = Math.max(sampleSize, 100 * numClusters);
    if (sampleSize > vectors.size()) {
      sampleSize = vectors.size();
      // Decrease the number of clusters if needed
      int maxNumClusters = Math.max(1, sampleSize / 100);
      numClusters = Math.min(numClusters, maxNumClusters);
    }

    Random random = new Random(seed);
    float[][] centroids;
    if (numClusters == 1) {
      centroids = new float[1][vectors.dimension()];
    } else {
      FloatVectorValues sampleVectors =
          vectors.size() <= sampleSize ? vectors : createSampleReader(vectors, sampleSize, seed);
      KMeans kmeans =
          new KMeans(sampleVectors, numClusters, random, initializationMethod, restarts, iters);
      centroids = kmeans.computeCentroids(normalizeCenters);
    }

    short[] vectorCentroids = null;
    // Assign each vector to the nearest centroid and update the centres
    if (assignCentroidsToVectors) {
      vectorCentroids = new short[vectors.size()];
      // Use kahan summation to get more precise results
      KMeans.runKMeansStep(vectors, centroids, vectorCentroids, true, normalizeCenters);
    }
    return new Results(centroids, vectorCentroids);
  }

  private KMeans(
      FloatVectorValues vectors,
      int numCentroids,
      Random random,
      KmeansInitializationMethod initializationMethod,
      int restarts,
      int iters) {
    this.vectors = vectors;
    this.numVectors = vectors.size();
    this.numCentroids = numCentroids;
    this.random = random;
    this.initializationMethod = initializationMethod;
    this.restarts = restarts;
    this.iters = iters;
  }

  private float[][] computeCentroids(boolean normalizeCenters) throws IOException {
    short[] vectorCentroids = new short[numVectors];
    double minSquaredDist = Double.MAX_VALUE;
    double squaredDist = 0;
    float[][] bestCentroids = null;

    for (int restart = 0; restart < restarts; restart++) {
      float[][] centroids =
          switch (initializationMethod) {
            case FORGY -> initializeForgy();
            case RESERVOIR_SAMPLING -> initializeReservoirSampling();
            case PLUS_PLUS -> initializePlusPlus();
          };
      double prevSquaredDist = Double.MAX_VALUE;
      for (int iter = 0; iter < iters; iter++) {
        squaredDist = runKMeansStep(vectors, centroids, vectorCentroids, false, normalizeCenters);
        // Check for convergence
        if (prevSquaredDist <= (squaredDist + 1e-6)) {
          break;
        }
        prevSquaredDist = squaredDist;
      }
      if (squaredDist < minSquaredDist) {
        minSquaredDist = squaredDist;
        bestCentroids = centroids;
      }
    }
    return bestCentroids;
  }

  /**
   * Initialize centroids using Forgy method: randomly select numCentroids vectors for initial
   * centroids
   */
  private float[][] initializeForgy() throws IOException {
    Set selection = new HashSet<>();
    while (selection.size() < numCentroids) {
      selection.add(random.nextInt(numVectors));
    }
    float[][] initialCentroids = new float[numCentroids][];
    int i = 0;
    for (Integer selectedIdx : selection) {
      float[] vector = vectors.vectorValue(selectedIdx);
      initialCentroids[i++] = ArrayUtil.copyOfSubArray(vector, 0, vector.length);
    }
    return initialCentroids;
  }

  /** Initialize centroids using a reservoir sampling method */
  private float[][] initializeReservoirSampling() throws IOException {
    float[][] initialCentroids = new float[numCentroids][];
    for (int index = 0; index < numVectors; index++) {
      float[] vector = vectors.vectorValue(index);
      if (index < numCentroids) {
        initialCentroids[index] = ArrayUtil.copyOfSubArray(vector, 0, vector.length);
      } else if (random.nextDouble() < numCentroids * (1.0 / index)) {
        int c = random.nextInt(numCentroids);
        initialCentroids[c] = ArrayUtil.copyOfSubArray(vector, 0, vector.length);
      }
    }
    return initialCentroids;
  }

  /** Initialize centroids using Kmeans++ method */
  private float[][] initializePlusPlus() throws IOException {
    float[][] initialCentroids = new float[numCentroids][];
    // Choose the first centroid uniformly at random
    int firstIndex = random.nextInt(numVectors);
    float[] value = vectors.vectorValue(firstIndex);
    initialCentroids[0] = ArrayUtil.copyOfSubArray(value, 0, value.length);

    // Store distances of each point to the nearest centroid
    float[] minDistances = new float[numVectors];
    Arrays.fill(minDistances, Float.MAX_VALUE);

    // Step 2 and 3: Select remaining centroids
    for (int i = 1; i < numCentroids; i++) {
      // Update distances with the new centroid
      double totalSum = 0;
      for (int j = 0; j < numVectors; j++) {
        // TODO: replace with RandomVectorScorer::score possible on quantized vectors
        float dist = VectorUtil.squareDistance(vectors.vectorValue(j), initialCentroids[i - 1]);
        if (dist < minDistances[j]) {
          minDistances[j] = dist;
        }
        totalSum += minDistances[j];
      }

      // Randomly select next centroid
      double r = totalSum * random.nextDouble();
      double cumulativeSum = 0;
      int nextCentroidIndex = -1;
      for (int j = 0; j < numVectors; j++) {
        cumulativeSum += minDistances[j];
        if (cumulativeSum >= r && minDistances[j] > 0) {
          nextCentroidIndex = j;
          break;
        }
      }
      // Update centroid
      value = vectors.vectorValue(nextCentroidIndex);
      initialCentroids[i] = ArrayUtil.copyOfSubArray(value, 0, value.length);
    }
    return initialCentroids;
  }

  /**
   * Run kmeans step
   *
   * @param vectors float vectors
   * @param centroids centroids, new calculated centroids are written here
   * @param docCentroids for each document which centroid it belongs to, results will be written
   *     here
   * @param useKahanSummation for large datasets use Kahan summation to calculate centroids, since
   *     we can easily reach the limits of float precision
   * @param normalizeCentroids if centroids should be normalized; used for cosine similarity only
   * @throws IOException if there is an error accessing vector values
   */
  private static double runKMeansStep(
      FloatVectorValues vectors,
      float[][] centroids,
      short[] docCentroids,
      boolean useKahanSummation,
      boolean normalizeCentroids)
      throws IOException {
    short numCentroids = (short) centroids.length;

    float[][] newCentroids = new float[numCentroids][centroids[0].length];
    int[] newCentroidSize = new int[numCentroids];
    float[][] compensations = null;
    if (useKahanSummation) {
      compensations = new float[numCentroids][centroids[0].length];
    }

    double sumSquaredDist = 0;
    for (int docID = 0; docID < vectors.size(); docID++) {
      float[] vector = vectors.vectorValue(docID);
      short bestCentroid = 0;
      if (numCentroids > 1) {
        float minSquaredDist = Float.MAX_VALUE;
        for (short c = 0; c < numCentroids; c++) {
          // TODO: replace with RandomVectorScorer::score possible on quantized vectors
          float squareDist = VectorUtil.squareDistance(centroids[c], vector);
          if (squareDist < minSquaredDist) {
            bestCentroid = c;
            minSquaredDist = squareDist;
          }
        }
        sumSquaredDist += minSquaredDist;
      }

      newCentroidSize[bestCentroid] += 1;
      for (int dim = 0; dim < vector.length; dim++) {
        if (useKahanSummation) {
          float y = vector[dim] - compensations[bestCentroid][dim];
          float t = newCentroids[bestCentroid][dim] + y;
          compensations[bestCentroid][dim] = (t - newCentroids[bestCentroid][dim]) - y;
          newCentroids[bestCentroid][dim] = t;
        } else {
          newCentroids[bestCentroid][dim] += vector[dim];
        }
      }
      docCentroids[docID] = bestCentroid;
    }

    List unassignedCentroids = new ArrayList<>();
    for (int c = 0; c < numCentroids; c++) {
      if (newCentroidSize[c] > 0) {
        for (int dim = 0; dim < newCentroids[c].length; dim++) {
          centroids[c][dim] = newCentroids[c][dim] / newCentroidSize[c];
        }
      } else {
        unassignedCentroids.add(c);
      }
    }
    if (unassignedCentroids.size() > 0) {
      assignCentroids(vectors, centroids, unassignedCentroids);
    }
    if (normalizeCentroids) {
      for (int c = 0; c < centroids.length; c++) {
        VectorUtil.l2normalize(centroids[c], false);
      }
    }
    return sumSquaredDist;
  }

  /**
   * For centroids that did not get any points, assign outlying points to them chose points by
   * descending distance to the current centroid set
   */
  static void assignCentroids(
      FloatVectorValues vectors, float[][] centroids, List unassignedCentroidsIdxs)
      throws IOException {
    int[] assignedCentroidsIdxs = new int[centroids.length - unassignedCentroidsIdxs.size()];
    int assignedIndex = 0;
    for (int i = 0; i < centroids.length; i++) {
      if (unassignedCentroidsIdxs.contains(i) == false) {
        assignedCentroidsIdxs[assignedIndex++] = i;
      }
    }
    NeighborQueue queue = new NeighborQueue(unassignedCentroidsIdxs.size(), false);
    for (int i = 0; i < vectors.size(); i++) {
      float[] vector = vectors.vectorValue(i);
      for (short j = 0; j < assignedCentroidsIdxs.length; j++) {
        float squareDist = VectorUtil.squareDistance(centroids[assignedCentroidsIdxs[j]], vector);
        queue.insertWithOverflow(i, squareDist);
      }
    }
    for (int i = 0; i < unassignedCentroidsIdxs.size(); i++) {
      float[] vector = vectors.vectorValue(queue.topNode());
      int unassignedCentroidIdx = unassignedCentroidsIdxs.get(i);
      centroids[unassignedCentroidIdx] = ArrayUtil.copyArray(vector);
      queue.pop();
    }
  }

  /** Kmeans initialization methods */
  public enum KmeansInitializationMethod {
    FORGY,
    RESERVOIR_SAMPLING,
    PLUS_PLUS
  }

  /**
   * Results of KMeans clustering
   *
   * @param centroids the produced centroids
   * @param vectorCentroids for each vector which centroid it belongs to, we use short type, as we
   *     expect less than {@code MAX_NUM_CENTROIDS} which is equal to 32767 centroids. Can be {@code
   *     null} if they were not computed.
   */
  public record Results(float[][] centroids, short[] vectorCentroids) {}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy