All downloads are FREE. Search and download functionality uses the official Maven repository.

hex.tree.gbm.GBMModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.gbm;

import hex.Model;
import hex.schemas.GBMModelV2;
import hex.tree.*;
import java.util.Arrays;
import water.H2O;
import water.Key;
import water.api.ModelSchema;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.ArrayUtils;

public class GBMModel extends SharedTreeModel {

  public static class GBMParameters extends SharedTreeModel.SharedTreeParameters {
    /** Distribution functions.  Note: AUTO will select gaussian for
     *  continuous, and multinomial for categorical response
     *
     *  

TODO: Replace with drop-down that displays different distributions * depending on cont/cat response */ public enum Family { AUTO, bernoulli } public Family _loss = Family.AUTO; public float _learn_rate=0.1f; // Learning rate from 0.0 to 1.0 } public static class GBMOutput extends SharedTreeModel.SharedTreeOutput { public GBMOutput( GBM b ) { super(b); } } public GBMModel(Key selfKey, GBMParameters parms, GBMOutput output ) { super(selfKey,parms,output); } // Default publically visible Schema is V2 @Override public ModelSchema schema() { return new GBMModelV2(); } /** Bulk scoring API for one row. Chunks are all compatible with the model, * and expect the last Chunks are for the final distribution and prediction. * Default method is to just load the data into the tmp array, then call * subclass scoring logic. */ @Override protected float[] score0( Chunk chks[], int row_in_chunk, double[] tmp, float[] preds ) { assert chks.length>=tmp.length; for( int i=0; i1 ) { // classification // Because we call Math.exp, we have to be numerically stable or else // we get Infinities, and then shortly NaN's. Rescale the data so the // largest value is +/-1 and the other values are smaller. // See notes here: http://www.hongliangjie.com/2011/01/07/logsum/ float maxval=Float.NEGATIVE_INFINITY; float dsum=0; if( _output.nclasses()==2 ) p[2] = - p[1]; // Find a max for( int k=1; k





© 2015 - 2025 Weber Informatics LLC | Privacy Policy