// com.datastax.data.prepare.spark.dataset.MultiStringIndexerOperator (Maven / Gradle / Ivy)
package com.datastax.data.prepare.spark.dataset;
import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.datastax.data.prepare.util.Consts;
import com.datastax.data.prepare.util.CustomException;
import com.datastax.data.prepare.util.SharedMethods;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
/**
 * Insight operator that applies Spark ML's {@link StringIndexer} to several columns of a
 * dataset in one call, appending one index column per requested input column.
 */
public class MultiStringIndexerOperator implements Operator {

    private static final Logger logger = LoggerFactory.getLogger(MultiStringIndexerOperator.class);

    /**
     * Fits and applies a {@link StringIndexer} for each (input column, output column) pair.
     *
     * <p>{@code column} and {@code indexerColumnName} are split with {@code Consts.DELIMITER}
     * (the component description says semicolon - presumably that is the delimiter value;
     * confirm against {@code Consts}) and paired by position. Each name is trimmed.
     * A pair is skipped, with an info log, when its input name is blank or is not a column
     * known for the dataset. All pairs are validated before any indexer is fitted, so an
     * invalid argument is reported without first running Spark jobs for the earlier pairs.
     *
     * @param dataset           the dataset to transform; when {@code null} the method logs
     *                          and returns {@code null}
     * @param column            delimiter-separated names of the columns to index
     * @param indexerColumnName delimiter-separated names of the index columns to create,
     *                          one per entry of {@code column}
     * @return a dataset with the index columns appended, or {@code null} when
     *         {@code dataset} is {@code null}
     * @throws NullPointerException when {@code column} is {@code null} or empty, or when
     *                              {@code indexerColumnName} is {@code null}
     * @throws CustomException      when the two name lists differ in length, when an output
     *                              name is blank, or when an output name collides with an
     *                              existing column or an earlier output name
     */
    @InsightComponent(name = "StringIndexer", description = "将字符串转换成索引,和标签数值化转换相同,支持多列转换")
    public static Dataset multiStringIndexer(
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset dataset,
            @InsightComponentArg(name = "列名", description = "需要转换的列名,多个列名用分号隔开") String column,
            @InsightComponentArg(name = "转换后的列名", description = "转换生成的索引列的列名,不能与现有列名重复") String indexerColumnName) {
        if (dataset == null) {
            logger.info("数据集为空");
            return null;
        }
        if (column == null || column.length() == 0) {
            throw new NullPointerException("StringIndexer组件的参数为空");
        }
        // Previously a null value surfaced as a message-less NPE from split(); keep the
        // exception type but say which argument is missing.
        if (indexerColumnName == null) {
            throw new NullPointerException("StringIndexer组件的转换后的列名参数为空");
        }

        // Column names known so far: filled by SharedMethods.recordSchema from the dataset
        // schema, then extended with every output name accepted below. Left as a raw Map
        // because recordSchema's parameter type is not visible from this file.
        Map knownColumns = new HashMap<>();
        SharedMethods.recordSchema(dataset.schema().fields(), knownColumns);

        String[] columns = column.split(Consts.DELIMITER);
        String[] results = indexerColumnName.split(Consts.DELIMITER);
        if (columns.length != results.length) {
            throw new CustomException("StringIndexer组件的列名和转换后的列名的数量不等");
        }

        // Pass 1: validate every pair up front. StringIndexer.fit() runs a Spark job, so
        // failing here avoids paying for earlier columns when a later pair is invalid.
        String[][] pairs = new String[columns.length][];
        int pairCount = 0;
        for (int i = 0; i < columns.length; i++) {
            String inputName = columns[i].trim();
            String outputName = results[i].trim();
            if (inputName.length() == 0) {
                logger.info("列名参数的第{}个参数去掉前后空格后为空,跳过", i + 1);
                continue;
            }
            if (!knownColumns.containsKey(inputName)) {
                logger.info("数据集中找不到{}列,跳过", inputName);
                continue;
            }
            if (outputName.length() == 0) {
                throw new CustomException("转换后的列名参数的第" + (i + 1) + "个参数去掉前后空格后为空");
            }
            if (knownColumns.containsKey(outputName)) {
                throw new CustomException("转换后生成的列名" + outputName + "和现有列名冲突");
            }
            // Register the new column so later pairs can neither reuse its name as an
            // output nor fail to find it as an input. Only the key is consulted in this
            // method. NOTE(review): StringIndexer emits a double-typed column, so
            // IntegerType here looks inaccurate; left unchanged because the entry format
            // is defined by SharedMethods.recordSchema, which is outside this file.
            knownColumns.put(outputName, new Object[]{knownColumns.size() + 1, DataTypes.IntegerType});
            pairs[pairCount++] = new String[]{inputName, outputName};
        }

        // Pass 2: fit and apply the indexers in the order the pairs were given, each one
        // operating on the result of the previous transformation.
        Dataset data = dataset.toDF();
        for (int i = 0; i < pairCount; i++) {
            data = new StringIndexer()
                    .setInputCol(pairs[i][0])
                    .setOutputCol(pairs[i][1])
                    .fit(data)
                    .transform(data);
        }
        return data;
    }
}