All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.GlobalInteractionConstraints Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree;

import water.Iced;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.IcedHashSet;
import water.util.IcedInt;

import java.util.*;

/**
 * Class to process global interaction constraints information and use this information for 
 * make a split decision in a tree.
 */
public class GlobalInteractionConstraints extends Iced {

    // Map where key is column index and value is a set of columns indices which can interact with the key column
    private IcedHashMap> allowedInteractionMap; 
    
    public GlobalInteractionConstraints(String[][] userFeatureInteractions, String[] treeFeatureNames){
        this.allowedInteractionMap = new IcedHashMap<>();
        parseInteractionsIndices(userFeatureInteractions, treeFeatureNames);
        // There should be always at least one column index in the map as a key
        assert this.allowedInteractionMap != null;
        assert this.allowedInteractionMap.size() != 0;
    }

    /**
     * Parse input interaction constraints String array into Map to easy use for split decision.
     * @param userInteractionConstraints input interaction constraints String array
     * @param columnNames column names from used dataset for training to match indices correctly
     */
    private void parseInteractionsIndices(String[][] userInteractionConstraints, String[] columnNames){
        IcedHashSet interactions;
        for (String[] list : userInteractionConstraints) {
            interactions = new IcedHashSet<>();
            for (int i = 0; i < list.length; i++) {
                String item = list[i];
                // first find only name
                int start = ArrayUtils.findWithPrefix(columnNames, item);
                // find start index and add indices until end index
                assert start != -1 : "Column name should be in defined column names.";
                if (start > -1) {               // find exact position - no encoding  
                    interactions.add(new IcedInt(start));
                } else {                       // find first occur of the name with prefix - encoding
                    start = - start - 2;
                    assert columnNames[start].startsWith(item): "The column name should be find correctly.";
                    // iterate until find all encoding indices
                    int end = start;
                    while (end < columnNames.length && columnNames[end].startsWith(item)) {
                        interactions.add(new IcedInt(end));
                        end++;
                    }
                }
            }
            addInteractionsSetToMap(interactions);
        }
    }

    private void addInteractionsSetToMap(IcedHashSet interactions){
        for (IcedInt index : interactions) {
            if (!allowedInteractionMap.containsKey(index)) {
                allowedInteractionMap.put(index, interactions);
            } else {
                IcedHashSet set = new IcedHashSet<>();
                set.addAll(allowedInteractionMap.get(index));
                set.addAll(interactions);
                allowedInteractionMap.put(index, set);
            }
        }
    }
    
    public IcedHashSet getAllowedInteractionForIndex(int columnIndex){
        return allowedInteractionMap.get(new IcedInt(columnIndex));
    }
    
    public boolean allowedInteractionContainsColumn(int columnIndex){
        return allowedInteractionMap.containsKey(new IcedInt(columnIndex));
    }
    
    public IcedHashSet getAllAllowedColumnIndices(){
        IcedHashSet indices = new IcedHashSet<>();
        indices.addAll(allowedInteractionMap.keySet());
        return indices;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy