All downloads are FREE. Search and download functionality uses the official Maven repository.

biz.k11i.xgboost.gbm.Dart Maven / Gradle / Ivy

The newest version!
package biz.k11i.xgboost.gbm;

import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.util.FVec;
import biz.k11i.xgboost.util.ModelReader;

import java.io.IOException;

/**
 * Gradient boosted DART tree implementation.
 */
public class Dart extends GBTree {
    /**
     * Per-tree weights read by {@link #loadModel}; stays {@code null} when the model has no
     * trees (the weights are not read in that case).
     */
    private float[] weightDrop;

    /** Package-private constructor; all state is populated by {@link #loadModel}. */
    Dart() {
        // do nothing
    }

    /**
     * Loads the base gradient-boosted-tree model, then the DART tree weights.
     *
     * <p>The weights are only read when the model contains at least one tree. They are stored
     * as a {@code long} element count followed by the weight array itself.
     *
     * @param reader       source of the serialized model
     * @param with_pbuffer forwarded unchanged to {@link GBTree#loadModel}
     * @throws IOException if reading fails, or if the stored weight count is negative or too
     *                     large to be used as a Java array length
     */
    @Override
    public void loadModel(ModelReader reader, boolean with_pbuffer) throws IOException {
        super.loadModel(reader, with_pbuffer);
        if (mparam.num_trees != 0) {
            long size = reader.readLong();
            // The count comes from an external file: reject values that a plain (int) cast
            // would silently wrap, instead of letting them fail obscurely further down.
            if (size < 0 || size > Integer.MAX_VALUE) {
                throw new IOException("Invalid DART weight count: " + size);
            }
            weightDrop = reader.readFloatArray((int) size);
        }
    }

    /**
     * Computes the weighted sum of leaf values over the trees of one output group.
     *
     * @param feat        feature vector to evaluate
     * @param bst_group   index of the output group whose trees are summed
     * @param root_index  root index passed through to {@link RegTree#getLeafValue}
     * @param ntree_limit number of the group's trees to use; {@code 0} means all of them, and
     *                    larger values are capped at the group's tree count
     * @return sum over the selected trees of {@code weightDrop[i] * leafValue}
     */
    double pred(FVec feat, int bst_group, int root_index, int ntree_limit) {
        RegTree[] trees = _groupTrees[bst_group];
        int treeleft = ntree_limit == 0 ? trees.length : Math.min(ntree_limit, trees.length);

        double psum = 0;
        for (int i = 0; i < treeleft; i++) {
            // NOTE(review): weightDrop is one flat array read in loadModel, but `i` is the
            // tree's position within _groupTrees[bst_group]. If the weights are ordered by
            // global tree index (as upstream XGBoost's DART does), this lookup is only correct
            // for single-group models — confirm before trusting multi-class DART output.
            psum += weightDrop[i] * trees[i].getLeafValue(feat, root_index);
        }

        return psum;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy