All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.ignite.ml.tree.DecisionTree Maven / Gradle / Ivy

Go to download

Apache Ignite® is a Distributed Database For High-Performance Computing With In-Memory Speed.

There is a newer version: 2.15.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.tree;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.util.StepFunction;
import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor;
import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder;

/**
 * Distributed decision tree trainer that allows to fit trees using row-partitioned dataset.
 *
 * @param  Type of impurity measure.
 */
abstract class DecisionTree> implements DatasetTrainer {
    /** Max tree deep. */
    private final int maxDeep;

    /** Min impurity decrease. */
    private final double minImpurityDecrease;

    /** Step function compressor. */
    private final StepFunctionCompressor compressor;

    /** Decision tree leaf builder. */
    private final DecisionTreeLeafBuilder decisionTreeLeafBuilder;

    /**
     * Constructs a new distributed decision tree trainer.
     *
     * @param maxDeep Max tree deep.
     * @param minImpurityDecrease Min impurity decrease.
     * @param compressor Impurity function compressor.
     * @param decisionTreeLeafBuilder Decision tree leaf builder.
     */
    DecisionTree(int maxDeep, double minImpurityDecrease, StepFunctionCompressor compressor, DecisionTreeLeafBuilder decisionTreeLeafBuilder) {
        this.maxDeep = maxDeep;
        this.minImpurityDecrease = minImpurityDecrease;
        this.compressor = compressor;
        this.decisionTreeLeafBuilder = decisionTreeLeafBuilder;
    }

    /** {@inheritDoc} */
    @Override public  DecisionTreeNode fit(DatasetBuilder datasetBuilder,
        IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) {
        try (Dataset dataset = datasetBuilder.build(
            new EmptyContextBuilder<>(),
            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor)
        )) {
            return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns impurity measure calculator.
     *
     * @param dataset Dataset.
     * @return Impurity measure calculator.
     */
    abstract ImpurityMeasureCalculator getImpurityMeasureCalculator(Dataset dataset);

    /**
     * Splits the node specified by the given dataset and predicate and returns decision tree node.
     *
     * @param dataset Dataset.
     * @param filter Decision tree node predicate.
     * @param deep Current tree deep.
     * @param impurityCalc Impurity measure calculator.
     * @return Decision tree node.
     */
    private DecisionTreeNode split(Dataset dataset, TreeFilter filter, int deep,
        ImpurityMeasureCalculator impurityCalc) {
        if (deep >= maxDeep)
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);

        StepFunction[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc);

        if (criterionFunctions == null)
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);

        SplitPoint splitPnt = calculateBestSplitPoint(criterionFunctions);

        if (splitPnt == null)
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);

        return new DecisionTreeConditionalNode(
            splitPnt.col,
            splitPnt.threshold,
            split(dataset, updatePredicateForThenNode(filter, splitPnt), deep + 1, impurityCalc),
            split(dataset, updatePredicateForElseNode(filter, splitPnt), deep + 1, impurityCalc)
        );
    }

    /**
     * Calculates impurity measure functions for all columns for the node specified by the given dataset and predicate.
     *
     * @param dataset Dataset.
     * @param filter Decision tree node predicate.
     * @param impurityCalc Impurity measure calculator.
     * @return Array of impurity measure functions for all columns.
     */
    private StepFunction[] calculateImpurityForAllColumns(Dataset dataset,
        TreeFilter filter, ImpurityMeasureCalculator impurityCalc) {
        return dataset.compute(
            part -> {
                if (compressor != null)
                    return compressor.compress(impurityCalc.calculate(part.filter(filter)));
                else
                    return impurityCalc.calculate(part.filter(filter));
            }, this::reduce
        );
    }

    /**
     * Calculates best split point.
     *
     * @param criterionFunctions Array of impurity measure functions for all columns.
     * @return Best split point.
     */
    private SplitPoint calculateBestSplitPoint(StepFunction[] criterionFunctions) {
        SplitPoint res = null;

        for (int col = 0; col < criterionFunctions.length; col++) {
            StepFunction criterionFunctionForCol = criterionFunctions[col];

            double[] arguments = criterionFunctionForCol.getX();
            T[] values = criterionFunctionForCol.getY();

            for (int leftSize = 1; leftSize < values.length - 1; leftSize++) {
                if ((values[0].impurity() - values[leftSize].impurity()) > minImpurityDecrease
                    && (res == null || values[leftSize].compareTo(res.val) < 0))
                    res = new SplitPoint<>(values[leftSize], col, calculateThreshold(arguments, leftSize));
            }
        }

        return res;
    }

    /**
     * Merges two arrays gotten from two partitions.
     *
     * @param a First step function.
     * @param b Second step function.
     * @return Merged step function.
     */
    private StepFunction[] reduce(StepFunction[] a, StepFunction[] b) {
        if (a == null)
            return b;
        if (b == null)
            return a;
        else {
            StepFunction[] res = Arrays.copyOf(a, a.length);

            for (int i = 0; i < res.length; i++)
                res[i] = res[i].add(b[i]);

            return res;
        }
    }

    /**
     * Calculates threshold based on the given step function arguments and split point (specified left size).
     *
     * @param arguments Step function arguments.
     * @param leftSize Split point (left size).
     * @return Threshold.
     */
    private double calculateThreshold(double[] arguments, int leftSize) {
        return (arguments[leftSize] + arguments[leftSize + 1]) / 2.0;
    }

    /**
     * Constructs a new predicate for "then" node based on the parent node predicate and split point.
     *
     * @param filter Parent node predicate.
     * @param splitPnt Split point.
     * @return Predicate for "then" node.
     */
    private TreeFilter updatePredicateForThenNode(TreeFilter filter, SplitPoint splitPnt) {
        return filter.and(f -> f[splitPnt.col] > splitPnt.threshold);
    }

    /**
     * Constructs a new predicate for "else" node based on the parent node predicate and split point.
     *
     * @param filter Parent node predicate.
     * @param splitPnt Split point.
     * @return Predicate for "else" node.
     */
    private TreeFilter updatePredicateForElseNode(TreeFilter filter, SplitPoint splitPnt) {
        return filter.and(f -> f[splitPnt.col] <= splitPnt.threshold);
    }

    /**
     * Util class that represents split point.
     */
    private static class SplitPoint> implements Serializable {
        /** */
        private static final long serialVersionUID = -1758525953544425043L;

        /** Split point impurity measure value. */
        private final T val;

        /** Column. */
        private final int col;

        /** Threshold. */
        private final double threshold;

        /**
         * Constructs a new instance of split point.
         *
         * @param val Split point impurity measure value.
         * @param col Column.
         * @param threshold Threshold.
         */
        SplitPoint(T val, int col, double threshold) {
            this.val = val;
            this.col = col;
            this.threshold = threshold;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy