All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.xgboost.OneHotEncoderFactory Maven / Gradle / Ivy

package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.util.FVec;
import hex.genmodel.GenModel;

import java.io.Serializable;

/**
 * Creates {@link FVec} views over raw input rows in which the categorical inputs are
 * exposed as one-hot columns and the numeric inputs follow directly after them.
 *
 * <p>The one-hot column layout is described by {@code catOffsets}: categorical {@code c}
 * owns the column indices {@code [catOffsets[c], catOffsets[c + 1])}, and
 * {@code catOffsets[cats]} is the total number of one-hot columns.
 */
class OneHotEncoderFactory implements Serializable {

    // Selects which FVec implementation fromArray() hands out.
    private boolean _compatible10;
    // When true, "not hot" columns and zero-valued numerics are reported as NaN.
    private final boolean _sparse;
    // Column boundaries of each categorical; may be null (then there are no one-hot columns).
    private final int[] _catOffsets;
    // Number of categorical inputs.
    private final int _cats;
    // Number of numeric inputs.
    private final int _nums;
    // Passed through unchanged to GenModel.setCats.
    private final boolean _useAllFactorLevels;
    // For every one-hot column index, the position of the categorical that owns it.
    private final int[] _catMap;
    // Value reported for a one-hot column that is not hot: NaN if sparse, 0 otherwise.
    private final float _notHot;

    /**
     * Stores the encoding configuration and precomputes the column-to-categorical map.
     *
     * @param compatible10       if true, produced vectors use the compatibility lookup
     *                           (see {@code OneHotEncoderFVecCompatible10})
     * @param sparse             if true, not-hot columns and zero numerics become NaN
     * @param catOffsets         one-hot column boundaries per categorical, or null
     * @param cats               number of categorical inputs
     * @param nums               number of numeric inputs
     * @param useAllFactorLevels forwarded to {@code GenModel.setCats}
     */
    OneHotEncoderFactory(boolean compatible10, boolean sparse, int[] catOffsets, int cats, int nums, boolean useAllFactorLevels) {
        _compatible10 = compatible10;
        _sparse = sparse;
        _catOffsets = catOffsets;
        _cats = cats;
        _nums = nums;
        _useAllFactorLevels = useAllFactorLevels;

        if (sparse) {
            _notHot = Float.NaN;
        } else {
            _notHot = 0;
        }

        final int[] columnOwner;
        if (catOffsets != null) {
            columnOwner = new int[catOffsets[cats]];
            for (int cat = 0; cat < cats; cat++) {
                final int end = catOffsets[cat + 1];
                for (int column = catOffsets[cat]; column < end; column++) {
                    columnOwner[column] = cat;
                }
            }
        } else {
            columnOwner = new int[0];
        }
        _catMap = columnOwner;
    }

    /**
     * Wraps one input row into an {@link FVec}.
     *
     * <p>Categorical values are resolved by {@code GenModel.setCats}; numeric values are
     * read from {@code input} starting at position {@code _cats} and narrowed to float.
     * In sparse mode a numeric value equal to zero is stored as NaN.
     *
     * @param input raw row values
     * @return a vector backed by the values extracted from {@code input}
     */
    FVec fromArray(double[] input) {
        final float[] numValues = new float[_nums];
        final int[] catValues = new int[_cats];
        GenModel.setCats(input, catValues, _cats, _catOffsets, _useAllFactorLevels);

        for (int n = 0; n < numValues.length; n++) {
            final float value = (float) input[_cats + n];
            if (_sparse && value == 0) {
                numValues[n] = Float.NaN;
            } else {
                numValues[n] = value;
            }
        }

        return _compatible10
                ? new OneHotEncoderFVecCompatible10(catValues, numValues)
                : new DefaultOneHotEncoderFVec(catValues, numValues);
    }

    /**
     * Shared vector implementation: one-hot columns first, numeric values after them.
     * Subclasses decide when a one-hot column counts as hot.
     */
    private abstract class AbstractOneHotEncoderFVec implements FVec {
        protected final int[] _catValues;
        protected final float[] _numValues;

        private AbstractOneHotEncoderFVec(int[] catValues, float[] numValues) {
            _catValues = catValues;
            _numValues = numValues;
        }

        /**
         * Returns the value at the given column.
         *
         * <p>Indices below the number of one-hot columns yield 1 for a hot column and
         * the configured not-hot value otherwise; any other index is looked up in the
         * numeric values, shifted by the number of one-hot columns.
         */
        @Override
        public final float fvalue(int index) {
            final int oneHotColumns = _catMap.length;
            if (index < oneHotColumns) {
                if (getCategoricalValue(index)) {
                    return 1;
                }
                return _notHot;
            }
            return _numValues[index - oneHotColumns];
        }

        /**
         * Tells whether the one-hot column at {@code index} is hot for this row.
         */
        protected abstract boolean getCategoricalValue(int index);
    }

    /**
     * A column is hot only when the value stored for its owning categorical equals the
     * column index.
     */
    private class DefaultOneHotEncoderFVec extends AbstractOneHotEncoderFVec {

        public DefaultOneHotEncoderFVec(int[] catValues, float[] numValues) {
            super(catValues, numValues);
        }

        @Override
        protected boolean getCategoricalValue(int index) {
            final int owner = _catMap[index];
            return index == _catValues[owner];
        }
    }

    /**
     * Compatibility lookup: a column is hot when the value stored for ANY categorical
     * equals the column index, not just the value of the categorical that owns it.
     * NOTE(review): presumably reproduces the behavior of an older encoder version - confirm.
     */
    private class OneHotEncoderFVecCompatible10 extends AbstractOneHotEncoderFVec {

        public OneHotEncoderFVecCompatible10(int[] catValues, float[] numValues) {
            super(catValues, numValues);
        }

        @Override
        protected boolean getCategoricalValue(int index) {
            // Fast path: the owning categorical points at this column.
            if (_catValues[_catMap[index]] == index) {
                return true;
            }
            // Otherwise any other categorical pointing at this column also counts.
            for (int i = 0; i < _catValues.length; i++) {
                if (_catValues[i] == index) {
                    return true;
                }
            }
            return false;
        }
    }


}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy