All downloads are free. Search and download functionality uses the official Maven repository.

hivemall.smile.utils.SmileExtUtils Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.smile.utils;

import hivemall.annotations.VisibleForTesting;
import hivemall.smile.classification.DecisionTree.SplitRule;
import hivemall.utils.collections.arrays.SparseIntArray;
import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.MathUtils;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import matrix4j.matrix.ColumnMajorMatrix;
import matrix4j.matrix.Matrix;
import matrix4j.matrix.MatrixUtils;
import matrix4j.vector.VectorProcedure;
import smile.data.NominalAttribute;
import smile.data.NumericAttribute;
import smile.sort.QuickSort;

import java.util.Arrays;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.roaringbitmap.RoaringBitmap;

public final class SmileExtUtils {
    /** Byte code {@code 1}; presumably tags an attribute as numeric (confirm against callers). */
    public static final byte NUMERIC = (byte) 1;
    /** Byte code {@code 2}; presumably tags an attribute as nominal (confirm against callers). */
    public static final byte NOMINAL = (byte) 2;

    /** Not instantiable: this class only exposes static members. */
    private SmileExtUtils() {}

    /**
     * Builds the set of nominal attribute positions from a type specification.
     *
     * @param opt comma separated list of Q and C. Q for {@link NumericAttribute}, C for
     *        {@link NominalAttribute}. {@code null} yields an empty bitmap.
     * @return bitmap holding the zero-based position of every {@code C} entry
     * @throws UDFArgumentException if an entry is neither {@code Q} nor {@code C}
     */
    @Nonnull
    public static RoaringBitmap resolveAttributes(@Nullable final String opt)
            throws UDFArgumentException {
        final RoaringBitmap nominalPositions = new RoaringBitmap();
        if (opt != null) {
            int position = 0;
            for (final String token : opt.split(",")) {
                if ("C".equals(token)) {
                    nominalPositions.add(position);
                } else if (!"Q".equals(token)) {
                    throw new UDFArgumentException("Unsupported attribute type: " + token);
                }
                // "Q" entries only advance the position
                position++;
            }
        }
        return nominalPositions;
    }

    /**
     * Parses a list of nominal attribute indexes into a bitmap.
     *
     * @param opt comma separated list of zero-start indexes; {@code null} yields an empty
     *        bitmap
     * @return bitmap containing every parsed index
     * @throws UDFArgumentException if an entry is rejected by {@code NumberUtils.isDigits}
     */
    @Nonnull
    public static RoaringBitmap parseNominalAttributeIndicies(@Nullable final String opt)
            throws UDFArgumentException {
        final RoaringBitmap indices = new RoaringBitmap();
        if (opt == null) {
            return indices;
        }
        final String[] tokens = opt.split(",");
        for (int k = 0; k < tokens.length; k++) {
            final String token = tokens[k];
            if (!NumberUtils.isDigits(token)) {
                throw new UDFArgumentException("Expected integer but got " + token);
            }
            indices.add(NumberUtils.parseInt(token));
        }
        return indices;
    }

    /**
     * Collects the positions of the nominal entries of an attribute array.
     *
     * @param original attributes to inspect, in column order
     * @return bitmap holding the zero-based position of every nominal attribute
     * @throws UnsupportedOperationException if an attribute is neither nominal nor numeric
     */
    @VisibleForTesting
    @Nonnull
    public static RoaringBitmap convertAttributeTypes(
            @Nonnull final smile.data.Attribute[] original) {
        final RoaringBitmap nominalPositions = new RoaringBitmap();
        int position = 0;
        for (final smile.data.Attribute attribute : original) {
            switch (attribute.type) {
                case NUMERIC:
                    // numeric attributes are simply not recorded
                    break;
                case NOMINAL:
                    nominalPositions.add(position);
                    break;
                default:
                    throw new UnsupportedOperationException(
                        "Unsupported type: " + attribute.type);
            }
            position++;
        }
        return nominalPositions;
    }

    /**
     * Builds, for every non-nominal column of {@code x}, an array of the selected row indexes
     * after it has been passed through {@code QuickSort.sort} together with the column values
     * (presumably ordering the rows by ascending value; confirm against smile's QuickSort).
     *
     * <p>Columns contained in {@code nominalAttrs} are skipped, as are columns for which no row
     * was collected; their slot in the per-column array stays {@code null}. A row {@code i} is
     * ignored whenever {@code samples[i] == 0}.
     *
     * @param nominalAttrs bitmap of column indexes treated as categorical
     * @param x feature matrix; {@code x.isSparse()} selects which of the two code paths is used
     * @param samples per-row marker indexed by row number; rows whose entry is {@code 0} are
     *        excluded
     * @return a {@link VariableOrder} wrapping the per-column index arrays
     */
    @Nonnull
    public static VariableOrder sort(@Nonnull final RoaringBitmap nominalAttrs,
            @Nonnull final Matrix x, @Nonnull final int[] samples) {
        final int n = x.numRows();
        final int p = x.numColumns();

        // one slot per column; slots that are never assigned below remain null
        final SparseIntArray[] index = new SparseIntArray[p];
        if (x.isSparse()) {
            // size handed to both buffers; presumably a guess at the entries per column
            int initSize = n / 10;
            final DoubleArrayList dlist = new DoubleArrayList(initSize);
            final IntArrayList ilist = new IntArrayList(initSize);
            // callback that records (value, row) pairs for rows that are part of the sample
            final VectorProcedure proc = new VectorProcedure() {
                @Override
                public void apply(final int i, final double v) {
                    if (samples[i] == 0) {
                        return;
                    }
                    dlist.add(v);
                    ilist.add(i);
                }
            };

            // column-major form so that a single column can be traversed with the callback
            final ColumnMajorMatrix x2 = x.toColumnMajorMatrix();
            for (int j = 0; j < p; j++) {
                if (nominalAttrs.contains(j)) {
                    continue; // nop for categorical columns
                }
                // sort only numerical columns
                x2.eachNonNullInColumn(j, proc);
                if (ilist.isEmpty()) {
                    continue;
                }
                int[] rowPtrs = ilist.toArray();
                // the explicit length restricts the sort to the collected entries;
                // dlist.array() is presumably the backing array and may be longer (confirm)
                QuickSort.sort(dlist.array(), rowPtrs, rowPtrs.length);
                index[j] = new SparseIntArray(rowPtrs);
                // both buffers are reused for the next column
                dlist.clear();
                ilist.clear();
            }
        } else {
            final DoubleArrayList dlist = new DoubleArrayList(n);
            final IntArrayList ilist = new IntArrayList(n);
            for (int j = 0; j < p; j++) {
                if (nominalAttrs.contains(j)) {
                    continue; // nop for categorical columns
                }
                // sort only numerical columns
                for (int i = 0; i < n; i++) {
                    if (samples[i] == 0) {
                        continue;
                    }
                    double x_ij = x.get(i, j);
                    dlist.add(x_ij);
                    ilist.add(i);
                }
                if (ilist.isEmpty()) {
                    continue;
                }
                int[] rowPtrs = ilist.toArray();
                // the explicit length restricts the sort to the collected entries
                QuickSort.sort(dlist.array(), rowPtrs, rowPtrs.length);
                index[j] = new SparseIntArray(rowPtrs);
                // both buffers are reused for the next column
                dlist.clear();
                ilist.clear();
            }
        }

        return new VariableOrder(index);
    }

    /**
     * Extracts the sorted distinct class labels of {@code y} and validates them.
     *
     * @param y class label of each sample
     * @return the distinct labels in ascending order
     * @throws HiveException if there are fewer than two distinct labels, if a label is
     *         negative, or if two consecutive sorted labels differ by more than one
     */
    @Nonnull
    public static int[] classLabels(@Nonnull final int[] y) throws HiveException {
        final int[] distinct = smile.math.Math.unique(y);
        Arrays.sort(distinct);

        final int numClasses = distinct.length;
        if (numClasses < 2) {
            throw new HiveException("Only one class.");
        }

        // The array is sorted ascending, so a negative label can only surface at the front.
        final int smallest = distinct[0];
        if (smallest < 0) {
            throw new HiveException("Negative class label: " + smallest);
        }

        int previous = smallest;
        for (int k = 1; k < numClasses; k++) {
            final int current = distinct[k];
            if ((current - previous) > 1) {
                throw new HiveException("Missing class: " + (previous + 1));
            }
            previous = current;
        }

        return distinct;
    }

    /**
     * Maps a rule name (compared case-insensitively) to a {@link SplitRule}.
     *
     * @param ruleName {@code "gini"}, {@code "entropy"} or {@code "classification_error"};
     *        may be {@code null}
     * @return the matching rule; {@link SplitRule#GINI} for {@code null} or any other name
     */
    @Nonnull
    public static SplitRule resolveSplitRule(@Nullable String ruleName) {
        if ("entropy".equalsIgnoreCase(ruleName)) {
            return SplitRule.ENTROPY;
        }
        if ("classification_error".equalsIgnoreCase(ruleName)) {
            return SplitRule.CLASSIFICATION_ERROR;
        }
        // "gini", null and every unrecognized name resolve to the Gini rule
        return SplitRule.GINI;
    }

    /**
     * Resolves the number of input variables from a user supplied setting.
     *
     * <ul>
     * <li>{@code numVars <= 0}: {@code ceil(sqrt(numColumns))}</li>
     * <li>{@code 0 < numVars <= 1}: {@code numVars} is a fraction of the column count. The
     * truncated product is raised to 1 when the matrix has at least one column, so that a
     * small fraction (e.g. {@code 0.1} of 5 columns) no longer yields zero variables.</li>
     * <li>{@code numVars > 1}: {@code numVars} truncated to an int</li>
     * </ul>
     *
     * @param numVars requested number (or fraction) of variables
     * @param x matrix whose column count is used for the first two cases
     * @return the resolved number of input variables
     */
    public static int computeNumInputVars(final float numVars, @Nonnull final Matrix x) {
        if (numVars <= 0.f) {
            final int dims = x.numColumns();
            return (int) Math.ceil(Math.sqrt(dims));
        }
        if (numVars <= 1.f) {
            final int dims = x.numColumns();
            final int scaled = (int) (numVars * dims);
            // At least one variable, but never more than the columns available
            // (the upper bound only matters for a matrix without columns).
            return Math.min(dims, Math.max(1, scaled));
        }
        return (int) numVars;
    }

    /**
     * Derives a seed from the id of the calling thread and {@link System#nanoTime()}.
     *
     * @return the product of the two values (long multiplication, so it may wrap around)
     */
    public static long generateSeed() {
        final long threadId = Thread.currentThread().getId();
        final long nanos = System.nanoTime();
        return threadId * nanos;
    }

    /**
     * Shuffles {@code x} in place, walking from the last position down to the second and
     * swapping each with a position drawn from the not yet fixed prefix.
     *
     * @param x array to shuffle in place
     * @param rnd source of the random positions
     */
    public static void shuffle(@Nonnull final int[] x, @Nonnull final PRNG rnd) {
        int remaining = x.length;
        while (remaining > 1) {
            final int pick = rnd.nextInt(remaining);
            remaining--;
            swap(x, remaining, pick);
        }
    }

    /**
     * Shuffles the rows of {@code x} and the entries of {@code y} with the same sequence of
     * swaps. {@code y} is always shuffled in place.
     *
     * @param x matrix to shuffle; swapped in place when {@code x.swappable()}, otherwise the
     *        result of {@code MatrixUtils.shuffle} is returned
     * @param y labels, one per row of {@code x}
     * @param seed seed of the random number generator; {@code -1} requests a generated seed
     * @return {@code x} itself when it is swappable, otherwise the matrix produced by
     *         {@code MatrixUtils.shuffle}
     * @throws IllegalArgumentException if the number of rows differs from {@code y.length}
     */
    @Nonnull
    public static Matrix shuffle(@Nonnull final Matrix x, @Nonnull final int[] y, long seed) {
        final int numRows = x.numRows();
        if (numRows != y.length) {
            throw new IllegalArgumentException(
                "x.length (" + numRows + ") != y.length (" + y.length + ')');
        }

        final long effectiveSeed = (seed == -1L) ? generateSeed() : seed;
        final PRNG rnd = RandomNumberGeneratorFactory.createPRNG(effectiveSeed);

        if (!x.swappable()) {
            // Rows cannot be swapped in place: shuffle an index permutation instead.
            final int[] order = MathUtils.permutation(numRows);
            for (int last = numRows - 1; last >= 1; last--) {
                final int pick = rnd.nextInt(last + 1);
                swap(order, last, pick);
                swap(y, last, pick);
            }
            return MatrixUtils.shuffle(x, order);
        }

        for (int last = numRows - 1; last >= 1; last--) {
            final int pick = rnd.nextInt(last + 1);
            x.swap(last, pick);
            swap(y, last, pick);
        }
        return x;
    }

    /**
     * Shuffles the rows of {@code x} and the entries of {@code y} with the same sequence of
     * swaps. {@code y} is always shuffled in place.
     *
     * @param x matrix to shuffle; swapped in place when {@code x.swappable()}, otherwise the
     *        result of {@code MatrixUtils.shuffle} is returned
     * @param y target values, one per row of {@code x}
     * @param seed seed of the random number generator; {@code -1} requests a generated seed
     * @return {@code x} itself when it is swappable, otherwise the matrix produced by
     *         {@code MatrixUtils.shuffle}
     * @throws IllegalArgumentException if the number of rows differs from {@code y.length}
     */
    @Nonnull
    public static Matrix shuffle(@Nonnull final Matrix x, @Nonnull final double[] y, long seed) {
        final int numRows = x.numRows();
        if (numRows != y.length) {
            throw new IllegalArgumentException(
                "x.length (" + numRows + ") != y.length (" + y.length + ')');
        }
        if (seed == -1L) {
            seed = generateSeed();
        }

        final PRNG rnd = RandomNumberGeneratorFactory.createPRNG(seed);
        if (x.swappable()) {
            for (int i = numRows; i > 1; i--) {
                int j = rnd.nextInt(i);
                int k = i - 1;
                x.swap(k, j);
                swap(y, k, j);
            }
            return x;
        } else {
            // Rows cannot be swapped in place: shuffle an index permutation instead.
            final int[] indices = MathUtils.permutation(numRows);
            for (int i = numRows; i > 1; i--) {
                int j = rnd.nextInt(i);
                int k = i - 1;
                swap(indices, k, j);
                swap(y, k, j);
            }
            return MatrixUtils.shuffle(x, indices);
        }
    }

    /**
     * Exchanges the values stored at positions {@code i} and {@code j} of {@code x}.
     */
    private static void swap(final int[] x, final int i, final int j) {
        final int first = x[i];
        final int second = x[j];
        x[i] = second;
        x[j] = first;
    }

    /**
     * Exchanges the values stored at positions {@code i} and {@code j} of {@code x}.
     */
    private static void swap(final double[] x, final int i, final int j) {
        final double first = x[i];
        final double second = x[j];
        x[i] = second;
        x[j] = first;
    }

    /**
     * Tells whether the column count of {@code x} differs from the number of entries in
     * {@code attributes}, which the caller is expected to fill with the categorical columns.
     *
     * @param x matrix whose column count is compared
     * @param attributes bitmap of categorical column indexes
     * @return {@code true} when the two counts differ, i.e. at least one numerical column
     */
    public static boolean containsNumericType(@Nonnull final Matrix x,
            final RoaringBitmap attributes) {
        return x.numColumns() != attributes.getCardinality();
    }

    /**
     * Looks up the name of a feature, falling back to {@code "feature#" + index}.
     *
     * @param index position of the feature
     * @param names optional names by position
     * @return {@code names[index]} when {@code names} is given and {@code index} is below its
     *         length, otherwise the fallback
     */
    @Nonnull
    public static String resolveFeatureName(final int index, @Nullable final String[] names) {
        final boolean named = names != null && index < names.length;
        return named ? names[index] : "feature#" + index;
    }

    /**
     * Looks up a name by position, falling back to the decimal form of {@code index}.
     *
     * @param index position to resolve
     * @param names optional names by position
     * @return {@code names[index]} when {@code names} is given and {@code index} is below its
     *         length, otherwise {@code String.valueOf(index)}
     */
    @Nonnull
    public static String resolveName(final int index, @Nullable final String[] names) {
        if (names != null && index < names.length) {
            return names[index];
        }
        return String.valueOf(index);
    }

    /**
     * Generates {@code n} evenly spaced hue values. The step is computed on a 360 degree
     * scale and each value is divided by 360 again, so entry {@code i} is
     * {@code i * (360 / n) / 360}.
     *
     * @param n number of hues to generate; must be at least 1
     * @return colors
     */
    public static double[] getColorBrew(@Nonnegative int n) {
        Preconditions.checkArgument(n >= 1);

        final double stepDegrees = 360.d / n;
        final double[] hues = new double[n];
        int k = 0;
        while (k < n) {
            hues[k] = k * stepDegrees / 360.d;
            k++;
        }
        return hues;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy