All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.base.cart.RegressionNode Maven / Gradle / Ivy

There is a newer version: 4.0.0
Show newest version
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.base.cart;

import java.math.BigInteger;
import java.util.List;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.math.MathEx;

/**
 * A leaf node in regression tree.
 *
 * @author Haifeng Li
 */
public class RegressionNode extends LeafNode {
    private static final long serialVersionUID = 2L;

    /** The mean of response variable. */
    private final double mean;

    /**
     * The predicted output. In standard regression tree,
     * this is same as the mean. However, in gradient tree
     * boosting, this may be different.
     */
    private final double output;

    /** The residual sum of squares. */
    private final double rss;

    /**
     * Constructor.
     *
     * @param size the number of samples in the node
     * @param output the predicted value for this node.
     * @param mean the mean of response variable.
     * @param rss the residual sum of squares.
     */
    public RegressionNode(int size, double output, double mean, double rss) {
        super(size);
        this.output = output;
        this.mean = mean;
        this.rss = rss;
    }

    /**
     * Returns the predicted value.
     * @return the predicted value.
     */
    public double output() {
        return output;
    }

    /**
     * Returns the mean of response variable.
     * @return the mean of response variable.
     */
    public double mean() {
        return mean;
    }

    /**
     * Returns the residual sum of squares.
     * @return the residual sum of squares.
     */
    public double impurity() {
        return rss;
    }

    @Override
    public double deviance() {
        return rss;
    }

    @Override
    public String dot(StructType schema, StructField response, int id) {
        return String.format(" %d [label=<%s = %.4f
size = %d
deviance = %.4f>, fillcolor=\"#00000000\", shape=ellipse];\n", id, response.name, output, size, rss); } @Override public int[] toString(StructType schema, StructField response, InternalNode parent, int depth, BigInteger id, List lines) { StringBuilder line = new StringBuilder(); // indent for (int i = 0; i < depth; i++) line.append(" "); line.append(id).append(") "); // split line.append(parent == null ? "root" : parent.toString(schema, this == parent.trueChild)).append(" "); // size line.append(size).append(" "); // deviance line.append(String.format("%.5g", rss)).append(" "); // fitted value line.append(String.format("%g", output)); // terminal node line.append(" *"); lines.add(line.toString()); return new int[]{size}; } @Override public boolean equals(Object o) { if (o instanceof RegressionNode) { RegressionNode a = (RegressionNode) o; return MathEx.equals(output, a.output); } return false; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy