All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.genmodel.MojoPipelineWriter Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.genmodel;

import hex.ModelCategory;
import hex.genmodel.descriptor.ModelDescriptor;
import hex.genmodel.attributes.Table;
import hex.genmodel.attributes.VariableImportances;

import java.io.IOException;
import java.util.Arrays;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.Map;

public class MojoPipelineWriter extends AbstractMojoWriter {

  private Map _models;
  private Map _inputMapping;
  private String _mainModelAlias;

  MojoPipelineWriter(Map models, Map inputMapping, String mainModelAlias) {
    super(makePipelineDescriptor(models, inputMapping, mainModelAlias));
    _models = models;
    _inputMapping = inputMapping;
    _mainModelAlias = mainModelAlias;
  }

  @Override
  public String mojoVersion() {
    return "1.00";
  }

  @Override
  protected void writeModelData() throws IOException {
    writekv("submodel_count", _models.size());
    int modelNum = 0;
    for (Map.Entry model : _models.entrySet()) {
      writekv("submodel_key_" + modelNum, model.getKey());
      writekv("submodel_dir_" + modelNum, "models/" + model.getKey() + "/");
      modelNum++;
    }
    writekv("generated_column_count", _inputMapping.size());
    int generatedColumnNum = 0;
    for (Map.Entry mapping : _inputMapping.entrySet()) {
      String inputSpec = mapping.getValue();
      String[] inputSpecArr = inputSpec.split(":", 2);
      writekv("generated_column_name_" + generatedColumnNum, mapping.getKey());
      writekv("generated_column_model_" + generatedColumnNum, inputSpecArr[0]);
      writekv("generated_column_index_" + generatedColumnNum, Integer.valueOf(inputSpecArr[1]));
      generatedColumnNum++;
    }
    writekv("main_model", _mainModelAlias);
  }

  private static MojoPipelineDescriptor makePipelineDescriptor(
          Map models, Map inputMapping, String mainModelAlias) {
    MojoModel finalModel = models.get(mainModelAlias);
    if (finalModel == null) {
      throw new IllegalArgumentException("Main model is missing. There is no model with alias '" + mainModelAlias + "'.");
    }
    LinkedHashMap schema = deriveInputSchema(models, inputMapping, finalModel);
    return new MojoPipelineDescriptor(schema, finalModel);
  }

  private static LinkedHashMap deriveInputSchema(
          Map allModels, Map inputMapping, MojoModel finalModel) {
    LinkedHashMap schema = new LinkedHashMap<>();

    for (MojoModel model : allModels.values()) {
      if (model == finalModel) {
        continue;
      }
      for (int i = 0; i < model.nfeatures(); i++) {
        String fName = model._names[i];
        if (schema.containsKey(fName)) { // make sure the domain matches
          String[] domain = schema.get(fName);
          if (! Arrays.equals(domain, model._domains[i])) {
            throw new IllegalStateException("Domains of column '" + fName + "' differ.");
          }
        } else {
          schema.put(fName, model._domains[i]);
        }
      }
    }

    for (int i = 0; i < finalModel._names.length; i++) { // we include the response of the final model as well
      String fName = finalModel._names[i];
      if (! inputMapping.containsKey(fName)) {
        schema.put(fName, finalModel._domains[i]);
      }
    }

    return schema;
  }

  private static class MojoPipelineDescriptor implements ModelDescriptor {

    private final MojoModel _finalModel;
    private final String[] _names;
    private final String[][] _domains;

    private MojoPipelineDescriptor(LinkedHashMap schema, MojoModel finalModel) {
      _finalModel = finalModel;
      _names = new String[schema.size()];
      _domains = new String[schema.size()][];
      int i = 0;
      for (Map.Entry field : schema.entrySet()) {
        _names[i] = field.getKey();
        _domains[i] = field.getValue();
        i++;
      }
    }

    @Override
    public String[][] scoringDomains() {
      return _domains;
    }

    @Override
    public String projectVersion() {
      return _finalModel._h2oVersion;
    }

    @Override
    public String algoName() {
      return "pipeline";
    }

    @Override
    public String algoFullName() {
      return "MOJO Pipeline";
    }

    @Override
    public String offsetColumn() {
      return _finalModel._offsetColumn;
    }

    @Override
    public String weightsColumn() {
      return null;
    }

    @Override
    public String foldColumn() {
      return null;
    }

    @Override
    public String treatmentColumn() { return null; }

    @Override
    public ModelCategory getModelCategory() {
      return _finalModel._category;
    }

    @Override
    public boolean isSupervised() {
      return _finalModel.isSupervised();
    }

    @Override
    public int nfeatures() {
      return isSupervised() ? columnNames().length - 1 : columnNames().length;
    }

    @Override
    public String[] features() {
      return Arrays.copyOf(columnNames(), nfeatures());
    }

    @Override
    public int nclasses() {
      return _finalModel.nclasses();
    }

    @Override
    public String[] columnNames() {
      return _names;
    }

    @Override
    public boolean balanceClasses() {
      return _finalModel._balanceClasses;
    }

    @Override
    public double defaultThreshold() {
      return _finalModel._defaultThreshold;
    }

    @Override
    public double[] priorClassDist() {
      return _finalModel._priorClassDistrib;
    }

    @Override
    public double[] modelClassDist() {
      return _finalModel._modelClassDistrib;
    }

    @Override
    public String uuid() {
      return _finalModel._uuid;
    }

    @Override
    public String timestamp() {
      return String.valueOf(new Date().getTime());
    }

    @Override
    public String[] getOrigNames() {
      return null;
    }

    @Override
    public String[][] getOrigDomains() {
      return null;
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy