com.datastax.data.prepare.spark.dataset.ColumnReduce.scala Maven / Gradle / Ivy
The newest version!
package com.datastax.data.prepare.spark.dataset
import com.datastax.data.prepare.spark.dataset.params.ColumnReduceParam
import com.datastax.data.prepare.util.{Consts, CustomException}
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row}
import org.slf4j.LoggerFactory
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
/**
 * Appends merged columns to a dataset.
 *
 * Each [[ColumnReduceParam]] describes one new column: its name, its data type, the
 * list of source columns to combine and the symbol used to join their values. A source
 * column may be a column of the input dataset or the new column of another parameter;
 * in the latter case the referenced column is computed first.
 */
object ColumnReduce {
  private val logger = LoggerFactory.getLogger(ColumnReduce.getClass)

  // TODO: combining a Vector-typed column with other numeric columns into a new vector
  // column may not work; this still has to be verified.

  /**
   * Returns `data` with one extra nullable column per entry of `columnReduceParams`,
   * appended in parameter order.
   *
   * Internally a new column is identified either by its 0-based parameter index
   * (keys of `recordPositions`) or by its 1-based parameter number (values of
   * `paramMap`, keys of `references`). A merge position `p >= 0` is an index into the
   * input row; `p < 0` refers to the new column of parameter number `-p`.
   *
   * @param data               input dataset
   * @param columnReduceParams descriptions of the columns to create
   * @return dataset whose schema is the input schema followed by the new columns
   * @throws IllegalArgumentException if a dataset column name or a new column name is
   *                                  duplicated, if a parameter lists its own new column,
   *                                  or if no parameter is built purely from dataset columns
   * @throws CustomException          if a listed column does not exist or the parameters
   *                                  reference each other in a cycle
   */
  def merge(data: Dataset[Row], columnReduceParams: Array[ColumnReduceParam]): Dataset[Row] = {
    var schema = data.schema
    // dataset column name -> (index in the input row, data type)
    val fieldMap = record(schema.fields, "数据集中的列名%s重复")((fs, i) => (fs.name, Tuple2(i, fs.dataType)))

    // new column name -> 1-based parameter number
    val paramMap = new mutable.HashMap[String, Int]
    for (i <- columnReduceParams.indices) {
      val newName = columnReduceParams(i).getNewColumnName
      require(!fieldMap.contains(newName), "生成的新列名" + newName + "和数据集中存在的列名重复")
      require(!paramMap.contains(newName), "合并生成的新列名中" + newName + "重复")
      paramMap += (newName -> (i + 1))
      schema = schema.add(newName, columnReduceParams(i).getDataType, true)
    }
    // Capture a plain Int in the row closure instead of the mutable schema variable.
    val outputWidth = schema.fields.length

    // 0-based parameter index -> input-row positions, for parameters that only use dataset columns
    val recordPositions = new mutable.HashMap[Int, Array[Int]]
    // 1-based parameter number -> (parameter numbers it depends on, merge positions)
    val references = new mutable.HashMap[Int, (List[Int], Array[Int])]
    for (i <- columnReduceParams.indices) {
      val names = columnReduceParams(i).getMergeColumns.split(Consts.DELIMITER)
      val positions = new Array[Int](names.length)
      val dependsOn = ListBuffer.empty[Int]
      for (j <- names.indices) {
        val name = names(j)
        if (fieldMap.contains(name)) {
          positions(j) = fieldMap(name)._1
        } else if (paramMap.contains(name)) {
          val paramNo = paramMap(name)
          require(paramNo != (i + 1), "参数循环嵌套")
          positions(j) = -paramNo
          dependsOn += paramNo
        } else {
          // Report the number of the offending parameter (i), not the position inside its merge list (j).
          throw new CustomException("第" + (i + 1) + "列参数中的合并列参数中的" + name + "列不存在")
        }
      }
      if (dependsOn.isEmpty) {
        recordPositions += (i -> positions)
      } else {
        references += ((i + 1) -> Tuple2(dependsOn.toList, positions))
      }
    }
    require(recordPositions.nonEmpty, "互相引用的参数没有初始入口")

    // Order the derived parameters so that each one comes after everything it references.
    val executionOrder = new ListBuffer[Int]
    for (paramNo <- references.keySet) {
      if (!executionOrder.contains(paramNo)) {
        getExecutionOrder(references, paramNo, executionOrder, new ListBuffer[Int])
      }
    }
    val resultOrder = executionOrder.toList

    val encoder = RowEncoder(schema)
    val result = data.map(r => {
      val width = r.size
      val array = new Array[Any](outputWidth)
      val inputValues = r.toSeq.toArray
      System.arraycopy(inputValues, 0, array, 0, inputValues.length)

      // Columns that only depend on the input row.
      recordPositions.foreach { case (paramIndex, positions) =>
        val param = columnReduceParams(paramIndex)
        val values = for (j <- positions.indices) yield r.get(positions(j))
        array(paramIndex + width) = convertType(values, param.getDataType, param.getConnectSymbol)
      }

      // Columns that depend on other new columns, in dependency order.
      for (paramNo <- resultOrder) {
        val positions = references(paramNo)._2
        val values = positions.indices.map { j =>
          val p = positions(j)
          // A negative position -n denotes the new column of parameter n, stored at width + n - 1.
          if (p < 0) array(width - p - 1) else r.get(p)
        }
        // paramNo is 1-based; use it (not the loop position) to pick the parameter's settings.
        val param = columnReduceParams(paramNo - 1)
        array(paramNo + width - 1) = convertType(values, param.getDataType, param.getConnectSymbol)
      }
      Row.fromSeq(array)
    })(encoder)
    result
  }

  /**
   * Builds a map from the key/value pairs produced by `func` for every element of `as`.
   *
   * @param as       source elements
   * @param errorMsg format string (one `%s` for the key) used when a key occurs twice
   * @param func     maps an element and its index to a key/value pair
   * @throws IllegalArgumentException if two elements produce the same key
   */
  private def record[T1, T2, T3](as: Array[T1], errorMsg: String)(func: (T1, Int) => (T2, T3)): mutable.HashMap[T2, T3] = {
    val map = new mutable.HashMap[T2, T3]
    for (i <- as.indices) {
      val entry = func(as(i), i)
      require(!map.contains(entry._1), errorMsg.format(entry._1))
      map += entry
    }
    map
  }

  /**
   * Appends `position` to `executionOrder` after all derived parameters it depends on
   * (depth-first topological ordering).
   *
   * @param references     derived parameters: parameter number -> (dependencies, merge positions)
   * @param position       parameter number to place; must be a key of `references`
   * @param executionOrder parameter numbers already ordered; extended in place
   * @param sign           parameter numbers currently being resolved on the call stack
   * @throws CustomException if the parameters reference each other in a cycle
   */
  private def getExecutionOrder(references: mutable.HashMap[Int, (List[Int], Array[Int])], position: Int, executionOrder: ListBuffer[Int], sign: ListBuffer[Int]): Unit = {
    if (!executionOrder.contains(position)) {
      // Reaching a parameter that is still being resolved means the references form a cycle.
      if (sign.contains(position)) {
        throw new CustomException("参数循环嵌套")
      }
      sign += position
      // Dependencies that are not keys of `references` are built from dataset columns only
      // and need no ordering.
      for (dependency <- references(position)._1 if references.contains(dependency)) {
        getExecutionOrder(references, dependency, executionOrder, sign)
      }
      // Unmark on exit so that a parameter with several derived dependencies is not
      // mistaken for a cycle when its second dependency is visited.
      sign -= position
      executionOrder += position
    }
  }

  /**
   * Combines the given values into a single cell of the requested type.
   *
   *  - `StringType`: the values joined with `connectSymbol`
   *  - `DoubleType`: the joined string parsed as a `Double`
   *  - any other type: a `DenseVector` holding each value parsed as a `Double`
   *
   * Date and timestamp targets are not supported here: produce a string column and
   * convert it with the type-conversion component instead.
   */
  private def convertType[T](indexedSeq: IndexedSeq[T], dataType: DataType, connectSymbol: String): Any = dataType match {
    case StringType =>
      indexedSeq.mkString(connectSymbol)
    case DoubleType =>
      indexedSeq.mkString(connectSymbol).toDouble
    case _ =>
      new DenseVector(indexedSeq.map(_.toString.toDouble).toArray)
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy