
// ai.h2o.mojos.runtime.h2o3.H2O3Transform (Maven / Gradle / Ivy)
package ai.h2o.mojos.runtime.h2o3;
import ai.h2o.mojos.runtime.frame.MojoColumnFloat64;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
/**
* A MOJO2 pipeline implementation that uses an H2O-3 (or Sparkling Water)
* MOJO as the predictor inside. The intent is to provide as identical an
* experience to the MOJO2 API as possible.
*
* A non-goal is to expose every possible low-level H2O-3 MOJO API capability.
* If you want to do that, call the H2O-3 MOJO API directly, instead.
*/
public class H2O3Transform extends MojoTransform {
    /** The raw H2O-3 model, taken from the wrapper; used here only to query the model category. */
    private final GenModel genModel;
    /** The H2O-3 prediction wrapper that performs the actual per-row scoring. */
    private final EasyPredictModelWrapper easyPredictModelWrapper;

    /**
     * A MOJO2 transformer implementation that uses an H2O-3 (or Sparkling Water) MOJO as the predictor inside.
     *
     * Must provide a valid Binomial, Multinomial or Regression model.
     * Other model types not currently supported.
     *
     * Note: later, we might consider splitting this into one class per each supported model type, to more closely represent underlying H2O-3 algos.
     *
     * @param meta frame metadata; not used by this constructor.
     * @param iindices indices of the frame columns fed to the model as inputs; forwarded to the superclass.
     * @param oindices indices of the frame columns that receive predictions; forwarded to the superclass.
     * @param easyPredictModelWrapper H2O-3 MOJO model.
     */
    H2O3Transform(MojoFrameMeta meta, int[] iindices, int[] oindices, EasyPredictModelWrapper easyPredictModelWrapper) {
        super(iindices, oindices);
        this.easyPredictModelWrapper = easyPredictModelWrapper;
        this.genModel = easyPredictModelWrapper.m;
    }

    /**
     * Scores every row of {@code frame} with the wrapped H2O-3 model and writes the
     * predictions into the output columns of the same frame.
     *
     * Each row is converted to a {@link RowData} keyed by input column name; cells whose
     * string value is {@code null} are left out of the row. For Binomial and Multinomial
     * models the class probabilities are written across the output columns; for Regression
     * models the predicted value is written to the first output column.
     *
     * @param frame the frame holding both the input columns and the output columns.
     * @throws UnsupportedOperationException if the model category is not Binomial, Multinomial
     *         or Regression, or if scoring a row fails (the original exception is attached as the cause).
     */
    @Override
    public void transform(final MojoFrame frame) {
        final ModelCategory modelCategory = genModel.getModelCategory();
        final int colCount = iindices.length;
        final int rowCount = frame.getNrows();

        // Column names and string data do not depend on the row, so fetch them once up front.
        final String[] names = new String[colCount];
        final String[][] columns = new String[colCount][];
        for (int j = 0; j < colCount; j += 1) {
            final int iidx = iindices[j];
            columns[j] = frame.getColumn(iidx).getDataAsStrings();
            names[j] = frame.getColumnName(iidx);
        }

        for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
            final RowData rowData = new RowData();
            for (int colIdx = 0; colIdx < colCount; colIdx++) {
                final String value = columns[colIdx][rowIdx];
                // Null cells are skipped rather than stored in the row.
                if (value != null) {
                    rowData.put(names[colIdx], value);
                }
            }
            try {
                switch (modelCategory) {
                    case Binomial: {
                        final BinomialModelPrediction p = easyPredictModelWrapper.predictBinomial(rowData);
                        setPrediction(frame, rowIdx, p.classProbabilities);
                        break;
                    }
                    case Multinomial: {
                        final MultinomialModelPrediction p = easyPredictModelWrapper.predictMultinomial(rowData);
                        setPrediction(frame, rowIdx, p.classProbabilities);
                        break;
                    }
                    case Regression: {
                        final RegressionModelPrediction p = easyPredictModelWrapper.predictRegression(rowData);
                        final MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(oindices[0]);
                        final double[] darr = (double[]) col.getData();
                        darr[rowIdx] = p.value;
                        break;
                    }
                    default:
                        throw new UnsupportedOperationException("Unsupported ModelCategory: " + modelCategory);
                }
            } catch (UnsupportedOperationException e) {
                // Already the exception type callers expect; rethrow so the generic handler below does not re-wrap it.
                throw e;
            } catch (PredictException e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) {
                    e.printStackTrace();
                }
                // Keep the original exception as the cause so its stack trace is not lost.
                throw new UnsupportedOperationException(String.format("%s failed: %s", modelCategory, e.getMessage()), e);
            } catch (Exception e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) {
                    e.printStackTrace();
                }
                // Keep the original exception as the cause so its stack trace is not lost.
                throw new UnsupportedOperationException(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage()), e);
            }
        }
    }

    /**
     * Writes one row of class probabilities into the output columns.
     *
     * The i-th output column (in {@code oindices} order) receives {@code classProbabilities[i]}.
     *
     * @param frame the frame whose output columns are written.
     * @param rowIdx the row to write.
     * @param classProbabilities per-class probabilities; must have at least {@code oindices.length} entries.
     */
    private void setPrediction(MojoFrame frame, int rowIdx, double[] classProbabilities) {
        for (int outputColIdx = 0; outputColIdx < oindices.length; outputColIdx++) {
            final int oidx = oindices[outputColIdx];
            final MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(oidx);
            final double[] darr = (double[]) col.getData();
            darr[rowIdx] = classProbabilities[outputColIdx];
        }
    }
}