All downloads are free. Search and download functionality uses the official Maven repository.

smile.feature.imputation.KNNImputer Maven / Gradle / Ivy

The newest version!
package smile.feature.imputation;

import java.util.Arrays;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.measure.NominalScale;
import smile.data.transform.Transform;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.math.distance.Distance;
import smile.math.MathEx;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;

/**
 * Missing value imputation with k-nearest neighbors. The KNN-based method
 * selects instances similar to the instance of interest to impute
 * missing values. For an instance A with missing values, this implementation
 * searches the training data once for the K instances nearest to A (in terms
 * of the configured distance, e.g. Euclidean distance), and fills each missing
 * attribute i of A from the values those neighbors have on attribute i:
 * <ul>
 * <li>numeric attributes receive the (unweighted) mean of the neighbors'
 *     values;</li>
 * <li>boolean, char and nominal attributes receive the most frequent
 *     neighbor value.</li>
 * </ul>
 * Neighbors whose own value on attribute i is missing are skipped, so fewer
 * than K values may contribute. If no value remains, or the attribute has
 * some other type, the attribute is left {@code null}.
 *
 * @author Haifeng Li
 */
public class KNNImputer implements Transform {
    /** The number of nearest neighbors used for imputation. */
    private final int k;
    /** K-nearest neighbor search algorithm. */
    private final KNNSearch<Tuple, Tuple> knn;

    /**
     * Constructor.
     * @param data the training data searched for nearest neighbors.
     * @param k the number of nearest neighbors used for imputation.
     * @param distance the distance measure between tuples.
     * @throws IllegalArgumentException if {@code k <= 0}.
     */
    public KNNImputer(DataFrame data, int k, Distance<Tuple> distance) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid number of nearest neighbors: " + k);
        }
        this.k = k;
        this.knn = LinearSearch.of(data.toList(), distance);
    }

    /**
     * Constructor with Euclidean distance on selected columns.
     * The distance is computed by {@code MathEx.squaredDistanceWithMissingValues}.
     * Squaring is monotone, so it ranks neighbors the same way the Euclidean
     * distance would.
     *
     * @param data the training data searched for nearest neighbors.
     * @param k the number of nearest neighbors used for imputation.
     * @param columns the columns used in Euclidean distance computation.
     *                If empty, all columns will be used.
     * @throws IllegalArgumentException if {@code k <= 0}.
     */
    public KNNImputer(DataFrame data, int k, String... columns) {
        this(data, k, (x, y) -> {
            double[] xd = x.toArray(columns);
            double[] yd = y.toArray(columns);
            return MathEx.squaredDistanceWithMissingValues(xd, yd);
        });
    }

    /**
     * Imputes the missing values of a tuple.
     *
     * <p>The k nearest neighbors of {@code x} are searched once, eagerly, when
     * this method is called. The returned tuple is a view: non-missing values of
     * {@code x} are returned unchanged, and the estimate for a missing value is
     * computed each time that attribute is accessed.
     *
     * @param x the tuple that may contain missing values.
     * @return a tuple with the same schema as {@code x} whose missing values
     *         are replaced by neighbor-based estimates, or {@code null} where
     *         no estimate is available.
     */
    @Override
    public Tuple apply(Tuple x) {
        StructType schema = x.schema();
        Neighbor<Tuple, Tuple>[] neighbors = knn.search(x, k);
        return new smile.data.AbstractTuple() {
            @Override
            public Object get(int i) {
                Object xi = x.get(i);
                if (!SimpleImputer.isMissing(xi)) {
                    return xi;
                }

                StructField field = schema.field(i);
                if (field.type.isBoolean()) {
                    int[] vector = intValues(neighbors, i);
                    return vector.length == 0 ? null : MathEx.mode(vector) != 0;
                } else if (field.type.isChar()) {
                    int[] vector = intValues(neighbors, i);
                    return vector.length == 0 ? null : (char) MathEx.mode(vector);
                } else if (field.measure instanceof NominalScale) {
                    int[] vector = intValues(neighbors, i);
                    return vector.length == 0 ? null : MathEx.mode(vector);
                } else if (field.type.isNumeric()) {
                    double[] vector = doubleValues(neighbors, i);
                    return vector.length == 0 ? null : MathEx.mean(vector);
                } else {
                    // Other attribute types are not imputed.
                    return null;
                }
            }

            @Override
            public StructType schema() {
                return schema;
            }
        };
    }

    /**
     * Collects the neighbors' values on column {@code i} as integers.
     * Values reported missing by {@code SimpleImputer.isMissing} are skipped,
     * as are values equal to the {@code Integer.MIN_VALUE} sentinel.
     */
    private static int[] intValues(Neighbor<Tuple, Tuple>[] neighbors, int i) {
        return Arrays.stream(neighbors)
                // Skip missing values before conversion so a null never reaches getInt.
                .filter(neighbor -> !SimpleImputer.isMissing(neighbor.key.get(i)))
                .mapToInt(neighbor -> neighbor.key.getInt(i))
                .filter(value -> value != Integer.MIN_VALUE)
                .toArray();
    }

    /**
     * Collects the neighbors' values on column {@code i} as doubles.
     * Values reported missing by {@code SimpleImputer.isMissing}, NaN values,
     * and the {@code Integer.MIN_VALUE} sentinel are skipped.
     */
    private static double[] doubleValues(Neighbor<Tuple, Tuple>[] neighbors, int i) {
        return Arrays.stream(neighbors)
                // Skip missing values before conversion so a null never reaches getDouble.
                .filter(neighbor -> !SimpleImputer.isMissing(neighbor.key.get(i)))
                .mapToDouble(neighbor -> neighbor.key.getDouble(i))
                // NaN must be tested explicitly: NaN != NaN, so an equality-based
                // omit never removes it, and one NaN would poison the mean.
                // The MIN_VALUE check keeps the original sentinel filtering,
                // presumably for integer columns that encode missing values this way.
                .filter(value -> !Double.isNaN(value) && value != Integer.MIN_VALUE)
                .toArray();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy