com.intel.analytics.bigdl.dllib.nn.SoftMax.scala
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.bigdl.dllib.nn

import com.intel.analytics.bigdl.dllib.nn.abstractnn.TensorModule
import com.intel.analytics.bigdl.dllib.tensor.Tensor
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.dllib.utils.Shape

import scala.reflect.ClassTag

/**
 * Applies the SoftMax function to an n-dimensional input Tensor, rescaling it so that the
 * elements of the n-dimensional output Tensor lie in the range (0, 1) and sum to 1.
* Softmax is defined as: f_i(x) = exp(x_i - shift) / sum_j exp(x_j - shift)
* where shift = max_i(x_i).
*/
@SerialVersionUID(-7842335603491194236L)
class SoftMax[T: ClassTag]()(implicit ev: TensorNumeric[T])
extends TensorModule[T] {
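
  // Why the shift: exp(x) overflows Float for x greater than roughly 88, so the
  // forward pass subtracts the per-slice maximum first. Since
  // exp(x_i - c) / sum_j exp(x_j - c) == exp(x_i) / sum_j exp(x_j) for any
  // constant c, choosing c = max_i(x_i) keeps every exponent <= 0 without
  // changing the result.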
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    val dim = input.dim()
    val sizes = input.size()
    // Maximum along the last dimension, kept as a size-1 slice for broadcasting.
    val shift = input.max(dim)._1
    val shiftInput = input.clone()
    if (dim <= 4 && dim > 1) {
      // Fast path: broadcast the subtraction slice by slice.
      optimizedOperation(shiftInput, shift, "-")
    } else {
      // Generic path: materialize the expanded shift tensor.
      shiftInput.sub(shift.expand(sizes).contiguous())
    }
    // exp() mutates shiftInput in place and returns it.
    val exp = shiftInput.exp()
    val clonedExp = exp.clone()
    val sum = clonedExp.sum(dim)
    if (dim <= 4 && dim > 1) {
      optimizedOperation(clonedExp, sum, "/")
    } else {
      clonedExp.div(sum.expand(sizes).contiguous())
    }
    output = clonedExp
    output
  }
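
  // Backward pass: for y = softmax(x) the vector-Jacobian product reduces to
  //   dL/dx_i = y_i * (dL/dy_i - sum_j y_j * dL/dy_j),
  // which is what updateGradInput computes: the inner sum runs along the last
  // dimension and is then broadcast back over the input shape.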
override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
val dim = input.dim()
val sum = (output.clone().cmul(gradOutput)).sum(dim)
gradInput = output.clone().cmul(gradOutput - sum.expand(input.size()))
gradInput
}

  // Applies `operation` ("-" or "/") between input1 and input2 in place,
  // broadcasting input2 along the last dimension (where max/sum left it with
  // size 1). Simplified, behavior-equivalent reconstruction of the truncated
  // original, which iterated over raw indices instead of narrow() views.
  private def optimizedOperation(input1: Tensor[T], input2: Tensor[T],
    operation: String): Unit = {
    val dim = input1.dim()
    var i = 1
    while (i <= input1.size(dim)) {
      val slice = input1.narrow(dim, i, 1)
      if (operation == "-") slice.sub(input2) else slice.div(input2)
      i += 1
    }
  }

  // SoftMax preserves the shape of its input.
  override def computeOutputShape(inputShape: Shape): Shape = inputShape
}
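
// Minimal usage sketch (an illustration, not part of the original file). It
// relies on the implicit TensorNumeric[Float] instance that BigDL's
// TensorNumeric companion object provides:
//
//   val layer = new SoftMax[Float]()
//   val input = Tensor[Float](2, 3).rand()
//   val output = layer.forward(input)           // each row sums to 1
//   val gradOut = Tensor[Float](2, 3).rand()
//   val gradIn = layer.backward(input, gradOut)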