All downloads are free. The search and download features use the official Maven repository.

com.intel.analytics.bigdl.nn.keras.Merge.scala Maven / Gradle / Ivy

There is a newer version: 0.13.0
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn.keras

import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.{CAddTable, CAveTable, CMaxTable, CMulTable, CosineDistance, DotProduct, JoinTable, ParallelTable, Sequential => TSequential}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{MultiShape, Shape}

import scala.reflect.ClassTag

/**
 * Used to merge a list of inputs into a single output, following some merge mode.
 * To merge layers, it must take at least two input layers.
 *
 * When using this layer as the first layer in a model, you need to provide the argument
 * inputShape for input layers (each as a Single Shape, does not include the batch dimension).
 *
 * The merge mode is matched case-insensitively (it is lower-cased before validation).
 *
 * @param layers A list of layer instances. Must be more than one layer.
 *               May be null, in which case only the bare merge operation is built.
 * @param mode Merge mode. String, must be one of: 'sum', 'mul', 'concat', 'ave', 'cos',
 *             'dot', 'max'. Default is 'sum'.
 * @param concatAxis Integer, axis to use when concatenating layers. Only specify this when merge
 *                   mode is 'concat'. Default is -1, meaning the last axis of the input.
 * @param inputShape Expected input shape; passed to [[Merge.calcBatchInputShape]] together
 *                   with `layers`. Default is null.
 * @tparam T The numeric type of parameter(e.g. weight, bias). Only support float/double now.
 */
class Merge[T: ClassTag](
   val layers: Array[AbstractModule[Activity, Activity, T]] = null,
   val mode: String = "sum",
   val concatAxis: Int = -1,
   // MultiShape isn't directly supported for serialization. Use Shape instead.
   val inputShape: Shape = null)(implicit ev: TensorNumeric[T])
  extends KerasLayer[Tensor[T], Tensor[T], T](Merge.calcBatchInputShape(inputShape, layers)) {

  private val mergeMode = mode.toLowerCase()
  // Resolved (non-negative) concat axis; updated by computeOutputShapeForConcat and
  // read by doBuild when the merge mode is 'concat'.
  private var axis = concatAxis

  require(mergeMode == "sum" || mergeMode == "mul" || mergeMode == "concat" || mergeMode == "ave"
  || mergeMode == "cos" || mergeMode == "dot" || mergeMode == "max",
  s"Invalid merge mode: $mergeMode")
  if (layers != null) {
    require(layers.length >= 2, s"Merge must take at least two input layers " +
      s"but found ${layers.length}")
    this.excludeInvalidLayers(layers)
  }

  /**
   * Compute the output shape for merge mode 'concat'.
   *
   * All inputs must have the same rank and agree on every dimension except the concat axis.
   * The size along the concat axis is the sum of the input sizes, or -1 (unknown) as soon as
   * any input has -1 on that axis. As a side effect, `axis` is set to the non-negative
   * equivalent of `concatAxis`.
   *
   * @param input Shapes of all inputs (each must be convertible via `toSingle()`).
   * @return The concatenated shape.
   */
  private def computeOutputShapeForConcat(input: List[Shape]): Shape = {
    val input1 = input.head.toSingle().toArray
    val output = input1.clone()
    require(Math.abs(concatAxis) < output.length, s"Invalid concat axis $concatAxis")
    axis = if (concatAxis < 0) concatAxis + output.length else concatAxis
    var i = 1
    while (i < input.length) {
      val input_i = input(i).toSingle().toArray
      // Inputs of a different rank can never be concatenated; fail with a clear message
      // instead of an index error (or a silent pass) in the per-dimension check below.
      require(input_i.length == output.length, s"Incompatible input dimension for merge " +
        s"mode concat: (${output.deep.mkString(", ")}), " +
        s"(${input_i.deep.mkString(", ")})")
      var j = 0
      while (j < input_i.length) {
        if (j != axis) require(input_i(j)==output(j), s"Incompatible input dimension for merge " +
          s"mode concat: (${output.deep.mkString(", ")}), " +
          s"(${input_i.deep.mkString(", ")})")
        j += 1
      }
      // An unknown (-1) size on the concat axis makes the merged size unknown. Once it is -1
      // it stays -1, while the remaining inputs are still validated by the loop above.
      output(axis) =
        if (output(axis) == -1 || input_i(axis) == -1) -1
        else output(axis) + input_i(axis)
      i += 1
    }
    Shape(output)
  }

  /**
   * Require every input shape to be identical to the first one.
   *
   * @param input Shapes of all inputs (each must be convertible via `toSingle()`).
   */
  private def checkSameInputShape(input: List[Shape]): Unit = {
    val input1 = input.head.toSingle().toArray
    var i = 1
    while (i < input.length) {
      val input_i = input(i).toSingle().toArray
      require(input_i.sameElements(input1), s"Incompatible input dimension for " +
        s"merge mode $mergeMode: (${input1.deep.mkString(", ")}), " +
        s"(${input_i.deep.mkString(", ")})")
      i += 1
    }
  }

  /**
   * Compute the output shape of the merge.
   *
   *  - 'concat': see `computeOutputShapeForConcat`.
   *  - 'dot': exactly two inputs of at most 2 dimensions; the result is Shape(-1, 1).
   *  - 'cos': exactly two inputs of at most 2 dimensions; the result is Shape(-1, 1, 1).
   *  - all other modes: the shape of the first input (all inputs must share it).
   *
   * @param inputShape A multi shape holding the shape of every input.
   * @return The shape of the merged output.
   */
  override def computeOutputShape(inputShape: Shape): Shape = {
    val input = inputShape.toMulti()
    if (mergeMode == "concat") {
      computeOutputShapeForConcat(input)
    }
    else {
      checkSameInputShape(input)
      if (mergeMode == "dot" || mergeMode == "cos") {
        require(input.head.toSingle().length <=2, s"For merge mode $mergeMode, 3D input " +
          s"or above is currently not supported, got input dim ${input.head.toSingle().length}")
        require(input.length == 2, s"Merge mode $mergeMode takes exactly two layers, " +
          s"but got ${input.length}")
        if (mergeMode == "dot") Shape(-1, 1) else Shape(-1, 1, 1)
      }
      else {
        input.head
      }
    }
  }

  /**
   * Build the underlying module that performs the merge.
   *
   * When `layers` is non-null the result is a Sequential made of a ParallelTable (one branch
   * per layer) followed by the merge operation; otherwise only the merge operation is returned.
   *
   * @param inputShape A multi shape holding the shape of every input.
   * @return The module implementing this layer.
   */
  override def doBuild(inputShape: Shape): AbstractModule[Tensor[T], Tensor[T], T] = {
    val input = inputShape.toMulti()
    // mergeMode was validated in the constructor, so the match below covers every case.
    val mergeLayer = mergeMode match {
      case "sum" => CAddTable()
      case "mul" => CMulTable()
      case "max" => CMaxTable()
      case "ave" => CAveTable()
      case "concat" =>
        val input1 = input.head.toSingle().toArray
        JoinTable(axis, input1.length -1)
      case "dot" =>
        val seq = TSequential[T]()
        seq.add(DotProduct())
        seq.add(com.intel.analytics.bigdl.nn.Reshape(Array(1), Some(true)))
        seq
      case "cos" =>
        val seq = TSequential[T]()
        seq.add(CosineDistance())
        seq.add(com.intel.analytics.bigdl.nn.Reshape(Array(1, 1), Some(true)))
        seq
    }
    if (layers != null) { // In the case `layers != null`, return a ParallelTable to merge layers
      val model = TSequential[T]()
      val parallel = ParallelTable()
      var i = 0
      while(i < layers.length) {
        parallel.add(layers(i).asInstanceOf[KerasLayer[Activity, Activity, T]].labor)
        i += 1
      }
      model.add(parallel)
      model.add(mergeLayer)
      model.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
    }
    else { // In the case `layers == null`, only return a merge layer to merge nodes not layers.
      mergeLayer.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
    }
  }
}

/**
 * Companion of the [[Merge]] layer: input-shape resolution plus factory helpers.
 */
object Merge {
  /**
   * Resolve the input shape handed to the [[Merge]] layer's parent constructor.
   *
   * When `layers` is non-null the result is a MultiShape with one entry per layer: the
   * layer's output shape if it is already built, otherwise the value returned by building it
   * with its own input shape. When `layers` is null and no `inputShape` is given, the result
   * is null.
   *
   * If `inputShape` is given, it must be a MultiShape after `KerasLayer.addBatch` and must
   * equal the shapes derived from `layers`; `layers` must therefore be non-null in that case.
   *
   * @param inputShape Expected input shape, or null to skip the consistency check.
   * @param layers Layers to merge, or null.
   * @return The MultiShape derived from `layers`, or null when `layers` is null.
   */
  def calcBatchInputShape[T: ClassTag](
    inputShape: Shape = null,
    layers: Array[AbstractModule[Activity, Activity, T]]): Shape = {
    val batchInputShape = KerasLayer.addBatch(inputShape)
    val actualInputShape = if (layers != null) {
      MultiShape(layers.map { layer =>
        if (layer.isBuilt()) {  // it's possible while reloaded from file
          layer.getOutputShape()
        } else {
          layer.build(layer.getInputShape())
        }
      }.toList)
    } else null
    if (batchInputShape != null) {
      require(batchInputShape.isInstanceOf[MultiShape],
        "Merge requires inputShape to be MultiShape")
      // Without layers there is nothing to compare against; fail with a clear message
      // rather than dereferencing a null shape below.
      require(actualInputShape != null,
        "Merge requires layers to be specified when inputShape is provided")
      require(batchInputShape.toMulti().equals(actualInputShape.toMulti()),
        "Actual layer input shapes are not the same as expected layer input shapes")
    }
    actualInputShape
  }

  /**
   * Create a [[Merge]] layer from a list of layers.
   *
   * @param layers Layers to merge, or null.
   * @param mode Merge mode. Default is 'sum'.
   * @param concatAxis Axis used when the merge mode is 'concat'. Default is -1.
   * @param inputShape Expected input shape. Default is null.
   * @return A new [[Merge]] instance.
   */
  def apply[@specialized(Float, Double) T: ClassTag](
    layers: List[AbstractModule[Activity, Activity, T]] = null,
    mode: String = "sum",
    concatAxis: Int = -1,
    inputShape: Shape = null)(implicit ev: TensorNumeric[T]): Merge[T] = {
    val layersArray = if (layers != null) layers.toArray else null
    new Merge[T](layersArray, mode, concatAxis, inputShape)
  }

  /**
   * Merge graph nodes: create a [[Merge]] layer without inner layers and connect it to the
   * given input nodes.
   *
   * @param inputs Nodes whose outputs are merged.
   * @param mode Merge mode. Default is 'sum'.
   * @param concatAxis Axis used when the merge mode is 'concat'. Default is -1.
   * @param name Optional name set on the created layer; ignored when null.
   * @return The node returned by connecting the merge layer to `inputs`.
   */
  def merge[@specialized(Float, Double) T: ClassTag](
    inputs: List[ModuleNode[T]],
    mode: String = "sum",
    concatAxis: Int = -1,
    name: String = null)(implicit ev: TensorNumeric[T]): ModuleNode[T] = {
    val mergeLayer = new Merge[T](mode = mode, concatAxis = concatAxis)
    if (name != null) mergeLayer.setName(name)
    mergeLayer.inputs(inputs.toArray)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy