All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.xgboost.matrix.RemoteMatrixLoader Maven / Gradle / Ivy

package hex.tree.xgboost.matrix;

import hex.tree.xgboost.task.XGBoostUploadMatrixTask;
import ai.h2o.xgboost4j.java.util.BigDenseMatrix;
import water.Key;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * {@link MatrixLoader} for XGBoost matrices that are uploaded to this node piece by piece and
 * parked in a process-wide registry until training picks them up.
 *
 * <p>The upload sequence, as implemented by the static entry points, is:
 * <ol>
 *   <li>{@code initSparse}/{@code initDense} allocates an empty matrix and registers it under a key.</li>
 *   <li>{@code sparseChunk}/{@code denseChunk} copies one chunk into it, at offsets precomputed in
 *       the dimensions object.</li>
 *   <li>{@code matrixData} attaches the remaining {@link XGBoostUploadMatrixTask.MatrixData}.</li>
 * </ol>
 * A loader instance then removes the train/valid entries for its model and converts them with
 * {@link RemoteMatrix#make()}.
 */
public class RemoteMatrixLoader extends MatrixLoader {

    /**
     * A matrix being assembled from uploaded pieces. {@link #data} stays {@code null} until
     * {@link RemoteMatrixLoader#matrixData} is called for its key.
     */
    static abstract class RemoteMatrix {
        XGBoostUploadMatrixTask.MatrixData data;

        /** Builds a DMatrix provider from the assembled matrix and its attached {@link #data}. */
        abstract MatrixLoader.DMatrixProvider make();
    }

    /** CSR sparse matrix preallocated from its dimensions and filled chunk by chunk. */
    static class RemoteSparseMatrix extends RemoteMatrix {

        final SparseMatrixDimensions dims;
        final SparseMatrix matrix;

        RemoteSparseMatrix(SparseMatrixDimensions dims) {
            this.dims = dims;
            this.matrix = SparseMatrixFactory.allocateCSRMatrix(dims);
        }

        @Override
        MatrixLoader.DMatrixProvider make() {
            return SparseMatrixFactory.toDMatrix(matrix, dims, data.actualRows, data.shape, data.resp, data.weights, data.offsets);
        }
    }

    /** Dense matrix backed by a {@link BigDenseMatrix} of {@code dims.rows x dims.cols}. */
    static class RemoteDenseMatrix extends RemoteMatrix {

        final XGBoostUploadMatrixTask.DenseMatrixDimensions dims;
        final BigDenseMatrix matrix;

        RemoteDenseMatrix(XGBoostUploadMatrixTask.DenseMatrixDimensions dims) {
            this.dims = dims;
            this.matrix = new BigDenseMatrix(dims.rows, dims.cols);
        }

        @Override
        DMatrixProvider make() {
            return new DenseMatrixFactory.DenseDMatrixProvider(data.actualRows, data.resp, data.weights, data.offsets, matrix);
        }
    }

    /**
     * Matrices under construction, keyed by {@link #trainMatrixKey(Key)} /
     * {@link #validMatrixKey(Key)} strings.
     *
     * <p>This is static state shared by every caller of the static entry points. A concurrent map
     * is used so that registrations, lookups and removals for different keys cannot corrupt it.
     * NOTE(review): presumes the upload entry points may be invoked from concurrent tasks; confirm
     * against the caller.
     */
    private static final Map<String, RemoteMatrix> REGISTRY = new ConcurrentHashMap<>();

    /**
     * Returns the matrix registered under {@code key}.
     *
     * @throws IllegalStateException if no matrix is registered under {@code key}
     */
    private static RemoteMatrix lookup(String key) {
        RemoteMatrix m = REGISTRY.get(key);
        if (m == null) {
            throw new IllegalStateException("No matrix registered under key " + key);
        }
        return m;
    }

    /**
     * Removes the matrix registered under {@code key} and builds its DMatrix provider.
     *
     * @throws IllegalStateException if no matrix is registered under {@code key}, or its
     *     {@code MatrixData} was never attached via {@link #matrixData}
     */
    private static DMatrixProvider removeAndMake(String key) {
        RemoteMatrix m = REGISTRY.remove(key);
        if (m == null) {
            throw new IllegalStateException("No matrix registered under key " + key);
        }
        if (m.data == null) {
            throw new IllegalStateException("Matrix data were not uploaded for key " + key);
        }
        return m.make();
    }

    /** Allocates an empty CSR matrix with the given dimensions and registers it under {@code key}. */
    public static void initSparse(String key, SparseMatrixDimensions dims) {
        RemoteSparseMatrix m = new RemoteSparseMatrix(dims);
        REGISTRY.put(key, m);
    }

    /**
     * Copies one uploaded sparse chunk into the matrix registered under {@code key}.
     *
     * <p>The write positions are taken from the dimensions object: the number of rows and
     * non-zero elements preceding chunk {@code chunk.id}.
     *
     * @throws IllegalStateException if nothing is registered under {@code key}
     */
    public static void sparseChunk(String key, XGBoostUploadMatrixTask.SparseMatrixChunk chunk) {
        RemoteSparseMatrix m = (RemoteSparseMatrix) lookup(key);
        long nonZeroCount = m.dims._precedingNonZeroElementsCounts[chunk.id];
        int rwRow = m.dims._precedingRowCounts[chunk.id];
        SparseMatrixFactory.NestedArrayPointer rowHeaderPointer = new SparseMatrixFactory.NestedArrayPointer(rwRow);
        SparseMatrixFactory.NestedArrayPointer dataPointer = new SparseMatrixFactory.NestedArrayPointer(nonZeroCount);
        for (int i = 0; i < chunk.rowHeader.length; i++) {
            rowHeaderPointer.setAndIncrement(m.matrix._rowHeaders, chunk.rowHeader[i]);
        }
        for (int i = 0; i < chunk.data.length; i++) {
            // Value and its column index share the same position, so advance only after both are written.
            dataPointer.set(m.matrix._sparseData, chunk.data[i]);
            dataPointer.set(m.matrix._colIndices, chunk.colIndices[i]);
            dataPointer.increment();
        }
    }

    /** Allocates an empty dense matrix with the given dimensions and registers it under {@code key}. */
    public static void initDense(String key, XGBoostUploadMatrixTask.DenseMatrixDimensions dims) {
        RemoteDenseMatrix m = new RemoteDenseMatrix(dims);
        REGISTRY.put(key, m);
    }

    /**
     * Copies one uploaded dense chunk into the matrix registered under {@code key}.
     *
     * <p>The chunk is written as a contiguous run that starts at row {@code dims.rowOffsets[chunk.id]}.
     *
     * @throws IllegalStateException if nothing is registered under {@code key}
     */
    public static void denseChunk(String key, XGBoostUploadMatrixTask.DenseMatrixChunk chunk) {
        RemoteDenseMatrix m = (RemoteDenseMatrix) lookup(key);
        // Widen to long before multiplying so the flat start offset cannot overflow int arithmetic
        // (the cast is a no-op if the operands are already long). The offset is loop-invariant.
        final long base = (long) m.dims.rowOffsets[chunk.id] * m.dims.cols;
        for (int i = 0; i < chunk.data.length; i++) {
            m.matrix.set(base + i, chunk.data[i]);
        }
    }

    /**
     * Attaches the remaining matrix data to the matrix registered under {@code key}.
     *
     * @throws IllegalStateException if nothing is registered under {@code key}
     */
    public static void matrixData(String key, XGBoostUploadMatrixTask.MatrixData data) {
        lookup(key).data = data;
    }

    /** Drops the matrix registered under {@code key}, if any. */
    public static void cleanup(String key) {
        REGISTRY.remove(key);
    }

    private final Key modelKey;

    /** Creates a loader serving the train/valid matrices uploaded for {@code modelKey}. */
    public RemoteMatrixLoader(Key modelKey) {
        this.modelKey = modelKey;
    }

    /**
     * Removes the uploaded training matrix for this model from the registry and builds it.
     *
     * @throws IllegalStateException if no training matrix, or no matrix data, was uploaded
     */
    @Override
    public DMatrixProvider makeLocalTrainMatrix() {
        return removeAndMake(trainMatrixKey(modelKey));
    }

    /** Registry key under which the training matrix for {@code modelKey} is uploaded. */
    public static String trainMatrixKey(Key modelKey) {
        return modelKey.toString() + "_train";
    }

    /** Returns whether a validation matrix is currently registered for this model. */
    @Override
    public boolean hasValidationFrame() {
        return REGISTRY.containsKey(validMatrixKey(modelKey));
    }

    /**
     * Removes the uploaded validation matrix for this model from the registry and builds it.
     *
     * @throws IllegalStateException if no validation matrix, or no matrix data, was uploaded
     */
    @Override
    public DMatrixProvider makeLocalValidMatrix() {
        return removeAndMake(validMatrixKey(modelKey));
    }

    /** Registry key under which the validation matrix for {@code modelKey} is uploaded. */
    public static String validMatrixKey(Key modelKey) {
        return modelKey.toString() + "_valid";
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy