All downloads are free. The search and download functionality uses the official Maven repository.

hex.modelselection.ModelSelectionUtils Maven / Gradle / Ivy

package hex.modelselection;

import hex.DataInfo;
import hex.Model;
import hex.glm.GLM;
import hex.glm.GLMModel;
import water.DKV;
import water.Key;
import water.fvec.Frame;
import java.lang.reflect.Field;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ModelSelectionUtils {
    /**
     * Builds {@code numModels} training frames, one per predictor combination of size {@code predNum}.
     * The first combination is [0, 1, ..., predNum-1]; after each frame is generated by
     * {@code generateOneFrame} and stored with {@code DKV.put}, the combination is advanced in place by
     * {@code updatePredIndices}.
     *
     * @param parms      model selection parameters forwarded to {@code generateOneFrame}
     * @param predNum    number of predictors in every combination
     * @param predNames  names of all candidate predictors
     * @param numModels  number of frames to generate
     * @param foldColumn fold column name forwarded to {@code generateOneFrame}
     * @return the generated frames, in generation order
     */
    public static Frame[] generateTrainingFrames(ModelSelectionModel.ModelSelectionParameters parms, int predNum, String[] predNames,
                                                 int numModels, String foldColumn) {
        final int totalPreds = predNames.length;
        final Frame[] frames = new Frame[numModels];
        // current combination of indices into predNames
        final int[] combo = IntStream.range(0, predNum).toArray();
        // largest value each position of the combination is allowed to reach
        final int[] upperLimits = IntStream.range(totalPreds - predNum, totalPreds).toArray();
        int built = 0;
        while (built < numModels) {
            final Frame oneFrame = generateOneFrame(combo, parms, predNames, foldColumn);
            frames[built] = oneFrame;
            DKV.put(oneFrame);
            updatePredIndices(combo, upperLimits);  // move on to the next combination
            built++;
        }
        return frames;
    }

    /***
     * Advances the predictor index combination stored in currentPredIndices, in place, to the next combination.
     * For example, with 4 predictors taken two at a time the indices move through
     * [0,1]->[0,2]->[0,3]->[1,2]->[1,3]->[2,3].
     * The positions are scanned from the last one backwards; the first position whose value is still below its
     * bound is incremented and every later position is reset by updateLaterIndices.  When every position has
     * already reached its bound the array is left unchanged.
     *
     * @param currentPredIndices current combination, updated in place
     * @param indicesBounds highest value allowed at each position of the combination
     */
    public static void updatePredIndices(int[] currentPredIndices, int[] indicesBounds) {
        final int lastPos = currentPredIndices.length - 1;
        int pos = lastPos;
        // skip positions that have already reached their bound, least significant first
        while (pos >= 0 && currentPredIndices[pos] >= indicesBounds[pos]) {
            pos--;
        }
        if (pos < 0) {
            return;     // nothing left to increment
        }
        currentPredIndices[pos]++;
        updateLaterIndices(currentPredIndices, pos, lastPos);
    }

    /***
     * Resets the positions that follow indexUpdated so that each one holds the value of its predecessor plus one.
     * With 5 predictors and combinations of 3, this is the step that turns [0, 1, 4] into [0, 2, 3] (after position 1
     * was incremented) or [0, 3, 4] into [1, 2, 3] (after position 0 was incremented).
     *
     * @param currentPredIndices combination to fix up, updated in place
     * @param indexUpdated position that has just been incremented
     * @param lastPredInd index of the last position in currentPredIndices
     */
    public static void updateLaterIndices(int[] currentPredIndices, int indexUpdated, int lastPredInd) {
        int pos = indexUpdated;
        while (pos < lastPredInd) {
            final int successor = pos + 1;
            currentPredIndices[successor] = currentPredIndices[pos] + 1;
            pos = successor;
        }
    }
    
    /***
     * Creates a new frame under a fresh key that contains, in this order: the predictor columns selected by
     * predIndices, the weights column (if set in parms), the offset column (if set in parms), the fold column
     * (if foldColumn is not null) and finally the response column.  All columns are taken from parms.train().
     *
     * @param predIndices indices into predNames of the predictors to include
     * @param parms model selection parameters supplying the training frame and the special column names
     * @param predNames names of all candidate predictors
     * @param foldColumn name of the fold column to include, or null for none
     * @return the newly assembled frame
     */
    public static Frame generateOneFrame(int[] predIndices, ModelSelectionModel.ModelSelectionParameters parms, String[] predNames,
                                         String foldColumn) {
        final Frame subset = new Frame(Key.make());
        final Frame source = parms.train();
        for (int predIdx : predIndices) {
            subset.add(predNames[predIdx], source.vec(predNames[predIdx]));
        }
        if (parms._weights_column != null) {
            subset.add(parms._weights_column, source.vec(parms._weights_column));
        }
        if (parms._offset_column != null) {
            subset.add(parms._offset_column, source.vec(parms._offset_column));
        }
        if (foldColumn != null) {
            subset.add(foldColumn, source.vec(foldColumn));
        }
        // the response column always goes last
        subset.add(parms._response_column, source.vec(parms._response_column));
        return subset;
    }

    /**
     * Creates a BitSet with an initial size of totalPredSize bits and sets the bit of every index in currIndices.
     *
     * @param currIndices indices whose bits are to be set
     * @param totalPredSize initial size of the BitSet
     * @return the populated BitSet
     */
    public static BitSet setBitSet(int[] currIndices, int totalPredSize) {
        final BitSet bits = new BitSet(totalPredSize);
        setBitSet(bits, currIndices);
        return bits;
    }

    /**
     * Sets, in predBitSet, the bit of every index found in currIndices.  Bits that are already set stay set.
     *
     * @param predBitSet BitSet to update in place
     * @param currIndices indices whose bits are to be set
     */
    public static void setBitSet(BitSet predBitSet, int[] currIndices) {
        Arrays.stream(currIndices).forEach(oneIndex -> predBitSet.set(oneIndex));
    }
    
    /**
     * Given a predictor subset with indices stored in currSubsetIndices, generates training frames by inserting, at
     * position newPredPos of that subset, one predictor index at a time taken from validSubsets.
     * <p>
     * Duplicate tracking only applies when usedCombo is not null and the resulting subset holds more than one
     * predictor.  In that case every candidate subset is recorded in usedCombo as a BitSet.  If usedCombo was empty
     * when this method was called, a frame is built for every candidate; otherwise a frame is built only for subsets
     * that were not yet present in usedCombo.  Without duplicate tracking a frame is built for every candidate.
     * Every frame built is stored with DKV.put.
     *
     * @param parms model selection parameters forwarded to generateOneFrame
     * @param predictorNames names of all candidate predictors
     * @param foldColumn fold column name forwarded to generateOneFrame
     * @param currSubsetIndices predictor indices of the current subset; not modified
     * @param newPredPos position in the subset at which the candidate predictor index is inserted
     * @param validSubsets Lists containing only valid predictor indices to choose from
     * @param usedCombo predictor combinations already used, updated in place; may be null to disable tracking
     * @return the training frames that were built
     */
    public static Frame[] generateMaxRTrainingFrames(ModelSelectionModel.ModelSelectionParameters parms,
                                                     String[] predictorNames, String foldColumn,
                                                     List<Integer> currSubsetIndices, int newPredPos,
                                                     List<Integer> validSubsets, Set<BitSet> usedCombo) {
        List<Frame> trainFramesList = new ArrayList<>();
        List<Integer> changedSubset = new ArrayList<>(currSubsetIndices);
        changedSubset.add(newPredPos, -1);  // placeholder, overwritten for every candidate below
        int[] predIndices = changedSubset.stream().mapToInt(Integer::intValue).toArray();
        int predNum = predictorNames.length;
        // duplicates are only tracked for the forward and replacement steps of maxR
        boolean trackCombos = usedCombo != null && predIndices.length > 1;
        // decided once, before anything is added: an initially empty usedCombo means every candidate is built
        boolean buildAll = !trackCombos || usedCombo.isEmpty();
        for (int predIndex : validSubsets) {  // consider valid predictor indices only
            predIndices[newPredPos] = predIndex;
            boolean unseen = true;
            if (trackCombos) {
                unseen = usedCombo.add(setBitSet(predIndices, predNum));    // true if not yet in the set
            }
            if (buildAll || unseen) {
                Frame trainFrame = generateOneFrame(predIndices, parms, predictorNames, foldColumn);
                DKV.put(trainFrame);
                trainFramesList.add(trainFrame);
            }
        }
        return trainFramesList.toArray(new Frame[0]);
    }
    
    /**
     * Returns a new array made of clones of the last numModels rows of array, kept in their original order.
     *
     * @param array source rows
     * @param numModels number of trailing rows to keep
     * @return new array of length numModels
     */
    public static String[][] shrinkStringArray(String[][] array, int numModels) {
        final int lastSrc = array.length - 1;
        final String[][] tail = new String[numModels][];
        // walk both arrays from their last row towards the front
        for (int dst = numModels - 1, src = lastSrc; dst >= 0; dst--, src--) {
            tail[dst] = array[src].clone();
        }
        return tail;
    }
    
    /**
     * Returns a new array made of clones of the last numModels rows of array, kept in their original order.
     *
     * @param array source rows
     * @param numModels number of trailing rows to keep
     * @return new array of length numModels
     */
    public static double[][] shrinkDoubleArray(double[][] array, int numModels) {
        final int lastSrc = array.length - 1;
        final double[][] tail = new double[numModels][];
        // walk both arrays from their last row towards the front
        for (int dst = numModels - 1, src = lastSrc; dst >= 0; dst--, src--) {
            tail[dst] = array[src].clone();
        }
        return tail;
    }

    /**
     * Returns a new array holding the last numModels entries of array, in their original order.  The entries
     * themselves are not copied.
     *
     * @param array source keys
     * @param numModels number of trailing entries to keep
     * @return new array of length numModels
     */
    public static Key[] shrinkKeyArray(Key[] array, int numModels) {
        final int firstKept = array.length - numModels;
        final Key[] tail = new Key[numModels];
        System.arraycopy(array, firstKept, tail, 0, numModels);
        return tail;
    }
    
    /**
     * Formats every element of val with Double.toString and joins the results with ", ".
     *
     * @param val values to format
     * @return the joined string; empty when val has no elements
     */
    public static String joinDouble(double[] val) {
        return Arrays.stream(val)
                .mapToObj(oneVal -> Double.toString(oneVal))
                .collect(Collectors.joining(", "));
    }
    /**
     * Given an array of GLMModels built, find the one with the highest R2 value that exceeds lastBestR2.  If found,
     * return the index where the best model is.  Else return -1.
     * <p>
     * Every non-null model that does not end up as the best one is deleted; null entries are skipped.  The model at
     * the returned index is never deleted.
     * 
     * @param lastBestR2 R2 value a model has to exceed to be considered
     * @param bestR2Models candidate models; entries may be null
     * @return index of the best model in bestR2Models, or -1 if no model exceeds lastBestR2
     */
    public static int findBestR2Model(double lastBestR2, GLMModel[] bestR2Models) {
        // -1 means "no model has beaten lastBestR2 yet", so the first improvement does not delete
        // whatever happens to sit in slot 0 (possibly null, or the improving model itself)
        int bestIndex = -1;
        double currBestR2 = lastBestR2;
        for (int index = 0; index < bestR2Models.length; index++) {
            GLMModel candidate = bestR2Models[index];
            if (candidate == null) {
                continue;
            }
            double candidateR2 = candidate.r2();
            if (candidateR2 > currBestR2) {
                if (bestIndex >= 0) {
                    bestR2Models[bestIndex].delete();   // previous best has been superseded
                }
                bestIndex = index;
                currBestR2 = candidateR2;
            } else {
                candidate.delete();
            }
        }
        return bestIndex;
    }
    
    /**
     * Creates one GLMParameters object per training frame.  Each object is first populated from parms by
     * setParamField (once with the fields declared in ModelSelectionParameters, once with the fields declared in
     * Model.Parameters); afterwards its training frame key and its cross-validation settings are overwritten with
     * the values passed in.
     *
     * @param trainingFrames frames to build parameters for
     * @param parms model selection parameters whose field values are copied
     * @param nfolds value assigned to _nfolds of every parameter object
     * @param foldColumn value assigned to _fold_column of every parameter object
     * @param foldAssignment value assigned to _fold_assignment of every parameter object
     * @return the parameter objects, in the same order as trainingFrames
     */
    public static GLMModel.GLMParameters[] generateGLMParameters(Frame[] trainingFrames,
                                                                 ModelSelectionModel.ModelSelectionParameters parms, 
                                                                 int nfolds, String foldColumn,
                                                                 Model.Parameters.FoldAssignmentScheme foldAssignment) {
        final GLMModel.GLMParameters[] glmParams = new GLMModel.GLMParameters[trainingFrames.length];
        final Field[] selectionFields = ModelSelectionModel.ModelSelectionParameters.class.getDeclaredFields();
        final Field[] baseFields = Model.Parameters.class.getDeclaredFields();
        int slot = 0;
        for (Frame oneFrame : trainingFrames) {
            final GLMModel.GLMParameters oneParam = new GLMModel.GLMParameters();
            glmParams[slot++] = oneParam;
            setParamField(parms, oneParam, false, selectionFields, Collections.emptyList());
            setParamField(parms, oneParam, true, baseFields, Collections.emptyList());
            // these always come from the arguments, whatever was copied from parms
            oneParam._train = oneFrame._key;
            oneParam._nfolds = nfolds;
            oneParam._fold_column = foldColumn;
            oneParam._fold_assignment = foldAssignment;
        }
        return glmParams;
    }
    
    /**
     * Copies field values from params into glmParam by reflection.  For every field in paramFields whose name is not
     * in excludeList, a field of the same name is looked up on the class of glmParam (or on its direct superclass
     * when superClassParams is true) and set to the value read from params.  Fields that do not exist on the target
     * or are not accessible are skipped silently.
     *
     * @param params source of the field values
     * @param glmParam object whose fields are assigned
     * @param superClassParams true to look the fields up on the superclass of glmParam's class, false to look them
     *                         up on glmParam's class itself
     * @param paramFields fields to copy
     * @param excludeList names of fields that must not be copied
     */
    public static void setParamField(Model.Parameters params, GLMModel.GLMParameters glmParam, boolean superClassParams,
                                     Field[] paramFields, List<String> excludeList) {
        // the class to look fields up on does not change between iterations
        final Class<?> targetClass = superClassParams ? glmParam.getClass().getSuperclass() : glmParam.getClass();
        for (Field oneField : paramFields) {
            final String fieldName = oneField.getName();
            if (excludeList.contains(fieldName)) {
                continue;
            }
            try {
                Field glmField = targetClass.getDeclaredField(fieldName);
                glmField.set(glmParam, oneField.get(params));
            } catch (IllegalAccessException | NoSuchFieldException ignored) {
                // best effort: only fields that exist on the target and are accessible are copied
            }
        }
    }
    
    // NOTE(review): this span is corrupted in the visible source.  The loop header below is cut off in the middle of
    // its condition, and the statements that follow use names (oneModel, currR2, bestR2Val, bestModel) that are not
    // declared anywhere in this span.  The end of buildGLMBuilders and the header of a following method therefore
    // appear to be missing.  Restore the original text before changing any code here.
    public static GLM[] buildGLMBuilders(GLMModel.GLMParameters[] trainingParams) {
        int numModels = trainingParams.length;
        // one builder slot per parameter set
        GLM[] builders = new GLM[numModels];
        for (int index=0; index 0) {
                // locate the "r2" row of the cross-validation metrics summary and read the value in its first column
                int r2Index = Arrays.asList(oneModel._output._cross_validation_metrics_summary.getRowHeaders()).indexOf("r2");
                Float tempR2 = (Float) oneModel._output._cross_validation_metrics_summary.get(r2Index, 0);
                currR2 = tempR2.doubleValue();
            }
            // keep the model with the highest R2 seen so far and delete every other model
            if (currR2 > bestR2Val) {
                bestR2Val = currR2;
                if (bestModel != null)
                    bestModel.delete();
                bestModel = oneModel;
            } else {
                oneModel.delete();
            }
        }
        return bestModel;
    }

    /**
     * Returns the column names of dinfo._adaptedFrame that remain after removing the names returned by
     * parms.getNonPredictors() and, when foldColumn is not null, the fold column.  The order of the remaining names
     * is the order in which they appear in the frame.
     *
     * @param parms model selection parameters supplying the non-predictor column names
     * @param dinfo data info whose adapted frame supplies the column names
     * @param foldColumn name of the fold column to drop, or null for none
     * @return names of the predictor columns
     */
    public static String[] extractPredictorNames(ModelSelectionModel.ModelSelectionParameters parms, DataInfo dinfo,
                                            String foldColumn) {
        // explicit ArrayList: elements are removed below, so the list has to be mutable
        List<String> frameNames = new ArrayList<>(Arrays.asList(dinfo._adaptedFrame.names()));
        for (String col : parms.getNonPredictors()) {
            frameNames.remove(col);
        }
        if (foldColumn != null) {
            frameNames.remove(foldColumn);  // no-op when the column is not present
        }
        return frameNames.toArray(new String[0]);
    }
    
    /**
     * Finds the predictor with the smallest absolute z-value in model, removes its name from numPredNames or
     * catPredNames (whichever list it was found through) and returns its index in predNames.
     * <p>
     * The intercept, if present among the coefficient names, is ignored.  The candidate for the numerical predictors
     * comes from findNumMinZVal, the candidate for the categorical predictors from findCatMinZVal.  The categorical
     * candidate is chosen when it exists and either has a smaller z-value than the numerical candidate or there is
     * no numerical candidate at all; otherwise the numerical candidate is chosen.
     *
     * @param model GLM model supplying the z-values and coefficient names
     * @param numPredNames names of the numerical predictors in the model; the chosen name is removed
     * @param catPredNames names of the categorical predictors in the model; the chosen name is removed
     * @param predNames names of all predictors
     * @return index in predNames of the chosen predictor
     * @throws IllegalStateException if neither a numerical nor a categorical candidate could be found
     */
    public static int findMinZValue(GLMModel model, List<String> numPredNames, List<String> catPredNames,
                                    List<String> predNames) {
        // both lists are modified below, so they are collected into explicitly mutable lists
        List<Double> zValList = Arrays.stream(model._output.zValues()).boxed().map(Math::abs)
                .collect(Collectors.toCollection(ArrayList::new));
        List<String> coeffNames = new ArrayList<>(Arrays.asList(model._output.coefficientNames()));
        int interceptIndex = coeffNames.indexOf("Intercept");
        if (interceptIndex >= 0) { // remove intercept terms
            zValList.remove(interceptIndex);
            coeffNames.remove(interceptIndex);
        }
        // grab min z-values for numerical and categorical columns
        PredNameMinZVal numericalPred = findNumMinZVal(numPredNames, zValList, coeffNames);
        PredNameMinZVal categoricalPred = findCatMinZVal(model, zValList);
        boolean hasNumerical = numericalPred._predName != null;
        boolean hasCategorical = categoricalPred._predName != null;

        // a missing numerical candidate carries the placeholder value -1, which must not win the comparison
        if (hasCategorical && (!hasNumerical || categoricalPred._minZVal < numericalPred._minZVal)) {
            catPredNames.remove(catPredNames.indexOf(categoricalPred._predName));
            return predNames.indexOf(categoricalPred._predName);
        }
        if (!hasNumerical) {
            throw new IllegalStateException("No numerical or categorical predictor with a z-value was found.");
        }
        numPredNames.remove(numPredNames.indexOf(numericalPred._predName));
        return predNames.indexOf(numericalPred._predName);
    }
    
    /**
     * Finds the numerical predictor with the smallest z-value.  The z-value of a predictor is read from zValList at
     * the position at which the predictor name occurs in coeffNames.  A NaN z-value is treated as 0.0 and is also
     * replaced by 0.0 inside zValList.  When several predictors share the smallest value, the first one in
     * numPredNames is returned.
     *
     * @param numPredNames names of the numerical predictors; may be null or empty
     * @param zValList z-values aligned with coeffNames; NaN entries that are visited are overwritten with 0.0
     * @param coeffNames coefficient names
     * @return name and z-value of the predictor found, or a null name with z-value -1 when numPredNames is null or
     *         empty
     */
    public static PredNameMinZVal findNumMinZVal(List<String> numPredNames, List<Double> zValList,
                                                 List<String> coeffNames) {
        double minNumVal = -1;
        String numPredMinZ = null;
        if (numPredNames != null) {
            boolean found = false;
            for (String predName : numPredNames) {
                int eleInd = coeffNames.indexOf(predName);
                double oneZValue = zValList.get(eleInd);
                if (Double.isNaN(oneZValue)) {  // NaN corresponds to coefficient of 0.0
                    oneZValue = 0.0;
                    zValList.set(eleInd, oneZValue);
                }
                // strict comparison keeps the first predictor when values tie
                if (!found || Double.compare(oneZValue, minNumVal) < 0) {
                    minNumVal = oneZValue;
                    numPredMinZ = predName;
                    found = true;
                }
            }
        }
        return new PredNameMinZVal(numPredMinZ, minNumVal);
    }

    /***
     * This method extracts the categorical coefficient z-value by using the following method:
     * 1. From GLMModel model, it extracts the column names of the dinfo._adaptedFrame that is used to build the glm 
     * model and generate the glm coefficients.  The column names will be in exactly the same order as the coefficient
     * names with the exception that each enum levels will not be given a name in the column names.
     * 2. To figure out which coefficient name corresponds to which column name, we use the catOffsets which will tell
     * us how many enum levels are used in the glm model coefficients.  If the catOffset for the first coefficient
     * says 3, that means that column will have three enum levels represented in the glm model coefficients.
     * 3. Each categorical column is represented by the largest z-value among its levels, and the column with the
     * smallest such value is returned.  A NaN z-value is treated as 0.0 and is also replaced by 0.0 inside zValList.
     *
     * @param model GLM model supplying the column names and the categorical offsets
     * @param zValList z-values indexed like the model coefficients; visited NaN entries are overwritten with 0.0
     * @return name and z-value of the categorical column found.  The name is null when no column qualifies; the
     *         z-value is then -1 if the model has no categorical offsets and Double.MAX_VALUE otherwise.
     */
    public static PredNameMinZVal findCatMinZVal(GLMModel model, List<Double> zValList) {
        String[] columnNames = model.names(); // column names of dinfo._adaptedFrame
        int[] catOffsets = model._output.getDinfo()._catOffsets;
        double minCatVal = -1;
        String catPredMinZ = null;
        if (catOffsets != null) {
            minCatVal = Double.MAX_VALUE;
            int numCatCol = catOffsets.length - 1;
            for (int catInd = 0; catInd < numCatCol; catInd++) {    // go through each categorical column
                int nextCatOffset = catOffsets[catInd + 1];
                boolean hasLevels = false;
                double colMaxZ = 0;
                for (int eleInd = catOffsets[catInd]; eleInd < nextCatOffset; eleInd++) {   // check z-value for each level
                    double oneZVal = zValList.get(eleInd);
                    if (Double.isNaN(oneZVal)) {
                        oneZVal = 0.0;
                        zValList.set(eleInd, oneZVal);
                    }
                    if (!hasLevels || Double.compare(oneZVal, colMaxZ) > 0) {
                        colMaxZ = oneZVal;  // a column is represented by its best (largest) level z-value
                    }
                    hasLevels = true;
                }
                if (hasLevels && colMaxZ < minCatVal) {
                    minCatVal = colMaxZ;
                    catPredMinZ = columnNames[catInd];
                }
            }
        }
        return new PredNameMinZVal(catPredMinZ, minCatVal);
    }
    
    /**
     * Simple holder pairing a predictor name with a z-value.
     */
    static class PredNameMinZVal {
        /** predictor name */
        String _predName;
        /** z-value associated with the predictor */
        double _minZVal;

        /**
         * @param predName predictor name to store
         * @param minZVal z-value to store
         */
        public PredNameMinZVal(String predName, double minZVal) {
            this._predName = predName;
            this._minZVal = minZVal;
        }
    }
    
    /**
     * Returns the column names of bestModel (bestModel._output._names) that also occur in coefNames, in the order in
     * which they appear in the model.
     *
     * @param coefNames names to keep
     * @param bestModel model supplying the column names
     * @return new list with the model column names that are contained in coefNames
     */
    public static List<String> extraModelColumnNames(List<String> coefNames, GLMModel bestModel) {
        List<String> coefUsed = new ArrayList<>();
        for (String modelColumn : bestModel._output._names) {
            if (coefNames.contains(modelColumn)) {
                coefUsed.add(modelColumn);
            }
        }
        return coefUsed;
    }
    
    /**
     * Updates validSubset in place after the predictor subset changed from originalSubset to currSubsetIndices:
     * indices found only in originalSubset are appended to validSubset, and indices found only in currSubsetIndices
     * are removed from it.  originalSubset and currSubsetIndices are not modified.
     *
     * @param validSubset predictor indices still available for selection; updated in place
     * @param originalSubset predictor indices of the subset before the change
     * @param currSubsetIndices predictor indices of the subset after the change
     */
    public static void updateValidSubset(List<Integer> validSubset, List<Integer> originalSubset,
                                         List<Integer> currSubsetIndices) {
        List<Integer> onlyInOriginal = new ArrayList<>(originalSubset);
        onlyInOriginal.removeAll(currSubsetIndices);
        List<Integer> onlyInCurr = new ArrayList<>(currSubsetIndices);
        onlyInCurr.removeAll(originalSubset);
        validSubset.addAll(onlyInOriginal);    // predictors dropped from the subset become available again
        validSubset.removeAll(onlyInCurr);     // predictors newly in the subset are no longer available
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy