All downloads are FREE. Search and download functionality uses the official Maven repository.

ai.h2o.automl.Models Maven / Gradle / Ivy

The newest version!
package ai.h2o.automl;

import hex.Model;
import hex.ModelContainer;
import water.*;
import water.api.schemas3.KeyV3;
import water.automl.api.schemas3.SchemaExtensions;
import water.util.ArrayUtils;

import java.lang.reflect.Array;
import java.lang.reflect.Modifier;
import java.util.Arrays;

public class Models extends Lockable> implements ModelContainer {

    private final int _type_id;
    private final Job _job;
    private Key[] _modelKeys = new Key[0];

    public Models(Key> key, Class clz) {
        this(key, clz, null);
    }

    public Models(Key> key, Class clz, Job job) {
        super(key);
        _type_id = (clz != null && !Modifier.isAbstract(clz.getModifiers())) ? TypeMap.getIcedId(clz.getName()) : -1;
        _job = job;
    }

    @Override
    public Key[] getModelKeys() {
        return _modelKeys.clone();
    }

    @Override
    @SuppressWarnings("unchecked")
    public M[] getModels() {
        Arrays.stream(_modelKeys).forEach(DKV::prefetch);
        Class clz = (Class)(_type_id >= 0 ? TypeMap.theFreezable(_type_id).getClass(): Model.class);
        return Arrays.stream(_modelKeys)
                .map(k -> k == null ? null : k.get())
                .toArray(l -> (M[])Array.newInstance(clz, l));
    }

    @Override
    public int getModelCount() {
        return _modelKeys.length;
    }

    public void addModel(Key key) {
        addModels(new Key[]{key});
    }

    public void addModels(Key[] keys) {
       write_lock(_job);
       _modelKeys = ArrayUtils.append(_modelKeys, keys);
       update(_job);
       unlock(_job);
    }

    @Override
    protected Futures remove_impl(final Futures fs, boolean cascade) {
        if (cascade) {
            for (Key k : _modelKeys)
                Keyed.remove(k, fs, true);
        }
        _modelKeys = new Key[0];
        return super.remove_impl(fs, cascade);
    }

    @Override
    public Class makeSchema() {
        return SchemaExtensions.ModelsKeyV3.class;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy