org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.feature.kbinsdiscretizer;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler.MinMaxReduceFunctionOperator;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * An Estimator which implements discretization (also known as quantization or binning) to transform
 * continuous features into discrete ones. The output values are in [0, numBins).
 *
 * <p>KBinsDiscretizer implements three different binning strategies, selected via {@link
 * KBinsDiscretizerParams#STRATEGY}. If the strategy is set to {@link KBinsDiscretizerParams#KMEANS}
 * or {@link KBinsDiscretizerParams#QUANTILE}, users should further set {@link
 * KBinsDiscretizerParams#SUB_SAMPLES} for better performance.
 *
 * <p>There are several corner cases for different inputs as listed below:
 *
 * <ul>
 *   <li>When the input values of one column are all the same, they are all mapped to the same bin
 *       (i.e., the zero-th bin). Thus the corresponding bin edges are {@code
 *       {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY}}.
 *   <li>When the number of distinct values of one column is less than the specified number of bins
 *       and the {@link KBinsDiscretizerParams#STRATEGY} is set to {@link
 *       KBinsDiscretizerParams#KMEANS}, we switch to {@link KBinsDiscretizerParams#UNIFORM}.
 *   <li>When the width of one output bin is zero, i.e., the left edge equals the right edge of the
 *       bin, we replace the right edge with the average value of its two neighbors. One exception
 *       is when the last two edges are the same --- in this case, the left edge is updated as the
 *       average of its two neighbors. For example, the bin edges {0, 1, 1, 2, 2} are transformed
 *       into {0, 1, 1.5, 1.75, 2}.
 * </ul>
 */
public class KBinsDiscretizer
        implements Estimator<KBinsDiscretizer, KBinsDiscretizerModel>,
                KBinsDiscretizerParams<KBinsDiscretizer> {
    private static final Logger LOG = LoggerFactory.getLogger(KBinsDiscretizer.class);
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    public KBinsDiscretizer() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    @Override
    public KBinsDiscretizerModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        String inputCol = getInputCol();
        String strategy = getStrategy();
        int numBins = getNumBins();

        DataStream<DenseVector> inputData =
                tEnv.toDataStream(inputs[0])
                        .map(
                                (MapFunction<Row, DenseVector>)
                                        value -> ((Vector) value.getField(inputCol)).toDense());

        DataStream<DenseVector> preprocessedData;
        if (strategy.equals(UNIFORM)) {
            // The uniform strategy only needs the per-column min and max, so the input is
            // reduced to a single (min, max) pair of vectors.
            preprocessedData =
                    inputData
                            .transform(
                                    "reduceInEachPartition",
                                    inputData.getType(),
                                    new MinMaxReduceFunctionOperator())
                            .transform(
                                    "reduceInFinalPartition",
                                    inputData.getType(),
                                    new MinMaxReduceFunctionOperator())
                            .setParallelism(1);
        } else {
            // The quantile and k-means strategies work on a sample of the input for better
            // performance.
            preprocessedData =
                    DataStreamUtils.sample(
                            inputData, getSubSamples(), getClass().getName().hashCode());
        }

        DataStream<KBinsDiscretizerModelData> modelData =
                DataStreamUtils.mapPartition(
                        preprocessedData,
                        new MapPartitionFunction<DenseVector, KBinsDiscretizerModelData>() {
                            @Override
                            public void mapPartition(
                                    Iterable<DenseVector> iterable,
                                    Collector<KBinsDiscretizerModelData> collector) {
                                List<DenseVector> list = new ArrayList<>();
                                iterable.iterator().forEachRemaining(list::add);

                                if (list.size() == 0) {
                                    throw new RuntimeException("The training set is empty.");
                                }

                                double[][] binEdges;
                                switch (strategy) {
                                    case UNIFORM:
                                        binEdges = findBinEdgesWithUniformStrategy(list, numBins);
                                        break;
                                    case QUANTILE:
                                        binEdges = findBinEdgesWithQuantileStrategy(list, numBins);
                                        break;
                                    case KMEANS:
                                        binEdges = findBinEdgesWithKMeansStrategy(list, numBins);
                                        break;
                                    default:
                                        throw new UnsupportedOperationException(
                                                "Unsupported "
                                                        + STRATEGY
                                                        + " type: "
                                                        + strategy
                                                        + ".");
                                }

                                collector.collect(new KBinsDiscretizerModelData(binEdges));
                            }
                        });
        modelData.getTransformation().setParallelism(1);

        KBinsDiscretizerModel model =
                new KBinsDiscretizerModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams(model, getParamMap());
        return model;
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    public static KBinsDiscretizer load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }

    private static double[][] findBinEdgesWithUniformStrategy(
            List<DenseVector> input, int numBins) {
        // For the uniform strategy, the preprocessed input contains exactly two vectors: the
        // per-column minimum and the per-column maximum.
        DenseVector minVector = input.get(0);
        DenseVector maxVector = input.get(1);
        int numColumns = minVector.size();
        double[][] binEdges = new double[numColumns][];

        for (int columnId = 0; columnId < numColumns; columnId++) {
            double min = minVector.get(columnId);
            double max = maxVector.get(columnId);

            if (min == max) {
                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
                binEdges[columnId] =
                        new double[] {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY};
            } else {
                double width = (max - min) / numBins;
                binEdges[columnId] = new double[numBins + 1];
                binEdges[columnId][0] = min;
                for (int edgeId = 1; edgeId < numBins + 1; edgeId++) {
                    binEdges[columnId][edgeId] = binEdges[columnId][edgeId - 1] + width;
                }
            }
        }
        return binEdges;
    }
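
    // Worked example for findBinEdgesWithUniformStrategy: with min = 0.0, max = 10.0 and
    // numBins = 5, the width is 2.0 and the bin edges are {0.0, 2.0, 4.0, 6.0, 8.0, 10.0}.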

    private static double[][] findBinEdgesWithQuantileStrategy(
            List<DenseVector> input, int numBins) {
        int numColumns = input.get(0).size();
        int numData = input.size();
        double[][] binEdges = new double[numColumns][];
        double[] features = new double[numData];

        for (int columnId = 0; columnId < numColumns; columnId++) {
            for (int i = 0; i < numData; i++) {
                features[i] = input.get(i).get(columnId);
            }
            Arrays.sort(features);

            if (features[0] == features[numData - 1]) {
                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
                binEdges[columnId] =
                        new double[] {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY};
            } else {
                double[] tempBinEdges;
                if (features.length > numBins) {
                    // Picks evenly spaced order statistics of the sorted features as candidate
                    // edges, plus the maximum as the last edge.
                    double width = 1.0 * features.length / numBins;
                    tempBinEdges = new double[numBins + 1];
                    for (int binEdgeId = 0; binEdgeId < numBins; binEdgeId++) {
                        tempBinEdges[binEdgeId] = features[(int) (binEdgeId * width)];
                    }
                    tempBinEdges[numBins] = features[numData - 1];
                } else {
                    tempBinEdges = features;
                }

                // Bins with zero width should be converted to a non-empty bin.
                Map<Double, Integer> edgesAndCnt = new HashMap<>(numBins);
                for (double edge : tempBinEdges) {
                    edgesAndCnt.put(edge, edgesAndCnt.getOrDefault(edge, 0) + 1);
                }
                // Keeps each distinct edge once, plus one extra copy for edges that appeared
                // more than once, so duplicated edges can be spread apart below.
                List<Double> edges = new ArrayList<>();
                for (Map.Entry<Double, Integer> edgeAndCnt : edgesAndCnt.entrySet()) {
                    double edge = edgeAndCnt.getKey();
                    int cnt = edgeAndCnt.getValue();
                    edges.add(edge);
                    if (cnt > 1) {
                        edges.add(edge);
                    }
                }
                tempBinEdges = edges.stream().mapToDouble(Double::doubleValue).toArray();
                Arrays.sort(tempBinEdges);

                int i = 1;
                // If there are two consecutive bin edges with the same value, we update the right
                // edge as the average of its two neighbors.
                for (; i < tempBinEdges.length - 1; i++) {
                    if (tempBinEdges[i] == tempBinEdges[i - 1]) {
                        tempBinEdges[i] = (tempBinEdges[i + 1] + tempBinEdges[i - 1]) / 2;
                    }
                }
                // If the last two bin edges are the same, we update the left bin edge as the
                // average of its two neighbors.
                if (tempBinEdges[i] == tempBinEdges[i - 1]) {
                    tempBinEdges[i - 1] = (tempBinEdges[i] + tempBinEdges[i - 2]) / 2;
                }
                binEdges[columnId] = tempBinEdges;
            }
        }
        return binEdges;
    }
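
    // Worked example for findBinEdgesWithQuantileStrategy: with sorted features
    // {0, 1, 1, 1, 2, 2, 2, 2} and numBins = 4, the candidate edges are {0, 1, 2, 2, 2}
    // (width = 2.0, so indices 0, 2, 4, 6 plus the maximum). Deduplication keeps {0, 1, 2, 2},
    // and since the last two edges are equal, the left one is replaced by the average of its
    // neighbors, giving the final edges {0, 1, 1.5, 2}.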

    private static double[][] findBinEdgesWithKMeansStrategy(
            List<DenseVector> input, int numBins) {
        int numColumns = input.get(0).size();
        int numData = input.size();
        double[][] binEdges = new double[numColumns][numBins + 1];
        double[] features = new double[numData];
        double[] kMeansCentroids = new double[numBins];
        double[] sumByCluster = new double[numBins];

        for (int columnId = 0; columnId < numColumns; columnId++) {
            for (int i = 0; i < numData; i++) {
                features[i] = input.get(i).get(columnId);
            }
            Arrays.sort(features);

            if (features[0] == features[numData - 1]) {
                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
                binEdges[columnId] =
                        new double[] {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY};
            } else {
                // Checks whether there are more than {numBins} distinct feature values in each
                // column. If the number of distinct values is less than {numBins + 1}, then we
                // do not need to conduct KMeans. Instead, we switch to using {@link
                // KBinsDiscretizerParams#UNIFORM} for binning.
                Set<Double> distinctFeatureValues = new HashSet<>(numBins + 1);
                for (double feature : features) {
                    distinctFeatureValues.add(feature);
                    if (distinctFeatureValues.size() >= numBins + 1) {
                        break;
                    }
                }
                if (distinctFeatureValues.size() <= numBins) {
                    double min = features[0];
                    double max = features[features.length - 1];
                    double width = (max - min) / numBins;
                    binEdges[columnId] = new double[numBins + 1];
                    binEdges[columnId][0] = min;
                    for (int edgeId = 1; edgeId < numBins + 1; edgeId++) {
                        binEdges[columnId][edgeId] = binEdges[columnId][edgeId - 1] + width;
                    }
                    continue;
                } else {
                    // Conducts KMeans here. Initializes the centroids with evenly spaced order
                    // statistics of the sorted features.
                    double width = 1.0 * features.length / numBins;
                    for (int clusterId = 0; clusterId < numBins; clusterId++) {
                        kMeansCentroids[clusterId] = features[(int) (clusterId * width)];
                    }

                    // Default values for KMeans.
                    final double tolerance = 1e-4;
                    final int maxIterations = 300;

                    double oldLoss = Double.MAX_VALUE;
                    double relativeLoss = Double.MAX_VALUE;
                    int iter = 0;
                    int[] countByCluster = new int[numBins];
                    while (iter < maxIterations && relativeLoss > tolerance) {
                        double loss = 0;
                        for (double featureValue : features) {
                            // Assigns each feature value to its closest centroid.
                            double minDistance = Math.abs(kMeansCentroids[0] - featureValue);
                            int clusterId = 0;
                            for (int i = 1; i < kMeansCentroids.length; i++) {
                                double distance = Math.abs(kMeansCentroids[i] - featureValue);
                                if (distance < minDistance) {
                                    minDistance = distance;
                                    clusterId = i;
                                }
                            }
                            countByCluster[clusterId]++;
                            sumByCluster[clusterId] += featureValue;
                            loss += minDistance;
                        }

                        // Updates cluster.
                        for (int clusterId = 0; clusterId < kMeansCentroids.length; clusterId++) {
                            kMeansCentroids[clusterId] =
                                    sumByCluster[clusterId] / countByCluster[clusterId];
                        }

                        loss /= features.length;
                        relativeLoss = Math.abs(loss - oldLoss);
                        oldLoss = loss;
                        iter++;
                        Arrays.fill(sumByCluster, 0);
                        Arrays.fill(countByCluster, 0);
                    }

                    // The final bin edges are the midpoints between adjacent centroids,
                    // bracketed by the column minimum and maximum.
                    Arrays.sort(kMeansCentroids);
                    binEdges[columnId] = new double[numBins + 1];
                    binEdges[columnId][0] = features[0];
                    binEdges[columnId][numBins] = features[features.length - 1];
                    for (int binEdgeId = 1; binEdgeId < numBins; binEdgeId++) {
                        binEdges[columnId][binEdgeId] =
                                (kMeansCentroids[binEdgeId - 1] + kMeansCentroids[binEdgeId]) / 2;
                    }
                }
            }
        }
        return binEdges;
    }
}
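
For reference, a minimal usage sketch of the estimator above, following the shape of the standard Flink ML examples. The dataset, column name, and parameter values are illustrative, and the sketch assumes the KBinsDiscretizer and KBinsDiscretizerParams classes from this package are on the classpath:

import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

public class KBinsDiscretizerSketch {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // A toy one-column dataset of dense vectors; "input" is the default input column name.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(1.0)),
                        Row.of(Vectors.dense(2.0)),
                        Row.of(Vectors.dense(10.0)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("input");

        // Trains three uniform-width bins and applies the fitted model to the same table.
        // With min = 1 and max = 10 this yields bin edges {1, 4, 7, 10}.
        KBinsDiscretizer discretizer =
                new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
        Table outputTable = discretizer.fit(inputTable).transform(inputTable)[0];
        outputTable.execute().print();
    }
}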



