All downloads are free. The search and download functionality uses the official Maven repository.

hex.generic.Generic Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.generic;

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.genmodel.*;
import hex.genmodel.descriptor.ModelDescriptor;
import hex.genmodel.descriptor.ModelDescriptorBuilder;
import water.H2O;
import water.Key;
import water.fvec.ByteVec;
import water.fvec.Frame;
import water.parser.ZipUtil;
import water.util.Log;

import java.io.IOException;
import java.net.URI;
import java.util.*;

/**
 * Generic model able to do scoring with any underlying model deserializable into a format known by the {@link GenericModel}.
 * <p>
 * The model data is either imported from {@code _parms._path} or taken from a frame already present under
 * {@code _parms._model_key}. Compressed (zip) data is loaded as a MOJO; anything else is treated as POJO source
 * code, which is an experimental feature.
 */
public class Generic extends ModelBuilder<GenericModel, GenericModelParameters, GenericModelOutput> {

    /**
     * Unmodifiable {@link Set} of lower-case algorithm names whose MOJOs are allowed to be imported as generic model.
     */
    private static final Set<String> ALLOWED_MOJO_ALGOS;
    static {
        final Set<String> allowedAlgos = new HashSet<>(Arrays.asList(
                "gbm",
                "glm",
                "xgboost",
                "isolationforest",
                "extendedisolationforest",
                "drf",
                "deeplearning",
                "stackedensemble",
                "coxph",
                "rulefit",
                "gam"
        ));
        ALLOWED_MOJO_ALGOS = Collections.unmodifiableSet(allowedAlgos);
    }


    /**
     * Creates a builder for the given parameters and immediately runs the cheap (non-expensive) validation.
     *
     * @param genericParameters parameters describing where the model data comes from
     */
    public Generic(GenericModelParameters genericParameters){
        super(genericParameters);
        init(false);
    }

    /**
     * Creates a builder with default parameters.
     *
     * @param startup_once passed through unchanged to the {@link ModelBuilder} constructor
     */
    public Generic(boolean startup_once) {
        super(new GenericModelParameters(), startup_once);
    }

    /**
     * Validates the parameters: a file path and a model key are mutually exclusive sources of model data.
     *
     * @param expensive passed through unchanged to {@link ModelBuilder#init(boolean)}
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        if (_parms._path != null && _parms._model_key != null) {
            error("_path", 
                    "Path cannot be set for MOJO that is supposed to be loaded from distributed memory (key=" + _parms._model_key + ").");
        }
    }

    @Override
    protected Driver trainModelImpl() {
        return new MojoDelegatingModelDriver();
    }

    /**
     * @return every {@link ModelCategory}; the category of the resulting model depends on the imported model
     */
    @Override
    public ModelCategory[] can_build() {
        return ModelCategory.values();
    }

    @Override
    public boolean haveMojo() {
        return true;
    }

    @Override
    public boolean isSupervised() {
        return false;
    }

    /**
     * Driver which does no training; it loads the model bytes and wraps them into a {@link GenericModel}.
     */
    class MojoDelegatingModelDriver extends Driver {

        @Override
        public void compute2() {
            if (_parms._path != null) { // If there is a file to be imported, do the import before the scope is entered
                _parms._model_key = importFile();
            }
            super.compute2();
        }

        /**
         * Reads the model bytes referenced by {@code _parms._model_key} and builds the generic model from them.
         *
         * @throws IllegalArgumentException if no model key is available or the key does not point to usable model data
         * @throws IllegalStateException    if reading the model data fails with an {@link IOException}
         */
        @Override
        public void computeImpl() {
            final Key<Frame> dataKey = _parms._model_key;
            if (dataKey == null) {
                throw new IllegalArgumentException("Either MOJO zip path or key to the uploaded MOJO frame must be specified");
            }
            final ByteVec modelBytes = readModelData(dataKey);
            try {
                final GenericModel genericModel;
                if (ZipUtil.isCompressed(modelBytes)) {
                    genericModel = importMojo(modelBytes, dataKey);
                } else {
                    warn("_path", "Trying to import a POJO model - this is currently an experimental feature.");
                    genericModel = importPojo(modelBytes, dataKey, _result.toString());
                }
                genericModel.write_lock(_job);
                genericModel.unlock(_job);
            } catch (IOException e) {
                throw new IllegalStateException("Unreachable model file: " + dataKey, e);
            }
        }

        /**
         * Loads a MOJO from the given bytes and wraps it into a {@link GenericModel}.
         *
         * @param mojoBytes bytes of the MOJO archive
         * @param dataKey   key of the frame the bytes come from
         * @return generic model backed by the loaded MOJO
         * @throws IOException              if the MOJO cannot be read
         * @throws IllegalArgumentException if the MOJO algorithm is not white-listed and the check is not disabled
         */
        private GenericModel importMojo(ByteVec mojoBytes, Key<Frame> dataKey) throws IOException {
            final MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend(
                    mojoBytes.openStream(_job._key), MojoReaderBackendFactory.CachingStrategy.MEMORY);
            final MojoModel mojoModel = ModelMojoReader.readFrom(readerBackend, true);

            final String algoName = mojoModel._modelDescriptor.algoName();
            // Locale.ROOT keeps the lookup independent of the JVM default locale (e.g. Turkish dotless 'i').
            if (!ALLOWED_MOJO_ALGOS.contains(algoName.toLowerCase(Locale.ROOT))) {
                if (_parms._disable_algo_check) {
                    Log.warn(String.format("MOJO model '%s' is not supported but user disabled white-list check. Trying to load anyway.", algoName));
                } else {
                    throw new IllegalArgumentException(String.format("Unsupported MOJO model '%s'. ", algoName));
                }
            }

            final GenericModelOutput genericModelOutput = new GenericModelOutput(mojoModel._modelDescriptor, mojoModel._modelAttributes, mojoModel._reproducibilityInformation);
            return new GenericModel(_result, _parms, genericModelOutput, mojoModel, dataKey);
        }

        /**
         * Loads a POJO from its source code and wraps it into a {@link GenericModel}.
         *
         * @param pojoBytes bytes of the POJO source
         * @param pojoKey   key of the frame the bytes come from
         * @param modelId   identifier handed to the POJO loader
         * @return generic model backed by the loaded POJO
         * @throws IOException if the POJO cannot be loaded
         */
        private GenericModel importPojo(ByteVec pojoBytes, Key<Frame> pojoKey, String modelId) throws IOException {
            GenModel genmodel = PojoLoader.loadPojoFromSourceCode(pojoBytes, pojoKey, modelId);
            ModelDescriptor pojoDescriptor = ModelDescriptorBuilder.makeDescriptor(genmodel);
            final GenericModelOutput genericModelOutput = new GenericModelOutput(pojoDescriptor);
            return new GenericModel(_result, _parms, genericModelOutput, genmodel, pojoKey);
        }
    }

    /**
     * Imports the file at {@code _parms._path} and returns the key of the imported data.
     *
     * @return key created from the single key reported by the import
     * @throws RuntimeException      if the import reports any failure
     * @throws IllegalStateException if the import does not yield exactly one key
     */
    private Key<Frame> importFile() {
        ArrayList<String> files = new ArrayList<>();
        ArrayList<String> keys = new ArrayList<>();
        ArrayList<String> fails = new ArrayList<>();
        ArrayList<String> dels = new ArrayList<>();
        H2O.getPM().importFiles(_parms._path, null, files, keys, fails, dels);
        if (!fails.isEmpty()) {
            throw new RuntimeException("Failed to import file: " + Arrays.toString(fails.toArray()));
        }
        // Checked explicitly (not via assert) so the guard also holds when assertions are disabled.
        if (keys.size() != 1) {
            throw new IllegalStateException("Expected exactly one imported file for path '" + _parms._path
                    + "', but got " + keys.size() + ": " + keys);
        }
        return Key.make(keys.get(0));
    }

    /**
     * Retrieves pre-uploaded MOJO archive and performs basic verifications, if present.
     *
     * @param key Key to MOJO bytes in DKV
     * @return An instance of {@link ByteVec} containing the bytes of an uploaded MOJO, if present. Or exception. Never returns null.
     * @throws NullPointerException     In case the supplied key itself is null
     * @throws IllegalArgumentException In case the supplied key is invalid (MOJO missing, more than one column,
     *                                  column not holding raw bytes, empty data)
     */
    private ByteVec readModelData(final Key<Frame> key) throws IllegalArgumentException {
        Objects.requireNonNull(key); // Nicer null pointer exception in case null key is accidentally provided

        final Frame mojoFrame = key.get();
        if (mojoFrame == null)
            throw new IllegalArgumentException(String.format("Given model frame with key '%s' was not found. Incorrect key provided ?", key));
        if (mojoFrame.numCols() > 1)
            throw new IllegalArgumentException(String.format("Given model frame with key '%s' should contain only 1 column with model bytes. More columns found. Incorrect key provided ?", key));

        // Held as Object so that both a missing vec (null) and a vec of another type are rejected by one instanceof test.
        final Object modelVec = mojoFrame.anyVec();
        if (!(modelVec instanceof ByteVec))
            throw new IllegalArgumentException(String.format("Given model frame with key '%s' does not contain a column with raw model bytes. Incorrect key provided ?", key));
        final ByteVec mojoData = (ByteVec) modelVec;

        if (mojoData.length() < 1)
            throw new IllegalArgumentException(String.format("Given model frame with key '%s' is empty (0 bytes). Please provide a non-empty model file.", key));

        return mojoData;
    }

    @Override
    public BuilderVisibility builderVisibility() {
        return BuilderVisibility.Stable;
    }

    /**
     * Convenience method for importing MOJO into H2O.
     * 
     * @param location absolute path to MOJO file
     * @param disableAlgoCheck if true skip the check of white-listed MOJO models, use at your own risk - some features might not work.
     * @return instance of H2O Model wrapping a MOJO 
     */
    public static GenericModel importMojoModel(String location, boolean disableAlgoCheck) {
        GenericModelParameters p = new GenericModelParameters();
        p._path = location;
        p._disable_algo_check = disableAlgoCheck;
        return new Generic(p).trainModel().get();
    }

    /**
     * Convenience method for importing MOJO into H2O with the white-list check enabled.
     *
     * @param location location of the MOJO file; its string form is used as the import path
     * @return instance of H2O Model wrapping a MOJO
     */
    public static GenericModel importMojoModel(URI location) {
        return importMojoModel(location.toString(), false);
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy