All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.nn.CrossProduct.scala Maven / Gradle / Ivy

There is a newer version: 0.11.1
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{T, Table}

import scala.reflect.ClassTag

/**
 * A layer which takes a table of multiple tensors (n >= 2) as input
 * and calculates the dot product for all combinations of pairs among the input tensors.
 *
 * Dot-product outputs are ordered according to the order of pairs in the input Table.
 * For instance, if the input (Table) is T(A, B, C), the output (Tensor) will be
 * [A dot B, A dot C, B dot C].
 *
 * Input Tensors may have one or two dimensions; if two, the first dimension is `batchSize`.
 * For convenience, the output is always a 2-dim Tensor of shape
 * (batchSize, n * (n - 1) / 2), where batchSize is 1 for 1-dim inputs.
 *
 * Table size checking and Tensor size checking are executed before each forward
 * when `numTensor` and `embeddingSize`, respectively, are set to values greater than zero.
 *
 * @param numTensor (for checking) number of Tensors the input Table contains,
 *                  default: 0 (won't check)
 * @param embeddingSize (for checking) vector length of the dot product,
 *                      default: 0 (won't check)
 */
class CrossProduct[T: ClassTag](
  val numTensor: Int = 0,
  val embeddingSize: Int = 0
)(implicit ev: TensorNumeric[T]) extends AbstractModule[Table, Tensor[T], T] {

  /**
   * Computes the pairwise dot products of the input tensors.
   *
   * @param input table of n >= 2 tensors, each 1-dim (size) or 2-dim (batch, size);
   *              the batch size and dimensionality are taken from the first tensor
   * @return output of shape (batch, n * (n - 1) / 2); column cc holds the dot product
   *         of the cc-th pair (i, j), i < j, in order (1,2), (1,3), ..., (n-1,n)
   * @throws IllegalArgumentException if the table has fewer than 2 tensors, does not
   *         match `numTensor` / `embeddingSize` (when set), or the first tensor is not 1/2-dim
   */
  override def updateOutput(input: Table): Tensor[T] = {
    val len = input.length()
    require(numTensor <= 0 || numTensor == len,
      s"Input tensor number is $len, unequal to numTensor($numTensor)!")
    checkPairable(len)

    val (_, batch, _) = getShape(input[Tensor[T]](1))

    if (embeddingSize > 0) {
      var k = 1
      while (k <= len) {
        checkEmbeddingSize(input(k))
        k += 1
      }
    }

    output.resize(batch, numPairs(len))

    // Enumerate pairs i < j in row-major order; cc is the 1-based output column.
    var cc = 1
    var i = 1
    while (i < len) {
      var j = i + 1
      while (j <= len) {
        val ijDot = batchDot(input(i), input(j))
        output.select(2, cc).copy(ijDot)
        cc += 1
        j += 1
      }
      i += 1
    }

    output
  }

  /**
   * Back-propagates `gradOutput` to every input tensor.
   *
   * For column cc = <input(i), input(j)>, input(i) receives input(j) * gradOutput(:, cc)
   * and input(j) receives input(i) * gradOutput(:, cc); contributions from all pairs an
   * input participates in are summed.
   *
   * @param input the same table passed to `updateOutput`
   * @param gradOutput 2-dim tensor of shape (batch, n * (n - 1) / 2)
   * @return a fresh table whose entries match the shapes of the corresponding inputs
   * @throws IllegalArgumentException if the table has fewer than 2 tensors or
   *         `gradOutput` does not have the expected shape
   */
  override def updateGradInput(input: Table, gradOutput: Tensor[T]): Table = {
    gradInput = T()

    val len = input.length()
    checkPairable(len)
    val gout = gradOutput

    require(gout.dim() == 2, s"invalid dim of gradOutput(${gout.dim()})!")

    val outLen = numPairs(len)
    require(gout.size(2) == outLen,
      s"invalid colSize of gradOutput(${gout.size(2)}), it should be $outLen!")

    val (dim, batch, emLen) = getShape(input[Tensor[T]](1))
    // Without this check a mismatched batch surfaces only inside the per-pair cmul below.
    require(gout.size(1) == batch,
      s"invalid rowSize of gradOutput(${gout.size(1)}), it should be $batch!")

    var cc = 1
    var i = 1
    while (i < len) {
      var j = i + 1
      while (j <= len) {
        // 1-dim inputs are viewed as (1, emLen) to line up with the 2-dim gradOutput.
        val (ti, tj) = dim match {
          case 1 =>
            input[Tensor[T]](i).view(1, emLen) -> input[Tensor[T]](j).view(1, emLen)
          case 2 =>
            input[Tensor[T]](i) -> input[Tensor[T]](j)
        }

        // cc-th column of gradOutput: the upstream gradient of <input(i), input(j)>.
        val go = gout.narrow(2, cc, 1)

        // Gradient w.r.t. input(j) is input(i) scaled row-wise by go.
        val jInc = Tensor[T]().resizeAs(ti).copy(ti).cmul(go)
        if (dim == 1) jInc.squeeze()
        accumulateGrad(j, jInc)

        // Gradient w.r.t. input(i) is input(j) scaled row-wise by go.
        val iInc = Tensor[T]().resizeAs(tj).copy(tj).cmul(go)
        if (dim == 1) iInc.squeeze()
        accumulateGrad(i, iInc)

        cc += 1
        j += 1
      }
      i += 1
    }

    gradInput
  }

  /**
   * Checks that the vector length of `t` (its only dimension if 1-dim, otherwise its
   * second dimension) equals `embeddingSize`, unless `embeddingSize` <= 0.
   */
  protected def checkEmbeddingSize(t: Tensor[T]): Unit = {
    val size = if (t.dim() == 1) t.size(1) else t.size(2)
    require(embeddingSize <= 0 || embeddingSize == size,
      s"size of input Tensor($size) not equal to embeddingSize($embeddingSize)!")
  }

  /**
   * Row-wise dot product of two equally-shaped tensors: multiplies them element-wise,
   * sums over dimension 2 and squeezes the result. 1-dim inputs are first viewed as a
   * single row.
   */
  protected def batchDot(t1: Tensor[T], t2: Tensor[T]): Tensor[T] = {
    val (input1, input2) =
      if (t1.dim() == 1) (t1.view(1, t1.size(1)), t2.view(1, t2.size(1))) else (t1, t2)

    val buffer = Tensor[T]().resizeAs(input1)
    buffer.cmul(input1, input2)
    buffer.sum(2).squeeze()
  }

  /** Returns (dim, batch, size) of `t`; a 1-dim tensor is treated as a batch of one. */
  private def getShape(t: Tensor[T]): (Int, Int, Int) = {
    val (batch, size) = t.dim() match {
      case 1 => 1 -> t.size(1)
      case 2 => t.size(1) -> t.size(2)
      case n => throw new IllegalArgumentException(s"wrong dim of input Tensor($n)!")
    }
    (t.dim(), batch, size)
  }

  /** Number of unordered pairs among `len` inputs, i.e. the number of output columns. */
  private def numPairs(len: Int): Int = len * (len - 1) / 2

  /** Rejects tables with fewer than two tensors, for which no pair exists. */
  private def checkPairable(len: Int): Unit = {
    require(len >= 2, s"CrossProduct needs at least 2 input tensors, but got $len!")
  }

  /**
   * Adds `inc` into gradInput(idx). On first write `inc` is stored directly; this is
   * safe because callers always pass a freshly allocated tensor.
   */
  private def accumulateGrad(idx: Int, inc: Tensor[T]): Unit = {
    gradInput.get[Tensor[T]](idx) match {
      case None => gradInput.update(idx, inc)
      case Some(v) => v.add(inc)
    }
  }
}

object CrossProduct {

  /** Creates a CrossProduct layer with table-size and embedding-size checks disabled. */
  def apply[T: ClassTag]()(implicit ev: TensorNumeric[T]): CrossProduct[T] = new CrossProduct[T]()

  /**
   * Creates a CrossProduct layer.
   *
   * @param numTensor expected number of input tensors; <= 0 disables the check
   * @param embeddingSize expected vector length of each input; <= 0 disables the check
   */
  def apply[T: ClassTag](
    numTensor: Int = 0,
    embeddingSize: Int = 0
  )(implicit ev: TensorNumeric[T]): CrossProduct[T] = {
    new CrossProduct(numTensor, embeddingSize)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy