All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.MojoPipelineBuilder Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.genmodel;

import hex.genmodel.utils.IOUtils;

import java.io.*;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;

/**
 * Assembles several MOJO archives into a single pipeline archive.
 *
 * <p>Models are registered under an alias with {@link #addModel(String, File)}; one of them is
 * designated the main model with {@link #addMainModel(String, File)}. Column mappings recorded
 * with the {@code addMapping} methods associate a column name with a prediction index of a
 * registered model. {@link #buildPipeline(File)} then writes the pipeline descriptor (via
 * {@code MojoPipelineWriter}) and copies the content of every registered MOJO archive under
 * {@code models/<alias>/} in the output zip.
 *
 * <p>Instances hold plain {@link HashMap}s and perform no synchronization.
 */
public class MojoPipelineBuilder {

  /** Source MOJO archive of each registered model, keyed by model alias. */
  private final Map<String, File> _files = new HashMap<>();
  /** Loaded model of each registered MOJO, keyed by model alias. */
  private final Map<String, MojoModel> _models = new HashMap<>();
  /** Column name mapped to {@code "<modelAlias>:<predictionIndex>"}. */
  private final Map<String, String> _mapping = new HashMap<>();
  /** Alias of the main model; {@code null} until {@link #addMainModel} is called. */
  private String _mainModelAlias;

  /**
   * Loads a MOJO and registers it under the given alias. Registering the same alias twice
   * replaces the previous registration.
   *
   * @param modelAlias alias used to reference the model in mappings and in the output archive
   * @param mojoFile   MOJO archive on disk
   * @return this builder
   * @throws IOException if the MOJO cannot be loaded
   */
  public MojoPipelineBuilder addModel(String modelAlias, File mojoFile) throws IOException {
    // Load first so that a MOJO that fails to load leaves the builder state untouched.
    MojoModel model = MojoModel.load(mojoFile.getAbsolutePath());
    _files.put(modelAlias, mojoFile);
    _models.put(modelAlias, model);
    return this;
  }

  /**
   * Registers a model exactly like {@link #addModel(String, File)} and marks it as the main
   * model of the pipeline. A later call overrides the previously chosen main model.
   *
   * @param modelAlias alias of the main model
   * @param mojoFile   MOJO archive on disk
   * @return this builder
   * @throws IOException if the MOJO cannot be loaded
   */
  public MojoPipelineBuilder addMainModel(String modelAlias, File mojoFile) throws IOException {
    addModel(modelAlias, mojoFile);
    _mainModelAlias = modelAlias;
    return this;
  }

  /**
   * Records every mapping in the given list, in iteration order.
   *
   * @param specs mapping specifications to add
   * @return this builder
   */
  public MojoPipelineBuilder addMappings(List<MappingSpec> specs) {
    for (MappingSpec spec : specs) {
      addMapping(spec);
    }
    return this;
  }

  /**
   * Records a single mapping described by a {@link MappingSpec}.
   *
   * @param spec mapping specification
   * @return this builder
   */
  public MojoPipelineBuilder addMapping(MappingSpec spec) {
    return addMapping(spec._columnName, spec._modelAlias, spec._predsIndex);
  }

  /**
   * Records that {@code columnName} is produced by prediction number
   * {@code sourceModelPredictionIndex} of the model registered as {@code sourceModelAlias}.
   * A later mapping for the same column replaces the earlier one.
   *
   * @param columnName                 name of the mapped column
   * @param sourceModelAlias           alias of the model supplying the value
   * @param sourceModelPredictionIndex index into the source model's predictions
   * @return this builder
   */
  public MojoPipelineBuilder addMapping(String columnName, String sourceModelAlias, int sourceModelPredictionIndex) {
    _mapping.put(columnName, sourceModelAlias + ":" + sourceModelPredictionIndex);
    return this;
  }

  /**
   * Writes the pipeline archive: first whatever {@code MojoPipelineWriter} emits for the
   * registered models and mappings, then a copy of every registered MOJO archive under
   * {@code models/<alias>/}.
   *
   * @param pipelineFile destination zip file (created or overwritten)
   * @throws IOException if reading a MOJO archive or writing the pipeline fails
   */
  public void buildPipeline(File pipelineFile) throws IOException {
    MojoPipelineWriter w = new MojoPipelineWriter(_models, _mapping, _mainModelAlias);

    try (FileOutputStream fos = new FileOutputStream(pipelineFile);
         ZipOutputStream zos = new ZipOutputStream(fos)) {
      w.writeTo(zos);
      for (Map.Entry<String, File> mojoFile : _files.entrySet()) {
        copyMojoEntries(mojoFile.getKey(), mojoFile.getValue(), zos);
      }
    }
  }

  /** Copies every entry of one MOJO archive into {@code zos} under {@code models/<modelAlias>/}. */
  private static void copyMojoEntries(String modelAlias, File mojoFile, ZipOutputStream zos) throws IOException {
    try (ZipFile zf = new ZipFile(mojoFile)) {
      Enumeration<? extends ZipEntry> entries = zf.entries();
      while (entries.hasMoreElements()) {
        ZipEntry ze = entries.nextElement();

        ZipEntry copy = new ZipEntry("models/" + modelAlias + "/" + ze.getName());
        // Carry metadata over from the SOURCE entry; a freshly created ZipEntry reports -1
        // ("unknown") for both size and time, so reading them from `copy` is meaningless.
        if (ze.getSize() >= 0) {
          copy.setSize(ze.getSize());
        }
        if (ze.getTime() != -1) {
          copy.setTime(ze.getTime());
        }
        zos.putNextEntry(copy);
        try (InputStream input = zf.getInputStream(ze)) {
          IOUtils.copyStream(input, zos);
        }
        zos.closeEntry();
      }
    }
  }

  /**
   * Textual mapping specification of the form {@code <columnName>=<modelAlias>:<predsIndex>}.
   */
  public static class MappingSpec {
    /** Name of the mapped column (text before the first {@code '='}). */
    public String _columnName;
    /** Alias of the source model (text between the first {@code '='} and the following {@code ':'}). */
    public String _modelAlias;
    /** Prediction index of the source model (integer after the {@code ':'}). */
    public int _predsIndex;

    /**
     * Parses a specification such as {@code "col=model:2"}.
     *
     * @param spec specification string
     * @return the parsed specification
     * @throws NumberFormatException     if the prediction index is not a valid integer
     * @throws IndexOutOfBoundsException if the {@code '='} or {@code ':'} separator is missing
     */
    public static MappingSpec parse(String spec) throws NumberFormatException, IndexOutOfBoundsException {
      MappingSpec ms = new MappingSpec();
      // Limit of 2 splits on the first separator only, leaving the remainder intact.
      String[] parts = spec.split("=", 2);
      ms._columnName = parts[0];
      parts = parts[1].split(":", 2);
      ms._modelAlias = parts[0];
      ms._predsIndex = Integer.parseInt(parts[1]);
      return ms;
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy