
ai.h2o.mojos.runtime.h2o3.KlimeTransform Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mojo2-runtime-h2o3-impl Show documentation
Show all versions of mojo2-runtime-h2o3-impl Show documentation
MOJO2 H2O-3 Runtime Implementation
package ai.h2o.mojos.runtime.h2o3;
import ai.h2o.mojos.runtime.frame.MojoColumnFloat64;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.KLimeModelPrediction;
/**
 * A MOJO2 pipeline transform that uses a k-LIME MOJO (built on the H2O-3 MOJO framework)
 * as the predictor inside. The intent is to provide an experience that is as close
 * to the MOJO2 API as possible.
 *
 * <p>For every row of the frame, the input columns are handed to the wrapped model as strings
 * and the resulting prediction is written to the output columns in this order: predicted value,
 * cluster, then one column per reason code.
 */
public class KlimeTransform extends MojoTransform {
    private final EasyPredictModelWrapper easyPredictModelWrapper;
    private final GenModel genModel;

    /**
     * Creates a MOJO2 transformer that uses a k-LIME MOJO (built on the H2O-3 MOJO framework)
     * as the predictor inside.
     *
     * @param meta frame metadata; not read by this constructor
     * @param iindices input column indices, forwarded to the {@code MojoTransform} constructor
     * @param oindices output column indices, forwarded to the {@code MojoTransform} constructor
     * @param easyPredictModelWrapper H2O-3 MOJO model; its {@code m} field supplies the underlying
     *     {@link GenModel}
     */
    KlimeTransform(MojoFrameMeta meta, int[] iindices, int[] oindices, EasyPredictModelWrapper easyPredictModelWrapper) {
        super(iindices, oindices);
        this.easyPredictModelWrapper = easyPredictModelWrapper;
        this.genModel = easyPredictModelWrapper.m;
    }

    /**
     * Scores every row of {@code frame} with the k-LIME model and stores the predictions in the
     * output columns.
     *
     * <p>Output column 0 receives the predicted value, column 1 the cluster, and every further
     * column the reason code at position {@code outputColumn - 2}. Input cells whose string value
     * is {@code null} are left out of the row passed to the model.
     *
     * @param frame frame holding both the input columns to read and the output columns to fill;
     *     output columns are cast to {@link MojoColumnFloat64}
     * @throws UnsupportedOperationException if prediction or writing the result fails for a row;
     *     the original exception is attached as the cause
     */
    @Override
    public void transform(final MojoFrame frame) {
        final ModelCategory modelCategory = genModel.getModelCategory();
        final int colCount = iindices.length;
        final int rowCount = frame.getNrows();

        // Resolve the string data and the name of each input column once, not once per cell.
        final String[][] columns = new String[colCount][];
        final String[] columnNames = new String[colCount];
        for (int j = 0; j < colCount; j++) {
            final int iidx = iindices[j];
            columns[j] = frame.getColumn(iidx).getDataAsStrings();
            columnNames[j] = frame.getColumnName(iidx);
        }

        for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
            final RowData rowData = new RowData();
            for (int colIdx = 0; colIdx < colCount; colIdx++) {
                final String value = columns[colIdx][rowIdx];
                // Null cells are omitted so the model sees the column as absent for this row.
                if (value != null) {
                    rowData.put(columnNames[colIdx], value);
                }
            }
            try {
                final KLimeModelPrediction p = easyPredictModelWrapper.predictKLime(rowData);
                // Read inside the try so that any failure is wrapped like the other prediction errors.
                final int predsSize = genModel.getPredsSize();
                for (int outputColIdx = 0; outputColIdx < predsSize; outputColIdx++) {
                    final MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(oindices[outputColIdx]);
                    // NOTE(review): assumes getData() exposes the column's backing array, so the
                    // write below lands in the frame - confirm against MojoColumn.
                    final double[] darr = (double[]) col.getData();
                    switch (outputColIdx) {
                        case 0:
                            darr[rowIdx] = p.value;
                            break;
                        case 1:
                            darr[rowIdx] = p.cluster;
                            break;
                        default:
                            // The first two outputs are value and cluster; the rest are reason codes.
                            darr[rowIdx] = p.reasonCodes[outputColIdx - 2];
                            break;
                    }
                }
            } catch (PredictException e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) e.printStackTrace();
                // Keep the original exception as the cause so its stack trace is not lost.
                throw new UnsupportedOperationException(String.format("%s failed: %s", modelCategory, e.getMessage()), e);
            } catch (Exception e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) e.printStackTrace();
                // Keep the original exception as the cause so its stack trace is not lost.
                throw new UnsupportedOperationException(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage()), e);
            }
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy