com.datastax.data.prepare.spark.dataset.DataBinningOperator

package com.datastax.data.prepare.spark.dataset;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponentArg;
import com.datastax.data.prepare.spark.dataset.params.DataBinning;
import com.datastax.data.prepare.util.Consts;
import com.datastax.data.prepare.util.SharedMethods;
import com.google.common.base.Strings;
import org.apache.spark.ml.feature.Bucketizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;

import static org.apache.spark.sql.functions.*;

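/**
 * Discretizes (bins) numeric columns of a Spark {@link Dataset}. Supports size-based,
 * equal-width, frequency-based and user-specified binning, driven either by
 * {@link DataBinning} parameter objects or by a JSON specification.
 */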
public class DataBinningOperator implements Operator {
    private static final Logger logger = LoggerFactory.getLogger(DataBinningOperator.class);

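    /**
     * Applies each {@link DataBinning} in turn to the columns selected by
     * SharedMethods.attributeFilter, deduplicating against previously seen fields.
     */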
    static <T> Dataset<T> binning(Dataset<T> data, DataBinning... dataBinnings) {
        if(data.count() == 0 || dataBinnings.length == 0) {
            logger.info("Detail parameter of DataBinning is empty or Dataset is empty");
            return data;
        }
        StructField[] recordFields = new StructField[data.schema().fields().length];
        for (DataBinning dataBinning : dataBinnings) {
            if (dataBinning == null) {
                continue;
            }
            StructField[] fields = SharedMethods.attributeFilter(data, dataBinning.getAttributeSelector(), dataBinning.isInvertSelection(),
                    dataBinning.getAttribute(), dataBinning.getRegularExpression(), dataBinning.getValueType());
            if (fields == null) {
                continue;
            }
            data = binningHandle(data, dataBinning, SharedMethods.dropDuplicates(recordFields, fields));
        }
        return data;
    }


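    /**
     * Dispatches on the "method" field of the first element of the JSON array.
     * Only the "method" key is visible in this class; the remaining fields are
     * whatever {@link DataBinning} expects, e.g. (hypothetical):
     *
     * <pre>[{"method": "frequency", "attribute": "age", "binSize": 4}]</pre>
     */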
    static <T> Dataset<T> binning(Dataset<T> data, String json) {
        if(json == null || "".equals(json)) {
            return data;
        }
        JSONArray array = JSON.parseArray(json);
        JSONObject object = array.getJSONObject(0);
        if(Consts.SIZE.equals(object.getString("method")) || Consts.FREQUENCY.equals(object.getString("method"))) {
            return basicDiscretize(data, array);
        }
        if(Consts.BINNING.equals(object.getString("method"))) {
            return binDiscretize(data, array);
        }
        if(Consts.USER_SPECIFICATION.equals(object.getString("method"))) {
            return userDefineDiscretize(data, array);
        }

        return data;
    }

    @InsightComponent(name = "基本离散化", type = "com.datastax.insight.dataprprocess.basicDiscretize", description = "数据离散化", order = 500901)
    public static  Dataset basicDiscretize(
            @InsightComponentArg(externalInput = true, name = "data", description = "待分箱的数据集") Dataset data,
            @InsightComponentArg(name = "参数", description = "数据离散的json参数") JSONArray array) {
        if(array.isEmpty()) {
            return data;
        }
        DataBinning[] dataBinnings = new DataBinning[array.size()];
        int sign = 0;
        for(int i = 0; i < array.size(); i++) {
            // Loop body lost in the extracted source; a plausible reconstruction
            // (assumes DataBinning deserializes as a fastjson bean):
            dataBinnings[sign++] = array.getObject(i, DataBinning.class);
        }
        return binning(data, dataBinnings);
    }

    // (The @InsightComponent annotation for this method is not recoverable from the
    // garbled source.)
    public static <T> Dataset<T> binDiscretize(
            @InsightComponentArg(externalInput = true, name = "data", description = "dataset to be binned") Dataset<T> data,
            @InsightComponentArg(name = "parameters", description = "JSON parameters for discretization") JSONArray array) {
        if(array.isEmpty()) {
            return data;
        }
        DataBinning[] dataBinnings = new DataBinning[array.size()];
        int sign = 0;
        for(int i = 0; i < array.size(); i++) {
            // Loop body lost in the extracted source; see basicDiscretize above.
            dataBinnings[sign++] = array.getObject(i, DataBinning.class);
        }
        return binning(data, dataBinnings);
    }

    // (The @InsightComponent annotation for this method is not recoverable from the
    // garbled source.)
    public static <T> Dataset<T> userDefineDiscretize(
            @InsightComponentArg(externalInput = true, name = "data", description = "dataset to be binned") Dataset<T> data,
            @InsightComponentArg(name = "parameters", description = "JSON parameters for discretization") JSONArray array) {
        if(array.isEmpty()) {
            return data;
        }
        DataBinning[] dataBinnings = new DataBinning[array.size()];
        int sign = 0;
        for(int i = 0; i < array.size(); i++) {
            // Loop body lost in the extracted source; see basicDiscretize above.
            dataBinnings[sign++] = array.getObject(i, DataBinning.class);
        }
        return binning(data, dataBinnings);
    }

    /**
     * Routes each selected numeric column to the discretization routine named by
     * dataBinning.getBinningType(), validating binSize where the routine needs it.
     */
    private static <T> Dataset<T> binningHandle(Dataset<T> data, DataBinning dataBinning, StructField[] fields) {
        if(fields == null) {
            logger.info("没有属性被选中,返回原数据集");
            return data;
        }
        if(dataBinning.getBinningType() == null || "".equals(dataBinning.getBinningType())) {
            logger.info("数据离散类型为空,返回原数据集");
            return data;
        }
        if(Consts.SIZE.equals(dataBinning.getBinningType())) {
            if(dataBinning.getBinSize() < 1 || dataBinning.getBinSize() >= data.count()) {
                logger.info("大小离散化的binSize小于1或者大于数据集的行数,返回原数据");
                return data;
            }
            for (StructField field : fields) {
                if (SharedMethods.isNumericType(field)) {
                    data = sizeDiscretize(data, field, dataBinning);
                }
            }
        }
        if(Consts.BINNING.equals(dataBinning.getBinningType())) {
            if(dataBinning.getBinSize() < 1 ) {
                logger.info("分级离散化的binSize小于1,返回原数据");
                return data;
            }
            for (StructField field : fields) {
                if (SharedMethods.isNumericType(field)) {
                    data = binningDiscretize(data, field, dataBinning);
                }
            }
        }
        if(Consts.FREQUENCY.equals(dataBinning.getBinningType())) {
            if(dataBinning.getBinSize() < 1 || dataBinning.getBinSize() >= data.count()) {
                logger.info("频率离散化的binSize小于1或者大于数据集的行数,返回原数据集");
                return data;
            }
            for (StructField field : fields) {
                if (SharedMethods.isNumericType(field)) {
                    data = frequencyDiscretize(data, field, dataBinning);
                }
            }
        }
        if(Consts.USER_SPECIFICATION.equals(dataBinning.getBinningType())) {
            for (StructField field : fields) {
                if (SharedMethods.isNumericType(field)) {
                    data = userSpecificationDiscretize(data, field, dataBinning);
                }
            }
        }
        if(Consts.ENTROPY.equals(dataBinning.getBinningType())) {
            for (StructField field : fields) {
                if (SharedMethods.isNumericType(field)) {
                    data = entropyDiscretize(data, field);
                }
            }
        }

        return data;
    }

    /**
     * Size discretization: sorts the column and chooses split points so that each
     * bucket holds roughly binSize rows, never splitting a run of equal values.
     */
    private static <T> Dataset<T> sizeDiscretize(Dataset<T> data, final StructField field, final DataBinning dataBinning) {
        int binSize = dataBinning.getBinSize();
        Row[] rows = (Row[]) data.sort(field.name()).select(field.name()).collect();
        double[] columnData = changeType(rows);
        if(columnData == null) {
            logger.info(field.name() + "列全部为空,返回原数据集");
            return data;
        }
        DecimalFormat format = new DecimalFormat("#.000");
        int mod = (int) data.count()%binSize;
        int size = (int) data.count()/binSize+(mod == 0 ? 0 : 1)+2;
        int position = mod == 0 ? binSize-1 : mod-1;
        int preposition = -1;
        int i = 0;
        int point = 0;
        double[] doubles = new double[size];
        doubles[point] = Double.NEGATIVE_INFINITY;
        // Skip the leading run of equal values, guarding against running off the
        // end of the array when the whole column holds a single value
        while(i + 1 < columnData.length && columnData[i] == columnData[i+1]) {
            i++;
        }
        if(i >= mod && i + 1 < columnData.length) {
            doubles[++point] = Double.parseDouble(format.format((columnData[i]+columnData[i+1])/2.0));
            position = i+binSize-(i+1-mod)%binSize;
        }
        while(true) {
            if(position >= columnData.length-1) {
                break;
            }
            i = position;
            while(columnData[i] == columnData[position+1]) {
                i--;
                if(i == preposition) {
                    break;
                }
            }
            if(i == preposition) {
                preposition = position;
                position = position + binSize;
                continue;
            }
            double temp = Double.parseDouble(format.format((columnData[i]+columnData[i+1])/2.0));
            if(temp != doubles[point]) {
                doubles[++point] = temp;
            }
            preposition = position;
            position = position + binSize;
        }
        if(point == 0) {
            logger.info("数据集的" + field.name() + "列全部相同,返回原数据集");
            return data;
        }
        doubles[++point] = Double.POSITIVE_INFINITY;
        return bucketizer(data, dropRedundant(doubles, point), field);
    }

    /**
     * Binning (equal-width) discretization: splits [min, max] into binSize
     * equal-width intervals, optionally clipped to user-defined boundaries.
     */
    private static <T> Dataset<T> binningDiscretize(Dataset<T> data, final StructField field, final DataBinning dataBinning) {
        Row[] rows = (Row[]) data.select(min(field.name()), max(field.name())).collect();
        if(rows[0].get(0) == null && rows[0].get(1) == null) {
            logger.info(field.name() + "列的最大值最小值都为空,表示该列为空列,返回原数据集");
            return data;
        }
        double min = Double.parseDouble(rows[0].get(0).toString());
        double max = Double.parseDouble(rows[0].get(1).toString());
        if(min == max) {
            logger.info(field.name() + "列最大值和最小值相等,返回一个bucket");
            return bucketizer(data, new double[]{Double.NEGATIVE_INFINITY, min, Double.POSITIVE_INFINITY}, field);
        }
        if(dataBinning.isDefineBoundaries()) {
            if(dataBinning.getMinValue() >= dataBinning.getMaxValue()) {
                logger.info("分级离散化用户自定的边界的最小值大于或等于最大值,边界应在[" + min + ", " + max + "]范围内,返回原数据集");
                return data;
            }
            if(dataBinning.getMaxValue() < min) {
                logger.info("分级离散化用户自定的边界的最大值小于数据集最小值,边界应在[" + min + ", " + max + "]范围内,返回原数据集");
                return data;
            }
            if(dataBinning.getMinValue() > max) {
                logger.info("分级离散化用户自定的边界的最小值大于数据集最大值,边界应在[" + min + ", " + max + "]范围内,返回原数据集");
                return data;
            }
            if(dataBinning.getMinValue() > min) {
                min = dataBinning.getMinValue();
            }
            if(dataBinning.getMaxValue() < max) {
                max = dataBinning.getMaxValue();
            }
        }
        double[] doubles = new double[dataBinning.getBinSize() + 3];
        int position = 0;
        doubles[position] = Double.NEGATIVE_INFINITY;
        doubles[++position] = min;
        if(dataBinning.getBinSize() != 1) {
            double interval = (max-min)/dataBinning.getBinSize();
            for(int i = 1; i < dataBinning.getBinSize(); i++) {
                // Evenly spaced interior split points (reconstructed; this fills the
                // binSize-1 remaining slots of the binSize+3 element array)
                doubles[++position] = min + interval * i;
            }
        }
        doubles[++position] = max;
        doubles[++position] = Double.POSITIVE_INFINITY;
        return bucketizer(data, dropRedundant(doubles, position), field);
    }

    /**
     * Frequency discretization: after dropping duplicates and nulls, places split
     * points so that each bucket covers roughly the same number of distinct values.
     */
    private static <T> Dataset<T> frequencyDiscretize(Dataset<T> data, final StructField field, final DataBinning dataBinning) {
        Row[] rows = (Row[]) data.dropDuplicates(field.name()).sort(field.name()).select(field.name()).collect();
        double[] columnData = changeType(rows);
        if(columnData == null) {
            logger.info(field.name() + "列全部为空,返回原数据集");
            return data;
        }
        if(dataBinning.getBinSize() >= columnData.length) {
            logger.info("频率离散化的binSize大于或等于" + field.name() + "列去重和去空之后的长度,返回该列内容为0.0");
            return bucketizer(data, new double[]{Double.NEGATIVE_INFINITY, columnData[columnData.length-1]+1, Double.POSITIVE_INFINITY}, field);
        }
        int position = columnData.length/dataBinning.getBinSize();
        int interval = position;
        int mod = columnData.length%dataBinning.getBinSize();
        double[] doubles = new double[dataBinning.getBinSize()+2];
        int i = 0, j = 1;
        doubles[i] = Double.NEGATIVE_INFINITY;
        while(true) {
            if(j != 1) {
                if(mod != 0) {
                    position = position + interval + 1;
                    mod--;
                }else {
                    position = position + interval;
                }
            }
            if(position >= columnData.length) {
                break;
            }
            doubles[++i] = (columnData[position]+columnData[position-1])/2.0;
            j++;
        }
        doubles[++i] = Double.POSITIVE_INFINITY;
        return bucketizer(data, dropRedundant(doubles, i), field);
    }

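    /**
     * User-specified discretization: buckets the column at the caller-supplied
     * upper limits and relabels each bucket with the corresponding class name.
     */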
    private static <T> Dataset<T> userSpecificationDiscretize(Dataset<T> data, final StructField field, final DataBinning dataBinning) {
        if(dataBinning.getUpperLimits() == null || dataBinning.getClassNames() == null) {
            logger.info("用户自定离散化参数为空,返回原数据集");
            return data;
        }
        return bucketizer(data, dataBinning.getUpperLimits(), field, dataBinning.getClassNames());
    }

    // TODO: minimum-entropy discretization (unimplemented; the commented-out sketch
    // below was partly lost in the extracted source)
    private static <T> Dataset<T> entropyDiscretize(Dataset<T> data, final StructField field) {
//        Row[] rows = (Row[]) data.groupBy(field.name()).count().sort(field.name()).select(field.name(), "count").collect();
//        double sign;
//        int total = 0;
//        for(int i=0; i < rows.length; i++) { ... }
        return data;
    }

    // Reconstructed helper (the original definition was lost in the extracted source):
    // converts a single-column Row array to double[], skipping nulls; returns null
    // when the column is entirely null.
    private static double[] changeType(Row[] rows) {
        int count = 0;
        for (Row row : rows) {
            if (!row.isNullAt(0)) {
                count++;
            }
        }
        if (count == 0) {
            return null;
        }
        double[] result = new double[count];
        int i = 0;
        for (Row row : rows) {
            if (!row.isNullAt(0)) {
                result[i++] = Double.parseDouble(row.get(0).toString());
            }
        }
        return result;
    }

    // Reconstructed helper (the original definition was lost in the extracted source):
    // keeps only the first point+1 split values, dropping unused trailing slots.
    private static double[] dropRedundant(double[] doubles, int point) {
        double[] result = new double[point + 1];
        System.arraycopy(doubles, 0, result, 0, point + 1);
        return result;
    }

    private static <T> Dataset<T> bucketizer(Dataset<T> data, double[] doubles, StructField field) {
        if(judge(doubles)) {
            logger.info("bucketizer范围为[-Infinity, Infinity]");
            return data;
        }
        String bucketedName = "bucketed-" + field.name();
        Bucketizer bucketizer = new Bucketizer()
                .setInputCol(field.name())
                .setOutputCol(bucketedName)
                .setSplits(doubles);
        return (Dataset<T>) bucketizer.transform(data).withColumn(field.name(), col(bucketedName)).drop(bucketedName);
    }

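    /**
     * Buckets the column, then replaces each numeric bucket index with the matching
     * label from {@code strings} by joining on a generated row id.
     */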
    private static <T> Dataset<T> bucketizer(Dataset<T> data, double[] doubles, StructField field, String[] strings) {
        if(judge(doubles)) {
            logger.info("bucketizer范围为[-Infinity, Infinity]");
            return data;
        }
        data = bucketizer(data, doubles, field);
        String id = "auto_increasing_id";
        String joinColumn = "join_column_for_type_change";
        SparkSession session = SparkContextBuilder.getSession();
        Row[] temp = (Row[]) data.select(field.name()).collect();
        // TODO (andy): collects the whole column to the driver; performs poorly on
        // large data and may crash. Rewrite in Scala.
        List<Row> result = new ArrayList<>(temp.length);
        for(int i = 0; i < temp.length; i++) {
            // Reconstructed (the original loop body was garbled): pair a row id with
            // the class name for that row's bucket index.
            result.add(RowFactory.create((long) i, strings[(int) temp[i].getDouble(0)]));
        }
        // Reconstructed from the imports and the visible tail of this method: build an
        // (id, label) dataset and swap the bucket indices for their labels via a join.
        // NOTE: this assumes monotonically_increasing_id() yields consecutive ids from
        // 0, which only holds for a single-partition dataset.
        StructType schema = new StructType(new StructField[]{
                DataTypes.createStructField(id, DataTypes.LongType, false),
                DataTypes.createStructField(joinColumn, DataTypes.StringType, false)});
        Dataset<Row> data1 = session.createDataFrame(result, schema);
        return (Dataset<T>) data.withColumn(id, monotonically_increasing_id()).join(data1, id).withColumn(field.name(), col(joinColumn)).drop(id, joinColumn);
    }

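    // True when the first two split points are -Infinity and +Infinity, i.e. the
    // splits describe a single all-encompassing bucket.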
    private static boolean judge(double[] doubles) {
        return doubles[0] == Double.NEGATIVE_INFINITY && doubles[1] == Double.POSITIVE_INFINITY;
    }
}
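
Below is a minimal usage sketch, not part of the artifact. It assumes a local SparkSession, same-package access to the package-private binning(Dataset, String) method, and a hypothetical "frequency" value for the "method" key, since the actual Consts.* strings are not visible in this listing.

package com.datastax.data.prepare.spark.dataset;

import java.util.Arrays;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class DataBinningExample {
    public static void main(String[] args) {
        SparkSession session = SparkSession.builder()
                .appName("data-binning-example")
                .master("local[*]")
                .getOrCreate();

        // A one-column numeric dataset to discretize.
        StructType schema = new StructType(new StructField[]{
                DataTypes.createStructField("x", DataTypes.DoubleType, false)});
        Dataset<Row> data = session.createDataFrame(Arrays.asList(
                RowFactory.create(1.0), RowFactory.create(2.0), RowFactory.create(9.0),
                RowFactory.create(10.0), RowFactory.create(11.0), RowFactory.create(20.0)),
                schema);

        // Hypothetical JSON spec: only the "method" key is known from the dispatch
        // code above; the real Consts.* values and DataBinning field names may differ.
        Dataset<Row> binned = DataBinningOperator.binning(data, "[{\"method\": \"frequency\"}]");
        binned.show();

        session.stop();
    }
}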