All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ciir.umass.edu.learning.tree.RegressionTree Maven / Gradle / Ivy

There is a newer version: 2.10.1
Show newest version
/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.List;

import ciir.umass.edu.learning.DataPoint;

/**
 * A regression tree made of {@link Split} nodes.
 *
 * <p>A tree is either wrapped around an already-built root split
 * ({@link #RegressionTree(Split)}) or grown from training data by calling
 * {@link #fit()} on an instance created with the training constructor.
 *
 * @author vdang
 */
public class RegressionTree {

    //Parameters
    /**
     * Budget that stops {@link #fit()} from splitting further; -1 for unlimited
     * (the size of the tree will then be controlled *ONLY* by minLeafSupport).
     */
    protected int nodes = 10;
    /** A node holding fewer than {@code 2 * minLeafSupport} samples is never split by {@link #fit()}. */
    protected int minLeafSupport = 1;

    //Member variables and functions
    /** Root of the tree; {@code null} until supplied to the constructor or built by {@link #fit()}. */
    protected Split root = null;
    /** Leaves reported by {@code root.leaves()}; {@code null} until the tree has a root. */
    protected List<Split> leaves = null;

    protected DataPoint[] trainingSamples = null;
    protected double[] trainingLabels = null;
    protected int[] features = null;
    protected float[][] thresholds = null;
    /** Identity permutation {@code 0..trainingSamples.length-1} handed to the root split. */
    protected int[] index = null;
    protected FeatureHistogram hist = null;

    /**
     * Wraps an existing tree structure.
     *
     * @param root root split of the tree
     * @throws NullPointerException if {@code root} is {@code null}
     */
    public RegressionTree(final Split root) {
        this.root = root;
        leaves = root.leaves();
    }

    /**
     * Prepares a tree to be grown with {@link #fit()}.
     *
     * @param nLeaves         stored as the node budget ({@code -1} for unlimited)
     * @param trainingSamples training samples; only the array length is read here
     * @param labels          labels passed to every {@code Split.split} call during fitting
     * @param hist            feature histogram passed to the root split
     * @param minLeafSupport  minimum support passed to every {@code Split.split} call
     * @throws NullPointerException if {@code trainingSamples} is {@code null}
     */
    public RegressionTree(final int nLeaves, final DataPoint[] trainingSamples, final double[] labels, final FeatureHistogram hist,
            final int minLeafSupport) {
        this.nodes = nLeaves;
        this.trainingSamples = trainingSamples;
        this.trainingLabels = labels;
        this.hist = hist;
        this.minLeafSupport = minLeafSupport;
        index = new int[trainingSamples.length];
        for (int i = 0; i < index.length; i++) {
            index[i] = i;
        }
    }

    /**
     * Fit the tree from the specified training data.
     *
     * <p>Candidate nodes wait in a queue that {@link #insert(List, Split)} keeps
     * ordered by non-increasing deviance, so the head is always the candidate with
     * the largest deviance. A candidate that is too small or that fails to split is
     * counted as finished; a successful split replaces the candidate with its two
     * children. Growth stops when the queue is empty or, unless {@code nodes == -1},
     * when finished plus pending nodes reach {@code nodes}.
     */
    public void fit() {
        final List<Split> queue = new ArrayList<>();
        root = new Split(index, hist, Float.MAX_VALUE, 0);
        root.setRoot(true);

        // Ensure inserts occur only after successful splits
        if (root.split(trainingLabels, minLeafSupport)) {
            insert(queue, root.getLeft());
            insert(queue, root.getRight());
        }

        // Number of nodes that have been finalized and will not be split again.
        int taken = 0;
        while ((nodes == -1 || taken + queue.size() < nodes) && !queue.isEmpty()) {
            final Split leaf = queue.remove(0);

            // Too few samples to produce two children that both satisfy minLeafSupport.
            if (leaf.getSamples().length < 2 * minLeafSupport) {
                taken++;
                continue;
            }

            if (!leaf.split(trainingLabels, minLeafSupport)) {
                taken++;
            } else {
                insert(queue, leaf.getLeft());
                insert(queue, leaf.getRight());
            }
        }
        leaves = root.leaves();
    }

    /**
     * Get the tree output for the input sample.
     *
     * @param dp the sample to evaluate
     * @return the value returned by the root split for {@code dp}
     */
    public double eval(final DataPoint dp) {
        return root.eval(dp);
    }

    /**
     * Retrieve all leaf nodes in the tree.
     *
     * @return the stored leaf list (not a copy); {@code null} if the tree has not been built
     */
    public List<Split> leaves() {
        return leaves;
    }

    /**
     * Clear samples associated with each leaf (when they are no longer necessary) in order to save memory.
     * Also drops this tree's references to its training data, so {@link #fit()} must not be called afterwards.
     */
    public void clearSamples() {
        trainingSamples = null;
        trainingLabels = null;
        features = null;
        thresholds = null;
        index = null;
        hist = null;
        // leaves is null when the tree was created for training but never fit.
        if (leaves == null) {
            return;
        }
        for (final Split leaf : leaves) {
            leaf.clearSamples();
        }
    }

    /**
     * Generate the string representation of the tree.
     *
     * @return the root's string form, or the empty string when there is no root
     */
    @Override
    public String toString() {
        if (root != null) {
            return root.toString();
        }
        return "";
    }

    /**
     * Generate the string representation of the tree using the given indent.
     *
     * @param indent indent string forwarded to the root split
     * @return the root's string form, or the empty string when there is no root
     */
    public String toString(final String indent) {
        if (root != null) {
            return root.toString(indent);
        }
        return "";
    }

    /**
     * Sum the deviance of every leaf.
     *
     * @return the total of {@code getDeviance()} over all leaves
     */
    public double variance() {
        double var = 0;
        for (final Split leaf : leaves) {
            var += leaf.getDeviance();
        }
        return var;
    }

    /**
     * Insert {@code s} into {@code ls}, keeping the list ordered by non-increasing deviance.
     * The new element is placed before the first existing element whose deviance is
     * not greater than its own, i.e. ahead of any elements with equal deviance.
     *
     * @param ls list already ordered by non-increasing deviance
     * @param s  split to insert
     */
    protected void insert(final List<Split> ls, final Split s) {
        int i = 0;
        while (i < ls.size() && ls.get(i).getDeviance() > s.getDeviance()) {
            i++;
        }
        ls.add(i, s);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy