All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.pipeline.MojoPipelineReader Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.pipeline;

import hex.genmodel.MojoModel;
import hex.genmodel.MultiModelMojoReader;

import java.util.LinkedList;
import java.util.Map;
import java.util.List;
import java.util.HashMap;

public class MojoPipelineReader extends MultiModelMojoReader {

  @Override
  public String getModelName() {
    return "MOJO Pipeline";
  }

  @Override
  protected void readParentModelData() {
    String mainModelAlias = readkv("main_model");
    String[] generatedColumns = readGeneratedColumns();

    _model._mainModel = getModel(mainModelAlias);
    _model._generatedColumnCount = generatedColumns.length;
    _model._targetMainModelRowIndices = new int[_model._mainModel._nfeatures - generatedColumns.length];
    _model._sourceRowIndices = findIndices(_model._names, _model._mainModel._names, _model._mainModel._nfeatures,
            _model._targetMainModelRowIndices, generatedColumns);

    Map> m2idxs = readModel2GeneratedColumnIndex();

    _model._models = new MojoPipeline.PipelineSubModel[getSubModels().size() - 1];
    int modelsCnt = 0;
    int genColsCnt = 0;
    for (Map.Entry subModel : getSubModels().entrySet()) {
      if (mainModelAlias.equals(subModel.getKey())) {
        continue;
      }
      final MojoModel m = subModel.getValue();
      final List generatedColsIdxs = m2idxs.get(subModel.getKey());

      MojoPipeline.PipelineSubModel psm = _model._models[modelsCnt++] = new MojoPipeline.PipelineSubModel();
      psm._mojoModel = m;
      psm._inputMapping = mapModelColumns(m);
      psm._predsSize = m.getPredsSize(m.getModelCategory());
      psm._sourcePredsIndices = new int[generatedColsIdxs.size()];
      String[] targetColNames = new String[generatedColsIdxs.size()];
      int t = 0;
      for (int i : generatedColsIdxs) {
        psm._sourcePredsIndices[t] = readkv("generated_column_index_" + i, 0);
        targetColNames[t] = readkv("generated_column_name_" + i, "");
        t++;
      }
      psm._targetRowIndices = findIndices(_model._mainModel._names, targetColNames);
      genColsCnt += t;
    }
    assert modelsCnt == _model._models.length;
    assert genColsCnt == _model._generatedColumnCount;
  }

  private Map> readModel2GeneratedColumnIndex() {
    final int cnt = readkv("generated_column_count", 0);
    Map> map = new HashMap<>(cnt);
    for (int i = 0; i < cnt; i++) {
      String alias = readkv("generated_column_model_" + i);
      if (! map.containsKey(alias)) {
        map.put(alias, new LinkedList());
      }
      List indices = map.get(alias);
      indices.add(i);
    }
    return map;
  }

  private String[] readGeneratedColumns() {
    final int cnt = readkv("generated_column_count", 0);
    final String[] names = new String[cnt];
    for (int i = 0; i < names.length; i++) {
      names[i] = readkv("generated_column_name_" + i, "");
    }
    return names;
  }

  @Override
  protected MojoPipeline makeModel(String[] columns, String[][] domains, String responseColumn) {
    return new MojoPipeline(columns, domains, responseColumn);
  }

  private int[] mapModelColumns(MojoModel subModel) {
    return findIndices(_model._names, subModel._names, subModel._nfeatures, null, new String[0]);
  }

  private static int[] findIndices(String[] strings, String[] subset) {
    return findIndices(strings, subset, subset.length, null, new String[0]);
  }

  private static int[] findIndices(String[] strings, String[] subset, int firstN, int[] outSubsetIdxs, String[] ignored) {
    final int[] idx = new int[firstN - ignored.length];
    assert outSubsetIdxs == null || outSubsetIdxs.length == idx.length;
    int cnt = 0;
    outer: for (int i = 0; i < firstN; i++) {
      final String s = subset[i];
      assert s != null;
      for (String si : ignored) {
        if (s.equals(si)) {
          continue outer;
        }
      }
      for (int j = 0; j < strings.length; j++) {
        if (s.equals(strings[j])) {
          if (outSubsetIdxs != null) {
            outSubsetIdxs[cnt] = i;
          }
          idx[cnt++] = j;
          continue outer;
        }
      }
      throw new IllegalStateException("Pipeline doesn't have input column '" + subset[i] + "'.");
    }
    assert cnt == idx.length;
    return idx;
  }

  @Override public String mojoVersion() {
    return "1.00";
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy