All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.xgboost.matrix.MatrixLoader Maven / Gradle / Ivy

package hex.tree.xgboost.matrix;

import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.XGBoostError;
import water.Iced;

import java.util.Arrays;
import java.util.Objects;

/**
 * Base class for loaders that produce XGBoost {@link DMatrix} providers for training and,
 * optionally, validation data.
 */
public abstract class MatrixLoader extends Iced {

    /**
     * Builds a {@link DMatrix} and attaches the label, weight and base-margin (offset) vectors to it.
     * Subclasses supply the matrix contents through {@link #makeDMatrix()} and may release any
     * intermediate buffers they hold in {@link #dispose()}.
     */
    public static abstract class DMatrixProvider {

        /** Number of rows the produced matrix is expected to contain (checked via assertion). */
        protected long actualRows;
        /** Label vector passed to {@link DMatrix#setLabel}. */
        protected float[] response;
        /** Optional weight vector; when {@code null} no weights are set on the matrix. */
        protected float[] weights;
        /** Optional base-margin vector; when {@code null} no base margin is set on the matrix. */
        protected float[] offsets;

        protected DMatrixProvider(long actualRows, float[] response, float[] weights, float[] offsets) {
            this.actualRows = actualRows;
            this.response = response;
            this.weights = weights;
            this.offsets = offsets;
        }

        /**
         * Creates the raw matrix (without label/weight/margin vectors attached).
         *
         * @return a newly created matrix
         * @throws XGBoostError if the native matrix cannot be created
         */
        protected abstract DMatrix makeDMatrix() throws XGBoostError;

        /**
         * Prints the first {@code nrow} rows of the provided data.
         *
         * @param nrow number of rows to print
         */
        @SuppressWarnings("unused") // used for debugging
        public abstract void print(int nrow);

        /**
         * Hook invoked once by {@link #get()} right after {@link #makeDMatrix()} returns or fails.
         * Subclasses may override it to free intermediate data; the default does nothing.
         */
        protected void dispose() {}

        /**
         * Builds the matrix and attaches the label vector, plus the weight and base-margin vectors
         * when they are non-null.
         *
         * <p>{@link #dispose()} is always called after {@link #makeDMatrix()}, even when it throws.
         * If attaching the vectors fails, the partially initialised matrix is disposed before the
         * exception propagates, so native memory is not leaked.
         *
         * @return the fully initialised matrix; the caller owns it
         * @throws XGBoostError if matrix creation or setting any of the vectors fails
         */
        public final DMatrix get() throws XGBoostError {
            final DMatrix mat;
            try {
                mat = makeDMatrix();
            } finally {
                // Release provider-side buffers regardless of whether construction succeeded.
                dispose();
            }
            boolean initialized = false;
            try {
                assert mat.rowNum() == actualRows;
                mat.setLabel(response);
                if (weights != null) {
                    mat.setWeight(weights);
                }
                if (offsets != null) {
                    mat.setBaseMargin(offsets);
                }
                initialized = true;
                return mat;
            } finally {
                // The matrix holds native memory; free it if we are not handing it to the caller.
                if (!initialized) {
                    mat.dispose();
                }
            }
        }

        /**
         * Compares row count and the label, weight and offset vectors. Any other state a subclass
         * holds is not considered unless the subclass overrides this method.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof DMatrixProvider)) return false;
            DMatrixProvider that = (DMatrixProvider) o;
            return actualRows == that.actualRows &&
                Arrays.equals(response, that.response) &&
                Arrays.equals(weights, that.weights) &&
                Arrays.equals(offsets, that.offsets);
        }

        /** Hash consistent with {@link #equals(Object)}: row count and the three vectors. */
        @Override
        public int hashCode() {
            int result = Objects.hash(actualRows);
            result = 31 * result + Arrays.hashCode(response);
            result = 31 * result + Arrays.hashCode(weights);
            result = 31 * result + Arrays.hashCode(offsets);
            return result;
        }
    }

    /**
     * Creates the provider for the training matrix.
     *
     * @return provider for the training data
     */
    public abstract DMatrixProvider makeLocalTrainMatrix();

    /**
     * Creates the provider for the validation matrix; presumably only meaningful when
     * {@link #hasValidationFrame()} returns {@code true} (confirm against implementations).
     *
     * @return provider for the validation data
     */
    public abstract DMatrixProvider makeLocalValidMatrix();

    /**
     * @return whether this loader has validation data available
     */
    public abstract boolean hasValidationFrame();

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy