All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.genmodel.algos.deepwater.DeepwaterMojoModel Maven / Gradle / Ivy

package hex.genmodel.algos.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.genmodel.ConverterFactoryProvidingModel;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.deepwater.caffe.DeepwaterCaffeBackend;
import hex.genmodel.easy.CategoricalEncoder;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowToRawDataConverter;

import java.io.File;
import java.util.Map;

public class DeepwaterMojoModel extends MojoModel implements ConverterFactoryProvidingModel {
  public String _problem_type;
  public int _mini_batch_size;
  public int _height;
  public int _width;
  public int _channels;

  public int _nums;
  public int _cats;
  public int[] _catOffsets;
  public double[] _normMul;
  public double[] _normSub;
  public double[] _normRespMul;
  public double[] _normRespSub;
  public boolean _useAllFactorLevels;

  transient byte[] _network;
  transient byte[] _parameters;
  public transient float[] _meanImageData;

  BackendTrain _backend; //interface provider
  BackendModel _model;  //pointer to C++ process
  ImageDataSet _imageDataSet; //interface provider
  RuntimeOptions _opts;
  BackendParams _backendParams;

  DeepwaterMojoModel(String[] columns, String[][] domains, String responseColumn) {
    super(columns, domains, responseColumn);
  }


  /**
   * Corresponds to `hex.DeepWater.score0()`
   */
  @Override
  public final double[] score0(double[] doubles, double offset, double[] preds) {
    assert(doubles != null) : "doubles are null";
    float[] floats;
    int cats = _catOffsets == null ? 0 : _catOffsets[_cats];
    if (_nums > 0) {
      floats = new float[_nums + cats]; //TODO: use thread-local storage
      GenModel.setInput(doubles, floats, _nums, _cats, _catOffsets, _normMul, _normSub, _useAllFactorLevels, true);
    } else {
      floats = new float[doubles.length];
      for (int i=0; i 1) {
      for (int i = 0; i < predFloats.length; ++i)
        preds[1 + i] = predFloats[i];
      if (_balanceClasses)
        GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib);
      preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, doubles, _defaultThreshold);
    } else {
      if (_normRespMul!=null && _normRespSub!=null)
        preds[0] = predFloats[0] * _normRespMul[0] + _normRespSub[0];
      else
        preds[0] = predFloats[0];
    }
    return preds;
  }

  @Override
  public double[] score0(double[] row, double[] preds) {
    return score0(row, 0.0, preds);
  }

  static public BackendTrain createDeepWaterBackend(String backend) {
    try {
      // For Caffe, only instantiate if installed at the right place
      File f = new File(DeepwaterCaffeBackend.CAFFE_H2O_DIR);
      if (backend.equals("caffe") && f.exists() && f.isDirectory())
        return new DeepwaterCaffeBackend();
      if (backend.equals("mxnet"))
        backend="deepwater.backends.mxnet.MXNetBackend";
      else if (backend.equals("tensorflow"))
        backend = "deepwater.backends.tensorflow.TensorflowBackend";
//      else if (backend.equals("xgrpc"))
//        backend="deepwater.backends.grpc.XGRPCBackendTrain";
      return (BackendTrain) (Class.forName(backend).newInstance());
    } catch (Exception ignored) {
      //ignored.printStackTrace();
    }
    return null;
  }

  @Override
  public RowToRawDataConverter makeConverterFactory(Map modelColumnNameToIndexMap,
                                                    Map domainMap,
                                                    EasyPredictModelWrapper.ErrorConsumer errorConsumer,
                                                    EasyPredictModelWrapper.Config config) {
    if (_problem_type.equals("image"))
      return new DWImageConverter(this, modelColumnNameToIndexMap, domainMap, errorConsumer, config);
    else if (_problem_type.equals("text")) {
      return new DWTextConverter(this, modelColumnNameToIndexMap, domainMap, errorConsumer, config);
    }
    return new RowToRawDataConverter(this, modelColumnNameToIndexMap, domainMap,
            errorConsumer, config);

  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy