All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.optim.EvaluateMethods.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.tensor.Tensor

object EvaluateMethods {
  def calcAccuracy[T](output: Tensor[T], target: Tensor[T]): (Int, Int) = {
    var correct = 0
    var count = 0
    if (output.dim() == 2) {
      output.max(2)._2.squeeze().map(target, (a, b) => {
        if (a == b) {
          correct += 1
        }
        a
      })
      count += output.size(1)
    } else if (output.dim == 1) {
      require(target.size(1) == 1)
      output.max(1)._2.map(target, (a, b) => {
        if (a == b) {
          correct += 1
        }
        a
      })
      count += 1
    } else {
      throw new IllegalArgumentException
    }

    (correct, count)
  }

  def calcTop5Accuracy[T](output: Tensor[T], target: Tensor[T]): (Int, Int) = {
    var correct = 0
    var count = 0
    if (output.dim() == 2) {
      val indices = output.topk(5, 2, false)._2
      var i = 1
      while (i <= output.size(1)) {
        if (indices.valueAt(i, 1) == target.valueAt(i)
          || indices.valueAt(i, 2) == target.valueAt(i)
          || indices.valueAt(i, 3) == target.valueAt(i)
          || indices.valueAt(i, 4) == target.valueAt(i)
          || indices.valueAt(i, 5) == target.valueAt(i)) {
          correct += 1
        }
        i += 1
      }
      count += output.size(1)
    } else if (output.dim == 1) {
      require(target.size(1) == 1)
      val indices = output.topk(5, 1, false)._2
      if (indices.valueAt(1) == target.valueAt(1) || indices.valueAt(2) == target.valueAt(1)
        || indices.valueAt(3) == target.valueAt(1) || indices.valueAt(4) == target.valueAt(1)
        || indices.valueAt(5) == target.valueAt(1)) {
        correct += 1
      }
      count += 1
    } else {
      throw new IllegalArgumentException
    }

    (correct, count)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy