All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.python.api.PythonBigDLKeras.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.python.api

import java.util.{List => JList}

import com.intel.analytics.bigdl.{Criterion, DataSet, nn}
import com.intel.analytics.bigdl.dataset.{DataSet, LocalDataSet, MiniBatch}
import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.{Container, SpatialBatchNormalization}
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.keras._
import com.intel.analytics.bigdl.numeric._
import com.intel.analytics.bigdl.optim.{OptimMethod, Regularizer, ValidationMethod}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.transform.vision.image.{ImageFeature, ImageFeatureToMiniBatch}
import com.intel.analytics.bigdl.utils.{Engine, MultiShape, Shape, SingleShape}
import org.apache.spark.api.java.JavaRDD

import scala.collection.JavaConverters._
import scala.language.existentials
import scala.reflect.ClassTag


/**
 * Factory for [[PythonBigDLKeras]] instances bound to a concrete element type.
 */
object PythonBigDLKeras {

  /** Returns a fresh API instance whose tensors hold `Float` values. */
  def ofFloat(): PythonBigDLKeras[Float] = {
    new PythonBigDLKeras[Float]
  }

  /** Returns a fresh API instance whose tensors hold `Double` values. */
  def ofDouble(): PythonBigDLKeras[Double] = {
    new PythonBigDLKeras[Double]
  }
}

/**
 * Python-facing API for BigDL's Keras-style layers and models.
 *
 * Each `createKerasXxx` method is an entry point invoked from Python. Java collection
 * arguments (`JList`) are converted to Scala values ([[Shape]], `Array[Int]`, `List`)
 * before being forwarded to the matching layer in `com.intel.analytics.bigdl.nn.keras`.
 * A `null` list argument is passed through as `null`; every `inputShape` parameter
 * defaults to `null`.
 *
 * @tparam T element type of the underlying tensors (see `PythonBigDLKeras.ofFloat`
 *           and `PythonBigDLKeras.ofDouble`)
 */
class PythonBigDLKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonBigDL[T] {

  /**
   * Converts a Java list of dimension sizes into a single [[Shape]].
   *
   * @param inputShape dimension sizes, or null
   * @return the corresponding shape, or null when `inputShape` is null
   */
  def toScalaShape(inputShape: JList[Int]): Shape = {
    if (inputShape == null) {
      null
    } else {
      Shape(inputShape.asScala.toArray)
    }
  }

  /**
   * Converts a Java list of per-input dimension lists into a multi-input [[Shape]]
   * built from a `List` of single shapes.
   *
   * @param inputShape one dimension list per input, or null
   * @return the combined shape, or null when `inputShape` is null
   */
  def toScalaMultiShape(inputShape: JList[JList[Int]]): Shape = {
    if (inputShape == null) {
      null
    } else {
      Shape(inputShape.asScala.toArray.map(shape => Shape(shape.asScala.toArray)).toList)
    }
  }

  /**
   * Converts a Java list of ints into an `Array[Int]`.
   *
   * @param list values to convert, or null
   * @return the array, or null when `list` is null
   */
  def toScalaArray(list: JList[Int]): Array[Int] = {
    if (list == null) {
      null
    } else {
      list.asScala.toArray
    }
  }

  /**
   * Creates a functional (graph) Keras model from its input and output nodes.
   *
   * @param input graph input nodes
   * @param output graph output nodes
   */
  def createKerasModel(input: JList[ModuleNode[T]],
      output: JList[ModuleNode[T]]): Model[T] = {
    nn.keras.Model(input.asScala.toArray, output.asScala.toArray)
  }

  /** Creates an empty Keras-style [[nn.keras.Sequential]] container. */
  def createKerasSequential(): nn.keras.Sequential[T] = {
    nn.keras.Sequential[T]()
  }

  /**
   * Creates an input node for a functional model via `Input`.
   *
   * @param name optional node name (null by default)
   * @param inputShape dimension sizes, or null
   */
  def createKerasInput(
    name : String = null,
    inputShape: JList[Int] = null): ModuleNode[T] = {
    Input(name = name, inputShape = toScalaShape(inputShape))
  }

  /** Creates an [[InputLayer]], typically used as the first layer of a Sequential. */
  def createKerasInputLayer(
    inputShape: JList[Int] = null): KerasLayer[Activity, Activity, T] = {
    InputLayer(inputShape = toScalaShape(inputShape))
  }

  /**
   * Converts a [[Shape]] into nested Java lists for the Python side.
   *
   * A [[SingleShape]] is first wrapped into a one-element [[MultiShape]], so the
   * result is always a list of dimension lists. Each element is flattened with
   * `toSingle()`; nested multi-shapes are presumably not supported by that call
   * (TODO confirm).
   */
  def shapeToJList(shape: Shape): JList[JList[Int]] = {
    val shapes = shape match {
      case single: SingleShape => MultiShape(List(single))
      case other => other
    }
    shapes.toMulti().map(single => single.toSingle().toList.asJava).toList.asJava
  }

  /** Returns the output shape of `module` as nested Java lists (see `shapeToJList`). */
  def getOutputShape(module: Container[Activity, Activity, T]): JList[JList[Int]] = {
    val output = module.getOutputShape()
    shapeToJList(output)
  }

  /** Returns the input shape of `module` as nested Java lists (see `shapeToJList`). */
  def getInputShape(module: Container[Activity, Activity, T]): JList[JList[Int]] = {
    val input = module.getInputShape()
    // TODO: inputShape can be nested MultiShape
    shapeToJList(input)
  }

  /** Creates a [[Dense]] layer; all arguments are forwarded to `Dense.apply`. */
  def createKerasDense(
    outputDim: Int,
    init: String = "glorot_uniform",
    activation: String = null,
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): Dense[T] = {
    Dense(outputDim, init, activation, wRegularizer,
      bRegularizer, bias, toScalaShape(inputShape))
  }

  /** Creates an [[Embedding]] layer; all arguments are forwarded to `Embedding.apply`. */
  def createKerasEmbedding(
    inputDim: Int,
    outputDim: Int,
    init: String = "uniform",
    wRegularizer: Regularizer[T] = null,
    inputShape: JList[Int] = null): Embedding[T] = {
    Embedding[T](inputDim, outputDim, init, wRegularizer, toScalaShape(inputShape))
  }

  /** Creates a [[BatchNormalization]] layer; arguments go to `BatchNormalization.apply`. */
  def createKerasBatchNormalization(
    epsilon: Double = 0.001,
    momentum: Double = 0.99,
    betaInit: String = "zero",
    gammaInit: String = "one",
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): BatchNormalization[T] = {
    BatchNormalization[T](epsilon, momentum, betaInit,
      gammaInit, dimOrdering, toScalaShape(inputShape))
  }

  /**
   * Replaces the running mean of a batch-normalization layer.
   *
   * The layer's `labor` is cast to [[SpatialBatchNormalization]]; the cast throws
   * `ClassCastException` if it is of another type.
   */
  def setRunningMean(module: BatchNormalization[T], runningMean: JTensor): Unit = {
    module.labor.asInstanceOf[SpatialBatchNormalization[T]]
      .runningMean.set(toTensor(runningMean))
  }

  /**
   * Replaces the running statistics tensor of a batch-normalization layer.
   *
   * Note: despite the "Std" in the name, this writes the `runningVar` field of the
   * underlying [[SpatialBatchNormalization]] (`labor`, cast as in `setRunningMean`).
   */
  def setRunningStd(module: BatchNormalization[T], runningStd: JTensor): Unit = {
    module.labor.asInstanceOf[SpatialBatchNormalization[T]]
      .runningVar.set(toTensor(runningStd))
  }

  /** Returns the running mean of the layer's underlying [[SpatialBatchNormalization]]. */
  def getRunningMean(module: BatchNormalization[T]): JTensor = {
    toJTensor(module.labor.asInstanceOf[SpatialBatchNormalization[T]]
      .runningMean)
  }

  /**
   * Returns the running statistics tensor of the layer.
   *
   * Note: despite the "Std" in the name, this reads `runningVar`.
   */
  def getRunningStd(module: BatchNormalization[T]): JTensor = {
    toJTensor(module.labor.asInstanceOf[SpatialBatchNormalization[T]]
      .runningVar)
  }

  /**
   * Creates a [[Merge]] layer.
   *
   * @param layers layers to merge, or null (forwarded as null)
   * @param inputShape one dimension list per input, converted with `toScalaMultiShape`
   */
  def createKerasMerge(
    layers: JList[AbstractModule[Activity, Activity, T]] = null,
    mode: String = "sum",
    concatAxis: Int = -1,
    inputShape: JList[JList[Int]]): Merge[T] = {
    val layersList = if (layers != null) layers.asScala.toList else null
    Merge[T](layersList, mode, concatAxis, toScalaMultiShape(inputShape))
  }

  /**
   * Creates a [[Convolution2D]] via its constructor. `init` and `activation` are
   * resolved with `KerasUtils`; `dimOrdering` is mapped by `KerasUtils.toBigDLFormat`.
   */
  def createKerasConvolution2D(
    nbFilter: Int,
    nbRow: Int,
    nbCol: Int,
    init: String = "glorot_uniform",
    activation: String = null,
    borderMode: String = "valid",
    subsample: JList[Int],
    dimOrdering: String = "th",
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): Convolution2D[T] = {
    new Convolution2D(nbFilter, nbRow, nbCol, KerasUtils.getInitMethod(init),
      KerasUtils.getKerasActivation(activation), borderMode,
      toScalaArray(subsample), KerasUtils.toBigDLFormat(dimOrdering),
      wRegularizer, bRegularizer, bias, toScalaShape(inputShape))
  }

  /** Creates a [[MaxPooling2D]]; `dimOrdering` is mapped by `KerasUtils.toBigDLFormat`. */
  def createKerasMaxPooling2D(
    poolSize: JList[Int],
    strides: JList[Int],
    borderMode: String = "valid",
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): MaxPooling2D[T] = {
    new MaxPooling2D[T](toScalaArray(poolSize), toScalaArray(strides),
      borderMode, KerasUtils.toBigDLFormat(dimOrdering), toScalaShape(inputShape))
  }

  /** Creates an [[Activation]] layer for the named activation function. */
  def createKerasActivation(
    activation: String,
    inputShape: JList[Int] = null): Activation[T] = {
    Activation(activation, toScalaShape(inputShape))
  }

  /** Creates a [[Reshape]] layer with the given target dimensions. */
  def createKerasReshape(
    targetShape: JList[Int],
    inputShape: JList[Int] = null): Reshape[T] = {
    Reshape(toScalaArray(targetShape), toScalaShape(inputShape))
  }

  /** Creates a [[Dropout]] layer with drop probability `p`. */
  def createKerasDropout(
    p: Double,
    inputShape: JList[Int] = null): Dropout[T] = {
    Dropout(p, toScalaShape(inputShape))
  }

  /** Creates a [[Flatten]] layer. */
  def createKerasFlatten(
    inputShape: JList[Int] = null): Flatten[T] = {
    Flatten(toScalaShape(inputShape))
  }

  /** Creates a [[SimpleRNN]] layer; arguments are forwarded to `SimpleRNN.apply`. */
  def createKerasSimpleRNN(
    outputDim: Int,
    activation: String = "tanh",
    returnSequences: Boolean = false,
    goBackwards: Boolean = false,
    wRegularizer: Regularizer[T] = null,
    uRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    inputShape: JList[Int] = null): SimpleRNN[T] = {
    SimpleRNN(outputDim, activation, returnSequences, goBackwards,
      wRegularizer, uRegularizer, bRegularizer, toScalaShape(inputShape))
  }

  /** Creates an [[LSTM]] layer; arguments are forwarded to `LSTM.apply`. */
  def createKerasLSTM(
    outputDim: Int,
    activation: String = "tanh",
    innerActivation: String = "hard_sigmoid",
    returnSequences: Boolean = false,
    goBackwards: Boolean = false,
    wRegularizer: Regularizer[T] = null,
    uRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    inputShape: JList[Int] = null): LSTM[T] = {
    LSTM(outputDim, activation, innerActivation, returnSequences,
      goBackwards, wRegularizer, uRegularizer, bRegularizer, toScalaShape(inputShape))
  }

  /** Creates a [[GRU]] layer; arguments are forwarded to `GRU.apply`. */
  def createKerasGRU(
    outputDim: Int,
    activation: String = "tanh",
    innerActivation: String = "hard_sigmoid",
    returnSequences: Boolean = false,
    goBackwards: Boolean = false,
    wRegularizer: Regularizer[T] = null,
    uRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    inputShape: JList[Int] = null): GRU[T] = {
    GRU(outputDim, activation, innerActivation, returnSequences,
      goBackwards, wRegularizer, uRegularizer, bRegularizer, toScalaShape(inputShape))
  }

  /** Creates a [[Highway]] layer; arguments are forwarded to `Highway.apply`. */
  def createKerasHighway(
    activation: String = null,
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): Highway[T] = {
    Highway(activation, wRegularizer, bRegularizer, bias, toScalaShape(inputShape))
  }

  /** Creates a [[ZeroPadding1D]] layer with the given padding values. */
  def createKerasZeroPadding1D(
    padding: JList[Int],
    inputShape: JList[Int] = null): ZeroPadding1D[T] = {
    new ZeroPadding1D(toScalaArray(padding), toScalaShape(inputShape))
  }

  /** Creates a [[ZeroPadding2D]]; `dimOrdering` is mapped by `KerasUtils.toBigDLFormat`. */
  def createKerasZeroPadding2D(
    padding: JList[Int],
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): ZeroPadding2D[T] = {
    new ZeroPadding2D(toScalaArray(padding),
      KerasUtils.toBigDLFormat(dimOrdering), toScalaShape(inputShape))
  }

  /** Creates an [[UpSampling1D]] layer repeating each step `length` times. */
  def createKerasUpSampling1D(
    length: Int = 2,
    inputShape: JList[Int] = null): UpSampling1D[T] = {
    UpSampling1D(length, toScalaShape(inputShape))
  }

  /** Creates an [[UpSampling2D]]; `dimOrdering` is mapped by `KerasUtils.toBigDLFormat`. */
  def createKerasUpSampling2D(
    size: JList[Int],
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): UpSampling2D[T] = {
    new UpSampling2D(toScalaArray(size), KerasUtils.toBigDLFormat(dimOrdering),
      toScalaShape(inputShape))
  }

  /** Creates an [[UpSampling3D]]; `dimOrdering` uses the 5D mapping `toBigDLFormat5D`. */
  def createKerasUpSampling3D(
    size: JList[Int],
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): UpSampling3D[T] = {
    new UpSampling3D(toScalaArray(size), KerasUtils.toBigDLFormat5D(dimOrdering),
      toScalaShape(inputShape))
  }

  /** Creates a [[MaxoutDense]] layer; arguments are forwarded to `MaxoutDense.apply`. */
  def createKerasMaxoutDense(
    outputDim: Int,
    nbFeature: Int = 4,
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): MaxoutDense[T] = {
    MaxoutDense(outputDim, nbFeature, wRegularizer,
      bRegularizer, bias, toScalaShape(inputShape))
  }

  /** Creates a [[Convolution1D]] layer; arguments go to `Convolution1D.apply`. */
  def createKerasConvolution1D(
    nbFilter: Int,
    filterLength: Int,
    init: String = "glorot_uniform",
    activation: String = null,
    borderMode: String = "valid",
    subsampleLength: Int = 1,
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): Convolution1D[T] = {
    Convolution1D(nbFilter, filterLength, init, activation, borderMode,
      subsampleLength, wRegularizer, bRegularizer, bias, toScalaShape(inputShape))
  }

  /**
   * Creates a [[Convolution3D]] via its constructor. `init` and `activation` are
   * resolved with `KerasUtils`; `dimOrdering` uses the 5D mapping `toBigDLFormat5D`.
   */
  def createKerasConvolution3D(
    nbFilter: Int,
    kernelDim1: Int,
    kernelDim2: Int,
    kernelDim3: Int,
    init: String = "glorot_uniform",
    activation: String = null,
    borderMode: String = "valid",
    subsample: JList[Int],
    dimOrdering: String = "th",
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): Convolution3D[T] = {
    new Convolution3D(nbFilter, kernelDim1, kernelDim2, kernelDim3,
      KerasUtils.getInitMethod(init), KerasUtils.getKerasActivation(activation),
      borderMode, toScalaArray(subsample), KerasUtils.toBigDLFormat5D(dimOrdering),
      wRegularizer, bRegularizer, bias, toScalaShape(inputShape))
  }

  /** Creates a [[MaxPooling1D]] layer; `stride` defaults to -1 as in `MaxPooling1D`. */
  def createKerasMaxPooling1D(
    poolLength: Int = 2,
    stride: Int = -1,
    borderMode: String = "valid",
    inputShape: JList[Int] = null): MaxPooling1D[T] = {
    MaxPooling1D(poolLength, stride, borderMode, toScalaShape(inputShape))
  }

  /** Creates a [[MaxPooling3D]]; `dimOrdering` uses the 5D mapping `toBigDLFormat5D`. */
  def createKerasMaxPooling3D(
    poolSize: JList[Int],
    strides: JList[Int],
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): MaxPooling3D[T] = {
    new MaxPooling3D(toScalaArray(poolSize), toScalaArray(strides),
      KerasUtils.toBigDLFormat5D(dimOrdering), toScalaShape(inputShape))
  }

  /** Creates an [[AveragePooling1D]] layer; arguments go to `AveragePooling1D.apply`. */
  def createKerasAveragePooling1D(
    poolLength: Int = 2,
    stride: Int = -1,
    borderMode: String = "valid",
    inputShape: JList[Int] = null): AveragePooling1D[T] = {
    AveragePooling1D(poolLength, stride, borderMode, toScalaShape(inputShape))
  }

  /** Creates an [[AveragePooling2D]]; `dimOrdering` is mapped by `toBigDLFormat`. */
  def createKerasAveragePooling2D(
    poolSize: JList[Int],
    strides: JList[Int],
    borderMode: String = "valid",
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): AveragePooling2D[T] = {
    new AveragePooling2D(toScalaArray(poolSize), toScalaArray(strides),
      borderMode, KerasUtils.toBigDLFormat(dimOrdering), toScalaShape(inputShape))
  }

  /** Creates an [[AveragePooling3D]]; `dimOrdering` uses the 5D mapping `toBigDLFormat5D`. */
  def createKerasAveragePooling3D(
    poolSize: JList[Int],
    strides: JList[Int],
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): AveragePooling3D[T] = {
    new AveragePooling3D(toScalaArray(poolSize), toScalaArray(strides),
      KerasUtils.toBigDLFormat5D(dimOrdering), toScalaShape(inputShape))
  }

  /** Creates a [[GlobalAveragePooling2D]]; `dimOrdering` is passed through as a string. */
  def createKerasGlobalAveragePooling2D(
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): GlobalAveragePooling2D[T] = {
    GlobalAveragePooling2D(dimOrdering, toScalaShape(inputShape))
  }

  /** Creates a [[GlobalMaxPooling2D]]; `dimOrdering` is passed through as a string. */
  def createKerasGlobalMaxPooling2D(
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): GlobalMaxPooling2D[T] = {
    GlobalMaxPooling2D(dimOrdering, toScalaShape(inputShape))
  }

  /** Creates a [[RepeatVector]] layer repeating its input `n` times. */
  def createKerasRepeatVector(
    n: Int,
    inputShape: JList[Int] = null): RepeatVector[T] = {
    RepeatVector(n, toScalaShape(inputShape))
  }

  /** Creates a [[Permute]] layer with the given dimension order. */
  def createKerasPermute(
    dims: JList[Int],
    inputShape: JList[Int] = null): Permute[T] = {
    Permute(toScalaArray(dims), toScalaShape(inputShape))
  }

  /** Creates a [[Cropping1D]] layer with the given crop amounts. */
  def createKerasCropping1D(
    cropping: JList[Int],
    inputShape: JList[Int] = null): Cropping1D[T] = {
    new Cropping1D(toScalaArray(cropping), toScalaShape(inputShape))
  }

  /** Creates a [[Cropping2D]]; `dimOrdering` is mapped by `KerasUtils.toBigDLFormat`. */
  def createKerasCropping2D(
    heightCrop: JList[Int],
    widthCrop: JList[Int],
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): Cropping2D[T] = {
    new Cropping2D(toScalaArray(heightCrop), toScalaArray(widthCrop),
      KerasUtils.toBigDLFormat(dimOrdering), toScalaShape(inputShape))
  }

  /** Creates a [[Cropping3D]]; `dimOrdering` uses the 5D mapping `toBigDLFormat5D`. */
  def createKerasCropping3D(
    dim1Crop: JList[Int],
    dim2Crop: JList[Int],
    dim3Crop: JList[Int],
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): Cropping3D[T] = {
    new Cropping3D(toScalaArray(dim1Crop), toScalaArray(dim2Crop), toScalaArray(dim3Crop),
      KerasUtils.toBigDLFormat5D(dimOrdering), toScalaShape(inputShape))
  }

  /** Creates an [[AtrousConvolution1D]]; arguments go to `AtrousConvolution1D.apply`. */
  def createKerasAtrousConvolution1D(
    nbFilter: Int,
    filterLength: Int,
    init: String = "glorot_uniform",
    activation: String = null,
    subsampleLength: Int = 1,
    atrousRate: Int = 1,
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    inputShape: JList[Int] = null): AtrousConvolution1D[T] = {
    AtrousConvolution1D(nbFilter, filterLength, init, activation,
      subsampleLength, atrousRate, wRegularizer, bRegularizer, toScalaShape(inputShape))
  }

  /**
   * Creates an [[AtrousConvolution2D]] via its constructor. `init` and `activation`
   * are resolved with `KerasUtils`; `dimOrdering` is mapped by `toBigDLFormat`.
   */
  def createKerasAtrousConvolution2D(
    nbFilter: Int,
    nbRow: Int,
    nbCol: Int,
    init: String = "glorot_uniform",
    activation: String = null,
    subsample: JList[Int],
    atrousRate: JList[Int],
    dimOrdering: String = "th",
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    inputShape: JList[Int] = null): AtrousConvolution2D[T] = {
    new AtrousConvolution2D(nbFilter, nbRow, nbCol, KerasUtils.getInitMethod(init),
      KerasUtils.getKerasActivation(activation), toScalaArray(subsample),
      toScalaArray(atrousRate), KerasUtils.toBigDLFormat(dimOrdering),
      wRegularizer, bRegularizer, toScalaShape(inputShape))
  }

  /**
   * Creates a [[Deconvolution2D]] via its constructor. `init` and `activation` are
   * resolved with `KerasUtils`; `dimOrdering` is mapped by `toBigDLFormat`.
   */
  def createKerasDeconvolution2D(
    nbFilter: Int,
    nbRow: Int,
    nbCol: Int,
    init: String = "glorot_uniform",
    activation: String = null,
    subsample: JList[Int],
    dimOrdering: String = "th",
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): Deconvolution2D[T] = {
    new Deconvolution2D(nbFilter, nbRow, nbCol, KerasUtils.getInitMethod(init),
      KerasUtils.getKerasActivation(activation), toScalaArray(subsample),
      KerasUtils.toBigDLFormat(dimOrdering), wRegularizer, bRegularizer,
      bias, toScalaShape(inputShape))
  }

  /** Creates a [[ConvLSTM2D]]; `dimOrdering` is passed through as a string. */
  def createKerasConvLSTM2D(
    nbFilter: Int,
    nbKernel: Int,
    activation: String = "tanh",
    innerActivation: String = "hard_sigmoid",
    dimOrdering: String = "th",
    subsample: Int = 1,
    wRegularizer: Regularizer[T] = null,
    uRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    returnSequences: Boolean = false,
    goBackwards: Boolean = false,
    inputShape: JList[Int] = null): ConvLSTM2D[T] = {
    ConvLSTM2D(nbFilter, nbKernel, activation, innerActivation,
      dimOrdering, subsample, wRegularizer, uRegularizer, bRegularizer,
      returnSequences, goBackwards, toScalaShape(inputShape))
  }

  /** Creates a [[LocallyConnected1D]]; arguments go to `LocallyConnected1D.apply`. */
  def createKerasLocallyConnected1D(
    nbFilter: Int,
    filterLength: Int,
    activation: String = null,
    subsampleLength: Int = 1,
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): LocallyConnected1D[T] = {
    LocallyConnected1D(nbFilter, filterLength, activation, subsampleLength,
      wRegularizer, bRegularizer, bias, toScalaShape(inputShape))
  }

  /**
   * Creates a [[LocallyConnected2D]] via its constructor. `activation` is resolved
   * with `KerasUtils`; `dimOrdering` is mapped by `toBigDLFormat`.
   */
  def createKerasLocallyConnected2D(
    nbFilter: Int,
    nbRow: Int,
    nbCol: Int,
    activation: String = null,
    borderMode: String = "valid",
    subsample: JList[Int],
    dimOrdering: String = "th",
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): LocallyConnected2D[T] = {
    new LocallyConnected2D(nbFilter, nbRow, nbCol, KerasUtils.getKerasActivation(activation),
      borderMode, toScalaArray(subsample), KerasUtils.toBigDLFormat(dimOrdering),
      wRegularizer, bRegularizer, bias, toScalaShape(inputShape))
  }

  /**
   * Creates a [[SeparableConvolution2D]] via its constructor. `init` and `activation`
   * are resolved with `KerasUtils`; `dimOrdering` is mapped by `toBigDLFormat`.
   */
  def createKerasSeparableConvolution2D(
    nbFilter: Int,
    nbRow: Int,
    nbCol: Int,
    init: String = "glorot_uniform",
    activation: String = null,
    borderMode: String = "valid",
    subsample: JList[Int],
    depthMultiplier: Int = 1,
    dimOrdering: String = "th",
    depthwiseRegularizer: Regularizer[T] = null,
    pointwiseRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    bias: Boolean = true,
    inputShape: JList[Int] = null): SeparableConvolution2D[T] = {
    new SeparableConvolution2D(nbFilter, nbRow, nbCol, KerasUtils.getInitMethod(init),
      KerasUtils.getKerasActivation(activation), borderMode, toScalaArray(subsample),
      depthMultiplier, KerasUtils.toBigDLFormat(dimOrdering),
      depthwiseRegularizer, pointwiseRegularizer, bRegularizer, bias, toScalaShape(inputShape))
  }

  /** Creates a [[ZeroPadding3D]]; `dimOrdering` uses the 5D mapping `toBigDLFormat5D`. */
  def createKerasZeroPadding3D(
    padding: JList[Int],
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): ZeroPadding3D[T] = {
    new ZeroPadding3D(toScalaArray(padding), KerasUtils.toBigDLFormat5D(dimOrdering),
      toScalaShape(inputShape))
  }

  /** Creates a [[GlobalAveragePooling1D]] layer. */
  def createKerasGlobalAveragePooling1D(
    inputShape: JList[Int] = null): GlobalAveragePooling1D[T] = {
    GlobalAveragePooling1D(toScalaShape(inputShape))
  }

  /** Creates a [[GlobalMaxPooling1D]] layer. */
  def createKerasGlobalMaxPooling1D(
    inputShape: JList[Int] = null): GlobalMaxPooling1D[T] = {
    GlobalMaxPooling1D(toScalaShape(inputShape))
  }

  /** Creates a [[GlobalMaxPooling3D]]; `dimOrdering` is passed through as a string. */
  def createKerasGlobalMaxPooling3D(
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): GlobalMaxPooling3D[T] = {
    GlobalMaxPooling3D(dimOrdering, toScalaShape(inputShape))
  }

  /** Creates a [[GlobalAveragePooling3D]]; `dimOrdering` is passed through as a string. */
  def createKerasGlobalAveragePooling3D(
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): GlobalAveragePooling3D[T] = {
    GlobalAveragePooling3D(dimOrdering, toScalaShape(inputShape))
  }

  /** Creates a [[SpatialDropout1D]] layer with drop probability `p`. */
  def createKerasSpatialDropout1D(
    p: Double = 0.5,
    inputShape: JList[Int] = null): SpatialDropout1D[T] = {
    SpatialDropout1D(p, toScalaShape(inputShape))
  }

  /** Creates a [[SpatialDropout2D]]; `dimOrdering` is passed through as a string. */
  def createKerasSpatialDropout2D(
    p: Double = 0.5,
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): SpatialDropout2D[T] = {
    SpatialDropout2D(p, dimOrdering, toScalaShape(inputShape))
  }

  /** Creates a [[SpatialDropout3D]]; `dimOrdering` is passed through as a string. */
  def createKerasSpatialDropout3D(
    p: Double = 0.5,
    dimOrdering: String = "th",
    inputShape: JList[Int] = null): SpatialDropout3D[T] = {
    SpatialDropout3D(p, dimOrdering, toScalaShape(inputShape))
  }

  /** Creates a [[GaussianDropout]] layer with rate `p`. */
  def createKerasGaussianDropout(
    p: Double,
    inputShape: JList[Int] = null): GaussianDropout[T] = {
    GaussianDropout(p, toScalaShape(inputShape))
  }

  /** Creates a [[GaussianNoise]] layer with standard deviation `sigma`. */
  def createKerasGaussianNoise(
    sigma: Double,
    inputShape: JList[Int] = null): GaussianNoise[T] = {
    GaussianNoise(sigma, toScalaShape(inputShape))
  }

  /** Creates a [[Masking]] layer using `maskValue` as the mask. */
  def createKerasMasking(
    maskValue: Double = 0.0,
    inputShape: JList[Int] = null): Masking[T] = {
    Masking(maskValue, toScalaShape(inputShape))
  }

  /** Creates an [[SReLU]] layer; `sharedAxes` may be null (forwarded as null). */
  def createKerasSReLU(
    tLeftInit: String = "zero",
    aLeftInit: String = "glorot_uniform",
    tRightInit: String = "glorot_uniform",
    aRightInit: String = "one",
    sharedAxes: JList[Int] = null,
    inputShape: JList[Int] = null): SReLU[T] = {
    SReLU(tLeftInit, aLeftInit, tRightInit, aRightInit,
      toScalaArray(sharedAxes), toScalaShape(inputShape))
  }

  /** Creates an [[ELU]] layer with the given `alpha`. */
  def createKerasELU(
    alpha: Double = 1.0,
    inputShape: JList[Int] = null): ELU[T] = {
    ELU(alpha, toScalaShape(inputShape))
  }

  /** Creates a [[LeakyReLU]] layer with the given `alpha`. */
  def createKerasLeakyReLU(
    alpha: Double = 0.01,
    inputShape: JList[Int] = null): LeakyReLU[T] = {
    LeakyReLU(alpha, toScalaShape(inputShape))
  }

  /** Creates a [[ThresholdedReLU]] layer with threshold `theta`. */
  def createKerasThresholdedReLU(
    theta: Double = 1.0,
    inputShape: JList[Int] = null): ThresholdedReLU[T] = {
    ThresholdedReLU(theta, toScalaShape(inputShape))
  }

  /** Wraps `layer` in a [[TimeDistributed]] layer. */
  def createKerasTimeDistributed(
    layer: KerasLayer[Tensor[T], Tensor[T], T],
    inputShape: JList[Int] = null): TimeDistributed[T] = {
    TimeDistributed(layer, toScalaShape(inputShape))
  }

  /** Wraps a recurrent `layer` in a [[Bidirectional]] layer using `mergeMode`. */
  def createKerasBidirectional(
    layer: Recurrent[T],
    mergeMode: String = "concat",
    inputShape: JList[Int] = null): Bidirectional[T] = {
    Bidirectional(layer, mergeMode, toScalaShape(inputShape))
  }

  /**
   * Configures `module` for training.
   *
   * @param metrics validation methods, or null (forwarded to `compile` as null)
   */
  def compile(
    module: KerasModel[T],
    optimizer: OptimMethod[T],
    loss: Criterion[T],
    metrics: JList[ValidationMethod[T]] = null): Unit = {
    module.compile(optimizer, loss,
      if (metrics == null) null else metrics.asScala.toArray)
  }

  /**
   * Trains `module` on an RDD of Python samples.
   *
   * @param x training samples, converted with `toJSample`
   * @param validationData validation samples, or null (forwarded as null)
   */
  def fit(
    module: KerasModel[T],
    x: JavaRDD[Sample],
    batchSize: Int = 32,
    epochs: Int = 10,
    validationData: JavaRDD[Sample] = null): Unit = {
    module.fit(toJSample(x), batchSize, epochs,
      if (validationData == null) null else toJSample(validationData))
  }

  /**
   * Trains `module` on image feature datasets.
   *
   * Each dataset is transformed with `ImageFeatureToMiniBatch(batchSize)`;
   * a null `validationData` is forwarded as null.
   */
  def fit(
    module: KerasModel[T],
    x: DataSet[ImageFeature],
    batchSize: Int,
    epochs: Int,
    validationData: DataSet[ImageFeature]): Unit = {
    val trainData = x -> ImageFeatureToMiniBatch[T](batchSize)
    val valData =
      if (validationData != null) validationData -> ImageFeatureToMiniBatch[T](batchSize)
      else null
    module.fit(trainData, epochs, valData)
  }

  /**
   * Trains `module` locally on in-memory tensors.
   *
   * Samples are built from `xTrain`/`yTrain` and batched into a local dataset.
   * Validation data is built only when both `xVal` and `yVal` are non-null.
   * Before training, `Engine.setNodeAndCore(1, localCores)` is called.
   *
   * @param localCores number of cores passed to the engine configuration
   */
  def fit(
    module: KerasModel[T],
    xTrain: JList[JTensor],
    yTrain: JTensor,
    batchSize: Int,
    epochs: Int,
    xVal: JList[JTensor],
    yVal: JTensor,
    localCores: Int): Unit = {
    val trainArray = toSampleArray(xTrain.asScala.toList.map{f => toTensor(f)}, toTensor(yTrain))
    val trainData = batching(DataSet.array(trainArray), batchSize)
      .asInstanceOf[LocalDataSet[MiniBatch[T]]]
    val valData = if (xVal != null && yVal != null) {
      val valArray = toSampleArray(xVal.asScala.toList.map{f => toTensor(f)}, toTensor(yVal))
      batching(DataSet.array(valArray), batchSize)
    } else null
    Engine.setNodeAndCore(1, localCores)
    module.fit(trainData, epochs, valData)
  }

  /**
   * Evaluates `module` on an RDD of Python samples.
   *
   * @return one [[EvaluatedResult]] per entry returned by `module.evaluate`, holding
   *         the two components of the validation result and the method's `toString`
   */
  def evaluate(
    module: KerasModel[T],
    x: JavaRDD[Sample],
    batchSize: Int = 32): JList[EvaluatedResult] = {
    val resultArray = module.evaluate(toJSample(x), batchSize)
    val testResultArray = resultArray.map { result =>
      // Compute the validation result once instead of once per tuple field.
      val evaluated = result._1.result()
      EvaluatedResult(evaluated._1, evaluated._2, result._2.toString())
    }
    testResultArray.toList.asJava
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy