All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.descriptor.ModelDescriptorBuilder Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.genmodel.descriptor;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.attributes.ModelAttributes;
import hex.genmodel.attributes.parameters.ColumnSpecifier;
import hex.genmodel.utils.ArrayUtils;

import java.io.Serializable;
import java.util.Arrays;

public class ModelDescriptorBuilder {

    /**
     * Creates a {@link ModelDescriptor} populated from a MOJO model.
     *
     * @param mojoModel         MOJO model whose properties are captured by the descriptor
     * @param fullAlgorithmName full algorithm name, reported by {@link ModelDescriptor#algoFullName()}
     * @param modelAttributes   optional model attributes; when {@code null}, the descriptor reports
     *                          {@code null} for both the weights column and the treatment column
     * @return a newly created {@link ModelDescriptor}
     */
    public static ModelDescriptor makeDescriptor(final MojoModel mojoModel, final String fullAlgorithmName,
                                                 final ModelAttributes modelAttributes) {
        final ModelDescriptor descriptor = new MojoModelDescriptor(mojoModel, fullAlgorithmName, modelAttributes);
        return descriptor;
    }

    /**
     * Creates a {@link ModelDescriptor} populated from a POJO model.
     *
     * @param pojoModel generated model whose properties are captured by the descriptor
     * @return a newly created {@link ModelDescriptor}
     */
    public static ModelDescriptor makeDescriptor(final GenModel pojoModel) {
        final ModelDescriptor descriptor = new PojoModelDescriptor(pojoModel);
        return descriptor;
    }

    /**
     * Serializable {@link ModelDescriptor} whose values are captured from a {@link MojoModel} when the
     * descriptor is constructed. Arrays obtained from the model are stored as-is; no defensive copies
     * are made on the way in or out (except by {@link #features()}).
     */
    public static class MojoModelDescriptor implements ModelDescriptor, Serializable {
        // Mandatory
        private final String _h2oVersion;
        private final hex.ModelCategory _category;
        private final String _uuid;
        private final boolean _supervised;
        private final int _nfeatures;
        private final int _nclasses;
        private final boolean _balanceClasses;
        private final double _defaultThreshold;
        private final double[] _priorClassDistrib;
        private final double[] _modelClassDistrib;
        private final String _offsetColumn;
        private final String _foldColumn;
        private final String _weightsColumn;
        private final String _treatmentColumn;
        private final String[][] _domains;
        private final String[][] _origDomains;
        private final String[] _names;
        private final String[] _origNames;
        private final String _algoName;
        private final String _fullAlgoName;

        /**
         * Captures the descriptor values from the given model.
         *
         * @param mojoModel         model to read the values from
         * @param fullAlgorithmName value returned by {@link #algoFullName()}
         * @param modelAttributes   optional source of the "weights_column" and "treatment_column"
         *                          parameters; when {@code null}, both columns are {@code null}
         */
        private MojoModelDescriptor(final MojoModel mojoModel, final String fullAlgorithmName,
                                    final ModelAttributes modelAttributes) {
            // Identity and versioning.
            _h2oVersion = mojoModel._h2oVersion;
            _uuid = mojoModel._uuid;
            _algoName = mojoModel._algoName;
            _fullAlgoName = fullAlgorithmName;

            // Model shape.
            _category = mojoModel._category;
            _supervised = mojoModel.isSupervised();
            _nfeatures = mojoModel.nfeatures();
            _nclasses = mojoModel._nclasses;

            // Class balancing and thresholds.
            _balanceClasses = mojoModel._balanceClasses;
            _defaultThreshold = mojoModel._defaultThreshold;
            _priorClassDistrib = mojoModel._priorClassDistrib;
            _modelClassDistrib = mojoModel._modelClassDistrib;

            // Special columns stored directly on the model.
            _offsetColumn = mojoModel._offsetColumn;
            _foldColumn = mojoModel._foldColumn;

            // Column names and categorical domains.
            _domains = mojoModel._domains;
            _origDomains = mojoModel.getOrigDomainValues();
            _names = mojoModel._names;
            _origNames = mojoModel.getOrigNames();

            // Special columns only available through the optional model attributes.
            String weightsColumn = null;
            String treatmentColumn = null;
            if (modelAttributes != null) {
                final ColumnSpecifier weightsSpec =
                        (ColumnSpecifier) modelAttributes.getParameterValueByName("weights_column");
                if (weightsSpec != null) {
                    weightsColumn = weightsSpec.getColumnName();
                }
                treatmentColumn = (String) modelAttributes.getParameterValueByName("treatment_column");
            }
            _weightsColumn = weightsColumn;
            _treatmentColumn = treatmentColumn;
        }

        /** @return the domains array taken from the model */
        @Override
        public String[][] scoringDomains() {
            return _domains;
        }

        /** @return the H2O version string taken from the model */
        @Override
        public String projectVersion() {
            return _h2oVersion;
        }

        /** @return the algorithm name taken from the model */
        @Override
        public String algoName() {
            return _algoName;
        }

        /** @return the full algorithm name supplied at construction */
        @Override
        public String algoFullName() {
            return _fullAlgoName;
        }

        /** @return the offset column taken from the model */
        @Override
        public String offsetColumn() {
            return _offsetColumn;
        }

        /** @return the weights column name, or {@code null} when it could not be obtained from the model attributes */
        @Override
        public String weightsColumn() {
            return _weightsColumn;
        }

        /** @return the fold column taken from the model */
        @Override
        public String foldColumn() {
            return _foldColumn;
        }

        /** @return the treatment column, or {@code null} when no model attributes were supplied */
        @Override
        public String treatmentColumn() {
            return _treatmentColumn;
        }

        /** @return the model category taken from the model */
        @Override
        public ModelCategory getModelCategory() {
            return _category;
        }

        /** @return the supervised flag as reported by the model at construction */
        @Override
        public boolean isSupervised() {
            return _supervised;
        }

        /** @return the number of features as reported by the model at construction */
        @Override
        public int nfeatures() {
            return _nfeatures;
        }

        /** @return a new array holding the first {@link #nfeatures()} entries of {@link #columnNames()} */
        @Override
        public String[] features() {
            final String[] allColumns = columnNames();
            return Arrays.copyOf(allColumns, nfeatures());
        }

        /** @return the number of classes taken from the model */
        @Override
        public int nclasses() {
            return _nclasses;
        }

        /** @return the names array taken from the model */
        @Override
        public String[] columnNames() {
            return _names;
        }

        /** @return the balance-classes flag taken from the model */
        @Override
        public boolean balanceClasses() {
            return _balanceClasses;
        }

        /** @return the default threshold taken from the model */
        @Override
        public double defaultThreshold() {
            return _defaultThreshold;
        }

        /** @return the prior class distribution array taken from the model */
        @Override
        public double[] priorClassDist() {
            return _priorClassDistrib;
        }

        /** @return the model class distribution array taken from the model */
        @Override
        public double[] modelClassDist() {
            return _modelClassDistrib;
        }

        /** @return the UUID taken from the model */
        @Override
        public String uuid() {
            return _uuid;
        }

        /** @return always {@code null} */
        @Override
        public String timestamp() {
            return null;
        }

        /** @return the original names as reported by the model at construction */
        @Override
        public String[] getOrigNames() {
            return _origNames;
        }

        /** @return the original domain values as reported by the model at construction */
        @Override
        public String[][] getOrigDomains() {
            return _origDomains;
        }
    }

    /**
     * Serializable {@link ModelDescriptor} whose values are captured from a {@link GenModel} when the
     * descriptor is constructed. Properties that are not read from the model are reported as fixed
     * values ({@code null}, {@code false}, {@code NaN}, or a constant string).
     */
    public static class PojoModelDescriptor implements ModelDescriptor, Serializable {
        // Mandatory
        private final hex.ModelCategory _category;
        private final boolean _supervised;
        private final int _nfeatures;
        private final int _nclasses;
        private final String _offsetColumn;
        private final String[][] _domains;
        private final String[][] _origDomains;
        private final String[] _names;
        private final String[] _origNames;

        /**
         * Captures the descriptor values from the given model.
         *
         * @param pojoModel model to read the values from
         */
        private PojoModelDescriptor(final GenModel pojoModel) {
            _category = pojoModel.getModelCategory();
            _supervised = pojoModel.isSupervised();
            _nfeatures = pojoModel.nfeatures();
            _nclasses = pojoModel.getNumResponseClasses();
            _offsetColumn = pojoModel.getOffsetName();
            _domains = pojoModel.getDomainValues();
            _origDomains = pojoModel.getOrigDomainValues();
            _names = namesIncludingResponse(pojoModel, _domains);
            _origNames = pojoModel.getOrigNames();
        }

        /**
         * Returns the model's column names, with the response column appended when all of the
         * following hold: the names array is exactly one element shorter than the domains array,
         * the model is supervised, and the last name is not already the response column.
         * Otherwise the model's own names array is returned unchanged.
         *
         * @param pojoModel model providing the names and the response column
         * @param domains   domain values used for the length comparison
         * @return the names array to expose through {@link #columnNames()}
         */
        private static String[] namesIncludingResponse(final GenModel pojoModel, final String[][] domains) {
            final String[] modelNames = pojoModel.getNames();
            if (modelNames.length != domains.length - 1) {
                return modelNames;
            }
            if (!pojoModel.isSupervised()) {
                return modelNames;
            }
            final String responseColumn = pojoModel._responseColumn;
            // NOTE(review): indexing the last element fails on an empty names array; presumably
            // a supervised model always has at least one name here - confirm with callers.
            final String lastName = modelNames[modelNames.length - 1];
            if (lastName.equals(responseColumn)) {
                return modelNames;
            }
            return ArrayUtils.append(modelNames, responseColumn);
        }

        /** @return the domain values as reported by the model at construction */
        @Override
        public String[][] scoringDomains() {
            return _domains;
        }

        /** @return always {@code "unknown"} */
        @Override
        public String projectVersion() {
            return "unknown";
        }

        /** @return always {@code "pojo"} */
        @Override
        public String algoName() {
            return "pojo";
        }

        /** @return always {@code "POJO Scorer"} */
        @Override
        public String algoFullName() {
            return "POJO Scorer";
        }

        /** @return the offset name as reported by the model at construction */
        @Override
        public String offsetColumn() {
            return _offsetColumn;
        }

        /** @return always {@code null} */
        @Override
        public String weightsColumn() {
            return null;
        }

        /** @return always {@code null} */
        @Override
        public String foldColumn() {
            return null;
        }

        /** @return always {@code null} */
        @Override
        public String treatmentColumn() {
            return null;
        }

        /** @return the model category as reported by the model at construction */
        @Override
        public ModelCategory getModelCategory() {
            return _category;
        }

        /** @return the supervised flag as reported by the model at construction */
        @Override
        public boolean isSupervised() {
            return _supervised;
        }

        /** @return the number of features as reported by the model at construction */
        @Override
        public int nfeatures() {
            return _nfeatures;
        }

        /** @return a new array holding the first {@link #nfeatures()} entries of {@link #columnNames()} */
        @Override
        public String[] features() {
            final String[] allColumns = columnNames();
            return Arrays.copyOf(allColumns, nfeatures());
        }

        /** @return the number of response classes as reported by the model at construction */
        @Override
        public int nclasses() {
            return _nclasses;
        }

        /** @return the model's names, possibly with the response column appended at construction */
        @Override
        public String[] columnNames() {
            return _names;
        }

        /** @return always {@code false} */
        @Override
        public boolean balanceClasses() {
            return false;
        }

        /** @return always {@link Double#NaN} */
        @Override
        public double defaultThreshold() {
            return Double.NaN;
        }

        /** @return always {@code null} */
        @Override
        public double[] priorClassDist() {
            return null;
        }

        /** @return always {@code null} */
        @Override
        public double[] modelClassDist() {
            return null;
        }

        /** @return always {@code null} */
        @Override
        public String uuid() {
            return null;
        }

        /** @return always {@code null} */
        @Override
        public String timestamp() {
            return null;
        }

        /** @return the original names as reported by the model at construction */
        @Override
        public String[] getOrigNames() {
            return _origNames;
        }

        /** @return the original domain values as reported by the model at construction */
        @Override
        public String[][] getOrigDomains() {
            return _origDomains;
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy