All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.optim.Validator.scala Maven / Gradle / Ivy

There is a newer version: 0.11.1
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.utils.Engine
import com.intel.analytics.bigdl.dataset.{DistributedDataSet, LocalDataSet, MiniBatch}
import org.apache.log4j.Logger

/**
 * Base class for validators: objects that evaluate a model against a data set
 * using the validation methods (for example [[Top1Accuracy]]) handed to `test`.
 *
 * @param model the model to be validated
 * @param dataSet the data set the model is validated on
 * @tparam T numeric type, which can be [[Float]] or [[Double]]
 * @tparam D element type of the data set, such as [[MiniBatch]]
 */
abstract class Validator[T, D](model: Module[T], dataSet: DataSet[D]) {

  /**
   * Evaluates the model with the given validation methods.
   *
   * @param vMethods the validation methods to apply
   * @return validation results, each paired with a validation method
   *         (presumably the method that produced it; the pairing and ordering
   *         are decided by the concrete implementation)
   */
  def test(
    vMethods: Array[ValidationMethod[T]]
  ): Array[(ValidationResult, ValidationMethod[T])]
}
/**
 * Deprecated factory for [[Validator]] instances. Use `model.evaluate` instead.
 */
@deprecated(
  "Validator(model, dataset) is deprecated. Please use model.evaluate instead",
  "0.2.0")
object Validator {
  private val logger = Logger.getLogger(getClass)

  /**
   * Creates the validator matching the runtime kind of `dataset`: a
   * [[DistriValidator]] for a [[DistributedDataSet]], a [[LocalValidator]] for a
   * [[LocalDataSet]]. A deprecation warning is logged on every call.
   *
   * The element type of the data set is not checked at runtime (generic types
   * are erased); the data set is cast to one holding `MiniBatch[T]`.
   *
   * @param model the model to be validated
   * @param dataset the data set used to validate the model
   * @tparam T numeric type of the model
   * @tparam D element type of the data set
   * @return a validator for the given model and data set
   * @throws UnsupportedOperationException if `dataset` is neither a
   *         [[DistributedDataSet]] nor a [[LocalDataSet]] (including `null`)
   */
  def apply[T, D](model: Module[T], dataset: DataSet[D]): Validator[T, D] = {
    logger.warn("Validator(model, dataset) is deprecated. Please use model.evaluate instead")
    dataset match {
      case d: DistributedDataSet[_] =>
        new DistriValidator[T](
          model = model,
          dataSet = d.asInstanceOf[DistributedDataSet[MiniBatch[T]]]
        ).asInstanceOf[Validator[T, D]]
      case d: LocalDataSet[_] =>
        new LocalValidator[T](
          model = model,
          dataSet = d.asInstanceOf[LocalDataSet[MiniBatch[T]]]
        ).asInstanceOf[Validator[T, D]]
      case _ =>
        // Type patterns never match null, so a null data set also lands here;
        // Option(...) keeps the diagnostic itself from throwing an NPE.
        val actualType = Option(dataset).map(_.getClass.getName).getOrElse("null")
        throw new UnsupportedOperationException(
          s"Validator does not support a data set of type $actualType; " +
            "expected a DistributedDataSet or a LocalDataSet")
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy