All Downloads are FREE. The search and download functionality uses the official Maven repository.

com.datastax.data.prepare.spark.dataset.FlatMapOperator Maven / Gradle / Ivy

The newest version!
package com.datastax.data.prepare.spark.dataset;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.datastax.data.prepare.spark.dataset.params.FlatMapParam;
import com.datastax.data.prepare.util.Consts;
import com.datastax.data.prepare.util.CustomException;
import com.datastax.data.prepare.util.SharedMethods;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Operator that splits one or more columns of a dataset into several string columns.
 *
 * <p>Each configured column is replaced, in place, by {@code sliceNum} nullable string columns
 * obtained by splitting the textual value of the cell with a regular expression. Columns that
 * are not configured are copied through unchanged.
 */
public class FlatMapOperator implements Operator {
    private static final Logger logger = LoggerFactory.getLogger(FlatMapOperator.class);

    /**
     * Parses the component parameters and splits the selected columns of {@code data}.
     *
     * <p>Every element of {@code array} is read as a JSON object with the keys
     * {@code selector} (column to split), {@code selectorValue} (regular expression used for
     * splitting), {@code method} (number of resulting columns) and {@code methodValue}
     * (names of the resulting columns joined with {@code Consts.DELIMITER}). Entries without a
     * column or a pattern are skipped. When the names are missing or their count differs from
     * the slice count, default names {@code <column>_sp1 .. <column>_spN} are generated.
     *
     * @param data  the dataset to transform; when {@code null}, {@code null} is returned
     * @param array the split definitions; when {@code null} or empty, {@code data} is returned as is
     * @return the transformed dataset, {@code data} itself when there are no parameters,
     *         or {@code null} when {@code data} is {@code null}
     * @throws CustomException if an entry has a missing or non-positive slice count, or if a
     *                         selected column does not exist in the dataset
     */
    @InsightComponent(name = "列拆分", description = "将一列按照一定规则拆分成多列")
    public static <T> Dataset<T> flatMap(
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<T> data,
            @InsightComponentArg(name = "参数", description = "参数") JSONArray array) {
        // A missing parameter array is handled like an empty one instead of failing with a NullPointerException.
        if(array == null || array.isEmpty()) {
            logger.info("列拆分组件参数为空, 返回原数据集");
            return data;
        }
        if(data == null) {
            logger.info("列拆分组件中的数据集为空, 返回空");
            return null;
        }
        List<FlatMapParam> flatMapParams = new ArrayList<>();
        for(int i = 0; i < array.size(); i++) {
            JSONObject jsonObject = array.getJSONObject(i);
            String column = jsonObject.getString("selector");
            String pattern = jsonObject.getString("selectorValue");
            // Read as a boxed value first: unboxing a missing "method" directly would throw a NullPointerException.
            Integer sliceValue = jsonObject.getInteger("method");
            String columnName = jsonObject.getString("methodValue");
            if(sliceValue == null || sliceValue <= 0) {
                throw new CustomException("分列数小于等于0,若选中列中的某一个值切分后生成值的数量与其他不等,便会报错。因此需要设定该值");
            }
            int sliceNum = sliceValue;
            if(column == null || pattern == null) {
                continue;
            }
            if(columnName == null || columnName.split(Consts.DELIMITER).length != sliceNum) {
                logger.info("切分后的列名为空或者切分后的列数和填写的分列数不符,默认为列名加_sp加数值迭加(1,2,3...)");
                columnName = defaultSliceColumnNames(column, sliceNum);
            }
            FlatMapParam flatMapParam = new FlatMapParam();
            flatMapParam.setColumn(column);
            flatMapParam.setPattern(pattern);
            flatMapParam.setSliceNum(sliceNum);
            flatMapParam.setSliceColumnName(columnName);
            flatMapParams.add(flatMapParam);
        }
        return flatMap1(data, flatMapParams);
    }

    /**
     * Builds the default names {@code <column>_sp1 .. <column>_spN}, joined with {@code Consts.DELIMITER}.
     */
    private static String defaultSliceColumnNames(String column, int sliceNum) {
        StringBuilder names = new StringBuilder();
        for(int j = 1; j <= sliceNum; j++) {
            if(j > 1) {
                names.append(Consts.DELIMITER);
            }
            names.append(column).append("_sp").append(j);
        }
        return names.toString();
    }

    /**
     * Replaces every configured column by its slices and rebuilds the dataset with a matching schema.
     *
     * <p>For a configured column, a {@code null} cell yields {@code sliceNum} {@code null} values;
     * otherwise the cell's {@code toString()} is split with the configured pattern and the first
     * {@code sliceNum} pieces are kept, padding with {@code null} when fewer pieces exist. When the
     * same column is configured more than once, only the first definition is used.
     *
     * @param data          the dataset to transform; must not be {@code null}
     * @param flatMapParams the split definitions, one per column
     * @return a dataset created from the transformed rows and the new schema
     * @throws CustomException if a configured column is not found in the schema of {@code data}
     */
    // "unchecked": the rebuilt DataFrame is cast back to the declared return type.
    // "rawtypes": schemaRecord is kept raw because the parameter type of
    // SharedMethods.recordSchema is not visible from this class.
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected static <T> Dataset<T> flatMap1(Dataset<T> data, List<FlatMapParam> flatMapParams) {
        SparkSession spark = SparkContextBuilder.getSession();
        Map schemaRecord = new HashMap<>();
        StructField[] fields = data.schema().fields();
        // NOTE(review): presumably maps a column name to an array whose first element is the
        // column position; that is the only way the lookups below can work. Confirm in SharedMethods.
        SharedMethods.recordSchema(fields, schemaRecord);

        // Column position -> split definition for that column.
        Map<Integer, FlatMapParam> infos = new HashMap<>();
        for(FlatMapParam flatMapParam : flatMapParams) {
            String column = flatMapParam.getColumn();
            Object recorded = schemaRecord.get(column);
            if(recorded == null) {
                // Fail with a clear message instead of a NullPointerException on an unknown column.
                throw new CustomException("列拆分组件中的列" + column + "在数据集中不存在");
            }
            int position = Integer.parseInt(((Object[]) recorded)[0].toString());
            if(infos.containsKey(position)) {
                logger.info("分割的列" + column + "重复,跳过");
            } else {
                infos.put(position, flatMapParam);
            }
        }

        JavaRDD<Row> javaRDD = data.toDF().javaRDD().map(new Function<Row, Row>() {
            @Override
            public Row call(Row r) throws Exception {
                List<Object> list = new ArrayList<>();
                for(int i = 0; i < r.size(); i++) {
                    FlatMapParam param = infos.get(i);
                    if(param == null) {
                        // Not a split column: copy the value through unchanged.
                        list.add(r.get(i));
                        continue;
                    }
                    int sliceNum = param.getSliceNum();
                    Object obj = r.get(i);
                    // An empty array makes every slice null for a null cell.
                    String[] values = obj == null ? new String[0] : obj.toString().split(param.getPattern());
                    for(int j = 0; j < sliceNum; j++) {
                        // Pad with null when the value produced fewer pieces than requested;
                        // pieces beyond sliceNum are dropped.
                        list.add(j < values.length ? values[j] : null);
                    }
                }
                return RowFactory.create(list.toArray());
            }
        });

        // Mirror the row layout: each split column becomes one nullable string field per slice name.
        StructType newSchema = new StructType();
        for(int i = 0; i < fields.length; i++) {
            FlatMapParam param = infos.get(i);
            if(param == null) {
                newSchema = newSchema.add(fields[i]);
                continue;
            }
            for(String name : param.getSliceColumnName().split(Consts.DELIMITER)) {
                newSchema = newSchema.add(name, DataTypes.StringType, true);
            }
        }
        return (Dataset<T>) spark.createDataFrame(javaRDD, newSchema);
    }

}