All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.genmodel.algos.rulefit.RuleFitMojoReader Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.rulefit;

import hex.genmodel.MultiModelMojoReader;
import java.io.IOException;

/**
 * MOJO reader for RuleFit models.
 *
 * <p>Fills the fields of a {@link RuleFitMojoModel} (linear sub-model, model type, rule ensemble
 * and the auxiliary name arrays) from values looked up by key through {@code readkv}.
 *
 * <p>All key strings used below are reproduced exactly as found, including their spelling
 * ("treshold") and the places where no {@code _} separator is used. NOTE(review): they
 * presumably have to match the keys emitted when the MOJO was written - verify against the
 * writer before "fixing" any of them.
 */
public class RuleFitMojoReader extends MultiModelMojoReader<RuleFitMojoModel> {

  /**
   * Reads the RuleFit-specific part of the model.
   *
   * <p>The numeric {@code model_type} is mapped as: {@code 0} to {@code LINEAR}, {@code 1} to
   * {@code RULES_AND_LINEAR}, any other value to {@code RULES}. The rule ensemble is read only
   * when the model type is not {@code LINEAR}.
   *
   * @throws IOException if reading the rule ensemble fails
   */
  @Override
  protected void readParentModelData() throws IOException {
    _model._linearModel = getModel((String) readkv("linear_model"));

    int modelType = readkv("model_type");
    switch (modelType) {
      case 0:
        _model._modelType = RuleFitMojoModel.ModelType.LINEAR;
        break;
      case 1:
        _model._modelType = RuleFitMojoModel.ModelType.RULES_AND_LINEAR;
        break;
      default:
        _model._modelType = RuleFitMojoModel.ModelType.RULES;
        break;
    }

    // depth and ntrees must be set before the ensemble is read: they size the rule array.
    _model._depth = readkv("depth");
    _model._ntrees = readkv("ntrees");

    if (_model._modelType != RuleFitMojoModel.ModelType.LINEAR) {
      _model._ruleEnsemble = readRuleEnseble();
    }

    _model._dataFromRulesCodes = readIndexedStrings("data_from_rules_codes_");
    _model._weightsColumn = readkv("weights_column");
    _model._linearNames = readIndexedStrings("linear_names_");
  }

  /** @return the algorithm name of this reader, {@code "rulefit"} */
  @Override
  public String getModelName() {
    return "rulefit";
  }

  /**
   * Creates the empty model instance that this reader subsequently fills in.
   *
   * @param columns passed through to the {@link RuleFitMojoModel} constructor
   * @param domains passed through to the {@link RuleFitMojoModel} constructor
   * @param responseColumn passed through to the {@link RuleFitMojoModel} constructor
   * @return a new {@link RuleFitMojoModel}
   */
  @Override
  protected RuleFitMojoModel makeModel(String[] columns, String[][] domains, String responseColumn) {
    return new RuleFitMojoModel(columns, domains, responseColumn);
  }

  /** @return the MOJO version string reported by this reader, {@code "1.00"} */
  @Override
  public String mojoVersion() {
    return "1.00";
  }

  /**
   * Reads all rules and wraps them in a {@link MojoRuleEnsemble}.
   *
   * @return the rule ensemble built from {@link #readOrderedRuleEnseble()}
   * @throws IOException if reading a rule fails
   */
  MojoRuleEnsemble readRuleEnseble() throws IOException {
    return new MojoRuleEnsemble(readOrderedRuleEnseble());
  }

  /**
   * Reads the rules into a three-dimensional array indexed as {@code [i][j][k]}, where
   * {@code i} runs over {@code _model._depth}, {@code j} over {@code _model._ntrees} and
   * {@code k} over the rule count stored under {@code "num_rules_M" + i + "T" + j}.
   * Each rule is read with the id {@code i + "_" + j + "_" + k}.
   *
   * @return the rules grouped by the two outer indices
   * @throws IOException if reading a rule fails
   */
  MojoRule[][][] readOrderedRuleEnseble() throws IOException {
    int depth = _model._depth;
    int ntrees = _model._ntrees;
    MojoRule[][][] orderedRules = new MojoRule[depth][ntrees][];

    for (int i = 0; i < depth; i++) {
      for (int j = 0; j < ntrees; j++) {
        int numRules = readkv("num_rules_M" + i + "T" + j);
        String idPrefix = i + "_" + j + "_";
        MojoRule[] rules = new MojoRule[numRules];
        for (int k = 0; k < numRules; k++) {
          rules[k] = readRule(idPrefix + k);
        }
        orderedRules[i][j] = rules;
      }
    }
    return orderedRules;
  }

  /**
   * Reads a single rule: its conditions, prediction value, language rule, coefficient,
   * variable name and support.
   *
   * @param ruleId id suffix appended to every key belonging to this rule
   * @return the populated rule; {@code _support} is {@code Double.NaN} when no value is stored
   *     under {@code "support_rule_id_" + ruleId}
   * @throws IOException declared for callers; nothing in this method body throws it directly
   */
  MojoRule readRule(String ruleId) throws IOException {
    MojoRule rule = new MojoRule();

    int numConditions = readkv("num_conditions_rule_id_" + ruleId);
    MojoCondition[] conditions = new MojoCondition[numConditions];
    for (int i = 0; i < numConditions; i++) {
      conditions[i] = readCondition(i, ruleId);
    }
    rule._conditions = conditions;

    rule._predictionValue = readkv("prediction_value_rule_id_" + ruleId);
    rule._languageRule = readkv("language_rule_rule_id_" + ruleId);
    rule._coefficient = readkv("coefficient_rule_id_" + ruleId);
    rule._varName = readkv("var_name_rule_id_" + ruleId);

    // The support entry may be absent; look it up once and fall back to NaN.
    Double support = readkv("support_rule_id_" + ruleId);
    rule._support = support != null ? support : Double.NaN;
    return rule;
  }

  /**
   * Reads one condition of a rule.
   *
   * <p>A stored {@code type} of {@code 0} yields a categorical condition (with its language and
   * integer category thresholds); any other value yields a numerical condition with a single
   * threshold. The stored {@code operator} maps {@code 0} to {@code LessThan}, {@code 1} to
   * {@code GreaterThanOrEqual} and any other value to {@code In}.
   *
   * @param conditionId index of the condition within its rule
   * @param ruleId id of the owning rule
   * @return the populated condition
   */
  MojoCondition readCondition(int conditionId, String ruleId) {
    MojoCondition condition = new MojoCondition();
    String conditionIdentifier = conditionId + "_" + ruleId;

    condition._featureIndex = readkv("feature_index_" + conditionIdentifier);

    int type = readkv("type_" + conditionIdentifier);
    if (type == 0) {
      condition._type = MojoCondition.Type.Categorical;

      int languageCount = readkv("language_cat_treshold_length_" + conditionIdentifier);
      String[] languageThresholds = new String[languageCount];
      for (int i = 0; i < languageCount; i++) {
        // toString() rather than a String cast: the stored value is not assumed to be a String.
        languageThresholds[i] =
            readkv("language_cat_treshold_" + i + "_" + conditionIdentifier).toString();
      }
      condition._languageCatThreshold = languageThresholds;

      int catCount = readkv("cat_treshold_length_" + conditionIdentifier);
      int[] catThresholds = new int[catCount];
      for (int i = 0; i < catCount; i++) {
        // The element key really does start with "cat_treshold_length_"; kept as is.
        catThresholds[i] = readkv("cat_treshold_length_" + i + "_" + conditionIdentifier);
      }
      condition._catThreshold = catThresholds;
    } else {
      condition._type = MojoCondition.Type.Numerical;
      // No "_" between the key prefix and the identifier; kept as is.
      condition._numThreshold = readkv("num_treshold" + conditionIdentifier);
    }

    int operator = readkv("operator_" + conditionIdentifier);
    switch (operator) {
      case 0:
        condition._operator = MojoCondition.Operator.LessThan;
        break;
      case 1:
        condition._operator = MojoCondition.Operator.GreaterThanOrEqual;
        break;
      default:
        condition._operator = MojoCondition.Operator.In;
        break;
    }

    condition._featureName = readkv("feature_name_" + conditionIdentifier);
    condition._NAsIncluded = readkv("nas_included_" + conditionIdentifier);
    // No "_" between the key prefix and the identifier; kept as is.
    condition._languageCondition = readkv("language_condition" + conditionIdentifier);

    return condition;
  }

  /**
   * Reads a string array stored as a length entry plus one entry per element.
   *
   * @param keyPrefix common key prefix; the length is read from {@code keyPrefix + "len"} and
   *     element {@code i} from {@code keyPrefix + i}
   * @return the array of values in index order
   */
  private String[] readIndexedStrings(String keyPrefix) {
    int len = readkv(keyPrefix + "len");
    String[] values = new String[len];
    for (int i = 0; i < len; i++) {
      values[i] = readkv(keyPrefix + i);
    }
    return values;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy