All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.rulefit.RuleFitUtils Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.rulefit;

import water.util.TwoDimTable;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Static helpers for naming rule-fit features, post-processing {@link Rule}s
 * (deduplication, sorting, extraction from GLM coefficients) and rendering them
 * as a {@link TwoDimTable}.
 */
public class RuleFitUtils {

    /**
     * Builds tree-path feature names of the form {@code tree_<modelId>.<name>}.
     *
     * @param modelId id embedded in the prefix
     * @param numCols number of leading entries of {@code names} to convert
     * @param names   source names; must hold at least {@code numCols} entries
     * @return a new array of {@code numCols} prefixed names
     */
    public static String[] getPathNames(int modelId, int numCols, String[] names) {
        String[] pathNames = new String[numCols];
        for (int i = 0; i < numCols; i++) {
            pathNames[i] = "tree_" + modelId + "." + names[i];
        }
        return pathNames;
    }

    /**
     * Builds linear feature names of the form {@code linear.<name>}.
     *
     * @param numCols number of leading entries of {@code names} to convert
     * @param names   source names; must hold at least {@code numCols} entries
     * @return a new array of {@code numCols} prefixed names
     */
    public static String[] getLinearNames(int numCols, String[] names) {
        String[] linearNames = new String[numCols];
        for (int i = 0; i < numCols; i++) {
            linearNames[i] = "linear." + names[i];
        }
        return linearNames;
    }

    /**
     * Merges rules that are equal according to {@code Rule.equals}.
     * <p>
     * Rules without conditions (linear rules) are always kept untouched. For any other
     * rule, if an equal rule has already been collected, the collected one is removed and
     * a merged rule is appended at the end of the result: it keeps the conditions,
     * prediction value and support of the earlier rule, joins both variable names with
     * {@code ", "} and sums both coefficients.
     *
     * @param rules             rules to process
     * @param remove_duplicates when {@code false} the input array is returned as is
     * @return the input array itself, or a new array holding the deduplicated rules
     */
    static Rule[] deduplicateRules(Rule[] rules, boolean remove_duplicates) {
        if (!remove_duplicates) {
            return rules;
        }
        List<Rule> transform = new ArrayList<>();
        for (Rule currRule : rules) {
            if (currRule.conditions == null) {
                // linear rules:
                transform.add(currRule);
                continue;
            }
            // non linear rules:
            int duplicateIdx = transform.indexOf(currRule);
            if (duplicateIdx < 0) {
                transform.add(currRule);
            } else {
                // replace the first equal rule by a merged one placed at the end of the list
                Rule ruleToExtend = transform.remove(duplicateIdx);
                Rule newRule = new Rule(ruleToExtend.conditions, ruleToExtend.predictionValue,
                        ruleToExtend.varName + ", " + currRule.varName,
                        ruleToExtend.coefficient + currRule.coefficient, ruleToExtend.support);
                transform.add(newRule);
            }
        }
        return transform.toArray(new Rule[0]);
    }

    /**
     * Sorts the given array in place by descending {@code Rule.getAbsCoefficient()}.
     *
     * @param rules rules to sort; the array is modified
     * @return the same array instance, now sorted
     */
    static Rule[] sortRules(Rule[] rules) {
        Comparator<Rule> ruleAbsCoefficientComparator =
                Comparator.<Rule>comparingDouble(Rule::getAbsCoefficient).reversed();
        Arrays.sort(rules, ruleAbsCoefficientComparator);
        return rules;
    }

    /**
     * Returns a single rule id.
     * <p>
     * After deduplication an id may list several equivalent rules separated by commas,
     * e.g. {@code "M0T0N1, M0T9N56, M9T34N56"}; only the first one is returned.
     *
     * @param ruleId a plain or comma-separated rule id
     * @return the text before the first comma, or {@code ruleId} itself when it has no comma
     */
    static String readRuleId(String ruleId) {
        // indexOf/substring instead of split: split(",") yields an empty array for ids
        // made only of commas, which would make [0] throw
        int commaIdx = ruleId.indexOf(',');
        return commaIdx < 0 ? ruleId : ruleId.substring(0, commaIdx);
    }

    /**
     * Turns GLM coefficients into rules.
     * <p>
     * Entries whose key is {@code "Intercept"} or contains {@code "Intercept_"}, and entries
     * with a zero value, are skipped. A key starting with {@code "linear."} produces a new
     * condition-less rule with support {@code 1.0}; any other key is resolved through
     * {@code ruleEnsemble.getRuleByVarName}. In both cases the coefficient is set on the rule,
     * so rules obtained from the ensemble are modified.
     *
     * @param glmCoefficients coefficient name to value map
     * @param ruleEnsemble    ensemble used to look up non-linear rules
     * @param classNames      class names, used to strip class suffixes when {@code nclasses > 2}
     * @param nclasses        number of response classes
     * @return one rule per retained coefficient, in the iteration order of the filtered map
     */
    static Rule[] getRules(HashMap<String, Double> glmCoefficients, RuleEnsemble ruleEnsemble, String[] classNames, int nclasses) {
        // extract variable-coefficient map (filter out intercept and zero betas)
        Map<String, Double> filteredRules = glmCoefficients.entrySet()
                .stream()
                .filter(e -> !("Intercept".equals(e.getKey()) || e.getKey().contains("Intercept_")) && 0 != e.getValue())
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

        List<Rule> rules = new ArrayList<>();
        for (Map.Entry<String, Double> entry : filteredRules.entrySet()) {
            Rule rule;
            if (!entry.getKey().startsWith("linear.")) {
                // NOTE(review): the looked-up rule is used without a null check - confirm
                // getRuleByVarName never returns null for a coefficient name
                rule = ruleEnsemble.getRuleByVarName(getVarName(entry.getKey(), classNames, nclasses));
            } else {
                rule = new Rule(null, entry.getValue(), entry.getKey());
                // linear rule applies to all the rows
                rule.support = 1.0;
            }
            rule.setCoefficient(entry.getValue());
            rules.add(rule);
        }

        return rules.toArray(new Rule[0]);
    }

    /**
     * Extracts the variable name from a coefficient key: the class-name suffix is removed
     * first when there are more than two classes, then everything up to and including the
     * last {@code '.'} is dropped.
     */
    static private String getVarName(String ruleKey, String[] classNames, int nclasses) {
        if (nclasses > 2) {
            ruleKey = removeClassNameSuffix(ruleKey, classNames);
        }
        return ruleKey.substring(ruleKey.lastIndexOf(".") + 1);
    }

    /**
     * Strips the first class name that {@code ruleKey} ends with, together with the single
     * character preceding it (presumably a separator such as {@code '_'} - confirm against
     * the code that builds the keys). Returns {@code ruleKey} unchanged when no class name matches.
     */
    private static String removeClassNameSuffix(String ruleKey, String[] classNames) {
        for (String className : classNames) {
            if (ruleKey.endsWith(className)) {
                return ruleKey.substring(0, ruleKey.length() - className.length() - 1);
            }
        }
        return ruleKey;
    }

    /**
     * Renders rules as a "Rule Importance" table with one row per rule.
     * <p>
     * Columns are {@code variable}, {@code class} (only when {@code isMultinomial}),
     * {@code coefficient}, {@code support} and {@code rule}. The class value is the last
     * {@code '_'}-separated segment of the rule's variable name.
     *
     * @param rules                rules to render, in row order
     * @param isMultinomial        whether to include the {@code class} column
     * @param generateLanguageRule when {@code true} the rule text comes from
     *                             {@code Rule.generateLanguageRule()}, otherwise from the
     *                             rule's {@code languageRule} field
     * @return the populated table
     */
    static TwoDimTable convertRulesToTable(Rule[] rules, boolean isMultinomial, boolean generateLanguageRule) {
        List<String> colHeaders = new ArrayList<>();
        List<String> colTypes = new ArrayList<>();
        List<String> colFormat = new ArrayList<>();

        colHeaders.add("variable");
        colTypes.add("string");
        colFormat.add("%s");
        if (isMultinomial) {
            colHeaders.add("class");
            colTypes.add("string");
            colFormat.add("%s");
        }
        colHeaders.add("coefficient");
        colTypes.add("double");
        colFormat.add("%.5f");
        colHeaders.add("support");
        colTypes.add("double");
        colFormat.add("%.5f");
        colHeaders.add("rule");
        colTypes.add("string");
        colFormat.add("%s");

        final int rows = rules.length;
        TwoDimTable table = new TwoDimTable("Rule Importance", null, new String[rows],
                colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormat.toArray(new String[0]), "");

        for (int row = 0; row < rows; row++) {
            Rule rule = rules[row];
            int col = 0;
            String varname = rule.varName;
            table.set(row, col++, varname);
            if (isMultinomial) {
                String[] segments = varname.split("_");
                table.set(row, col++, segments[segments.length - 1]);
            }
            table.set(row, col++, rule.coefficient);
            table.set(row, col++, rule.support);
            table.set(row, col, generateLanguageRule ? rule.generateLanguageRule() : rule.languageRule);
        }

        return table;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy