
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.stats.anovatest;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.param.HasFlatten;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
import org.apache.commons.math3.distribution.FDistribution;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
/**
 * An AlgoOperator which implements the ANOVA test algorithm.
 *
 * <p>See <a href="https://en.wikipedia.org/wiki/Analysis_of_variance">Wikipedia</a> for more
 * information on ANOVA test.
 *
 * <p>The input of this algorithm is a table containing a labelColumn of numerical type and a
 * featuresColumn of vector type. Each index in the input vector represents a feature to be tested.
 * By default, the output of this algorithm is a table containing a single row with the following
 * columns, each of which has one value per feature.
 *
 * <ul>
 *   <li>"pValues": vector
 *   <li>"degreesOfFreedom": int array
 *   <li>"fValues": vector
 * </ul>
 *
 * <p>The output of this algorithm can be flattened to multiple rows by setting {@link
 * HasFlatten#FLATTEN} to true, which would contain the following columns:
 *
 * <ul>
 *   <li>"featureIndex": int
 *   <li>"pValue": double
 *   <li>"degreeOfFreedom": int
 *   <li>"fValues": double
 * </ul>
 */
public class ANOVATest implements AlgoOperator, ANOVATestParams {
private final Map, Object> paramMap = new HashMap<>();
public ANOVATest() {
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
@Override
public Table[] transform(Table... inputs) {
Preconditions.checkArgument(inputs.length == 1);
final String featuresCol = getFeaturesCol();
final String labelCol = getLabelCol();
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
DataStream> inputData =
tEnv.toDataStream(inputs[0])
.map(
(MapFunction>)
row -> {
Number number = (Number) row.getField(labelCol);
Preconditions.checkNotNull(
number, "Input data must contain label value.");
return new Tuple2<>(
((Vector) row.getField(featuresCol)),
number.doubleValue());
},
Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE));
DataStream> streamWithANOVA =
DataStreamUtils.aggregate(
inputData,
new ANOVAAggregator(),
Types.OBJECT_ARRAY(
Types.TUPLE(
Types.DOUBLE,
Types.DOUBLE,
Types.MAP(
Types.DOUBLE,
Types.TUPLE(Types.DOUBLE, Types.LONG)))),
Types.LIST(Types.ROW(Types.INT, Types.DOUBLE, Types.LONG, Types.DOUBLE)));
return new Table[] {convertToTable(tEnv, streamWithANOVA, getFlatten())};
}
/** Computes the p-value, fValues and the number of degrees of freedom of input features. */
@SuppressWarnings("unchecked")
private static class ANOVAAggregator
implements AggregateFunction<
Tuple2,
Tuple3>>[],
List> {
@Override
public Tuple3>>[] createAccumulator() {
return new Tuple3[0];
}
@Override
public Tuple3>>[] add(
Tuple2 featuresAndLabel,
Tuple3>>[] acc) {
Vector features = featuresAndLabel.f0;
double label = featuresAndLabel.f1;
int numOfFeatures = features.size();
if (acc.length == 0) {
acc = new Tuple3[features.size()];
for (int i = 0; i < numOfFeatures; i++) {
acc[i] = Tuple3.of(0.0, 0.0, new HashMap<>());
}
}
for (int i = 0; i < numOfFeatures; i++) {
double featureValue = features.get(i);
acc[i].f0 += featureValue;
acc[i].f1 += featureValue * featureValue;
if (acc[i].f2.containsKey(label)) {
acc[i].f2.get(label).f0 += featureValue;
acc[i].f2.get(label).f1 += 1L;
} else {
acc[i].f2.put(label, Tuple2.of(featureValue, 1L));
}
}
return acc;
}
@Override
public List getResult(
Tuple3>>[] acc) {
List results = new ArrayList<>();
for (int i = 0; i < acc.length; i++) {
Tuple3 resultOfANOVA =
computeANOVA(acc[i].f0, acc[i].f1, acc[i].f2);
results.add(Row.of(i, resultOfANOVA.f0, resultOfANOVA.f1, resultOfANOVA.f2));
}
return results;
}
@Override
public Tuple3>>[] merge(
Tuple3>>[] acc1,
Tuple3>>[] acc2) {
if (acc1.length == 0) {
return acc2;
}
if (acc2.length == 0) {
return acc1;
}
IntStream.range(0, acc1.length)
.forEach(
i -> {
acc2[i].f0 += acc1[i].f0;
acc2[i].f1 += acc1[i].f1;
acc1[i].f2.forEach(
(k, v) -> {
if (acc2[i].f2.containsKey(k)) {
acc2[i].f2.get(k).f0 += v.f0;
acc2[i].f2.get(k).f1 += v.f1;
} else {
acc2[i].f2.put(k, v);
}
});
});
return acc2;
}
private Tuple3 computeANOVA(
double sum, double sumOfSq, HashMap> summary) {
long numOfClasses = summary.size();
long numOfSamples = summary.values().stream().mapToLong(t -> t.f1).sum();
double sqSum = sum * sum;
double ssTot = sumOfSq - sqSum / numOfSamples;
double totalSqSum = 0;
for (Tuple2 t : summary.values()) {
totalSqSum += t.f0 * t.f0 / t.f1;
}
double sumOfSqBetween = totalSqSum - (sqSum / numOfSamples);
double sumOfSqWithin = ssTot - sumOfSqBetween;
long degreeOfFreedomBetween = numOfClasses - 1;
Preconditions.checkArgument(
degreeOfFreedomBetween > 0, "Num of classes should be positive.");
long degreeOfFreedomWithin = numOfSamples - numOfClasses;
Preconditions.checkArgument(
degreeOfFreedomWithin > 0,
"Num of samples should be greater than num of classes.");
double meanSqBetween = sumOfSqBetween / degreeOfFreedomBetween;
double meanSqWithin = sumOfSqWithin / degreeOfFreedomWithin;
double fValue = meanSqBetween / meanSqWithin;
FDistribution fd = new FDistribution(degreeOfFreedomBetween, degreeOfFreedomWithin);
double pValue = 1 - fd.cumulativeProbability(fValue);
long degreeOfFreedom = degreeOfFreedomBetween + degreeOfFreedomWithin;
return Tuple3.of(pValue, degreeOfFreedom, fValue);
}
}
private Table convertToTable(
StreamTableEnvironment tEnv, DataStream> datastream, boolean flatten) {
if (flatten) {
DataStream output =
datastream
.flatMap(
(FlatMapFunction, Row>)
(list, collector) -> list.forEach(collector::collect))
.setParallelism(1)
.returns(Types.ROW(Types.INT, Types.DOUBLE, Types.LONG, Types.DOUBLE));
return tEnv.fromDataStream(output)
.as("featureIndex", "pValue", "degreeOfFreedom", "fValue");
} else {
DataStream> output =
datastream.map(
new MapFunction, Tuple3>() {
@Override
public Tuple3 map(
List rows) {
int numOfFeatures = rows.size();
DenseVector pValues = new DenseVector(numOfFeatures);
DenseVector fValues = new DenseVector(numOfFeatures);
long[] degrees = new long[numOfFeatures];
for (int i = 0; i < numOfFeatures; i++) {
Row row = rows.get(i);
pValues.set(i, (double) row.getField(1));
degrees[i] = (long) row.getField(2);
fValues.set(i, (double) row.getField(3));
}
return Tuple3.of(pValues, degrees, fValues);
}
});
return tEnv.fromDataStream(output).as("pValues", "degreesOfFreedom", "fValues");
}
}
@Override
public void save(String path) throws IOException {
ReadWriteUtils.saveMetadata(this, path);
}
public static ANOVATest load(StreamTableEnvironment tEnv, String path) throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
@Override
public Map, Object> getParamMap() {
return paramMap;
}
}