All downloads are free. Search and download functionality uses the official Maven repository.

com.datastax.data.prepare.spark.dataset.ColumnReduce.scala Maven / Gradle / Ivy

package com.datastax.data.prepare.spark.dataset

import com.datastax.data.prepare.spark.dataset.params.ColumnReduceParam
import com.datastax.data.prepare.util.{Consts, CustomException}
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row}
import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.collection.mutable.ListBuffer


object ColumnReduce {
  private val logger = LoggerFactory.getLogger(ColumnReduce.getClass)

  /**
   * Appends one derived column per [[ColumnReduceParam]] to `data`.
   *
   * Each parameter supplies:
   *  - a new column name (`getNewColumnName`),
   *  - its type (`getDataType`),
   *  - a `Consts.DELIMITER`-separated list of source columns (`getMergeColumns`),
   *  - a separator (`getConnectSymbol`).
   *
   * The source values are combined by `convertType`. A source column may be an existing column of
   * `data` or a column derived by another parameter of the same call. Derived columns that depend
   * on other derived columns are evaluated after their dependencies.
   *
   * TODO (carried over, unverified): merging a Vector column together with other numeric columns
   * into a vector column may not work; see the default case of `convertType`.
   *
   * @param data               input dataset
   * @param columnReduceParams merge specifications; the new columns are appended in array order
   * @return `data` with the new columns appended after the existing ones (unchanged if there are
   *         no parameters)
   * @throws IllegalArgumentException if `data` has duplicate column names, if a new column name
   *                                  clashes with an existing or another new column, if a
   *                                  parameter references its own new column, or if every
   *                                  parameter depends on another derived column
   * @throws CustomException if a referenced column does not exist, or derived columns reference
   *                         each other cyclically
   */
  def merge(data: Dataset[Row], columnReduceParams: Array[ColumnReduceParam]): Dataset[Row] = {
    // Existing column name -> (position, data type); rejects duplicate column names.
    val fieldMap = record(data.schema.fields, "数据集中的列名%s重复")((fs, i) => (fs.name, Tuple2(i, fs.dataType)))

    // New column name -> 1-based parameter number. The number is 1-based so that it can be stored
    // negated in a position array without colliding with existing column 0.
    val paramMap = new mutable.HashMap[String, Int]
    for(i <- columnReduceParams.indices) {
      val newName = columnReduceParams(i).getNewColumnName
      require(!fieldMap.contains(newName), "生成的新列名" + newName + "和数据集中存在的列名重复")
      require(!paramMap.contains(newName), "合并生成的新列名中" + newName + "重复")
      paramMap += (newName -> (i + 1))
    }
    val schema = columnReduceParams.foldLeft(data.schema)((s, p) => s.add(p.getNewColumnName, p.getDataType, true))

    // Direct parameters (0-based index -> source positions) read only existing columns.
    // Dependent parameters (1-based number -> (referenced parameter numbers, source positions))
    // read at least one derived column. Such a position is stored as -(1-based parameter number).
    val recordPositions = new mutable.HashMap[Int, Array[Int]]
    val references = new mutable.HashMap[Int, (List[Int], Array[Int])]
    for(i <- columnReduceParams.indices) {
      val names = columnReduceParams(i).getMergeColumns.split(Consts.DELIMITER)
      val positions = names.map { name =>
        if(fieldMap.contains(name)) {
          fieldMap(name)._1
        } else if(paramMap.contains(name)) {
          val paramNo = paramMap(name)
          require(paramNo != (i + 1), "参数循环嵌套")
          -paramNo
        } else {
          throw new CustomException("第" + (i + 1) + "列参数中的合并列参数中的" + name + "列不存在")
        }
      }
      val referenced = positions.filter(_ < 0).map(p => -p).toList
      if(referenced.isEmpty) {
        recordPositions += (i -> positions)
      } else {
        references += ((i + 1) -> Tuple2(referenced, positions))
      }
    }
    require(columnReduceParams.isEmpty || recordPositions.nonEmpty, "互相引用的参数没有初始入口")

    // Order the dependent parameters so that each one follows every dependent parameter it reads.
    val executionOrder = new ListBuffer[Int]
    for(paramNo <- references.keys.toList.sorted) {
      getExecutionOrder(references, paramNo, executionOrder, new ListBuffer[Int])
    }

    // Immutable evaluation plan captured by the row function:
    // (0-based parameter index, source positions, target type, separator).
    // Direct parameters come first, then dependent ones in dependency order.
    val direct = recordPositions.toList.map { case (i, positions) =>
      (i, positions, columnReduceParams(i).getDataType, columnReduceParams(i).getConnectSymbol)
    }
    val dependent = executionOrder.toList.map { paramNo =>
      val param = columnReduceParams(paramNo - 1)
      (paramNo - 1, references(paramNo)._2, param.getDataType, param.getConnectSymbol)
    }
    val steps = direct ++ dependent
    val width = schema.fields.length

    val encoder = RowEncoder(schema)
    data.map(r => {
      val base = r.size
      val array = new Array[Any](width)
      for(k <- 0 until base) {
        array(k) = r.get(k)
      }
      for((index, positions, dataType, connectSymbol) <- steps) {
        // A negative position -n refers to parameter n, whose value is stored at base + n - 1.
        val values = positions.toIndexedSeq.map(p => if(p < 0) array(base - p - 1) else array(p))
        array(base + index) = convertType(values, dataType, connectSymbol)
      }
      Row.fromSeq(array)
    })(encoder)
  }

  /**
   * Builds a map from `func(element, index)` pairs over `as`, rejecting duplicate keys.
   *
   * @param as       elements to index
   * @param errorMsg `require` message; `%s` is formatted with the duplicate key
   * @param func     maps an element and its index to a (key, value) pair
   * @return the resulting mutable map
   * @throws IllegalArgumentException if two elements produce the same key
   */
  private def record[T1, T2, T3](as: Array[T1], errorMsg: String)(func: (T1, Int) => (T2, T3)): mutable.HashMap[T2, T3] = {
    val map = new mutable.HashMap[T2, T3]
    for(i <- as.indices) {
      val t = func(as(i), i)
      require(!map.contains(t._1), errorMsg.format(t._1))
      map += t
    }
    map
  }

  /**
   * Appends `position` to `executionOrder` after every dependent parameter it references.
   *
   * This is a depth-first post-order traversal, i.e. a topological sort of the dependent
   * parameters.
   *
   * @param references     1-based parameter number -> (referenced parameter numbers, source positions)
   * @param position       1-based number of the dependent parameter to schedule
   * @param executionOrder accumulated order; each parameter is appended once, after its dependencies
   * @param sign           parameters on the current traversal path, used to detect cycles
   * @throws CustomException ("参数循环嵌套") if the references form a cycle
   */
  private def getExecutionOrder(references: mutable.HashMap[Int, (List[Int], Array[Int])], position: Int,
                                executionOrder: ListBuffer[Int], sign: ListBuffer[Int]): Unit = {
    if(!executionOrder.contains(position)) {
      // Re-entering a parameter that is still being resolved means the references form a cycle.
      if(sign.contains(position)) {
        throw new CustomException("参数循环嵌套")
      }
      sign += position
      // Direct parameters are not keys of `references`; they are always evaluated first.
      for(dep <- references(position)._1 if references.contains(dep)) {
        getExecutionOrder(references, dep, executionOrder, sign)
      }
      sign -= position
      executionOrder += position
    }
  }

  /**
   * Combines the source values of one merged column into a value of `dataType`.
   *
   *  - `StringType`: the values' string forms joined with `connectSymbol`.
   *  - `DoubleType`: the joined string parsed as a Double.
   *  - any other type: each value's string form parsed as a Double, collected into a `DenseVector`.
   *
   * DateType/TimestampType are deliberately not supported: casting the joined string does not
   * work. Produce a StringType column instead and convert its format with the type-conversion
   * component.
   *
   * @throws NumberFormatException if a value (or the joined string) cannot be parsed as a Double
   */
  private def convertType[T](indexedSeq: IndexedSeq[T], dataType: DataType, connectSymbol: String): Any = dataType match {
    case StringType =>
      indexedSeq.mkString(connectSymbol)
    case DoubleType =>
      indexedSeq.mkString(connectSymbol).toDouble
    case _ =>
      new DenseVector(indexedSeq.map(_.toString.toDouble).toArray)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy