All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.alibaba.alink.operator.common.statistics.StatisticsHelper Maven / Gradle / Ivy

package com.alibaba.alink.operator.common.statistics;

import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.List;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult;
import com.alibaba.alink.operator.common.statistics.basicstatistic.DenseVectorSummarizer;
import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummarizer;
import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary;
import com.alibaba.alink.operator.common.statistics.basicstatistic.VectorSummarizerUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Util for batch statistical calculation.
 */
public class StatisticsHelper {
    /**
     * Transform cols or vector col to vector and compute summary.
     * it deal with table which only have number cols and not has missing value.
     * It return tuple2, f0 is data, f1 is summary. data f0 is vector, data f1 is row which is reserved cols.
     */
    public static Tuple2>, DataSet> summaryHelper(BatchOperator in,
                                                                                                 String[] selectedColNames,
                                                                                                 String vectorColName,
                                                                                                 String[] reservedColNames) {
        DataSet> data = transformToVector(in, selectedColNames, vectorColName, reservedColNames);

        DataSet vectorDataSet = data
            .map(new MapFunction, Vector>() {
                @Override
                public Vector map(Tuple2 tuple2) {
                    return tuple2.f0;
                }
            });

        return Tuple2.of(data, summary(vectorDataSet));
    }

    /**
     * Transform to vector without reservedColNames and compute summary.
     * it deal with table which only have number cols and has no missing value.
     */
    public static Tuple2, DataSet> summaryHelper(BatchOperator in,
                                                                                    String[] selectedColNames,
                                                                                    String vectorColName) {
        DataSet data = transformToVector(in, selectedColNames, vectorColName);
        return Tuple2.of(data, summary(data));
    }

    /**
     * table summary, selectedColNames must be set.
     */
    public static DataSet summary(BatchOperator in, String[] selectedColNames) {
        return summarizer(in, selectedColNames, false)
            .map(new MapFunction() {
                @Override
                public TableSummary map(TableSummarizer summarizer) throws Exception {
                    return summarizer.toSummary();
                }
            }).name("toSummary");
    }

    /**
     * vector stat, selectedColName must be set.
     */
    public static DataSet vectorSummary(BatchOperator in, String selectedColName) {
        return vectorSummarizer(in, selectedColName, false)
            .map(new MapFunction() {
                @Override
                public BaseVectorSummary map(BaseVectorSummarizer summarizer) {
                    return summarizer.toSummary();
                }
            }).name("toSummary");
    }

    /**
     * given vector dataSet, return vector summary.
     */
    public static DataSet summary(DataSet data) {
        return summarizer(data, false)
            .map(new MapFunction() {
                @Override
                public BaseVectorSummary map(BaseVectorSummarizer summarizer) throws Exception {
                    return summarizer.toSummary();
                }
            }).name("toSummary");
    }

    /**
     * calculate correlation. result is tuple2, f0 is summary, f1 is correlation.
     */
    public static DataSet> pearsonCorrelation(BatchOperator in, String[] selectedColNames) {
        return summarizer(in, selectedColNames, true)
            .map(new MapFunction>() {
                @Override
                public Tuple2 map(TableSummarizer summarizer) {
                    return Tuple2.of(summarizer.toSummary(), summarizer.correlation());
                }
            });
    }

    /**
     * calculate correlation. result is tuple2, f0 is summary, f1 is correlation.
     */
    public static DataSet> vectorPearsonCorrelation(BatchOperator in, String selectedColName) {
        return vectorSummarizer(in, selectedColName, true)
            .map(new MapFunction>() {
                @Override
                public Tuple2 map(BaseVectorSummarizer summarizer) {
                    return Tuple2.of(summarizer.toSummary(), summarizer.correlation());
                }
            });
    }

    /**
     * transform to vector.
     */
    public static DataSet transformToVector(BatchOperator in,
                                                    String[] selectedColNames,
                                                    String vectorColName) {

        checkSimpleStatParameter(in, selectedColNames, vectorColName, null);

        if (selectedColNames != null && selectedColNames.length != 0) {
            int[] selectedColIndices = TableUtil.findColIndices(in.getColNames(), selectedColNames);
            return in.getDataSet().map(new ColsToVectorWithoutReservedColsMap(selectedColIndices));
        } else {
            int selectColIndex = TableUtil.findColIndex(in.getColNames(), vectorColName);
            return in.getDataSet().map(new VectorCoToVectorWithoutReservedColsMap(selectColIndex));
        }
    }

    /**
     * transform to vector, result is tuple2, f0 is vector, f1 is reserved cols.
     */
    public static DataSet> transformToVector(BatchOperator in,
                                                                 String[] selectedColNames,
                                                                 String vectorColName,
                                                                 String[] reservedColNames) {

        checkSimpleStatParameter(in, selectedColNames, vectorColName, reservedColNames);

        int[] reservedColIndices = TableUtil.findColIndices(in.getColNames(), reservedColNames);

        if (selectedColNames != null && selectedColNames.length != 0) {
            int[] selectedColIndices = TableUtil.findColIndices(in.getColNames(), selectedColNames);
            return in.getDataSet()
                .map(new ColsToVectorWithReservedColsMap(selectedColIndices, reservedColIndices))
                .name("transform_data");
        } else {
            int vectorColIndex = TableUtil.findColIndex(in.getColNames(), vectorColName);
            return in.getDataSet()
                .map(new VectorColToVectorWithReservedColsMap(vectorColIndex, reservedColIndices))
                .name("transform_data");
        }
    }

    /**
     * Transform vector or table columns to table columns.
     * selectedCols and vectorCol only be one non-empty, reserved cols can be empty.
     * 

* If selected cols is not null, it will combines selected cols and reserved cols, * and selected cols will transform to double type. *

* If vector col is not null, it will splits vector to cols and combines with reserved cols. *

* If selected cols and vector col both set, it will use vector col. */ public static DataSet transformToColumns(BatchOperator in, String[] selectedColNames, String vectorColName, String[] reservedColNames) { checkSimpleStatParameter(in, selectedColNames, vectorColName, reservedColNames); int[] reservedColIndices = null; if (reservedColNames != null) { reservedColIndices = TableUtil.findColIndices(in.getColNames(), reservedColNames); } if (selectedColNames != null && selectedColNames.length != 0) { int[] selectedColIndices = TableUtil.findColIndices(in.getSchema(), selectedColNames); return in.getDataSet().map(new ColsToDoubleColsMap(selectedColIndices, reservedColIndices)); } if (vectorColName != null) { int vectorColIndex = TableUtil.findColIndex(in.getColNames(), vectorColName); return in.getDataSet().map(new VectorColToTableMap(vectorColIndex, reservedColIndices)); } throw new InvalidParameterException("selectedColName and vectorColName must be set one only."); } /** * check parameters is invalid. */ private static void checkSimpleStatParameter(BatchOperator in, String[] selectedColNames, String vectorColName, String[] reservedColNames) { if (selectedColNames != null && selectedColNames.length != 0 && vectorColName != null) { throw new InvalidParameterException("selectedColName and vectorColName must be set one only."); } TableUtil.assertSelectedColExist(in.getColNames(), selectedColNames); TableUtil.assertNumericalCols(in.getSchema(), selectedColNames); TableUtil.assertSelectedColExist(in.getColNames(), vectorColName); TableUtil.assertVectorCols(in.getSchema(), vectorColName); TableUtil.assertSelectedColExist(in.getColNames(), reservedColNames); } /** * table stat */ private static DataSet summarizer(BatchOperator in, String[] selectedColNames, boolean calculateOuterProduct) { if (selectedColNames == null || selectedColNames.length == 0) { throw new InvalidParameterException("selectedColNames must be set."); } in = in.select(selectedColNames); return summarizer(in.getDataSet(), calculateOuterProduct, getNumericalColIndices(in.getColTypes()), selectedColNames); } /** * vector stat */ private static DataSet vectorSummarizer(BatchOperator in, String selectedColName, boolean calculateOuterProduct) { if (TableUtil.findColIndex(in.getColNames(), selectedColName) < 0) { throw new RuntimeException(selectedColName + " is not exist."); } DataSet data = transformToVector(in, null, selectedColName); return summarizer(data, calculateOuterProduct); } /** * given vector dataSet, return vector summary. if outerProduct is true, calculate covariance. */ private static DataSet summarizer(DataSet data, boolean bCov) { return data .mapPartition(new VectorSummarizerPartition(bCov)).name("summarizer_map") .reduce(new ReduceFunction() { @Override public BaseVectorSummarizer reduce(BaseVectorSummarizer value1, BaseVectorSummarizer value2) throws Exception { return VectorSummarizerUtil.merge(value1, value2); } }) .name("summarizer_reduce"); } /** * given data, return summary. numberIndices is the indices of cols which are number type in selected cols. */ private static DataSet summarizer(DataSet data, boolean bCov, int[] numberIndices, String[] selectedColNames) { return data .mapPartition(new TableSummarizerPartition(bCov, numberIndices, selectedColNames)) .reduce(new ReduceFunction() { @Override public TableSummarizer reduce(TableSummarizer left, TableSummarizer right) { return TableSummarizer.merge(left, right); } }); } /** * get indices which type is number */ private static int[] getNumericalColIndices(TypeInformation[] colTypes) { List numberColIndicesList = new ArrayList<>(); for (int i = 0; i < colTypes.length; i++) { if (TableUtil.isNumber(colTypes[i])) { numberColIndicesList.add(i); } } int[] numberColIndices = new int[numberColIndicesList.size()]; for (int i = 0; i < numberColIndices.length; i++) { numberColIndices[i] = numberColIndicesList.get(i); } return numberColIndices; } /** * MapFunction for Transform cols to vector. result it tuple, f0 is vector, f1 is reserved cols. */ private static class ColsToVectorWithReservedColsMap implements MapFunction> { private int[] selectedColIndices; private int[] reservedColIndices; ColsToVectorWithReservedColsMap(int[] selectedColIndices, int[] reservedColIndices) { this.selectedColIndices = selectedColIndices; this.reservedColIndices = reservedColIndices; } @Override public Tuple2 map(Row in) throws Exception { DenseVector dv = new DenseVector(selectedColIndices.length); for (int i = 0; i < selectedColIndices.length; i++) { dv.set(i, ((Number) in.getField(this.selectedColIndices[i])).doubleValue()); } Row out = new Row(reservedColIndices.length); for (int i = 0; i < this.reservedColIndices.length; ++i) { out.setField(i, in.getField(this.reservedColIndices[i])); } return new Tuple2<>(dv, out); } } /** * MapFunction for Transform cols to vector. */ private static class ColsToVectorWithoutReservedColsMap implements MapFunction { private int[] selectedColIndices; ColsToVectorWithoutReservedColsMap(int[] selectedColIndices) { this.selectedColIndices = selectedColIndices; } @Override public Vector map(Row in) throws Exception { DenseVector dv = new DenseVector(selectedColIndices.length); double[] data = dv.getData(); for (int i = 0; i < selectedColIndices.length; i++) { Object obj = in.getField(this.selectedColIndices[i]); if (obj instanceof Number) { data[i] = ((Number) obj).doubleValue(); } } return dv; } } /** * MapFunction for vector col transform to vector. result is tuple2, f0 is vector, f1 is reserved cols. */ private static class VectorColToVectorWithReservedColsMap implements MapFunction> { private int vectorColIndex; private int[] reservedColIndices; VectorColToVectorWithReservedColsMap(int vectorColIndex, int[] reservedColIndices) { this.vectorColIndex = vectorColIndex; this.reservedColIndices = reservedColIndices; } @Override public Tuple2 map(Row in) throws Exception { Vector vec = VectorUtil.getVector(in.getField(vectorColIndex)); if (vec == null) { throw new RuntimeException( "vector is null, please check your input data."); } Row out = new Row(reservedColIndices.length); for (int i = 0; i < this.reservedColIndices.length; ++i) { out.setField(i, in.getField(this.reservedColIndices[i])); } return Tuple2.of(vec, out); } } /** * MapFunction for vector col transform to vector. */ private static class VectorCoToVectorWithoutReservedColsMap implements MapFunction { private int vectorColIndex; VectorCoToVectorWithoutReservedColsMap(int vectorColIndex) { this.vectorColIndex = vectorColIndex; } @Override public Vector map(Row in) throws Exception { return VectorUtil.getVector(in.getField(vectorColIndex)); } } /** * Transform vector col to table columns. */ private static class VectorColToTableMap implements MapFunction { /** * vector col index. */ private int vectorColIndex; /** * reserved col indices. */ private int[] reservedColIndices; VectorColToTableMap(int vectorColIndex, int[] reservedColIndices) { this.vectorColIndex = vectorColIndex; this.reservedColIndices = reservedColIndices; if (reservedColIndices == null) { this.reservedColIndices = new int[0]; } } @Override public Row map(Row in) throws Exception { Row out; Vector vector = VectorUtil.getVector(in.getField(vectorColIndex)); DenseVector feature; if (vector instanceof DenseVector) { feature = (DenseVector) vector; } else { feature = ((SparseVector) vector).toDenseVector(); } if (feature.getData() != null) { out = new Row(feature.size() + reservedColIndices.length); for (int i = 0; i < feature.size(); i++) { out.setField(i, feature.get(i)); } for (int i = 0; i < reservedColIndices.length; i++) { out.setField(i + feature.size(), in.getField(reservedColIndices[i])); } } else { out = new Row(reservedColIndices.length); for (int i = 0; i < reservedColIndices.length; i++) { out.setField(i, in.getField(reservedColIndices[i])); } } return out; } } /** * MapFunction for cols. */ private static class ColsToDoubleColsMap implements MapFunction { private int[] selectedColIndices; private int[] reservedColIndices; ColsToDoubleColsMap(int[] selectedColIndices, int[] reservedColIndices) { this.selectedColIndices = selectedColIndices; this.reservedColIndices = reservedColIndices; if (reservedColIndices == null) { this.reservedColIndices = new int[0]; } } @Override public Row map(Row in) throws Exception { //table cols and reserved cols, table cols will be transform to double type. Row out = new Row(selectedColIndices.length + reservedColIndices.length); for (int i = 0; i < this.selectedColIndices.length; ++i) { out.setField(i, ((Number) in.getField(this.selectedColIndices[i])).doubleValue()); } for (int i = 0; i < reservedColIndices.length; i++) { out.setField(i + selectedColIndices.length, in.getField(reservedColIndices[i])); } return out; } } /** * It is table summary partition of one worker, will merge result later. */ public static class TableSummarizerPartition implements MapPartitionFunction { private boolean outerProduct; private int[] numericalIndices; private String[] selectedColNames; TableSummarizerPartition(boolean outerProduct, int[] numericalIndices, String[] selectedColNames) { this.outerProduct = outerProduct; this.numericalIndices = numericalIndices; this.selectedColNames = selectedColNames; } @Override public void mapPartition(Iterable iterable, Collector collector) throws Exception { TableSummarizer srt = new TableSummarizer(selectedColNames, numericalIndices, outerProduct); srt.colNames = selectedColNames; for (Row sv : iterable) { srt = (TableSummarizer) srt.visit(sv); } collector.collect(srt); } } /** * It is vector summary partition of one worker, will merge result later. */ public static class VectorSummarizerPartition implements MapPartitionFunction { private boolean outerProduct; public VectorSummarizerPartition(boolean outerProduct) { this.outerProduct = outerProduct; } @Override public void mapPartition(Iterable iterable, Collector collector) throws Exception { BaseVectorSummarizer srt = new DenseVectorSummarizer(outerProduct); for (Vector sv : iterable) { srt = srt.visit(sv); } collector.collect(srt); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy