All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.genmodel.algos.tree.ContributionComposer Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.genmodel.algos.tree;

import hex.genmodel.utils.ArrayUtils;

import java.util.Arrays;

public class ContributionComposer {
    
    /**
     * Sort #contribNameIds according to #contribs values and compose desired output with correct #topN, #bottomN fields
     *
     * @param contribNameIds Contribution corresponding feature ids
     * @param contribs Raw contribution values
     * @param topN Return only #topN highest #contribNameIds + bias.
     *             If topN<0 then sort all SHAP values in descending order
     *             If topN<0 && bottomN<0 then sort all SHAP values in descending order
     * @param bottomN Return only #bottomN lowest #contribNameIds + bias
     *                If topN and bottomN are defined together then return array of #topN + #bottomN + bias
     *                If bottomN<0 then sort all SHAP values in ascending order
     *                If topN<0 && bottomN<0 then sort all SHAP values in descending order
     * @param compareAbs True to compare absolute values of #contribs
     * @return Sorted contribNameIds array of corresponding contributions features.
     *         The size of returned array is #topN + #bottomN + bias
     */
    public final int[] composeContributions(final int[] contribNameIds, final float[] contribs, int topN, int bottomN, boolean compareAbs) {
        assert contribNameIds.length == contribs.length : "contribNameIds must have the same length as contribs";
        if (returnOnlyTopN(topN, bottomN)) {
            return composeSortedContributions(contribNameIds, contribs, topN, compareAbs, -1);
        } else if (returnOnlyBottomN(topN, bottomN)) {
            return composeSortedContributions(contribNameIds, contribs, bottomN, compareAbs,1);
        } else if (returnAllTopN(topN, bottomN, contribs.length)) {
            return composeSortedContributions(contribNameIds, contribs, contribs.length, compareAbs, -1);
        }

        composeSortedContributions(contribNameIds, contribs, contribNameIds.length, compareAbs,-1);
        int[] bottomSorted = Arrays.copyOfRange(contribNameIds, contribNameIds.length - 1 - bottomN, contribNameIds.length);
        reverse(bottomSorted, contribs, bottomSorted.length - 1);
        int[] contribNameIdsTmp = Arrays.copyOf(contribNameIds, topN);

        return ArrayUtils.append(contribNameIdsTmp, bottomSorted);
    }

    private boolean returnOnlyTopN(int topN, int bottomN) {
        return topN != 0 && bottomN == 0;
    }

    private boolean returnOnlyBottomN(int topN, int bottomN) {
        return topN == 0 && bottomN != 0;
    }

    private boolean returnAllTopN(int topN, int bottomN, int len) {
        return (topN + bottomN) >= len || topN < 0 || bottomN < 0;
    }

    public int checkAndAdjustInput(int n, int len) {
        if (n < 0 || n > len) {
            return len;
        }
        return n;
    }
    
    private int[] composeSortedContributions(final int[] contribNameIds, final float[] contribs, int n, boolean compareAbs, int increasing) {
        int nAdjusted = checkAndAdjustInput(n, contribs.length);
        sortContributions(contribNameIds, contribs, compareAbs, increasing);
        if (nAdjusted < contribs.length) {
            int bias = contribNameIds[contribs.length-1];
            int[] contribNameIdsSorted = Arrays.copyOfRange(contribNameIds, 0, nAdjusted + 1);
            contribNameIdsSorted[nAdjusted] = bias;
            return contribNameIdsSorted;
        }
        return contribNameIds;
    }
    
    private void sortContributions(final int[] contribNameIds, final float[] contribs, final boolean compareAbs, final int increasing) {
        ArrayUtils.sort(contribNameIds, contribs, 0, contribs.length -1, compareAbs, increasing);
    }

    private void reverse(int[] contribNameIds, float[] contribs, int len) {
        for (int i = 0; i < len/2; i++) {
            if (contribs[contribNameIds[i]] != contribs[contribNameIds[len - i - 1]]) {
                int tmp = contribNameIds[i];
                contribNameIds[i] = contribNameIds[len - i - 1];
                contribNameIds[len - i - 1] = tmp;
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy