All Downloads are FREE. Search and download functionality uses the official Maven repository.

edu.ucr.cs.bdlab.test.ScalaSparkTest.scala Maven / Gradle / Ivy

There is a newer version: 0.10.1-RC2
Show newest version
/*
 * Copyright 2020 University of California, Riverside
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.test

import java.awt.Color
import java.awt.image.BufferedImage
import java.io._

import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.io.Text
import org.apache.hadoop.util.LineReader
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterEach, Suite}

/**
 * Process-wide holder for the SparkContext that test suites share.
 * Keeping the reference in a Scala object makes it a singleton per JVM (per class loader).
 */
object SingleSparkContext {
  /** The shared context; null until one is created, and reset to null after it is stopped */
  var _sc: SparkContext = null
}

/**
 * A mixin for ScalaTest suites that provides a shared Spark context, a fresh scratch directory for each test,
 * helpers to read and copy test resources, and additional assertions for streams, arrays, and images.
 */
trait ScalaSparkTest extends Suite with BeforeAndAfterEach {
  // NOTE(review): presumably forces GeoTools to use (x, y) axis order; confirm against the GeoTools docs.
  // This runs every time a suite that mixes in this trait is constructed.
  System.setProperty("org.geotools.referencing.forceXY", "true")

  // Initialization and resetting of a common spark context

  /**
   * Returns the Spark context shared by all suites in this JVM, creating it on first use.
   * If a context is already active (e.g., created elsewhere), that context is reused instead of creating
   * a second one.
   *
   * @return the shared Spark context
   */
  def sparkContext: SparkContext = {
    if (SingleSparkContext._sc == null)
      SingleSparkContext._sc = SparkContext.getActive.getOrElse(new SparkContext("local", "test"))
    SingleSparkContext._sc
  }

  /** Stops the shared Spark context, if any, and forgets it so that [[sparkContext]] will not return it again. */
  def closeSC(): Unit = {
    if (SingleSparkContext._sc != null) {
      SingleSparkContext._sc.stop()
      SingleSparkContext._sc = null
    }
  }

  /** A scratch directory that gets deleted before and after each test */
  var _scratchDir: File = _

  /** The current scratch directory as a Hadoop path */
  def scratchPath: Path = new Path(_scratchDir.getPath)

  /** The current scratch directory */
  def scratchDir: File = _scratchDir

  /**
   * Deletes the current scratch directory (if one was created).
   * If it still exists afterwards, it is scheduled for deletion when the JVM exits.
   */
  def clearScratchDir: Unit = {
    if (_scratchDir != null) {
      FileUtil.fullyDelete(_scratchDir)
      if (_scratchDir.exists)
        _scratchDir.deleteOnExit()
    }
  }

  /**
   * Deletes the current scratch directory and creates a new, empty one named `target/scratch_test-N`,
   * where N is the first counter value whose directory does not exist yet.
   * A shutdown hook deletes the new directory when the JVM exits.
   */
  def resetScratchDir: Unit = {
    clearScratchDir
    var count = 0
    do {
      _scratchDir = new File(s"target/scratch_test-$count")
      count += 1
    } while (_scratchDir.exists())
    _scratchDir.mkdirs
    // Capture the path now; _scratchDir will point elsewhere by the time the hook runs
    val thisScratch = _scratchDir.toString
    Runtime.getRuntime.addShutdownHook(new Thread(
      new Runnable() {
        override def run(): Unit = FileUtil.fullyDelete(new File(thisScratch))
      }
    ))
  }

  /**
   * Restores the configuration of the shared Spark context (if any) to the defaults of a fresh local "test"
   * configuration. The Hadoop configuration is cleared entirely. All Spark settings except `spark.app.id` and
   * `spark.driver.host` are removed, and then the defaults are applied again.
   */
  def resetSparkConfiguration: Unit = {
    val sc = SingleSparkContext._sc
    if (sc != null) {
      val sparkConf: SparkConf = sc.conf
      sc.hadoopConfiguration.clear()
      // getAll returns a snapshot array, so removing entries while iterating over it is safe
      for ((key, _) <- sparkConf.getAll if key != "spark.app.id" && key != "spark.driver.host")
        sparkConf.remove(key)
      for ((key, value) <- new SparkConf().setMaster("local").setAppName("test").getAll)
        sparkConf.set(key, value)
    }
  }

  override protected def beforeEach(): Unit = {
    // Call super first so this trait stacks correctly with other BeforeAndAfterEach mixins
    super.beforeEach()
    resetScratchDir
    resetSparkConfiguration
  }

  override protected def afterEach(): Unit = {
    try clearScratchDir
    finally super.afterEach()
  }

  // Read text files in different ways

  /**
   * Opens a class-path resource relative to this class.
   *
   * @param resourcePath the path of the resource
   * @return an open stream; the caller is responsible for closing it
   * @throws FileNotFoundException if the resource does not exist (getResourceAsStream returned null)
   */
  private def openResourceStream(resourcePath: String): InputStream = {
    val in = getClass.getResourceAsStream(resourcePath)
    if (in == null)
      throw new FileNotFoundException(s"Resource '$resourcePath' not found")
    in
  }

  /**
   * Reads up-to maxLines from the given input stream. The stream is not closed.
   *
   * @param is       the input stream to read from
   * @param maxLines the upper bound of the number of lines to read
   * @return the lines as an array of Strings
   */
  def getLines(is: InputStream, maxLines: Int = Int.MaxValue): Array[String] = {
    val reader: LineReader = new LineReader(is)
    val line: Text = new Text
    // ArrayBuffer gives O(1) length and amortized O(1) append (a List would make this loop quadratic)
    val lines = new scala.collection.mutable.ArrayBuffer[String]()
    while (lines.length < maxLines && reader.readLine(line) > 0)
      lines += line.toString
    lines.toArray
  }

  /**
   * Read the first n lines from the given resource and return those lines as an array of Strings.
   * If the given upper bound is bigger than the input file, the entire input file is loaded and returned.
   * Therefore, the returned array might be smaller than the given upper bound if the file is smaller.
   *
   * @param resourcePath the path to the resource to read
   * @param maxLines     the upper bound of the number of lines to read
   * @return an array of strings containing the lines read from the input file
   * @throws FileNotFoundException if the resource does not exist
   */
  def readTextResource(resourcePath: String, maxLines: Int): Array[String] = {
    val is = new BufferedInputStream(openResourceStream(resourcePath))
    try getLines(is, maxLines)
    finally is.close()
  }

  /** Reads all lines of the given resource. See [[readTextResource(String, Int)]]. */
  def readTextResource(resourcePath: String): Array[String] = readTextResource(resourcePath, Int.MaxValue)

  /**
   * Read all lines of a text file.
   *
   * @param filename the name (or path) of the file
   * @return the lines of the file, one array entry per line
   */
  def readFile(filename: String): Array[String] = {
    val is = new BufferedInputStream(new FileInputStream(filename))
    try getLines(is, Int.MaxValue)
    finally is.close()
  }

  /**
   * Reads a CSV file that contains only numbers and returns the result as a two-dimensional array in a column format.
   * That is, if the input contains {@code n} lines and each line contains {@code d} columns, the return value is an
   * array of size {@code d} where each entry is an array of size {@code n}. The number of columns is taken from
   * the first line, and every other line must have the same number of columns.
   *
   * @param resourcePath the path of the resource to read (looked up with getClass.getResourceAsStream)
   * @param maxLines     the maximum number of lines to read
   * @return a two-dimensional array where the first index is the dimension (column) and the second dimension is
   *         the line number (row); an empty array if the resource has no lines
   */
  def readCoordsResource(resourcePath: String, maxLines: Int): Array[Array[Double]] = {
    val lines = readTextResource(resourcePath, maxLines)
    if (lines.isEmpty) {
      Array.empty[Array[Double]]
    } else {
      val numDimensions = lines(0).split(",").length
      val coords = Array.ofDim[Double](numDimensions, lines.length)
      for (i <- lines.indices) {
        val parts = lines(i).split(",")
        assert(numDimensions == parts.length,
          s"Number of dimensions differ ${parts.length} <> ${numDimensions} in line #$i")
        for (dim <- 0 until numDimensions)
          coords(dim)(i) = parts(dim).toDouble
      }
      coords
    }
  }

  /** Reads all lines of a numeric CSV resource in column format. See [[readCoordsResource(String, Int)]]. */
  def readCoordsResource(resourcePath: String): Array[Array[Double]] = readCoordsResource(resourcePath, Int.MaxValue)

  /**
   * Read a resource as a raw byte array
   *
   * @param resourcePath the path to the resource
   * @return the contents of that resource as a byte array
   * @throws FileNotFoundException if the resource does not exist
   */
  def readResourceData(resourcePath: String): Array[Byte] = {
    val in = openResourceStream(resourcePath)
    try {
      val out = new ByteArrayOutputStream()
      val buffer = new Array[Byte](4096)
      var bytesRead = in.read(buffer)
      while (bytesRead != -1) {
        out.write(buffer, 0, bytesRead)
        bytesRead = in.read(buffer)
      }
      out.toByteArray
    } finally in.close()
  }

  // Make copies of files for test

  /**
   * Copy a resource to a file to allow reading it as a file. Missing parent directories are created.
   *
   * @param resourcePath the path of the resource to read
   * @param filePath     the path of the file to write
   * @param overwrite    set this flag to automatically overwrite the output file.
   * @throws FileNotFoundException if the resource does not exist (no output file is created in that case)
   */
  def copyResource(resourcePath: String, filePath: File, overwrite: Boolean): Unit = {
    if (!overwrite && filePath.exists) fail("Cannot overwrite an existing file " + filePath)
    // Resolve against the absolute path: getParent is null for a bare file name such as "out.csv"
    val parentDir = filePath.getAbsoluteFile.getParentFile
    if (parentDir != null && !parentDir.exists) parentDir.mkdirs
    // Open the resource before creating the output file so a missing resource leaves no empty file behind
    val in = openResourceStream(resourcePath)
    try {
      val out = new FileOutputStream(filePath)
      try {
        val buffer = new Array[Byte](1024 * 1024)
        var bufferLength = in.read(buffer, 0, buffer.length)
        while (bufferLength > 0) {
          out.write(buffer, 0, bufferLength)
          bufferLength = in.read(buffer, 0, buffer.length)
        }
      } finally out.close()
    } finally in.close()
  }

  /** Copies a resource to a file that must not exist yet. See [[copyResource(String, File, Boolean)]]. */
  def copyResource(resourcePath: String, filePath: File): Unit =
    copyResource(resourcePath, filePath, false)

  /**
   * Copies the entries of a resource directory into the given local directory, which is created if missing.
   * Each line read from the directory resource is treated as one entry name and copied with
   * [[copyResource(String, File)]]. Entries that already exist locally cause a test failure.
   *
   * NOTE(review): this relies on getResourceAsStream returning a newline-separated listing for a directory. Verify
   * that this also works for directories inside a JAR file.
   *
   * @param resourcePath a path to a directory in the class path
   * @param localDirPath a path to local directory to write to.
   * @throws FileNotFoundException if the resource directory does not exist
   */
  def copyDirectoryFromResources(resourcePath: String, localDirPath: File): Unit = {
    val resPath = if (resourcePath.endsWith("/")) resourcePath else resourcePath + "/"
    val listing = new BufferedReader(new InputStreamReader(openResourceStream(resPath)))
    try {
      if (!localDirPath.exists) localDirPath.mkdirs
      var file = listing.readLine
      while (file != null) {
        copyResource(resPath + file, new File(localDirPath, file))
        file = listing.readLine
      }
    } finally listing.close()
  }

  /**
   * Copies a resource file into the scratch directory, keeping its file name.
   *
   * @param resourcePath the path of the resource to copy
   * @return the local copy
   */
  def makeFileCopy(resourcePath: String): File = {
    val localPath: File = new File(_scratchDir, new File(resourcePath).getName)
    copyResource(resourcePath, localPath)
    localPath
  }

  /**
   * Copies a resource directory into the scratch directory, keeping its name.
   *
   * @param resourcePath the path of the resource directory to copy
   * @return the local copy of the directory
   */
  def makeDirCopy(resourcePath: String): File = {
    val localPath: File = new File(_scratchDir, new File(resourcePath).getName)
    copyDirectoryFromResources(resourcePath, localPath)
    localPath
  }

  /**
   * Read all non-empty files in the given directory as an array of strings, one for each line.
   * Hidden files and files whose names start with '.' or '_' are skipped. Files are read in the order
   * that File.listFiles returns them.
   *
   * @param dir the directory to read
   * @return one string array that combines the contents of all files.
   */
  def readFilesInDirAsLines(dir: File) : Array[String] = {
    val files = dir.listFiles()
    assert(files != null, s"Cannot list the contents of '$dir'")
    val lines = new scala.collection.mutable.ArrayBuffer[String]()
    for (f <- files if !(f.getName.startsWith(".") || f.getName.startsWith("_")) && f.length() > 0)
      lines ++= readFile(f.getPath)
    lines.toArray
  }

  // New assertions

  /**
   * Tests the contents of two input stream for equality. The bytes are consumed from both streams while testing them.
   *
   * @param expected the correct (expected) stream
   * @param actual   the stream produced from the code
   * @throws IOException if an error happens while reading one of the input stream
   */
  @throws[IOException]
  def assertEquals(expected: InputStream, actual: InputStream): Unit = {
    // Number of bytes compared so far; reported as the offset of the first mismatch
    var offsetCompared = 0
    val expectedBuffer = new Array[Byte](1024)
    val actualBuffer = new Array[Byte](1024)
    // Number of bytes read from each stream but not compared yet; they sit at the start of each buffer
    var expectedLength = 0
    var actualLength = 0
    var done = false
    while (!done) {
      var readLength = expected.read(expectedBuffer, expectedLength, expectedBuffer.length - expectedLength)
      if (readLength > 0) expectedLength += readLength
      readLength = actual.read(actualBuffer, actualLength, actualBuffer.length - actualLength)
      if (readLength > 0) actualLength += readLength
      // One stream is exhausted while the other still has pending bytes
      if ((expectedLength == 0) ^ (actualLength == 0)) fail("File lengths not equal!")
      if (expectedLength == 0 && actualLength == 0) {
        done = true
      } else {
        for (i <- 0 until Math.min(expectedLength, actualLength)) {
          assert(expectedBuffer(i) == actualBuffer(i), "Contents differ at byte: " + offsetCompared)
          offsetCompared += 1
        }
        // Drop the compared bytes and move the remainder of the longer buffer to its front
        if (expectedLength < actualLength) {
          System.arraycopy(actualBuffer, expectedLength, actualBuffer, 0, actualLength - expectedLength)
          actualLength -= expectedLength
          expectedLength = 0
        } else {
          System.arraycopy(expectedBuffer, actualLength, expectedBuffer, 0, expectedLength - actualLength)
          expectedLength -= actualLength
          actualLength = 0
        }
      }
    }
  }

  /**
   * Tests two arrays for equality, element by element using ==
   *
   * @param expected the expected array
   * @param actual   the actual array produced by the code
   */
  def assertArrayEqualsGeneric(expected: Array[Any], actual: Array[Any]): Unit = {
    assert(expected.length == actual.length,
      s"Array lengths differ. Expected ${expected.length} but found ${actual.length}")
    for (i <- expected.indices) {
      // Interpolation (rather than .toString) keeps the message null-safe
      assert(expected(i) == actual(i),
        s"The arrays differ at position $i. Expected ${expected(i)} but found ${actual(i)}")
    }
  }

  /**
   * Tests two integer arrays for equality
   *
   * @param expected the expected array
   * @param actual   the actual array produced by the code
   */
  def assertArrayEquals(expected: Array[Int], actual: Array[Int]): Unit = {
    assert(expected.length == actual.length,
      s"Array lengths differ. Expected ${expected.length} but found ${actual.length}")
    for (i <- expected.indices) {
      assert(expected(i) == actual(i),
        s"The arrays differ at position $i. Expected ${expected(i)} but found ${actual(i)}")
    }
  }

  /**
   * Tests two float arrays for equality, allowing each pair of elements to differ by at most `tolerance`
   *
   * @param expected  the expected array
   * @param actual    the actual array produced by the code
   * @param tolerance the maximum allowed absolute difference between corresponding elements
   */
  def assertArrayEquals(expected: Array[Float], actual: Array[Float], tolerance: Float): Unit = {
    assert(expected.length == actual.length,
      s"Array lengths differ. Expected ${expected.length} but found ${actual.length}")
    for (i <- expected.indices) {
      // The == check also accepts equal infinities, whose difference would be NaN
      assert(expected(i) == actual(i) || Math.abs(expected(i) - actual(i)) <= tolerance,
        s"The arrays differ at position $i. Expected ${expected(i)} but found ${actual(i)} (tolerance $tolerance)")
    }
  }

  /**
   * Tests two images for equality (pixel-by-pixel)
   *
   * @param expected the first image (typically, the value that is known to be correct)
   * @param actual   the second image (typically, the value retrieved from the code)
   */
  def assertImageEquals(expected: BufferedImage, actual: BufferedImage): Unit =
    assertImageEquals(null, expected, actual)

  /**
   * Tests two images for equality (pixel-by-pixel)
   *
   * @param message  an optional message (may be null) prefixed to the failure messages
   * @param expected the first image (typically, the value that is known to be correct)
   * @param actual   the second image (typically, the value retrieved from the code)
   */
  def assertImageEquals(message: String, expected: BufferedImage, actual: BufferedImage): Unit = {
    // Uncomment the following block to write both images for a visual comparison
    /*
    try {
      javax.imageio.ImageIO.write(expected, "png", new File("expected.png"));
      javax.imageio.ImageIO.write(actual, "png", new File("actual.png"));
    } catch {
      case e: IOException => e.printStackTrace();
    }
    */
    // Without this guard a missing message would be rendered as the literal text "null. "
    val prefix = if (message == null) "" else s"$message. "
    assert(expected.getWidth == actual.getWidth,
      s"${prefix}Width is not compatible: Expected ${expected.getWidth} and actual ${actual.getWidth}")
    assert(expected.getHeight == actual.getHeight,
      s"${prefix}Height is not compatible: Expected ${expected.getHeight} and actual ${actual.getHeight}")
    for (x <- 0 until expected.getWidth; y <- 0 until expected.getHeight) {
      val expectedPixel = expected.getRGB(x, y)
      val actualPixel = actual.getRGB(x, y)
      assert(expectedPixel == actualPixel, s"${prefix}Pixels differ at ($x, $y): " +
        s"Expected ${new Color(expectedPixel, true)} but found ${new Color(actualPixel, true)}")
    }
  }

  /**
   * An assertion for the existence of a file (or a directory)
   *
   * @param filePath path to a local file
   */
  protected def assertFileExists(filePath: String): Unit =
    assert(new File(filePath).exists, s"File '$filePath' does not exist")
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy