hex.modelselection.ModelSelectionUtils Maven / Gradle / Ivy
package hex.modelselection;
import hex.DataInfo;
import hex.Model;
import hex.glm.GLM;
import hex.glm.GLMModel;
import water.DKV;
import water.Key;
import water.fvec.Frame;
import java.lang.reflect.Field;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class ModelSelectionUtils {
/**
 * Builds {@code numModels} training frames, each holding a different combination of {@code predNum}
 * predictors taken from {@code predNames}. The first combination uses predictor indices
 * {@code 0..predNum-1}; each following one is obtained by calling
 * {@link #updatePredIndices(int[], int[])}. Every generated frame is stored in the DKV.
 *
 * @param parms       model selection parameters providing the source training frame and special columns
 * @param predNum     number of predictors in each combination
 * @param predNames   names of all candidate predictors
 * @param numModels   number of frames (combinations) to generate
 * @param foldColumn  name of the fold column to carry over, or null
 * @return the generated frames, in combination order
 */
public static Frame[] generateTrainingFrames(ModelSelectionModel.ModelSelectionParameters parms, int predNum, String[] predNames,
                                             int numModels, String foldColumn) {
    final int totalPreds = predNames.length;
    final Frame[] frames = new Frame[numModels];
    // current combination: positions into predNames
    final int[] combo = IntStream.range(0, predNum).toArray();
    // largest value each position of the combination may reach
    final int[] upperLimits = IntStream.range(totalPreds - predNum, totalPreds).toArray();
    int built = 0;
    while (built < numModels) {
        final Frame oneFrame = generateOneFrame(combo, parms, predNames, foldColumn);
        frames[built] = oneFrame;
        DKV.put(oneFrame);
        updatePredIndices(combo, upperLimits); // advance to the next combination
        built++;
    }
    return frames;
}
/***
 * Advances the predictor indices in currentPredIndices, in place, to the next combination. For example, with
 * 4 predictors taken two at a time, the indices move through the sequence
 * [0,1]->[0,2]->[0,3]->[1,2]->[1,3]->[2,3].
 *
 * The right-most position that is still below its bound is incremented and every position after it is reset by
 * updateLaterIndices. When every position has already reached its bound, the array is left unchanged.
 *
 * @param currentPredIndices current combination of predictor indices, modified in place
 * @param indicesBounds largest value allowed at each position of currentPredIndices
 */
public static void updatePredIndices(int[] currentPredIndices, int[] indicesBounds) {
    final int lastPredInd = currentPredIndices.length - 1;
    int pos = lastPredInd;
    // walk left past every position that is already at (or above) its bound
    while (pos >= 0 && currentPredIndices[pos] >= indicesBounds[pos]) {
        pos--;
    }
    if (pos < 0) {
        return; // nothing left to increment
    }
    currentPredIndices[pos]++;
    updateLaterIndices(currentPredIndices, pos, lastPredInd);
}
/***
 * After the entry at indexUpdated has been changed, rewrites every later entry (up to and including lastPredInd)
 * so that each one is exactly one larger than the entry before it. With 5 predictors taken 3 at a time this is
 * what turns [0, 2, 4] into [0, 2, 3] (indexUpdated = 1) or [1, 3, 4] into [1, 2, 3] (indexUpdated = 0).
 *
 * @param currentPredIndices combination of predictor indices, modified in place
 * @param indexUpdated position whose value was just changed
 * @param lastPredInd last position of currentPredIndices to rewrite
 */
public static void updateLaterIndices(int[] currentPredIndices, int indexUpdated, int lastPredInd) {
    int pos = indexUpdated + 1;
    while (pos <= lastPredInd) {
        currentPredIndices[pos] = currentPredIndices[pos - 1] + 1;
        pos++;
    }
}
/***
 * Assembles a new frame from the training frame of parms. Columns are added in this order: the predictors
 * selected by predIndices, then the weights column, the offset column and the fold column (each only when it is
 * not null), and finally the response column.
 *
 * @param predIndices indices into predNames of the predictors to include
 * @param parms model selection parameters supplying the training frame and the special column names
 * @param predNames names of all candidate predictors
 * @param foldColumn name of the fold column to include, or null
 * @return the newly assembled frame (it is not stored in the DKV by this method)
 */
public static Frame generateOneFrame(int[] predIndices, ModelSelectionModel.ModelSelectionParameters parms, String[] predNames,
                                     String foldColumn) {
    final Frame subset = new Frame(Key.make());
    final Frame source = parms.train();
    for (int predIdx : predIndices) {
        subset.add(predNames[predIdx], source.vec(predNames[predIdx]));
    }
    if (parms._weights_column != null) {
        subset.add(parms._weights_column, source.vec(parms._weights_column));
    }
    if (parms._offset_column != null) {
        subset.add(parms._offset_column, source.vec(parms._offset_column));
    }
    if (foldColumn != null) {
        subset.add(foldColumn, source.vec(foldColumn));
    }
    // the response column always goes last
    subset.add(parms._response_column, source.vec(parms._response_column));
    return subset;
}
/**
 * Creates a BitSet with an initial size of {@code totalPredSize} bits and sets the bit at every index listed in
 * {@code currIndices}.
 *
 * @param currIndices   indices of the bits to set
 * @param totalPredSize initial size, in bits, of the new BitSet
 * @return the newly created BitSet
 */
public static BitSet setBitSet(int[] currIndices, int totalPredSize) {
    final BitSet result = new BitSet(totalPredSize);
    setBitSet(result, currIndices);
    return result;
}

/**
 * Sets, in {@code predBitSet}, the bit at every index listed in {@code currIndices}. Bits that are already set
 * are left untouched.
 *
 * @param predBitSet  BitSet to modify in place
 * @param currIndices indices of the bits to set
 */
public static void setBitSet(BitSet predBitSet, int[] currIndices) {
    for (int pos = 0; pos < currIndices.length; pos++) {
        predBitSet.set(currIndices[pos]);
    }
}
/**
 * Given a predictor subset with indices stored in currSubsetIndices, generates an array of training frames, one
 * per candidate predictor index in validSubsets. Each frame uses the subset with the candidate inserted at
 * position newPredPos. Every generated frame is stored in the DKV.
 *
 * Duplicate handling only applies when usedCombo is not null and the resulting subset has more than one
 * predictor:
 * <ul>
 *   <li>if usedCombo was empty when this method was called, every combination is recorded in usedCombo and a
 *   frame is built for each of them;</li>
 *   <li>otherwise a frame is built only for combinations that were not already in usedCombo (they are recorded
 *   as a side effect).</li>
 * </ul>
 * In all other cases a frame is built for every candidate without any duplicate check.
 *
 * @param parms             model selection parameters
 * @param predictorNames    names of all candidate predictors
 * @param foldColumn        name of the fold column, or null
 * @param currSubsetIndices predictor indices of the current subset (not modified)
 * @param newPredPos        position in the subset at which each candidate index is inserted
 * @param validSubsets      list containing only valid predictor indices to choose from
 * @param usedCombo         predictor combinations already tried, updated in place; may be null
 * @return the training frames that were generated
 */
public static Frame[] generateMaxRTrainingFrames(ModelSelectionModel.ModelSelectionParameters parms,
                                                 String[] predictorNames, String foldColumn,
                                                 List<Integer> currSubsetIndices, int newPredPos,
                                                 List<Integer> validSubsets, Set<BitSet> usedCombo) {
    List<Frame> trainFramesList = new ArrayList<>();
    List<Integer> changedSubset = new ArrayList<>(currSubsetIndices);
    changedSubset.add(newPredPos, -1);  // placeholder, overwritten for every candidate below
    int[] predIndices = changedSubset.stream().mapToInt(Integer::intValue).toArray();
    int predNum = predictorNames.length;
    boolean trackCombos = usedCombo != null && predIndices.length > 1;
    // evaluated once, before any combination of this call is recorded
    boolean emptyUsedCombo = trackCombos && usedCombo.isEmpty();
    for (int predIndex : validSubsets) {    // consider valid predictor indices only
        predIndices[newPredPos] = predIndex;
        boolean buildFrame = true;
        if (trackCombos) {
            boolean isNewCombo = usedCombo.add(setBitSet(predIndices, predNum));    // true if not seen before
            buildFrame = emptyUsedCombo || isNewCombo;
        }
        if (buildFrame) {
            Frame trainFrame = generateOneFrame(predIndices, parms, predictorNames, foldColumn);
            DKV.put(trainFrame);
            trainFramesList.add(trainFrame);
        }
    }
    return trainFramesList.toArray(new Frame[0]);
}
/**
 * Returns a new array holding copies of the last {@code numModels} rows of {@code array}, in their original
 * order. Each row is cloned, so the result does not share row arrays with the input.
 *
 * @param array     source rows
 * @param numModels number of trailing rows to keep
 * @return new array of length {@code numModels}
 */
public static String[][] shrinkStringArray(String[][] array, int numModels) {
    int src = array.length - 1;
    final String[][] kept = new String[numModels][];
    // fill from the back so the last source row lands in the last slot
    for (int dst = numModels - 1; dst >= 0; dst--, src--) {
        kept[dst] = array[src].clone();
    }
    return kept;
}
/**
 * Returns a new array holding copies of the last {@code numModels} rows of {@code array}, in their original
 * order. Each row is cloned, so the result does not share row arrays with the input.
 *
 * @param array     source rows
 * @param numModels number of trailing rows to keep
 * @return new array of length {@code numModels}
 */
public static double[][] shrinkDoubleArray(double[][] array, int numModels) {
    int src = array.length - 1;
    final double[][] kept = new double[numModels][];
    // fill from the back so the last source row lands in the last slot
    for (int dst = numModels - 1; dst >= 0; dst--, src--) {
        kept[dst] = array[src].clone();
    }
    return kept;
}
/**
 * Returns a new array holding the last {@code numModels} elements of {@code array}, in their original order.
 * The elements themselves are not copied.
 *
 * @param array     source keys
 * @param numModels number of trailing elements to keep
 * @return new array of length {@code numModels}
 */
public static Key[] shrinkKeyArray(Key[] array, int numModels) {
    final int firstKept = array.length - numModels;
    final Key[] tail = new Key[numModels];
    System.arraycopy(array, firstKept, tail, 0, numModels);
    return tail;
}
/**
 * Formats every element of {@code val} with {@link Double#toString(double)} and joins the results with
 * {@code ", "}. All elements are included; an empty array yields an empty string.
 *
 * @param val values to format
 * @return comma-separated textual form of {@code val}
 */
public static String joinDouble(double[] val) {
    return Arrays.stream(val)
            .mapToObj(oneVal -> Double.toString(oneVal))
            .collect(Collectors.joining(", "));
}
/**
 * Given an array of built GLMModels, finds the one with the highest R2 value that exceeds lastBestR2. Null
 * entries are skipped. Every non-null model that does not end up as the best one has delete() called on it;
 * the best model, if any, is kept.
 *
 * @param lastBestR2   R2 value that a model must exceed to be considered
 * @param bestR2Models candidate models; entries may be null
 * @return index of the best model in bestR2Models, or -1 when no model exceeds lastBestR2
 */
public static int findBestR2Model(double lastBestR2, GLMModel[] bestR2Models) {
    int numModel = bestR2Models.length;
    int bestIndex = -1;     // no candidate selected yet
    double currBestR2 = lastBestR2;
    for (int index = 0; index < numModel; index++) {
        GLMModel candidate = bestR2Models[index];
        if (candidate == null)
            continue;
        double candidateR2 = candidate.r2();
        if (candidateR2 > currBestR2) {
            // only a previously selected candidate is superseded; never delete the new best itself
            if (bestIndex >= 0)
                bestR2Models[bestIndex].delete();
            bestIndex = index;
            currBestR2 = candidateR2;
        } else {
            candidate.delete();
        }
    }
    return bestIndex;
}
/**
 * Creates one GLMParameters object per training frame. Each one is first filled through
 * {@code setParamField} from the fields declared on ModelSelectionParameters and on Model.Parameters, and then
 * has its training frame key, number of folds, fold column and fold assignment set from the arguments.
 *
 * @param trainingFrames  frames to train on, one GLM parameter set is produced per frame
 * @param parms           model selection parameters to copy field values from
 * @param nfolds          value assigned to {@code _nfolds}
 * @param foldColumn      value assigned to {@code _fold_column}
 * @param foldAssignment  value assigned to {@code _fold_assignment}
 * @return GLM parameter sets, in the same order as {@code trainingFrames}
 */
public static GLMModel.GLMParameters[] generateGLMParameters(Frame[] trainingFrames,
                                                             ModelSelectionModel.ModelSelectionParameters parms,
                                                             int nfolds, String foldColumn,
                                                             Model.Parameters.FoldAssignmentScheme foldAssignment) {
    final GLMModel.GLMParameters[] glmParams = new GLMModel.GLMParameters[trainingFrames.length];
    final Field[] selectionFields = ModelSelectionModel.ModelSelectionParameters.class.getDeclaredFields();
    final Field[] baseFields = Model.Parameters.class.getDeclaredFields();
    int slot = 0;
    for (Frame trainingFrame : trainingFrames) {
        final GLMModel.GLMParameters oneParam = new GLMModel.GLMParameters();
        glmParams[slot++] = oneParam;
        setParamField(parms, oneParam, false, selectionFields, Collections.emptyList());
        setParamField(parms, oneParam, true, baseFields, Collections.emptyList());
        // explicit settings are applied last so they win over any copied value
        oneParam._train = trainingFrame._key;
        oneParam._nfolds = nfolds;
        oneParam._fold_column = foldColumn;
        oneParam._fold_assignment = foldAssignment;
    }
    return glmParams;
}
/**
 * Copies field values from {@code params} into the same-named fields of {@code glmParam} using reflection.
 * For every field in {@code paramFields} whose name is not in {@code excludeList}, a field with the same name
 * is looked up among the fields declared either by the class of {@code glmParam} or by its direct superclass
 * (when {@code superClassParams} is true) and assigned the value read from {@code params}.
 *
 * This is a deliberate best-effort copy: fields that do not exist on the target or are not accessible are
 * skipped silently.
 *
 * @param params           source of the field values
 * @param glmParam         GLM parameters to assign to
 * @param superClassParams true to look up target fields on the direct superclass of {@code glmParam}'s class,
 *                         false to look them up on the class itself
 * @param paramFields      fields to copy
 * @param excludeList      names of fields that must not be copied
 */
public static void setParamField(Model.Parameters params, GLMModel.GLMParameters glmParam, boolean superClassParams,
                                 Field[] paramFields, List<String> excludeList) {
    // the class declaring the target fields is the same for every field, so resolve it once
    final Class<?> targetClass = superClassParams ? glmParam.getClass().getSuperclass() : glmParam.getClass();
    for (Field oneField : paramFields) {
        final String fieldName = oneField.getName();
        if (excludeList.contains(fieldName))
            continue;
        try {
            Field glmField = targetClass.getDeclaredField(fieldName);
            glmField.set(glmParam, oneField.get(params));
        } catch (IllegalAccessException | NoSuchFieldException ignored) {
            // intentionally ignored: only fields that exist on the target and are accessible are copied
        }
    }
}
// NOTE(review): this span looks corrupted. On the `for` line below, whatever stood between `index` and `0) {`
// appears to have been lost, so the remainder of buildGLMBuilders and the beginning of at least one other
// method are missing. Evidence visible here: `builders` is never filled or returned, the names oneModel, currR2,
// bestR2Val and bestModel are used without being declared, and the final `return bestModel;` does not match the
// declared GLM[] return type. Restore this region from the original source rather than guessing -- TODO confirm.
public static GLM[] buildGLMBuilders(GLMModel.GLMParameters[] trainingParams) {
int numModels = trainingParams.length;
GLM[] builders = new GLM[numModels];
for (int index=0; index 0) {
// when this branch is taken, currR2 is replaced by the value in the "r2" row, column 0, of the
// cross-validation metrics summary
int r2Index = Arrays.asList(oneModel._output._cross_validation_metrics_summary.getRowHeaders()).indexOf("r2");
Float tempR2 = (Float) oneModel._output._cross_validation_metrics_summary.get(r2Index, 0);
currR2 = tempR2.doubleValue();
}
// keep the model with the strictly larger R2 and call delete() on the one that lost
if (currR2 > bestR2Val) {
bestR2Val = currR2;
if (bestModel != null)
bestModel.delete();
bestModel = oneModel;
} else {
oneModel.delete();
}
}
return bestModel;
}
/**
 * Returns the column names of {@code dinfo._adaptedFrame} that remain after removing every name returned by
 * {@code parms.getNonPredictors()} and, when it is not null, {@code foldColumn}. The relative order of the
 * remaining names is preserved.
 *
 * @param parms      model selection parameters supplying the non-predictor column names
 * @param dinfo      data info whose adapted frame provides the column names
 * @param foldColumn name of the fold column to drop as well, or null
 * @return names of the predictor columns
 */
public static String[] extractPredictorNames(ModelSelectionModel.ModelSelectionParameters parms, DataInfo dinfo,
                                             String foldColumn) {
    // explicit ArrayList: the list is modified below, and Collectors.toList() gives no mutability guarantee
    List<String> frameNames = new ArrayList<>(Arrays.asList(dinfo._adaptedFrame.names()));
    for (String nonPredictor : parms.getNonPredictors())
        frameNames.remove(nonPredictor);
    if (foldColumn != null)
        frameNames.remove(foldColumn);  // no-op when the fold column is not among the names
    return frameNames.toArray(new String[0]);
}
/**
 * Finds the predictor with the smallest absolute z-value in {@code model}, removes its name from
 * {@code numPredNames} or {@code catPredNames} (whichever it belongs to) and returns its position in
 * {@code predNames}.
 *
 * The absolute z-values are taken from the model with the "Intercept" entry dropped. The numerical candidate
 * comes from {@code findNumMinZVal} and the categorical candidate from {@code findCatMinZVal}. The categorical
 * candidate is chosen when it exists and either there is no numerical candidate or its z-value is strictly
 * smaller than the numerical one; otherwise the numerical candidate is chosen.
 *
 * @param model        GLM model providing z-values and coefficient names
 * @param numPredNames names of numerical predictors still in the model, updated in place; may be null or empty
 * @param catPredNames names of categorical predictors still in the model, updated in place
 * @param predNames    names of all predictors
 * @return index in {@code predNames} of the chosen predictor
 * @throws IllegalStateException when neither a numerical nor a categorical candidate exists
 */
public static int findMinZValue(GLMModel model, List<String> numPredNames, List<String> catPredNames,
                                List<String> predNames) {
    // both lists are modified below, so they must be mutable
    List<Double> zValList = Arrays.stream(model._output.zValues()).boxed().map(Math::abs)
            .collect(Collectors.toCollection(ArrayList::new));
    List<String> coeffNames = new ArrayList<>(Arrays.asList(model._output.coefficientNames()));
    int interceptIndex = coeffNames.indexOf("Intercept");
    if (interceptIndex >= 0) {  // remove intercept terms
        zValList.remove(interceptIndex);
        coeffNames.remove(interceptIndex);
    }
    // grab min z-values for numerical and categorical columns
    PredNameMinZVal numericalPred = findNumMinZVal(numPredNames, zValList, coeffNames);
    PredNameMinZVal categoricalPred = findCatMinZVal(model, zValList);
    // a null name means no candidate of that kind; the accompanying z-value is then only a sentinel and must
    // not take part in the comparison
    boolean hasNumerical = numericalPred._predName != null;
    boolean hasCategorical = categoricalPred._predName != null;
    if (hasCategorical && (!hasNumerical || categoricalPred._minZVal < numericalPred._minZVal)) {
        // categorical pred has minimum z-value
        catPredNames.remove(catPredNames.indexOf(categoricalPred._predName));
        return predNames.indexOf(categoricalPred._predName);
    }
    if (!hasNumerical)
        throw new IllegalStateException("No numerical or categorical predictor is available to choose from.");
    // numerical pred has minimum z-value
    numPredNames.remove(numPredNames.indexOf(numericalPred._predName));
    return predNames.indexOf(numericalPred._predName);
}
/**
 * Finds the numerical predictor with the smallest z-value. The z-value of a predictor is read from
 * {@code zValList} at the position its name has in {@code coeffNames}. A NaN z-value is treated as 0.0 and is
 * also overwritten with 0.0 in {@code zValList}. When several predictors share the smallest value, the first one
 * in {@code numPredNames} is chosen.
 *
 * @param numPredNames names of the numerical predictors; may be null or empty
 * @param zValList     z-values aligned with {@code coeffNames}, NaN entries are replaced in place
 * @param coeffNames   coefficient names
 * @return name and z-value of the chosen predictor, or a null name with z-value -1 when {@code numPredNames}
 * is null or empty
 */
public static PredNameMinZVal findNumMinZVal(List<String> numPredNames, List<Double> zValList, List<String> coeffNames) {
    double minNumVal = -1;
    String numPredMinZ = null;
    if (numPredNames != null) {
        boolean found = false;
        for (String predName : numPredNames) {
            int eleInd = coeffNames.indexOf(predName);
            double oneZValue = zValList.get(eleInd);
            if (Double.isNaN(oneZValue)) {  // NaN corresponds to coefficient of 0.0
                oneZValue = 0.0;
                zValList.set(eleInd, 0.0);
            }
            // strict comparison keeps the first predictor among ties
            if (!found || Double.compare(oneZValue, minNumVal) < 0) {
                minNumVal = oneZValue;
                numPredMinZ = predName;
                found = true;
            }
        }
    }
    return new PredNameMinZVal(numPredMinZ, minNumVal);
}
/***
 * This method extracts the categorical predictor with the minimum z-value by using the following method:
 * 1. From GLMModel model, it takes the column names returned by model.names() and the catOffsets of the model's
 * DataInfo. Entries catOffsets[i] (inclusive) to catOffsets[i+1] (exclusive) are the positions in zValList of
 * the enum levels of the i-th categorical column, and columnNames[i] is used as the name of that column.
 * 2. For each categorical column, the z-value of the column is the largest z-value among its levels. A NaN
 * z-value is treated as 0.0 and is also overwritten with 0.0 in zValList. Columns without any level are skipped.
 * 3. The column with the smallest such z-value is returned; among ties the first column wins.
 *
 * @param model GLM model providing the column names and catOffsets
 * @param zValList z-values of the coefficients, NaN entries of categorical levels are replaced in place
 * @return name and z-value of the chosen categorical column. When catOffsets is null, the name is null and the
 * z-value is -1. When catOffsets is not null but no column qualifies, the name is null and the z-value is
 * Double.MAX_VALUE.
 */
public static PredNameMinZVal findCatMinZVal(GLMModel model, List<Double> zValList) {
    String[] columnNames = model.names();
    int[] catOffsets = model._output.getDinfo()._catOffsets;
    if (catOffsets == null)
        return new PredNameMinZVal(null, -1);
    double minCatVal = Double.MAX_VALUE;
    String catPredMinZ = null;
    int numCatCol = catOffsets.length - 1;
    for (int catInd = 0; catInd < numCatCol; catInd++) {    // go through each categorical column
        int firstLevel = catOffsets[catInd];
        int nextCatOffset = catOffsets[catInd + 1];
        if (nextCatOffset <= firstLevel)
            continue;   // column has no level among the coefficients
        double oneCatZ = Double.NEGATIVE_INFINITY;
        for (int eleInd = firstLevel; eleInd < nextCatOffset; eleInd++) {   // check z-value for each level
            double oneZVal = zValList.get(eleInd);
            if (Double.isNaN(oneZVal)) {
                oneZVal = 0.0;
                zValList.set(eleInd, 0.0);
            }
            oneCatZ = Math.max(oneCatZ, oneZVal);   // a column is represented by its best (largest) level z-value
        }
        if (oneCatZ < minCatVal) {
            minCatVal = oneCatZ;
            catPredMinZ = columnNames[catInd];
        }
    }
    return new PredNameMinZVal(catPredMinZ, minCatVal);
}
/**
 * Simple holder pairing a predictor name with a z-value.
 */
static class PredNameMinZVal {
    /** Name of the predictor; may be null. */
    String _predName;
    /** z-value associated with {@code _predName}. */
    double _minZVal;

    /**
     * @param predName name of the predictor
     * @param minZVal  z-value associated with the predictor
     */
    public PredNameMinZVal(String predName, double minZVal) {
        this._predName = predName;
        this._minZVal = minZVal;
    }
}
/**
 * Returns the entries of {@code bestModel._output._names} that are also contained in {@code coefNames}, in the
 * order in which they appear in {@code bestModel._output._names}.
 *
 * @param coefNames names to look for
 * @param bestModel model whose output column names are filtered
 * @return new, modifiable list with the matching names
 */
public static List<String> extraModelColumnNames(List<String> coefNames, GLMModel bestModel) {
    List<String> coefUsed = new ArrayList<>();
    for (String modelColumn : bestModel._output._names) {
        if (coefNames.contains(modelColumn))
            coefUsed.add(modelColumn);
    }
    return coefUsed;
}
/**
 * Updates {@code validSubset} in place after the current subset changed from {@code originalSubset} to
 * {@code currSubsetIndices}: indices found only in {@code originalSubset} are appended to {@code validSubset},
 * then indices found only in {@code currSubsetIndices} are removed from it. Neither {@code originalSubset} nor
 * {@code currSubsetIndices} is modified.
 *
 * @param validSubset       predictor indices still available for selection, modified in place
 * @param originalSubset    predictor indices of the subset before the change
 * @param currSubsetIndices predictor indices of the subset after the change
 */
public static void updateValidSubset(List<Integer> validSubset, List<Integer> originalSubset,
                                     List<Integer> currSubsetIndices) {
    List<Integer> onlyInOriginal = new ArrayList<>(originalSubset);
    onlyInOriginal.removeAll(currSubsetIndices);    // indices that left the subset
    List<Integer> onlyInCurr = new ArrayList<>(currSubsetIndices);
    onlyInCurr.removeAll(originalSubset);           // indices that entered the subset
    validSubset.addAll(onlyInOriginal);
    validSubset.removeAll(onlyInCurr);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy