All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.ignite.ml.clustering.kmeans.KMeansTrainer Maven / Gradle / Ivy

Go to download

Apache Ignite® is a Distributed Database For High-Performance Computing With In-Memory Speed.

There is a newer version: 2.15.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.clustering.kmeans;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.math.util.MapUtil;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * The trainer for KMeans algorithm.
 */
public class KMeansTrainer implements SingleLabelDatasetTrainer {
    /** Amount of clusters. */
    private int k = 2;

    /** Amount of iterations. */
    private int maxIterations = 10;

    /** Delta of convergence. */
    private double epsilon = 1e-4;

    /** Distance measure. */
    private DistanceMeasure distance = new EuclideanDistance();

    /** KMeans initializer. */
    private long seed;

    /**
     * Trains model based on the specified data.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return Model.
     */
    @Override public  KMeansModel fit(DatasetBuilder datasetBuilder,
        IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) {
        assert datasetBuilder != null;

        PartitionDataBuilder> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
            featureExtractor,
            lbExtractor
        );

        Vector[] centers;

        try (Dataset> dataset = datasetBuilder.build(
            (upstream, upstreamSize) -> new EmptyContext(),
            partDataBuilder
        )) {
            final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> a == null ? b : a);
            centers = initClusterCentersRandomly(dataset, k);

            boolean converged = false;
            int iteration = 0;

            while (iteration < maxIterations && !converged) {
                Vector[] newCentroids = new DenseLocalOnHeapVector[k];

                TotalCostAndCounts totalRes = calcDataForNewCentroids(centers, dataset, cols);

                converged = true;

                for (Integer ind : totalRes.sums.keySet()) {
                    Vector massCenter = totalRes.sums.get(ind).times(1.0 / totalRes.counts.get(ind));

                    if (converged && distance.compute(massCenter, centers[ind]) > epsilon * epsilon)
                        converged = false;

                    newCentroids[ind] = massCenter;
                }

                iteration++;
                centers = newCentroids;
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        return new KMeansModel(centers, distance);
    }

    /**
     * Prepares the data to define new centroids on current iteration.
     *
     * @param centers Current centers on the current iteration.
     * @param dataset Dataset.
     * @param cols Amount of columns.
     * @return Helper data to calculate the new centroids.
     */
    private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers,
        Dataset> dataset, int cols) {
        final Vector[] finalCenters = centers;

        return dataset.compute(data -> {

            TotalCostAndCounts res = new TotalCostAndCounts();

            for (int i = 0; i < data.rowSize(); i++) {
                final IgniteBiTuple closestCentroid = findClosestCentroid(finalCenters, data.getRow(i));

                int centroidIdx = closestCentroid.get1();

                data.setLabel(i, centroidIdx);

                res.totalCost += closestCentroid.get2();
                res.sums.putIfAbsent(centroidIdx, VectorUtils.zeroes(cols));

                int finalI = i;
                res.sums.compute(centroidIdx,
                    (IgniteBiFunction)(ind, v) -> v.plus(data.getRow(finalI).features()));

                res.counts.merge(centroidIdx, 1,
                    (IgniteBiFunction)(i1, i2) -> i1 + i2);
            }
            return res;
        }, (a, b) -> a == null ? b : a.merge(b));
    }

    /**
     * Find the closest cluster center index and distance to it from a given point.
     *
     * @param centers Centers to look in.
     * @param pnt Point.
     */
    private IgniteBiTuple findClosestCentroid(Vector[] centers, LabeledVector pnt) {
        double bestDistance = Double.POSITIVE_INFINITY;
        int bestInd = 0;

        for (int i = 0; i < centers.length; i++) {
            double dist = distance.compute(centers[i], pnt.features());
            if (dist < bestDistance) {
                bestDistance = dist;
                bestInd = i;
            }
        }
        return new IgniteBiTuple<>(bestInd, bestDistance);
    }

    /**
     * K cluster centers are initialized randomly.
     *
     * @param dataset The dataset to pick up random centers.
     * @param k Amount of clusters.
     * @return K cluster centers.
     */
    private Vector[] initClusterCentersRandomly(Dataset> dataset,
        int k) {

        Vector[] initCenters = new DenseLocalOnHeapVector[k];

        List rndPnts = dataset.compute(data -> {
            List rndPnt = new ArrayList<>();
            rndPnt.add(data.getRow(new Random(seed).nextInt(data.rowSize())));
            return rndPnt;
        }, (a, b) -> a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()));

        for (int i = 0; i < k; i++) {
            final LabeledVector rndPnt = rndPnts.get(new Random(seed).nextInt(rndPnts.size()));
            rndPnts.remove(rndPnt);
            initCenters[i] = rndPnt.features();
        }

        return initCenters;
    }

    /** Service class used for statistics. */
    private static class TotalCostAndCounts {
        /** */
        double totalCost;

        /** */
        ConcurrentHashMap sums = new ConcurrentHashMap<>();

        /** Count of points closest to the center with a given index. */
        ConcurrentHashMap counts = new ConcurrentHashMap<>();

        /** Merge current */
        TotalCostAndCounts merge(TotalCostAndCounts other) {
            this.totalCost += totalCost;
            this.sums = MapUtil.mergeMaps(sums, other.sums, Vector::plus, ConcurrentHashMap::new);
            this.counts = MapUtil.mergeMaps(counts, other.counts, (i1, i2) -> i1 + i2, ConcurrentHashMap::new);
            return this;
        }
    }

    /**
     * Gets the amount of clusters.
     *
     * @return The parameter value.
     */
    public int getK() {
        return k;
    }

    /**
     * Set up the amount of clusters.
     *
     * @param k The parameter value.
     * @return Model with new amount of clusters parameter value.
     */
    public KMeansTrainer withK(int k) {
        this.k = k;
        return this;
    }

    /**
     * Gets the max number of iterations before convergence.
     *
     * @return The parameter value.
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Set up the max number of iterations before convergence.
     *
     * @param maxIterations The parameter value.
     * @return Model with new max number of iterations before convergence parameter value.
     */
    public KMeansTrainer withMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /**
     * Gets the epsilon.
     *
     * @return The parameter value.
     */
    public double getEpsilon() {
        return epsilon;
    }

    /**
     * Set up the epsilon.
     *
     * @param epsilon The parameter value.
     * @return Model with new epsilon parameter value.
     */
    public KMeansTrainer withEpsilon(double epsilon) {
        this.epsilon = epsilon;
        return this;
    }

    /**
     * Gets the distance.
     *
     * @return The parameter value.
     */
    public DistanceMeasure getDistance() {
        return distance;
    }

    /**
     * Set up the distance.
     *
     * @param distance The parameter value.
     * @return Model with new distance parameter value.
     */
    public KMeansTrainer withDistance(DistanceMeasure distance) {
        this.distance = distance;
        return this;
    }

    /**
     * Gets the seed number.
     *
     * @return The parameter value.
     */
    public long getSeed() {
        return seed;
    }

    /**
     * Set up the seed.
     *
     * @param seed The parameter value.
     * @return Model with new seed parameter value.
     */
    public KMeansTrainer withSeed(long seed) {
        this.seed = seed;
        return this;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy