All downloads are FREE. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.rulefit.RuleFitMojoModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.rulefit;

import hex.genmodel.MojoModel;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * MOJO scoring model for RuleFit.
 *
 * <p>Final predictions come from an embedded linear model ({@code _linearModel}). Depending on
 * {@code _modelType}, the linear model's input vector is assembled from:
 * <ul>
 *   <li>{@code LINEAR} - the raw input row only;</li>
 *   <li>{@code RULES} - the rule features produced by {@code _ruleEnsemble} only;</li>
 *   <li>{@code RULES_AND_LINEAR} - the rule features followed by the raw input row.</li>
 * </ul>
 * The assembled vector is then permuted into the linear model's column order by {@link #map}.
 */
public class RuleFitMojoModel extends MojoModel {

  /** Which feature sets are fed to the embedded linear model. */
  public enum ModelType {LINEAR, RULES_AND_LINEAR, RULES}

  // The fields below are package-private and are not assigned in this class; presumably they are
  // populated by the MOJO reader in this package - TODO confirm.

  /** Linear model that produces the final predictions. */
  MojoModel _linearModel;
  /** Expands a raw row into rule features; not used in {@code LINEAR} mode. */
  MojoRuleEnsemble _ruleEnsemble;
  /** Passed through to the rule ensemble's transformRow (presumably the max rule depth - TODO confirm). */
  int _depth;
  /** Passed through to the rule ensemble's transformRow (presumably the tree count - TODO confirm). */
  int _ntrees;
  /** Selects how the linear model's input is assembled; see the class comment. */
  ModelType _modelType;
  /** Not read by the scoring path in this class. */
  String[] _dataFromRulesCodes;
  /** Not read by the scoring path in this class. */
  String _weightsColumn;
  /** Name of each position of the assembled input vector, in assembly order (used by {@link #map}). */
  String[] _linearNames;

  RuleFitMojoModel(String[] columns, String[][] domains, String responseColumn) {
    super(columns, domains, responseColumn);
  }

  /**
   * Scores one row: assembles the linear model's input according to {@code _modelType}, permutes
   * it into the linear model's column order and delegates to {@code _linearModel.score0}.
   *
   * @param row   raw input row
   * @param preds output buffer, filled in by the linear model
   * @return {@code preds}
   * @throws IllegalStateException if the model type is unsupported, the response column is not
   *     among this model's columns, or an assembled feature has no matching linear-model column
   */
  @Override
  public double[] score0(double[] row, double[] preds) {
    final double[] test;
    // Switching on an enum throws NullPointerException for a null _modelType, as the original
    // _modelType.equals(...) calls did.
    switch (_modelType) {
      case LINEAR:
        test = row;
        break;
      case RULES:
        test = ruleFeaturesOf(row);
        break;
      case RULES_AND_LINEAR: {
        double[] fromRules = ruleFeaturesOf(row);
        test = new double[fromRules.length + row.length];
        // Layout: rule features first, then the raw row.
        System.arraycopy(fromRules, 0, test, 0, fromRules.length);
        System.arraycopy(row, 0, test, fromRules.length, row.length);
        break;
      }
      default:
        throw new IllegalStateException("Unsupported RuleFit model type: " + _modelType);
    }
    _linearModel.score0(map(test), preds);
    return preds;
  }

  /** Expands {@code row} into rule features using the rule ensemble. */
  private double[] ruleFeaturesOf(double[] row) {
    return _ruleEnsemble.transformRow(row, _depth, _ntrees, _linearModel._names,
        _linearModel._domains, lookupResponseDomain());
  }

  /**
   * Returns this model's domain entry for the response column.
   *
   * @throws IllegalStateException if the response column is not among this model's names
   */
  private String[] lookupResponseDomain() {
    int responseIdx = Arrays.asList(_names).indexOf(_responseColumn);
    if (responseIdx < 0) {
      throw new IllegalStateException(
          "Response column '" + _responseColumn + "' not found among model columns");
    }
    return _domains[responseIdx];
  }

  /**
   * Reorders an assembled input vector into the linear model's column order.
   *
   * <p>Position {@code i} of {@code test} is named {@code _linearNames[i]}. Its value is written to
   * the index of that name in {@code _linearModel._names}; if a name occurs there more than once,
   * its first occurrence is used. Only the first {@code _linearModel.nfeatures()} positions of
   * {@code test} are read.
   *
   * @param test assembled input vector
   * @return a new array of length {@code _linearModel.nfeatures()} in linear-model column order
   * @throws IllegalStateException if some {@code _linearNames[i]} is not a linear-model column
   */
  double[] map(double[] test) {
    final String[] linearModelNames = _linearModel._names;
    final int nfeatures = _linearModel.nfeatures();
    // Index the linear model's column names once, instead of a linear indexOf scan per feature
    // (O(F) instead of O(F^2) per scored row). First occurrence wins, matching List.indexOf.
    final Map<String, Integer> positionOf =
        new HashMap<String, Integer>(2 * linearModelNames.length);
    for (int j = 0; j < linearModelNames.length; j++) {
      if (!positionOf.containsKey(linearModelNames[j])) {
        positionOf.put(linearModelNames[j], j);
      }
    }
    final double[] newtest = new double[nfeatures];
    for (int i = 0; i < nfeatures; i++) {
      Integer id = positionOf.get(_linearNames[i]);
      if (id == null) {
        throw new IllegalStateException(
            "Feature '" + _linearNames[i] + "' is not a column of the linear model");
      }
      newtest[id] = test[i];
    }
    return newtest;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy