All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.genmodel.ModelMojoReader Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel;

import com.google.gson.JsonObject;
import hex.genmodel.algos.isotonic.IsotonicCalibrator;
import hex.genmodel.attributes.ModelAttributes;
import hex.genmodel.attributes.ModelJsonReader;
import hex.genmodel.attributes.Table;
import hex.genmodel.descriptor.ModelDescriptorBuilder;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import hex.genmodel.utils.ParseUtils;
import hex.genmodel.utils.StringEscapeUtils;

import java.io.BufferedReader;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Helper class to deserialize a model from MOJO format. This is a counterpart to `ModelMojoWriter`.
 *
 * <p>The static {@link #readFrom(MojoReaderBackend, boolean)} entry point parses {@code model.ini},
 * obtains the algorithm-specific reader from {@code ModelMojoFactory} and lets it populate the model.
 * Subclasses supply the algorithm-specific parts through {@link #readModelData()} and
 * {@link #makeModel(String[], String[][], String)}.
 *
 * @param <M> type of the {@link MojoModel} produced by this reader
 */
public abstract class ModelMojoReader<M extends MojoModel> {

  /** The model being assembled; assigned in {@code readAll} from the result of {@code makeModel}. */
  protected M _model;

  /** Backend giving access to the individual files of the MOJO. */
  protected MojoReaderBackend _reader;

  /**
   * Parsed content of {@code model.ini}: keys of the {@code [info]} section mapped to {@link RawValue}
   * (or a plain {@code String} for {@code uuid}), plus the special keys {@code "[columns]"}
   * ({@code String[]}) and {@code "[domains]"} ({@code Map<Integer, String>}).
   */
  private Map<String, Object> _lkv;


  /**
   * De-serializes a {@link MojoModel}, creating an instance of {@link MojoModel} useful for scoring
   * and model evaluation.
   *
   * @param reader An instance of {@link MojoReaderBackend} to read from existing MOJO. After the model is de-serialized,
   *               the {@link MojoReaderBackend} instance is automatically closed if it implements {@link Closeable}.
   * @return De-serialized {@link MojoModel}
   * @throws IOException Whenever there is an error reading the {@link MojoModel}'s data.
   */
  public static MojoModel readFrom(MojoReaderBackend reader) throws IOException {
    return readFrom(reader, false);
  }

  /**
   * De-serializes a {@link MojoModel}, creating an instance of {@link MojoModel} useful for scoring
   * and model evaluation.
   *
   * @param reader      An instance of {@link MojoReaderBackend} to read from existing MOJO. The instance is closed
   *                    when this method finishes (successfully or not) if it implements {@link Closeable}.
   * @param readModelMetadata If true, parses also model metadata (model performance metrics... {@link ModelAttributes})
   *                          Model metadata are not required for scoring, it is advised to leave this option disabled
   *                          if you want to use MOJO for inference only.
   * @return De-serialized {@link MojoModel}
   * @throws IOException Whenever there is an error reading the {@link MojoModel}'s data.
   * @throws IllegalStateException if {@code model.ini} does not define the {@code algorithm} key
   */
  public static MojoModel readFrom(MojoReaderBackend reader, final boolean readModelMetadata) throws IOException {
    try {
      Map<String, Object> info = parseModelInfo(reader);
      if (! info.containsKey("algorithm"))
        throw new IllegalStateException("Unable to find information about the model's algorithm.");
      String algo = String.valueOf(info.get("algorithm"));
      ModelMojoReader<?> mmr = ModelMojoFactory.INSTANCE.getMojoReader(algo);
      mmr._lkv = info;
      mmr._reader = reader;
      mmr.readAll(readModelMetadata);
      return mmr._model;
    } finally {
      if (reader instanceof Closeable)
        ((Closeable) reader).close();
    }
  }

  public abstract String getModelName();

  //--------------------------------------------------------------------------------------------------------------------
  // Inheritance interface: ModelMojoReader subclasses are expected to override these methods to provide custom behavior
  //--------------------------------------------------------------------------------------------------------------------

  /**
   * Reads the algorithm-specific part of the model. Invoked from {@code readAll} after the common model
   * fields were populated and the MOJO version was checked.
   */
  protected abstract void readModelData() throws IOException;

  /**
   * Creates the (still empty) model instance for the given columns, their categorical domains
   * and the response column name ({@code null} for unsupervised models).
   */
  protected abstract M makeModel(String[] columns, String[][] domains, String responseColumn);

  /**
   * Variant of {@link #makeModel(String[], String[][], String)} that also receives the treatment column.
   * The default implementation ignores {@code treatmentColumn} and delegates to the three-argument version.
   */
  protected M makeModel(String[] columns, String[][] domains, String responseColumn, String treatmentColumn){
    return makeModel(columns, domains, responseColumn);
  }

  /**
   * Maximal version of the mojo file current model reader supports. Follows the major.minor
   * format, where minor is a 2-digit number. For example "1.00",
   * "2.05", "2.13". See README in mojoland repository for more details.
   */
  public abstract String mojoVersion();


  //--------------------------------------------------------------------------------------------------------------------
  // Interface for subclasses
  //--------------------------------------------------------------------------------------------------------------------

  /**
   * Retrieve value from the model's kv store which was previously put there using `writekv(key, value)`. We will
   * attempt to cast to your expected type, but this is obviously unsafe. Note that the value is deserialized from
   * the underlying string representation using {@link ParseUtils#tryParse(String, Object)}, which occasionally may get
   * the answer wrong.
   * If the `key` is missing in the local kv store, null will be returned. However when assigning to a primitive type
   * this would result in an NPE, so beware.
   *
   * @param key name of the key
   * @param <T> return type
   * @return parsed value, or null if the key is not present
   */
  @SuppressWarnings("unchecked")
  protected <T> T readkv(String key) {
    return (T) readkv(key, null);
  }

  /**
   * Retrieves the value associated with a given key. If no value is set for the key, the given default value is
   * returned instead. Uses same parsing logic as {@link ModelMojoReader#readkv(String)}. If default value is not
   * null its type is used to assist the parser to determine the return type.
   * @param key name of the key
   * @param defVal default value
   * @param <T> return type
   * @return parsed value
   */
  @SuppressWarnings("unchecked")
  protected <T> T readkv(String key, T defVal) {
    Object val = _lkv.get(key);
    // Entries that are not RawValue (e.g. `uuid`, `[columns]`, `[domains]`) are stored already in their final form.
    if (! (val instanceof RawValue))
      return val != null ? (T) val : defVal;
    return ((RawValue) val).parse(defVal);
  }
  
  /**
   * Retrieve binary data previously saved to the mojo file using `writeblob(key, blob)`.
   */
  protected byte[] readblob(String name) throws IOException {
    return getMojoReaderBackend().getBinaryFile(name);
  }

  /**
   * Asks the reader backend whether an entry of the given name exists in the MOJO.
   */
  protected boolean exists(String name) {
    return getMojoReaderBackend().exists(name);
  }

  /**
   * Retrieve text previously saved using `startWritingTextFile` + `writeln` as an array of lines. Each line is
   * trimmed to remove the leading and trailing whitespace.
   */
  protected Iterable<String> readtext(String name) throws IOException {
    return readtext(name, false);
  }

  /**
   * Retrieve text previously saved using `startWritingTextFile` + `writeln` as an array of lines. Each line is
   * trimmed to remove the leading and trailing whitespace. Removes escaping of the new line characters if enabled
   * (un-escaping happens before trimming).
   *
   * @param name name of the text file inside the MOJO
   * @param unescapeNewlines if true, every line is passed through {@link StringEscapeUtils#unescapeNewlines}
   * @return lines of the file, in file order
   * @throws IOException if the file cannot be read
   */
  protected Iterable<String> readtext(String name, boolean unescapeNewlines) throws IOException {
    ArrayList<String> res = new ArrayList<>(50);
    // try-with-resources closes the reader exactly once, on both the normal and the exceptional path
    try (BufferedReader br = getMojoReaderBackend().getTextFile(name)) {
      String line;
      while ((line = br.readLine()) != null) {
        if (unescapeNewlines)
          line = StringEscapeUtils.unescapeNewlines(line);
        res.add(line.trim());
      }
    }
    return res;
  }
  
  /**
   * Reads a text file into a String array of the given size, un-escaping new line characters in each line.
   * If the file has fewer lines than {@code size}, the remaining elements stay {@code null}.
   *
   * @param name name of the text file inside the MOJO
   * @param size length of the returned array
   * @return array holding the lines of the file
   * @throws IOException if the file cannot be read or contains more than {@code size} lines
   */
  protected String[] readStringArray(String name, int size) throws IOException {
    return readLinesIntoArray(name, size, true);
  }

  /**
   * Shared implementation of {@link #readStringArray(String, int)} and {@link #readStringArrays(int, String)}:
   * copies the lines of a text file into a pre-sized array, failing with a descriptive error instead of
   * an {@link ArrayIndexOutOfBoundsException} when the file is longer than expected.
   */
  private String[] readLinesIntoArray(String name, int size, boolean unescapeNewlines) throws IOException {
    String[] array = new String[size];
    int i = 0;
    for (String line : readtext(name, unescapeNewlines)) {
      if (i >= size)
        throw new IOException("Text file '" + name + "' contains more than the expected " + size + " lines.");
      array[i++] = line;
    }
    return array;
  }

  /**
   * Builds an {@link IsotonicCalibrator} from the `calib_min_x` / `calib_max_x` keys (defaulting to NaN)
   * and the `calib/thresholds_x` and `calib/thresholds_y` blobs.
   */
  protected IsotonicCalibrator readIsotonicCalibrator() throws IOException {
    return new IsotonicCalibrator(
      readkv("calib_min_x", Double.NaN),
      readkv("calib_max_x", Double.NaN),
      readblobDoubles("calib/thresholds_x"),
      readblobDoubles("calib/thresholds_y")
    );
  }

  /**
   * Reads a blob laid out as an int element count followed by that many doubles.
   */
  private double[] readblobDoubles(String filename) throws IOException {
    byte[] bytes = readblob(filename);
    final ByteBuffer bb = ByteBuffer.wrap(bytes);
    int size = bb.getInt();
    double[] doubles = new double[size];
    for (int i = 0; i < doubles.length; i++) {
      doubles[i] = bb.getDouble();
    }
    return doubles;
  }

  /**
   * Reads a two dimensional array written by {@link hex.ModelMojoWriter#writeRectangularDoubleArray} method.
   * 
   * Dimensions of the result are explicitly given as parameters.
   * 
   * @param title name of the blob, can't be null
   * @param firstSize number of rows
   * @param secondSize number of columns
   * @return a double[][] array with dimensions firstSize and secondSize, filled row by row from the blob
   * @throws IOException if the blob cannot be read
   */
  protected double[][] readRectangularDoubleArray(String title, int firstSize, int secondSize) throws IOException {
    assert null != title;
    
    final double [][] row = new double[firstSize][secondSize];
    final ByteBuffer bb = ByteBuffer.wrap(readblob(title));
    for (int i = 0; i < firstSize; i++) {
      for (int j = 0; j < secondSize; j++)
        row[i][j] = bb.getDouble();
    }
    return row;
  }
  
  /**
   * Reads a two dimensional array written by {@link hex.ModelMojoWriter#writeRectangularDoubleArray} method
   * 
   * Dimensions of the array are read from the mojo (keys {@code title + "_size1"} and {@code title + "_size2"}).
   *
   * @param title name of the blob, can't be null
   * @return the array stored under {@code title}
   * @throws IOException if the blob cannot be read or the dimension keys are missing
   */
  protected double[][] readRectangularDoubleArray(String title) throws IOException {
    assert null != title;

    final Integer size1 = readkv(title + "_size1");
    final Integer size2 = readkv(title + "_size2");
    // Fail with a readable message rather than an NPE caused by unboxing a missing value.
    if (size1 == null || size2 == null)
      throw new IOException("Dimensions of the array `" + title + "` are missing in the model info.");

    return readRectangularDoubleArray(title, size1, size2);
  }

  //--------------------------------------------------------------------------------------------------------------------
  // Private
  //--------------------------------------------------------------------------------------------------------------------

  /**
   * Creates the model and fills in everything that can be read from the MOJO: common fields, algorithm-specific
   * data, optionally the model metadata, and the reproducibility information.
   */
  private void readAll(final boolean readModelMetadata) throws IOException {
    String[] columns = (String[]) _lkv.get("[columns]");
    if (columns == null)
      throw new IOException("`[columns]` section is missing in the model info.");
    String[][] domains = parseModelDomains(columns.length);
    boolean isSupervised = readkv("supervised");
    // For supervised models the last column is passed to makeModel as the response column.
    _model = makeModel(columns, domains, isSupervised ? columns[columns.length - 1] : null, (String) readkv("treatment_column"));
    _model._uuid = readkv("uuid");
    _model._algoName = readkv("algo");
    _model._h2oVersion = readkv("h2o_version", "unknown");
    _model._category = hex.ModelCategory.valueOf((String) readkv("category"));
    _model._supervised = isSupervised;
    _model._nfeatures = readkv("n_features");
    _model._nclasses = readkv("n_classes");
    _model._balanceClasses = readkv("balance_classes");
    _model._defaultThreshold = readkv("default_threshold");
    _model._priorClassDistrib = readkv("prior_class_distrib");
    _model._modelClassDistrib = readkv("model_class_distrib");
    _model._offsetColumn = readkv("offset_column");
    _model._mojo_version = ((Number) readkv("mojo_version")).doubleValue();
    checkMaxSupportedMojoVersion();
    readModelData();
    if (readModelMetadata) {
      final String algoFullName = readkv("algorithm"); // The key'algo' contains the shortcut, 'algorithm' is the long version
      _model._modelAttributes = readModelSpecificAttributes();
      _model._modelDescriptor = ModelDescriptorBuilder.makeDescriptor(_model, algoFullName, _model._modelAttributes);
    }
    _model._reproducibilityInformation = readReproducibilityInformation() ;
  }

  /**
   * Reads the {@code output.reproducibility_information_table} tables from the model JSON.
   *
   * @return the tables, or null if the model JSON is not available or has no {@code output} element
   */
  protected Table[] readReproducibilityInformation() {
    final JsonObject modelJson = ModelJsonReader.parseModelJson(_reader);
    if (modelJson != null && modelJson.get("output") != null) {
      return ModelJsonReader.readTableArray(modelJson, "output.reproducibility_information_table");
    }
    return null;
  }

  /**
   * Builds the {@link ModelAttributes} from the model JSON.
   *
   * @return the attributes, or null if the model JSON is not available
   */
  protected ModelAttributes readModelSpecificAttributes() {
    final JsonObject modelJson = ModelJsonReader.parseModelJson(_reader);
    if(modelJson != null) {
      return new ModelAttributes(_model, modelJson);
    } else {
      return null;
    }
  }

  /**
   * Parses {@code model.ini} into a map. Lines starting with {@code #} and empty lines are skipped.
   * The file consists of three sections:
   * <ul>
   *   <li>{@code [info]} - {@code key = value} pairs, stored under their key as {@link RawValue}
   *       (the {@code uuid} value is stored as a plain String);</li>
   *   <li>{@code [columns]} - one column name per line, stored as {@code String[]} under {@code "[columns]"};
   *       requires {@code n_columns} to be defined beforehand;</li>
   *   <li>{@code [domains]} - {@code index: descriptor} lines, stored as {@code Map<Integer, String>}
   *       under {@code "[domains]"}.</li>
   * </ul>
   *
   * @throws IOException if the file cannot be read or is malformed
   */
  private static Map<String, Object> parseModelInfo(MojoReaderBackend reader) throws IOException {
    Map<String, Object> info = new HashMap<>();
    // try-with-resources closes the reader exactly once, on both the normal and the exceptional path
    try (BufferedReader br = reader.getTextFile("model.ini")) {
      String line;
      int section = 0;
      int ic = 0;  // Index for `columns` array
      String[] columns = new String[0];  // array of column names, will be initialized later
      Map<Integer, String> domains = new HashMap<>();  // map of (categorical column index => name of the domain file)
      while ((line = br.readLine()) != null) {
        line = line.trim();
        if (line.startsWith("#") || line.isEmpty()) continue;
        if (line.equals("[info]"))
          section = 1;
        else if (line.equals("[columns]")) {
          section = 2;  // Enter the [columns] section
          if (! info.containsKey("n_columns"))
            throw new IOException("`n_columns` variable is missing in the model info.");
          int n_columns = Integer.parseInt(((RawValue) info.get("n_columns"))._val);
          columns = new String[n_columns];
          info.put("[columns]", columns);
        } else if (line.equals("[domains]")) {
          section = 3; // Enter the [domains] section
          info.put("[domains]", domains);
        } else if (section == 1) {
          // [info] section: just parse key-value pairs and store them into the `info` map.
          String[] res = line.split("\\s*=\\s*", 2);
          if (res.length < 2)
            throw new IOException("Malformed line in the [info] section of the model info: " + line);
          // `uuid` is kept as a plain String so that it never goes through value parsing.
          info.put(res[0], res[0].equals("uuid")? res[1] : new RawValue(res[1]));
        } else if (section == 2) {
          // [columns] section
          if (ic >= columns.length)
            throw new IOException("`n_columns` variable is too small.");
          columns[ic++] = line;
        } else if (section == 3) {
          // [domains] section
          String[] res = line.split(":\\s*", 2);
          if (res.length < 2)
            throw new IOException("Malformed line in the [domains] section of the model info: " + line);
          int col_index = Integer.parseInt(res[0]);
          domains.put(col_index, res[1]);
        }
      }
    }
    return info;
  }

  /**
   * Loads the categorical domains referenced from the {@code [domains]} section of the model info.
   * Each entry has the form {@code "<number of elements> <file name>"}; the file is read from the
   * {@code domains/} folder of the MOJO, one domain value per line.
   *
   * @param n_columns number of columns; entries with an index outside of this range are ignored
   * @return array indexed by column, holding the domain of the column or null for columns without a domain
   * @throws IOException if a domain file cannot be read or its content does not match the declared size
   */
  private String[][] parseModelDomains(int n_columns) throws IOException {
    final boolean escapeDomainValues = Boolean.TRUE.equals(readkv("escape_domain_values")); // The key might not exist in older MOJOs
    String[][] domains = new String[n_columns][];
    @SuppressWarnings("unchecked")
    Map<Integer, String> domass = (Map<Integer, String>) _lkv.get("[domains]");
    // No [domains] section in the model info: there is nothing to load.
    if (domass == null)
      return domains;
    for (Map.Entry<Integer, String> e : domass.entrySet()) {
      int col_index = e.getKey();
      // There is a file with categories of the response column, but we ignore it.
      if (col_index >= n_columns) continue;
      String[] info = e.getValue().split(" ", 2);
      if (info.length < 2)
        throw new IOException("Malformed domain specification for column " + col_index + ": " + e.getValue());
      int n_elements = Integer.parseInt(info[0]);
      String domfile = info[1];
      String[] domain = new String[n_elements];

      try (BufferedReader br = getMojoReaderBackend().getTextFile("domains/" + domfile)) {
        String line;
        int id = 0;  // domain elements counter
        while ((line = br.readLine()) != null) {
          if (id >= n_elements)
            throw new IOException("Too many elements in the domain file " + domfile);
          if (escapeDomainValues) {
            line = StringEscapeUtils.unescapeNewlines(line);
          }
          domain[id++] = line;
        }
        if (id != n_elements)
          throw new IOException("Not enough elements in the domain file");
      }
      domains[col_index] = domain;
    }
    return domains;
  }

  /**
   * Unparsed textual value from the {@code [info]} section; the conversion to a typed value is deferred
   * until the value is requested via {@code readkv}.
   */
  private static class RawValue {
    private final String _val;
    RawValue(String val) { _val = val; }
    @SuppressWarnings("unchecked")
    <T> T parse(T defVal) { return (T) ParseUtils.tryParse(_val, defVal); }
    @Override
    public String toString() { return _val; }
  }

  /**
   * Rejects MOJOs whose version is higher than the version returned by {@link #mojoVersion()}.
   */
  private void checkMaxSupportedMojoVersion() throws IOException {
    if(_model._mojo_version > Double.parseDouble(mojoVersion())){
      throw new IOException(String.format("MOJO version incompatibility - the model MOJO version (%.2f) is higher than the current H2O version (%s) supports. Please, use a newer version of H2O to load MOJO model.", _model._mojo_version, mojoVersion()));
    }
  }

  /**
   * Resolves the link function by name, falling back to {@link #defaultLinkFunction(DistributionFamily)}
   * when the name is null.
   */
  public static LinkFunctionType readLinkFunction(String linkFunctionTypeName, DistributionFamily family) {
    if (linkFunctionTypeName != null)
      return LinkFunctionType.valueOf(linkFunctionTypeName);
    return defaultLinkFunction(family);
  }

  /**
   * Returns the default link function of a distribution family: logit for bernoulli, fractionalbinomial,
   * quasibinomial, modified_huber and ordinal; log for multinomial, poisson, gamma and tweedie;
   * identity for every other family.
   */
  public static LinkFunctionType defaultLinkFunction(DistributionFamily family){
    switch (family) {
      case bernoulli:
      case fractionalbinomial:
      case quasibinomial:
      case modified_huber:
      case ordinal:
        return LinkFunctionType.logit;
      case multinomial:
      case poisson:
      case gamma:
      case tweedie:
        return LinkFunctionType.log;
      default:
        return LinkFunctionType.identity;
    }
  }

  protected MojoReaderBackend getMojoReaderBackend() {
    return _reader;
  }

  /**
   * Reads a text file into a String array of the given size without un-escaping new line characters.
   * If the file has fewer lines than {@code aSize}, the remaining elements stay {@code null}.
   *
   * @param aSize length of the returned array
   * @param title name of the text file inside the MOJO
   * @return array holding the (trimmed) lines of the file
   * @throws IOException if the file cannot be read or contains more than {@code aSize} lines
   */
  public String[] readStringArrays(int aSize, String title) throws IOException {
    return readLinesIntoArray(title, aSize, false);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy