com.salesforce.op.test.OpEstimatorSpec.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of transmogrifai-features_2.11 Show documentation
AutoML library for building modular, reusable, strongly typed machine learning workflows on Spark with minimal hand tuning
The newest version!
/*
* Copyright (c) 2017, Salesforce.com, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.salesforce.op.test
import java.io.File
import com.salesforce.op.features.types._
import com.salesforce.op.stages._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.sql.Dataset
import org.scalactic.Equality
import org.scalatest.events.{Event, TestFailed}
import org.scalatest.{Args, Reporter}
import scala.collection.mutable.ArrayBuffer
import scala.reflect._
import scala.reflect.runtime.universe._
/**
 * Base test class for testing OP estimator instances.
 * Includes common tests for fitting the estimator and verifying the fitted model.
 *
 * Concrete subclasses provide the estimator instance, an input dataset, and the
 * expected transform result; this spec then (1) fits the estimator once (lazily),
 * (2) checks the fitted model's relationship to its parent estimator, and
 * (3) re-registers the full [[OpTransformerSpec]] test suite against the fitted
 * model so the model is exercised like any other transformer.
 *
 * @tparam O output feature type
 * @tparam ModelType model type produced by this estimator
 * @tparam EstimatorType type of the estimator being tested
 */
abstract class OpEstimatorSpec[O <: FeatureType : WeakTypeTag : ClassTag,
ModelType <: Model[ModelType] with OpPipelineStage[O] with OpTransformer : ClassTag,
EstimatorType <: Estimator[ModelType] with OpPipelineStage[O] : ClassTag]
extends OpPipelineStageSpec[O, EstimatorType] {
/**
 * Input Dataset to fit & transform
 */
val inputData: Dataset[_]
/**
 * Estimator instance to be tested
 */
val estimator: EstimatorType
/**
 * Expected result of the transformer applied on the Input Dataset
 */
val expectedResult: Seq[O]
// The stage under test (used by the inherited OpPipelineStageSpec checks) is the estimator itself.
final override lazy val stage = estimator
/**
 * Model (transformer) to fit
 */
// Lazy so fitting happens at most once, on first access, and is shared by all tests below
// (including the nested model spec, which reuses this same fitted instance).
final lazy val model: ModelType = estimator.fit(inputData)
it should "fit a model" in {
model should not be null
// Runtime class check relies on the ClassTag bound on ModelType.
model shouldBe a[ModelType]
}
// Registers every test from the nested transformer spec (built in modelSpec below)
// as shared behavior of this suite.
it should behave like modelSpec()
it should "have fitted a model that matches the estimator" in {
withClue("Model doesn't have a parent:") {
model.hasParent shouldBe true
}
withClue("Model parent should be the original estimator instance:") {
model.parent shouldBe estimator
}
withClue("Model and estimator output feature names don't match:") {
model.getOutputFeatureName shouldBe estimator.getOutputFeatureName
}
// NOTE(review): `assert` here is presumably the stage-comparison helper inherited from
// OpPipelineStageSpec (expectSameClass = false since model and estimator classes differ) —
// confirm against the parent spec's definition.
assert(model.asInstanceOf[OpPipelineStageBase], estimator, expectSameClass = false)
}
// TODO: test metadata
/**
 * Register all model spec tests
 *
 * Builds an anonymous [[OpTransformerSpec]] around the fitted model — reusing this
 * spec's input data, equalities, Spark session and temp dir — then manually runs each
 * of its tests through a failure-capturing Reporter and re-raises any failure so it
 * surfaces as a failure of this suite.
 */
private def modelSpec(): Unit = {
// Define transformer spec for the fitted model reusing the same inputs & Spark context
val modelSpec = new OpTransformerSpec[O, ModelType] {
override implicit val featureTypeEquality: Equality[O] = OpEstimatorSpec.this.featureTypeEquality
override implicit val seqEquality: Equality[Seq[O]] = OpEstimatorSpec.this.seqEquality
// All lazy: nothing is fit or evaluated until the registered tests actually run.
lazy val transformer: ModelType = OpEstimatorSpec.this.model
lazy val inputData: Dataset[_] = OpEstimatorSpec.this.inputData
lazy val expectedResult: Seq[O] = OpEstimatorSpec.this.expectedResult
override implicit lazy val spark = OpEstimatorSpec.this.spark
// Prefix reported test names with "model" to distinguish them from this suite's own tests.
override def specName: String = "model"
override def tempDir: File = OpEstimatorSpec.this.tempDir
}
// Register all model spec tests
for {
testName <- modelSpec.testNames
} registerTest(testName) {
// Run test & collect failures
val failures = ArrayBuffer.empty[TestFailed]
// Custom reporter that only records TestFailed events; all other events are dropped.
val reporter = new Reporter {
def apply(event: Event): Unit = event match {
case f: TestFailed => failures += f
case _ =>
}
}
// Note: We set 'runTestInNewInstance = true' to avoid restarting Spark context on every test run
val args = Args(reporter, runTestInNewInstance = true)
modelSpec.run(testName = Some(testName), args = args)
// Propagate the failure if any
// Only the first captured failure is rethrown; preserve its cause when present.
for {failure <- failures.headOption} {
failure.throwable.map(fail(failure.message, _)).getOrElse(fail(failure.message))
}
}
}
}