org.apache.ignite.ml.composition.boosting.GDBTrainer

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.composition.boosting;

import java.util.Arrays;
import java.util.List;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Abstract Gradient Boosting trainer. It implements gradient descent in functional space using a regressor
 * selected by the child class. At each learning iteration the trainer evaluates the gradient of the error
 * function and fits a regression model to it. After the learning step the model is added to the composition
 * of regressors with a weight equal to the gradient descent step.
 *
 * These classes can be used as regressor trainers: {@link DecisionTreeRegressionTrainer}, {@link KNNRegressionTrainer},
 * {@link LinearRegressionLSQRTrainer}, {@link RandomForestRegressionTrainer}, {@link LinearRegressionSGDTrainer}.
 *
 * In practice, decision trees are the most commonly used regressors (see {@link DecisionTreeRegressionTrainer}).
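 *
 * A minimal usage sketch (assuming the concrete subclass {@code GDBRegressionOnTreesTrainer} from the
 * {@code org.apache.ignite.ml.tree.boosting} package; the constructor arguments and the {@code datasetBuilder},
 * {@code featureExtractor}, {@code lbExtractor} and {@code featureVector} variables are illustrative placeholders):
 * <pre>{@code
 * GDBTrainer trainer = new GDBRegressionOnTreesTrainer(0.5, 500, 4, 0.0)
 *     .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.001));
 * ModelsComposition mdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
 * double prediction = mdl.apply(featureVector);
 * }</pre>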
 */
public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Double> {
    /** Gradient step. */
    private final double gradientStep;

    /** Count of iterations. */
    private final int cntOfIterations;

    /**
     * Loss function.
     */
    protected final Loss loss;

    /** Check convergence strategy factory. */
    protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001);

    /**
     * Constructs GDBTrainer instance.
     *
     * @param gradStepSize Grad step size.
     * @param cntOfIterations Count of learning iterations.
     * @param loss Loss function whose gradient is followed. The gradient takes the sample size as the first
     * argument, the valid answer as the second and the current model prediction as the third.
     */
    public GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) {
        gradientStep = gradStepSize;
        this.cntOfIterations = cntOfIterations;
        this.loss = loss;
    }

    /** {@inheritDoc} */
    @Override public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {

        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /** {@inheritDoc} */
    @Override protected <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {

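        // Learn the label set (relevant for classification); on failure fall back to the last trained model
        // or throw if the dataset is empty.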
        if (!learnLabels(datasetBuilder, featureExtractor, lbExtractor))
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

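        // First approximation of the composition: the mean of the labels and the total sample size.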
        IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(datasetBuilder, featureExtractor, lbExtractor);
        if (initAndSampleSize == null)
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

        Double mean = initAndSampleSize.get1();
        Long sampleSize = initAndSampleSize.get2();

        long learningStartTs = System.currentTimeMillis();

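        // Configure the boosting strategy: base regressor trainer, label mappings, iteration count,
        // loss function and convergence checking.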
        GDBLearningStrategy stgy = getLearningStrategy()
            .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
            .withExternalLabelToInternal(this::externalLabelToInternal)
            .withCntOfIterations(cntOfIterations)
            .withEnvironment(environment)
            .withLossGradient(loss)
            .withSampleSize(sampleSize)
            .withMeanLabelValue(mean)
            .withDefaultGradStepSize(gradientStep)
            .withCheckConvergenceStgyFactory(checkConvergenceStgyFactory);

        List<Model<Vector, Double>> models;
        if (mdl != null)
            models = stgy.update((GDBModel)mdl, datasetBuilder, featureExtractor, lbExtractor);
        else
            models = stgy.learnModels(datasetBuilder, featureExtractor, lbExtractor);

        double learningTime = (double)(System.currentTimeMillis() - learningStartTs) / 1000.0;
        environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);

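        // The final model is a weighted sum of the base model predictions plus the initial mean value
        // (used as the bias term by the aggregator).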
        WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(
            stgy.getCompositionWeights(),
            stgy.getMeanValue()
        );
        return new GDBModel(models, resAggregator, this::internalLabelToExternal);
    }

    /** {@inheritDoc} */
    @Override protected boolean checkState(ModelsComposition mdl) {
        return mdl instanceof GDBModel;
    }

    /**
     * Defines unique labels in the dataset if needed (useful in case of classification).
     *
     * @param builder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lExtractor Labels extractor.
     * @return true if labels learning was successful.
     */
    protected abstract <K, V> boolean learnLabels(DatasetBuilder<K, V> builder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lExtractor);

    /**
     * Returns regressor model trainer for one step of GDB.
     */
    @NotNull
    protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer();

    /**
     * Maps external representation of label to internal.
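     * For example, a binary classification subclass may map its two external class labels onto a fixed pair
     * of internal values.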
     *
     * @param lbl Label value.
     */
    protected abstract double externalLabelToInternal(double lbl);

    /**
     * Maps internal representation of label to external.
     *
     * @param lbl Label value.
     */
    protected abstract double internalLabelToExternal(double lbl);

    /**
     * Computes the mean value of the label as the first approximation.
     *
     * @param builder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     */
    protected <K, V> IgniteBiTuple<Double, Long> computeInitialValue(DatasetBuilder<K, V> builder,
        IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {

        try (Dataset<EmptyContext, DecisionTreeData> dataset = builder.build(
            new EmptyContextBuilder<>(),
            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, false)
        )) {
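            // Map-reduce over dataset partitions: each partition produces (label sum, label count),
            // and the partial tuples are merged below.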
            IgniteBiTuple<Double, Long> meanTuple = dataset.compute(
                data -> {
                    double sum = Arrays.stream(data.getLabels()).map(this::externalLabelToInternal).sum();
                    return new IgniteBiTuple<>(sum, (long)data.getLabels().length);
                },
                (a, b) -> {
                    if (a == null)
                        return b;
                    if (b == null)
                        return a;

                    a.set1(a.get1() + b.get1());
                    a.set2(a.get2() + b.get2());
                    return a;
                }
            );

            if (meanTuple != null)
                meanTuple.set1(meanTuple.get1() / meanTuple.get2());
            return meanTuple;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Sets the convergence checking strategy factory.
     *
     * @param factory Factory.
     * @return This trainer.
     */
    public GDBTrainer withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
        this.checkConvergenceStgyFactory = factory;
        return this;
    }

    /**
     * Returns learning strategy.
     *
     * @return learning strategy.
     */
    protected GDBLearningStrategy getLearningStrategy() {
        return new GDBLearningStrategy();
    }

    /**
     * GDB model.
     */
    public static class GDBModel extends ModelsComposition {
        /** Serial version uid. */
        private static final long serialVersionUID = 3476661240155508004L;

        /** Internal to external lbl mapping. */
        private final IgniteFunction<Double, Double> internalToExternalLblMapping;

        /**
         * Creates an instance of GDBModel.
         *
         * @param models Models.
         * @param predictionsAggregator Predictions aggregator.
         * @param internalToExternalLblMapping Internal to external lbl mapping.
         */
        public GDBModel(List<? extends Model<Vector, Double>> models,
            WeightedPredictionsAggregator predictionsAggregator,
            IgniteFunction<Double, Double> internalToExternalLblMapping) {

            super(models, predictionsAggregator);
            this.internalToExternalLblMapping = internalToExternalLblMapping;
        }

        /** {@inheritDoc} */
        @Override public Double apply(Vector features) {
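            // Aggregate the weighted base model predictions, then map the internal label value back to the
            // external representation.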
            return internalToExternalLblMapping.apply(super.apply(features));
        }
    }
}