All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.h2o.mojos.runtime.h2o3.H2O3Transform Maven / Gradle / Ivy

There is a newer version: 2.8.7.1
Show newest version
package ai.h2o.mojos.runtime.h2o3;

import ai.h2o.mojos.runtime.frame.MojoColumnFloat64;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;

/**
 * A MOJO2 pipeline implementation that uses an H2O-3 (or Sparkling Water)
 * MOJO as the predictor inside.  The intent is to provide as identical an
 * experience to the MOJO2 API as possible.
 * 

* A non-goal is to expose every possible low-level H2O-3 MOJO API capability. * If you want to do that, call the H2O-3 MOJO API directly, instead. */ public class H2O3Transform extends MojoTransform { private final GenModel genModel; private final EasyPredictModelWrapper easyPredictModelWrapper; /** * A MOJO2 transformer implementation that uses an H2O-3 (or Sparkling Water) MOJO as the predictor inside. *

* Must provide a valid Binomial, Multinomial or Regression model. * Other model types not currently supported. *

* Note: later, we might consider splitting this into one class per each supported model type, to more closely represent underlying H2O-3 algos. * * @param easyPredictModelWrapper H2O-3 MOJO model. */ H2O3Transform(MojoFrameMeta meta, int[] iindices, int[] oindices, EasyPredictModelWrapper easyPredictModelWrapper) { super(iindices, oindices); this.easyPredictModelWrapper = easyPredictModelWrapper; this.genModel = easyPredictModelWrapper.m; } @Override public void transform(final MojoFrame frame) { final ModelCategory modelCategory = genModel.getModelCategory(); final int colCount = iindices.length; final int rowCount = frame.getNrows(); final String[][] columns = new String[colCount][]; for (int j = 0; j < colCount; j += 1) { final int iidx = iindices[j]; columns[j] = frame.getColumn(iidx).getDataAsStrings(); } for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) { final RowData rowData = new RowData(); for (int colIdx = 0; colIdx < colCount; colIdx++) { final int iidx = iindices[colIdx]; final String key = frame.getColumnName(iidx); final String value = columns[colIdx][rowIdx]; if (value != null) { rowData.put(key, value); } } try { switch (modelCategory) { case Binomial: { final BinomialModelPrediction p = easyPredictModelWrapper.predictBinomial(rowData); setPrediction(frame, rowIdx, p.classProbabilities); } break; case Multinomial: { final MultinomialModelPrediction p = easyPredictModelWrapper.predictMultinomial(rowData); setPrediction(frame, rowIdx, p.classProbabilities); } break; case Regression: { final RegressionModelPrediction p = easyPredictModelWrapper.predictRegression(rowData); final MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(oindices[0]); final double[] darr = (double[]) col.getData(); darr[rowIdx] = p.value; } break; default: throw new UnsupportedOperationException("Unsupported ModelCategory: " + modelCategory.toString()); } } catch (UnsupportedOperationException e) { throw e; } catch (PredictException e) { if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) e.printStackTrace(); throw new UnsupportedOperationException(String.format("%s failed: %s", modelCategory, e.getMessage())); } catch (Exception e) { if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) e.printStackTrace(); throw new UnsupportedOperationException(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage())); } } } private void setPrediction(MojoFrame frame, int rowIdx, double[] classProbabilities) { for (int outputColIdx = 0; outputColIdx < oindices.length; outputColIdx++) { final int oidx = oindices[outputColIdx]; final MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(oidx); final double[] darr = (double[]) col.getData(); darr[rowIdx] = classProbabilities[outputColIdx]; } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy