com.intel.analytics.zoo.feature.common.MTSampleToMiniBatch.scala
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.zoo.feature.common

import com.intel.analytics.bigdl.dataset._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.EngineRef

import scala.reflect.ClassTag

/**
 * Convert a sequence of raw data elements to a sequence of [[MiniBatch]]:
 * the given transformer turns each raw element into a [[Sample]] using
 * multiple threads, and the resulting samples are then grouped into
 * [[MiniBatch]]es. See the usage sketch appended after the companion object.
 *
 * @param totalBatch          total batch size
 * @param transformer         transformer that converts raw data of type A to [[Sample]]
 * @param miniBatch           optional user-defined [[MiniBatch]] used to construct batches
 * @param featurePaddingParam feature padding strategy, see [[PaddingParam]] for details
 * @param labelPaddingParam   label padding strategy, see [[PaddingParam]] for details
 * @param partitionNum        optional partition number, used to derive the
 *                            per-partition batch size
 */
class MTSampleToMiniBatch[A: ClassTag, T: ClassTag] (
          totalBatch: Int,
          transformer: Transformer[A, Sample[T]],
          miniBatch: Option[MiniBatch[T]] = None,
          featurePaddingParam: Option[PaddingParam[T]] = None,
          labelPaddingParam: Option[PaddingParam[T]] = None,
          partitionNum: Option[Int] = None)
      (implicit ev: TensorNumeric[T]) extends Transformer[A, MiniBatch[T]] {

  private val batchPerPartition = Utils.getBatchSize(totalBatch, partitionNum)
  var miniBatchBuffer = miniBatch.orNull
  private val batchSize = batchPerPartition
  private val sampleData = new Array[Sample[T]](batchSize)

  private val parallelism = EngineRef.getCoreNumber()

  // Each worker thread uses its own clone of the transformer, since a
  // Transformer may hold mutable state and is not assumed to be thread-safe.
  private val transformers = (0 until parallelism).map(
    _ => transformer.cloneTransformer()
  ).toArray

  private val rawDataCache = new Array[Iterator[A]](batchSize)

  override def apply(prev: Iterator[A]): Iterator[MiniBatch[T]] = {
    new Iterator[MiniBatch[T]] {

      override def hasNext: Boolean = prev.hasNext

      override def next(): MiniBatch[T] = {
        if (prev.hasNext) {
          // Prefetch up to batchSize raw elements from the upstream iterator
          var count = 0
          while (count < batchSize && prev.hasNext) {
            val raw = prev.next()
            rawDataCache(count) = Iterator.single(raw)
            count += 1
          }

          // Multi-threaded processing: thread `tid` transforms the cached
          // elements at indices tid, tid + parallelism, tid + 2 * parallelism, ...
          (0 until parallelism).toParArray.foreach { tid =>
            var j = tid
            while (j < count) {
              sampleData(j) = transformers(tid).apply(rawDataCache(j)).next()
              j += parallelism
            }
          }

          // Lazily create the MiniBatch buffer, choosing its type from the first sample
          if (null == miniBatchBuffer) {
            val firstSample = sampleData(0)
            miniBatchBuffer = if (firstSample.isInstanceOf[TensorSample[T]]) {
              SparseMiniBatch(firstSample.numFeature(), firstSample.numLabel())
            } else {
              MiniBatch(firstSample.numFeature(), firstSample.numLabel(),
                featurePaddingParam, labelPaddingParam)
            }
          }

          if (count < batchSize) {
            miniBatchBuffer.set(sampleData.slice(0, count))
          } else {
            miniBatchBuffer.set(sampleData)
          }
        } else {
          null
        }
      }
    }
  }
}

object MTSampleToMiniBatch {
  /**
   * Apply an MTSampleToMiniBatch transformer.
   *
   * @param batchSize           total batch size
   * @param transformer         transformer that converts raw data to [[Sample]]
   * @param featurePaddingParam feature padding strategy, see
   *                            [[com.intel.analytics.bigdl.dataset.PaddingParam]] for details.
   * @param labelPaddingParam   label padding strategy, see
   *                            [[com.intel.analytics.bigdl.dataset.PaddingParam]] for details.
   * @param partitionNum        partition number of the dataset, used to derive the
   *                            per-partition batch size
   * @return an MTSampleToMiniBatch transformer
   */
  def apply[A: ClassTag, T: ClassTag](
         batchSize : Int,
         transformer: Transformer[A, Sample[T]],
         featurePaddingParam: Option[PaddingParam[T]] = None,
         labelPaddingParam: Option[PaddingParam[T]] = None,
         partitionNum: Option[Int] = None
         )(implicit ev: TensorNumeric[T]): MTSampleToMiniBatch[A, T] = {
    new MTSampleToMiniBatch[A, T](batchSize,
      transformer,
      None, featurePaddingParam, labelPaddingParam, partitionNum)
  }
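
  // Usage sketch (illustrative only): `MyRaw` and `rawToSample` are hypothetical
  // placeholders for your raw data type and an existing Transformer[MyRaw, Sample[Float]];
  // an implicit TensorNumeric[Float] (e.g. com.intel.analytics.bigdl.numeric.NumericFloat)
  // must be in scope.
  //
  //   val toBatch = MTSampleToMiniBatch[MyRaw, Float](
  //     batchSize = 32,
  //     transformer = rawToSample)
  //   val batches: Iterator[MiniBatch[Float]] = toBatch(rawIterator)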

  /**
   * Apply an MTSampleToMiniBatch transformer with a user-defined MiniBatch.
   *
   * @param miniBatch    a user-defined MiniBatch used to construct each mini batch
   * @param batchSize    total batch size
   * @param transformer  transformer that converts raw data to [[Sample]]
   * @param partitionNum partition number of the dataset, used to derive the
   *                     per-partition batch size
   * @return an MTSampleToMiniBatch transformer
   */
  def apply[A: ClassTag, T: ClassTag](
        miniBatch: MiniBatch[T],
        batchSize : Int,
        transformer: Transformer[A, Sample[T]],
        partitionNum: Option[Int])
        (implicit ev: TensorNumeric[T]): MTSampleToMiniBatch[A, T] = {
    new MTSampleToMiniBatch[A, T](batchSize,
      transformer,
      Some(miniBatch), partitionNum = partitionNum)
  }
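
  // Usage sketch with a user-defined MiniBatch (illustrative only): `myMiniBatch`
  // stands for any pre-built MiniBatch[Float] implementation, and `rawToSample`
  // for a Transformer[MyRaw, Sample[Float]] as above.
  //
  //   val toBatch = MTSampleToMiniBatch(
  //     miniBatch = myMiniBatch,
  //     batchSize = 32,
  //     transformer = rawToSample,
  //     partitionNum = None)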
}
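
// ---------------------------------------------------------------------------
// Usage sketch referenced from the class Scaladoc above. It is not part of the
// original Analytics Zoo source: it assumes a local BigDL environment in which
// the Engine has been initialized (e.g. via Engine.init), and uses a trivial
// pass-through transformer purely for illustration.
object MTSampleToMiniBatchUsageSketch {

  import com.intel.analytics.bigdl.numeric.NumericFloat
  import com.intel.analytics.bigdl.tensor.Tensor

  // Hypothetical transformer: the "raw" data is already a Sample[Float],
  // so it is simply passed through unchanged.
  class IdentitySampleTransformer extends Transformer[Sample[Float], Sample[Float]] {
    override def apply(prev: Iterator[Sample[Float]]): Iterator[Sample[Float]] = prev
  }

  def main(args: Array[String]): Unit = {
    // Eight dummy samples: a random 4-element feature and a scalar label each.
    val samples = (1 to 8).map { i =>
      Sample[Float](Tensor[Float](4).rand(), Tensor[Float](1).fill(i.toFloat))
    }.iterator

    // batchSize = 4 should yield two MiniBatches of 4 samples each.
    val toBatch = MTSampleToMiniBatch[Sample[Float], Float](
      batchSize = 4,
      transformer = new IdentitySampleTransformer)

    toBatch(samples).foreach(batch => println(s"batch of size ${batch.size()}"))
  }
}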



