
ai.h2o.mojos.runtime.h2o3.H2O3PipelineLoader Maven / Gradle / Ivy
package ai.h2o.mojos.runtime.h2o3;
import ai.h2o.mojos.runtime.AbstractPipelineLoader;
import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.MojoPipelineMeta;
import ai.h2o.mojos.runtime.MojoPipelineProtoImpl;
import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.api.MojoTransformMeta;
import ai.h2o.mojos.runtime.api.PipelineConfig;
import ai.h2o.mojos.runtime.api.backend.ReaderBackend;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import ai.h2o.mojos.runtime.transforms.MojoTransformExecPipeBuilder;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackend;
import hex.genmodel.easy.EasyPredictModelWrapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.joda.time.DateTime;
/**
 * Pipeline loader that exposes an H2O-3 MOJO model as a single-transform Mojo2 pipeline.
 *
 * <p>The model is read through an {@link H2O3BackendAdapter} over the supplied {@link ReaderBackend},
 * wrapped in an {@link EasyPredictModelWrapper} (see {@link #wrapModelForPrediction(MojoModel)}),
 * and executed by a single {@link H2O3Transform} with id {@code "h2o3-main"}.
 * The global column list contains the model's input columns first, followed by its output columns.
 */
class H2O3PipelineLoader extends AbstractPipelineLoader {
    /** Input columns followed by output columns; shared with the frame meta and the built pipeline. */
    private final List<MojoColumnMeta> globalColumns;
    private final MojoTransformExecPipeBuilder root;

    /**
     * Loads the H2O-3 MOJO from {@code backend} and builds a one-transform pipeline around it.
     *
     * @param backend reader for the MOJO contents
     * @param config  pipeline configuration, handed to the superclass and later to the pipeline
     * @throws IOException if the MOJO model cannot be loaded from {@code backend}
     * @throws UnsupportedOperationException if the model category is not Binomial, Multinomial or Regression
     */
    public H2O3PipelineLoader(ReaderBackend backend, PipelineConfig config) throws IOException {
        super(backend, config);
        final MojoReaderBackend mojoReader = new H2O3BackendAdapter(backend);
        final MojoModel model = MojoModel.load(mojoReader);
        final EasyPredictModelWrapper easyPredictModelWrapper = wrapModelForPrediction(model);
        final String name = "h2o3:" + model.getModelCategory().toString();

        // Inputs are registered before outputs, so every output index is greater than every input index.
        this.globalColumns = new ArrayList<>();
        final int[] inputIndices = readInputIndices(globalColumns, model);
        final int[] outputIndices = readOutputIndices(globalColumns, model);
        final MojoFrameMeta globalMeta = new MojoFrameMeta(globalColumns);

        final MojoTransform transform = new H2O3Transform(globalMeta, inputIndices, outputIndices, easyPredictModelWrapper);
        transform.setId("h2o3-main");
        transform.setName(name);

        // TODO: the real creation time is not available here; a fixed placeholder is used.
        // NOTE(review): this Joda constructor interprets the fields in the JVM's default time zone,
        // so the placeholder's instant differs between hosts. Consider a UTC-based value
        // (e.g. new DateTime(0L, DateTimeZone.UTC)) if host-independent metadata is wanted.
        final DateTime creationTime = new DateTime(1970, 1, 1, 0, 0);
        final MojoPipelineMeta pipelineMeta = new MojoPipelineMeta(
                model.getUUID(), creationTime);
        pipelineMeta.license = "H2O-3 Opensource";

        this.root = new MojoTransformExecPipeBuilder(inputIndices, outputIndices, transform, pipelineMeta);
        // The transform is also registered in the builder's transform list.
        this.root.transforms.add(transform);
    }

    /**
     * Appends one column per model input feature to {@code columns}.
     *
     * <p>A feature with a domain ({@code getDomainValues(i) != null}) becomes a
     * {@link MojoColumn.Type#Str} column; every other feature becomes {@link MojoColumn.Type#Float64}.
     * Columns are named after {@code genModel.getNames()}.
     *
     * @param columns  column list to append to (modified in place)
     * @param genModel model whose first {@code getNumCols()} names are the input features
     * @return for each input feature {@code i}, the index of its column within {@code columns}
     */
    static int[] readInputIndices(final List<MojoColumnMeta> columns, final GenModel genModel) {
        final int numCols = genModel.getNumCols();
        final String[] names = genModel.getNames();
        final int[] inputIndices = new int[numCols];
        for (int i = 0; i < numCols; i += 1) {
            final MojoColumn.Type columnType = (genModel.getDomainValues(i) == null) ? MojoColumn.Type.Float64 : MojoColumn.Type.Str;
            inputIndices[i] = columns.size();
            columns.add(MojoColumnMeta.create(names[i], columnType));
        }
        return inputIndices;
    }

    /**
     * Appends the model's prediction columns to {@code columns}.
     *
     * <ul>
     *   <li>Binomial / Multinomial: one {@link MojoColumn.Type#Float64} column per response class,
     *       named {@code <responseName>.<classLevel>}.</li>
     *   <li>Regression: a single {@link MojoColumn.Type#Float64} column named after the response.</li>
     * </ul>
     *
     * @param columns  column list to append to (modified in place)
     * @param genModel model describing the response
     * @return indices of the appended output columns within {@code columns}
     * @throws UnsupportedOperationException for any other model category
     */
    private static int[] readOutputIndices(final List<MojoColumnMeta> columns, final GenModel genModel) {
        switch (genModel.getModelCategory()) {
            case Binomial:
            case Multinomial: {
                final int numClasses = genModel.getNumResponseClasses();
                final String responseName = genModel.getResponseName();
                final String[] responseDomain = genModel.getDomainValues(genModel.getResponseIdx());
                final int[] outputIndices = new int[numClasses];
                for (int i = 0; i < numClasses; i += 1) {
                    final String columnName = responseName + "." + responseDomain[i];
                    outputIndices[i] = columns.size();
                    columns.add(MojoColumnMeta.create(columnName, MojoColumn.Type.Float64));
                }
                return outputIndices;
            }
            case Regression: {
                final MojoColumnMeta column = MojoColumnMeta.create(genModel.getResponseName(), MojoColumn.Type.Float64);
                final int[] outputIndices = new int[]{columns.size()};
                columns.add(column);
                return outputIndices;
            }
            default:
                throw new UnsupportedOperationException("Unsupported ModelCategory: " + genModel.getModelCategory().toString());
        }
    }

    /**
     * Returns the pipeline's global columns (inputs, then outputs).
     *
     * @return the internal column list itself, not a copy
     */
    @Override
    public List<MojoColumnMeta> getColumns() {
        return globalColumns;
    }

    /**
     * Returns the transform metadata collected by the exec-pipe builder.
     *
     * @return {@code root.metaTransforms}
     */
    // NOTE(review): raw return type kept; parameterize it to match AbstractPipelineLoader's
    // declaration (presumably List<MojoTransformMeta>, which this file imports) — confirm.
    @Override
    public List getTransformations() {
        return root.metaTransforms;
    }

    /**
     * Creates the executable pipeline from the global columns, the exec-pipe builder and the config.
     *
     * @return a new {@link MojoPipelineProtoImpl}
     */
    @Override
    protected final MojoPipeline internalLoad() {
        return new MojoPipelineProtoImpl(globalColumns, root, config);
    }

    /**
     * Wraps the specified {@link MojoModel} as an {@link EasyPredictModelWrapper} with
     * configuration to behave similar to Mojo2 behavior.
     *
     * <p>This includes configuring the wrapper to tolerate and ignore (by forcing to NA) bad input
     * without throwing an exception: unknown categorical levels and invalid numbers are both
     * converted to NA.
     *
     * @param model the loaded H2O-3 MOJO model
     * @return a prediction wrapper around {@code model}
     */
    static EasyPredictModelWrapper wrapModelForPrediction(MojoModel model) {
        EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
                .setModel(model)
                .setConvertUnknownCategoricalLevelsToNa(true)
                .setConvertInvalidNumbersToNa(true);
        return new EasyPredictModelWrapper(config);
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy