com.holdenkarau.spark.testing.DataFrameGenerator.scala

package com.holdenkarau.spark.testing

import java.math.RoundingMode
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.scalacheck.{Arbitrary, Gen}

object DataFrameGenerator {

  /**
   * Creates a DataFrame Generator for the given Schema.
   *
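   * A minimal property sketch (assumes Prop.forAll from ScalaCheck and an
   * existing sqlContext; the schema shown is illustrative):
   * {{{
   * val schema = StructType(List(
   *   StructField("name", StringType), StructField("age", IntegerType)))
   * val dataframeGen = DataFrameGenerator.arbitraryDataFrame(sqlContext, schema)
   * val property = forAll(dataframeGen.arbitrary) { dataframe =>
   *   dataframe.schema == schema && dataframe.count() >= 0
   * }
   * }}}
   *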
   * @param sqlContext    SQL Context.
   * @param schema        The required Schema.
   * @param minPartitions minimum number of partitions, defaults to 1.
   * @return Arbitrary DataFrames generator of the required schema.
   */
  def arbitraryDataFrame(
    sqlContext: SQLContext, schema: StructType, minPartitions: Int = 1):
      Arbitrary[DataFrame] = {
    arbitraryDataFrameWithCustomFields(sqlContext, schema, minPartitions)()
  }

  /**
   * Creates a DataFrame Generator for the given Schema, and the given custom
   * generators.
   * Custom generators should be specified as ColumnGenerator instances, one
   * per column that needs custom values; generators are matched to columns
   * by name.
   *
   * Note: The given custom generators should match the required schema; for
   * example, you can't use an Int generator for a StringType column.
   *
   * Note 2: The types accepted as userGenerators have changed.
   * ColumnGeneratorBase is now the base class of the accepted generators;
   * users upgrading to 0.6 need to change their calls to use ColumnGenerator
   * (previously Column). Further explanation can be found in the release
   * notes, and in the class descriptions at the bottom of this file.
   *
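   * A usage sketch with a custom generator for an "age" column (the schema
   * and column name are illustrative):
   * {{{
   * val gen = DataFrameGenerator.arbitraryDataFrameWithCustomFields(
   *   sqlContext, schema)(new ColumnGenerator("age", Gen.choose(10, 100)))
   * }}}
   *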
   * @param sqlContext     SQL Context.
   * @param schema         The required Schema.
   * @param minPartitions  minimum number of partitions, defaults to 1.
   * @param userGenerators custom user generators; each generator is applied
   *                       to the schema column whose name matches its
   *                       columnName.
   * @return Arbitrary DataFrames generator of the required schema.
   */
  def arbitraryDataFrameWithCustomFields(
    sqlContext: SQLContext, schema: StructType, minPartitions: Int = 1)
    (userGenerators: ColumnGeneratorBase*): Arbitrary[DataFrame] = {
    val arbitraryRDDs = RDDGenerator.genRDD(
      sqlContext.sparkContext, minPartitions)(
      getRowGenerator(schema, userGenerators))
    Arbitrary {
      arbitraryRDDs.map { r =>
        sqlContext.createDataFrame(r, schema)
      }
    }
  }

  /**
   * Creates a row generator for the required schema.
   *
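   * For example (assuming a StructType value named schema is in scope):
   * {{{
   * val rowGen: Gen[Row] = DataFrameGenerator.getRowGenerator(schema)
   * val sampleRow: Option[Row] = rowGen.sample
   * }}}
   *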
   * @param schema the required Row's schema.
   * @return Gen[Row]
   */
  def getRowGenerator(schema: StructType): Gen[Row] = {
    getRowGenerator(schema, List())
  }

  /**
   * Creates a row generator for the required schema, using the user's custom
   * generators.
   *
   * Note: Custom generators should match the required schema; for example,
   * you can't use an Int generator for a StringType column.
   *
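   * Sketch with a custom generator for a "name" column (column name and
   * values are illustrative):
   * {{{
   * val rowGen = DataFrameGenerator.getRowGenerator(
   *   schema, Seq(new ColumnGenerator("name", Gen.oneOf("alice", "bob"))))
   * }}}
   *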
   * @param schema           the required Row's schema.
   * @param customGenerators user custom generators; useful when you want to
   *                         control how specific columns are generated.
   * @return Gen[Row]
   */
  def getRowGenerator(
    schema: StructType, customGenerators: Seq[ColumnGeneratorBase]): Gen[Row] = {
    val generators: List[Gen[Any]] =
      createGenerators(schema.fields, customGenerators)
    val listGen: Gen[List[Any]] =
      Gen.sequence[List[Any], Any](generators)
    val generator: Gen[Row] =
      listGen.map(list => Row.fromSeq(list))
    generator
  }

  private def createGenerators(
    fields: Array[StructField],
    userGenerators: Seq[ColumnGeneratorBase]):
      List[Gen[Any]] = {
    val generatorMap = userGenerators.map(
      generator => (generator.columnName -> generator)).toMap
    fields.toList.map { field =>
      generatorMap.get(field.name) match {
        case Some(gen: ColumnGenerator) => gen.gen
        case Some(list: ColumnList) =>
          getGenerator(field.dataType, list.gen, nullable = field.nullable)
        case None => getGenerator(field.dataType, nullable = field.nullable)
      }
    }
  }

  private def getGenerator(
    dataType: DataType,
    generators: Seq[ColumnGeneratorBase] = Seq(),
    nullable: Boolean = false): Gen[Any] = {
    val nonNullGen = dataType match {
      case ByteType => Arbitrary.arbitrary[Byte]
      case ShortType => Arbitrary.arbitrary[Short]
      case IntegerType => Arbitrary.arbitrary[Int]
      case LongType => Arbitrary.arbitrary[Long]
      case FloatType => Arbitrary.arbitrary[Float]
      case DoubleType => Arbitrary.arbitrary[Double]
      case StringType => Arbitrary.arbitrary[String]
      case BinaryType => Arbitrary.arbitrary[Array[Byte]]
      case BooleanType => Arbitrary.arbitrary[Boolean]
      case TimestampType => Arbitrary.arbitrary[Long].map { l =>
        // Scale the raw Long down so the generated instant stays in a
        // reasonable epoch range.
        new Timestamp(l / 10000)
      }
      case DateType => Arbitrary.arbitrary[Long].map { l =>
        new Date(l / 10000)
      }
      case dec: DecimalType => {
        // With the new ANSI default we need to make sure we're passing in
        // valid values.
        Arbitrary.arbitrary[BigDecimal]
          .retryUntil { d =>
            try {
              val sd = new Decimal()
              // Make sure it can be converted
              sd.set(d, dec.precision, dec.scale)
              true
            } catch {
              case e: Exception => false
            }
          }
          .map(_.bigDecimal.setScale(dec.scale, RoundingMode.HALF_UP))
      }
      case arr: ArrayType => {
        val elementGenerator = getGenerator(arr.elementType, nullable = arr.containsNull)
        Gen.listOf(elementGenerator)
      }
      case map: MapType => {
        val keyGenerator = getGenerator(map.keyType)
        val valueGenerator = getGenerator(map.valueType, nullable = map.valueContainsNull)
        val keyValueGenerator: Gen[(Any, Any)] = for {
          key <- keyGenerator
          value <- valueGenerator
        } yield (key, value)

        Gen.mapOf(keyValueGenerator)
      }
      case row: StructType => getRowGenerator(row, generators)
      case MLUserDefinedType(generator) => generator
      case _ => throw new UnsupportedOperationException(
        s"Type: $dataType not supported")
    }
    if (nullable) {
      // Spark 3.1+ has difficulty with arrays that are null.
      dataType match {
        case _: ArrayType => nonNullGen
        case _ =>
          Gen.oneOf(nonNullGen, Gen.const(null))
      }
    } else {
      nonNullGen
    }
  }
}

/**
 * Previously Column. Allows the user to specify a generator for a
 * specific column.
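 *
 * Example (column name and range are illustrative):
 * {{{
 * new ColumnGenerator("age", Gen.choose(10, 100))
 * }}}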
 */
class ColumnGenerator(val columnName: String, generator: => Gen[Any])
    extends ColumnGeneratorBase {
  lazy val gen = generator
}

/**
 * ColumnList allows users to specify custom generators for a list of
 * columns inside a StructType column.
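 *
 * Sketch for a StructType column "user" containing an "age" field (names
 * are illustrative):
 * {{{
 * new ColumnList("user", Seq(new ColumnGenerator("age", Gen.choose(10, 100))))
 * }}}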
 */
class ColumnList(val columnName: String, generators: => Seq[ColumnGeneratorBase])
    extends ColumnGeneratorBase {
  lazy val gen = generators
}

/**
 * ColumnGeneratorBase - previously ColumnGenerator; it is now the base class
 * for all column generators.
 */
abstract class ColumnGeneratorBase extends java.io.Serializable {
  val columnName: String
}



