All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.serving.postprocessing.PostProcessing.scala Maven / Gradle / Ivy

/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.serving.postprocessing

import java.util.Base64

import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.zoo.serving.serialization.ArrowSerializer
import com.intel.analytics.zoo.serving.utils.TensorUtils


/**
 * PostProssing
 * PostProcessing contains two steps
 * step 1 is filter, which is optional,
 * used to transform output tensor to type wanted
 * step 2 is to ndarray string, which is mandatory
 * to parse tensor into readable string
 * this string could be parsed by json in Python to a list
 * @param tensor
 */
class PostProcessing(tensor: Tensor[Float], filter: String = "") {
  var t: Tensor[Float] = tensor
  val totalSize = TensorUtils.getTotalSize(t)

  /**
   * Transform tensor into readable string,
   * could apply to any shape of tensor
   * @return
   */
  def tensorToNdArrayString(): String = {
    val sizeArray = t.size()
    var strideArray = Array[Int]()
    (0 until sizeArray.length).foreach(i => {
      var res: Int = 1
      (0 to i).foreach(j => {
        res *= sizeArray(sizeArray.length - 1 - j)
      })
      strideArray = strideArray :+ res
    })
    val flatTensor = t.resize(totalSize).toArray()
    var str: String = ""
    (0 until flatTensor.length).foreach(i => {
      (0 until sizeArray.length).foreach(j => {
        if (i % strideArray(j) == 0) {
          str += "["
        }
      })
      str += flatTensor(i).toString
      (0 until sizeArray.length).foreach(j => {
        if ((i + 1) % strideArray(j) == 0) {
          str += "]"
        }
      })
      if (i != flatTensor.length - 1) {
        str += ","
      }
    })
    str
  }
  /**
   * TopN filter, take 1-D size (n) tensor as input
   * @param topN
   * @return string, representing 2-D size (topN, 2) tensor
   */
  def rankTopN(topN: Int): String = {
    val list = TensorUtils.getTopN(topN, t)
    var res: String = ""
    res += "["
    (0 until list.size).foreach(i =>
      res += "[" + list(i)._1.toString + "," + list(i)._2.toString + "]"
    )
    res += "]"
    res
  }

  /**
   * Pick TopN value of output tensor
   * only (1) * record_size * box_value_number is supported
   * thus only 2 or 3 dimension is valid for now
   */
  def pickTopN(topN: Int): String = {
    require(t.dim() == 2 || t.dim() == 3,
      "pickTopN post-processing only take 2 or 3 output dim tensor")
    val thisT = if (t.dim() == 3) {
      t.squeeze(1)
    } else {
      t
    }
    require(thisT.dim() == 2,
      "Your input dim is 3 but squeeze operation fails, please open issue to Analytics Zoo team")
    var res: String = ""
    res += "["
    (1 to topN).foreach(topIdx => {
      res += "["
      (1 to t.size(2)).foreach(boxIdx => {
        res += t.valueAt(topIdx, boxIdx)
        if (boxIdx != thisT.size(2)) {
          res += ","
        }
      })
      res += "]"
    })
    res += "]"
    res
  }
  def processTensor(): String = {
    if (filter != "") {
      require(filter.last == ')',
        "please check your filter format, should be filter_name(filter_args)")
      require(filter.split("\\(").length == 2,
        "please check your filter format, should be filter_name(filter_args)")

      val filterType = filter.split("\\(").head
      val filterArgs = filter.split("\\(").last.dropRight(1).split(",")
      val res = filterType match {
        case "topN" =>
          require(filterArgs.length == 1, "topN filter only support 1 argument, please check.")
          rankTopN(filterArgs(0).toInt)
        case "pickTopN" =>
          require(filterArgs.length == 1, "pickTopN filter only support 1 argument, please check.")
          pickTopN(filterArgs(0).toInt)
        case _ => ""
      }
      res
    }
    else {
      tensorToNdArrayString()
    }
  }
}
object PostProcessing {
  /**
   *
   * @param t the result of prediction
   * @param filter custom postprocessing
   * @param index index of tensor to select, -1 means return tensor directly without select
   * @return
   */
  def apply(t: Activity, filter: String = "", index: Int = -1): String = {
    if (filter == "") {
      val byteArr = ArrowSerializer.activityBatchToByte(t, index)
      Base64.getEncoder.encodeToString(byteArr)
    }
    else {
      if (t.isTable) {
        var value = ""
        t.toTable.keySet.foreach(key => {
          val cls = new PostProcessing(t.toTable(key)
            .asInstanceOf[Tensor[Float]].select(1, index), filter)
          value += cls.processTensor()
        })
        value
      } else if (t.isTensor) {
        val cls = new PostProcessing(t.toTensor[Float].select(1, index), filter)
        cls.processTensor()
      } else {
        throw new Error("Your input for Post-processing is invalid, " +
          "neither Table nor Tensor, please check.")
      }
    }

  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy