All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.dlframes.DLImageTransformer.scala Maven / Gradle / Ivy

There is a newer version: 0.11.1
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.dlframes

import com.intel.analytics.bigdl.dataset.Transformer
import com.intel.analytics.bigdl.transform.vision.image.{FeatureTransformer, ImageFeature, MatToTensor}
import org.apache.spark.ml.DLTransformerBase
import org.apache.spark.ml.adapter.{HasInputCol, HasOutputCol, SchemaUtils}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

/**
 * Provides DataFrame-based API for image pre-processing and feature transformation.
 * DLImageTransformer follows the Spark Transformer API pattern and can be used as one stage
 * in Spark ML pipeline.
 *
 * The input column can be either DLImageSchema.byteSchema or DLImageSchema.floatSchema. If
 * using DLImageReader, the default format is DLImageSchema.byteSchema
 * The output column is always DLImageSchema.floatSchema.
 *
 * @param transformer Single or a sequence of BigDL FeatureTransformers to be used. E.g.
 *                    Resize(256, 256) -> CenterCrop(224, 224) ->
 *                    ChannelNormalize(123, 117, 104, 1, 1, 1) -> MatToTensor()
 */
class DLImageTransformer (
    val transformer: Transformer[ImageFeature, ImageFeature],
    override val uid: String)
  extends DLTransformerBase with HasInputCol with HasOutputCol {

  def this(transformer: FeatureTransformer) =
    this(transformer, Identifiable.randomUID("DLImageTransformer"))

  setDefault(inputCol -> "image")
  def setInputCol(value: String): this.type = set(inputCol, value)

  setDefault(outputCol -> "output")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  protected def validateInputType(inputType: DataType): Unit = {
    val validTypes = Array(DLImageSchema.floatSchema, DLImageSchema.byteSchema)

    require(validTypes.exists(t => SchemaUtils.sameType(inputType, t)),
      s"Bad input type: $inputType. Requires ${validTypes.mkString(", ")}")
  }

  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    validateInputType(inputType)
    if (schema.fieldNames.contains($(outputCol))) {
      throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
    }

    val outputFields = schema.fields :+
      StructField($(outputCol), DLImageSchema.floatSchema, nullable = false)
    StructType(outputFields)
  }

  protected override def internalTransform(dataFrame: DataFrame): DataFrame = {
    transformSchema(dataFrame.schema, logging = true)
    val sc = dataFrame.sqlContext.sparkContext
    val localTransformer = this.transformer
    val transformerBC = sc.broadcast(localTransformer)
    val toTensorBC = sc.broadcast(MatToTensor[Float](shareBuffer = true))

    val inputColIndex = dataFrame.schema.fieldIndex($(inputCol))
    val resultRDD = dataFrame.rdd.mapPartitions { rowIter =>
      val localTransformer = transformerBC.value.cloneTransformer()
      val toTensorTransformer = toTensorBC.value.cloneTransformer().asInstanceOf[MatToTensor[Float]]
      rowIter.map { row =>
        val imf = DLImageSchema.row2IMF(row.getAs[Row](inputColIndex))
        val output = localTransformer.apply(Iterator(imf)).toArray.head
        if (!output.contains(ImageFeature.imageTensor)) {
          toTensorTransformer.transform(output)
        }
        Row.fromSeq(row.toSeq ++ Seq(DLImageSchema.imf2Row(output)))
      }
    }

    val resultSchema = transformSchema(dataFrame.schema)
    dataFrame.sqlContext.createDataFrame(resultRDD, resultSchema)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy