All downloads are FREE. Search and download functionality uses the official Maven repository.

hex.rulefit.RuleFitUtils Maven / Gradle / Ivy

package hex.rulefit;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Static helpers used by RuleFit to build column names and to consolidate and deduplicate
 * extracted {@link Rule}s.
 */
public class RuleFitUtils {

    /**
     * Builds column names for rules extracted from a tree model.
     *
     * @param modelId identifier embedded in every produced name
     * @param numCols number of names to produce; must not exceed {@code names.length}
     * @param names   base names to prefix
     * @return a new array whose element {@code i} is {@code "tree_" + modelId + "." + names[i]}
     */
    public static String[] getPathNames(int modelId, int numCols, String[] names) {
        return prefixNames("tree_" + modelId + ".", numCols, names);
    }

    /**
     * Builds column names for the linear terms.
     *
     * @param numCols number of names to produce; must not exceed {@code names.length}
     * @param names   base names to prefix
     * @return a new array whose element {@code i} is {@code "linear." + names[i]}
     */
    public static String[] getLinearNames(int numCols, String[] names) {
        return prefixNames("linear.", numCols, names);
    }

    /** Returns a new array holding {@code prefix + names[i]} for every {@code i < numCols}. */
    private static String[] prefixNames(String prefix, int numCols, String[] names) {
        String[] prefixed = new String[numCols];
        for (int i = 0; i < numCols; i++) {
            prefixed[i] = prefix + names[i];
        }
        return prefixed;
    }

    /**
     * Consolidates the conditions of every non-linear rule and optionally removes duplicates.
     *
     * <p>Note: {@code rules} is modified in place. Each non-linear element is passed through
     * {@link #consolidateRule(Rule, boolean)}, which mutates and returns the same instance.
     *
     * @param rules             rules to process; elements with {@code conditions == null} are
     *                          linear rules and are left untouched
     * @param remove_duplicates when {@code true}, rules are additionally passed through
     *                          {@link #deduplicateRules(Rule[])}
     * @return {@code rules} itself when {@code remove_duplicates} is false, otherwise the
     *         deduplicated array
     */
    static Rule[] consolidateRules(Rule[] rules, boolean remove_duplicates) {
        for (int i = 0; i < rules.length; i++) {
            if (rules[i].conditions != null) { // linear rules have no conditions to consolidate
                rules[i] = consolidateRule(rules[i], remove_duplicates);
            }
        }
        return remove_duplicates ? deduplicateRules(rules) : rules;
    }

    /**
     * Merges the conditions of {@code rule} feature by feature and regenerates its language rule.
     *
     * <p>The rule is mutated in place: its {@code conditions} and {@code languageRule} fields are
     * overwritten and the same instance is returned.
     *
     * @param rule              a non-linear rule ({@code rule.conditions} must be non-null)
     * @param remove_duplicates when {@code true}, the consolidated conditions are stably sorted by
     *                          feature name as preparation for deduplication, which groups rules
     *                          by {@code languageRule}
     * @return {@code rule}
     */
    static Rule consolidateRule(Rule rule, boolean remove_duplicates) {
        Condition[] conditions = rule.conditions;

        // distinct feature names, in order of first appearance
        Set<String> varNames = new LinkedHashSet<>();
        for (Condition condition : conditions) {
            varNames.add(condition.featureName);
        }

        List<Condition> consolidatedConditions = new ArrayList<>();
        for (String varName : varNames) {
            consolidatedConditions.addAll(consolidateConditionsByVar(conditions, varName));
        }

        if (remove_duplicates) {
            // List.sort is stable, so conditions on the same feature keep their relative order
            consolidatedConditions.sort(
                    Comparator.comparing((Condition condition) -> condition.featureName));
        }
        rule.conditions = consolidatedConditions.toArray(new Condition[0]);
        rule.languageRule = rule.generateLanguageRule();
        return rule;
    }

    /**
     * Consolidates all conditions on feature {@code varname}.
     *
     * <p>If exactly one condition matches, it is returned unchanged. Otherwise, all matching
     * {@code LessThan} conditions are folded together with {@link Condition#expandBy}, and
     * likewise the {@code GreaterThanOrEqual} and the {@code In} conditions.
     *
     * @param conditions all conditions of a rule
     * @param varname    feature name to select
     * @return the consolidated conditions, ordered LessThan, GreaterThanOrEqual, In (each present
     *         only if at least one such condition matched); empty when nothing matched
     */
    static List<Condition> consolidateConditionsByVar(Condition[] conditions, String varname) {
        List<Condition> currVarConditions = new ArrayList<>();
        for (Condition condition : conditions) {
            if (varname.equals(condition.featureName)) {
                currVarConditions.add(condition);
            }
        }
        if (currVarConditions.size() == 1) {
            return currVarConditions;
        }

        Condition lessThan = null;
        Condition greaterThanOrEqual = null;
        Condition inCondition = null;
        for (Condition current : currVarConditions) {
            if (Condition.Operator.LessThan.equals(current.operator)) {
                lessThan = expand(lessThan, current);
            } else if (Condition.Operator.GreaterThanOrEqual.equals(current.operator)) {
                greaterThanOrEqual = expand(greaterThanOrEqual, current);
            } else {
                assert Condition.Operator.In.equals(current.operator)
                        : "unexpected operator " + current.operator;
                inCondition = expand(inCondition, current);
            }
        }

        List<Condition> currVarConsolidatedConditions = new ArrayList<>();
        if (lessThan != null) {
            currVarConsolidatedConditions.add(lessThan);
        }
        if (greaterThanOrEqual != null) {
            currVarConsolidatedConditions.add(greaterThanOrEqual);
        }
        if (inCondition != null) {
            currVarConsolidatedConditions.add(inCondition);
        }
        return currVarConsolidatedConditions;
    }

    /** Folds {@code next} into {@code accumulated}; returns {@code next} when nothing is accumulated yet. */
    private static Condition expand(Condition accumulated, Condition next) {
        return accumulated == null ? next : accumulated.expandBy(next);
    }

    /**
     * Merges non-linear rules that share the same {@code languageRule}.
     *
     * <p>Each group of equivalent rules becomes one rule. That rule keeps the first rule's
     * conditions and prediction value, joins the {@code varName}s with {@code ", "}, and sums the
     * coefficients. Groups appear in order of their first occurrence in {@code rules}. Linear
     * rules ({@code conditions == null}) are appended afterwards, unchanged and in input order.
     *
     * @param rules rules to deduplicate (not modified)
     * @return a new array with the deduplicated rules
     */
    static Rule[] deduplicateRules(Rule[] rules) {
        // LinkedHashMap keeps group order stable; toCollection guarantees a mutable result list
        List<Rule> deduplicated = Arrays.stream(rules)
                .filter(rule -> rule.conditions != null)
                .collect(Collectors.groupingBy(rule -> rule.languageRule, LinkedHashMap::new, Collectors.toList()))
                .values().stream()
                .map(RuleFitUtils::mergeEquivalentRules)
                .collect(Collectors.toCollection(ArrayList::new));

        // add linear rules
        for (Rule rule : rules) {
            if (rule.conditions == null) {
                deduplicated.add(rule);
            }
        }
        return deduplicated.toArray(new Rule[0]);
    }

    /**
     * Left-folds a non-empty group of equivalent rules into one rule.
     *
     * @return the single element unchanged for a one-element group; otherwise a new rule
     */
    private static Rule mergeEquivalentRules(List<Rule> group) {
        Rule merged = group.get(0);
        for (int i = 1; i < group.size(); i++) {
            Rule next = group.get(i);
            merged = new Rule(merged.conditions, merged.predictionValue,
                    merged.varName + ", " + next.varName, merged.coefficient + next.coefficient);
        }
        return merged;
    }

    /**
     * Returns the first rule id of a possibly merged id.
     *
     * <p>After deduplication, an id may look like {@code "M0T0N1, M0T9N56, M9T34N56"}. All the
     * listed rules are equivalent, so only the part before the first comma is returned. Ids
     * without a comma are returned unchanged.
     *
     * @param ruleId a single or comma-joined rule id
     * @return the first rule id
     */
    static String readRuleId(String ruleId) {
        int separator = ruleId.indexOf(',');
        return separator < 0 ? ruleId : ruleId.substring(0, separator);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy