All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.execution.r.MapPartitionsRWrapper.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.r

import org.apache.spark.api.r._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.Row
import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.types.StructType

/**
 * Wraps an R function so that it can be applied to one partition of data at a time.
 *
 * The wrapper is itself a function from an iterator of partition elements to an iterator of
 * results. Each call hands the partition to an [[RRunner]] and converts what comes back into
 * rows.
 *
 * @param func          the R function to apply, as a byte array
 * @param packageNames  byte array forwarded unchanged to the [[RRunner]]
 * @param broadcastVars broadcast variables forwarded unchanged to the [[RRunner]]
 * @param inputSchema   schema of the incoming rows; when it equals `SERIALIZED_R_DATA_SCHEMA`
 *                      only the first column of each row is passed on, as-is
 * @param outputSchema  schema of the produced rows; when it equals `SERIALIZED_R_DATA_SCHEMA`
 *                      each result is wrapped in a single-column row instead of being decoded
 */
case class MapPartitionsRWrapper(
    func: Array[Byte],
    packageNames: Array[Byte],
    broadcastVars: Array[Broadcast[Object]],
    inputSchema: StructType,
    outputSchema: StructType) extends (Iterator[Any] => Iterator[Any]) {

  /**
   * Runs the R function over a single partition.
   *
   * @param iter the partition's elements; every element is treated as a [[Row]]
   * @return an iterator of rows built from the R runner's output
   */
  def apply(iter: Iterator[Any]): Iterator[Any] = {
    // Both input paths treat the elements as rows, so do the cast in one place.
    val rows = iter.asInstanceOf[Iterator[Row]]

    // Does the incoming DataFrame already hold serialized R data?
    val inputIsSerializedR = inputSchema == SERIALIZED_R_DATA_SCHEMA

    val (rInput, inputFormat, columnNames) =
      if (inputIsSerializedR) {
        // Pass the first column through untouched; no column names are supplied in this case.
        (rows.map(_.apply(0)), SerializationFormats.BYTE, null)
      } else {
        // Turn every row into a byte array that the R worker can deserialize.
        (rows.map(r => rowToRBytes(r)), SerializationFormats.ROW, inputSchema.fieldNames)
      }

    val outputFormat =
      if (outputSchema == SERIALIZED_R_DATA_SCHEMA) {
        SerializationFormats.BYTE
      } else {
        SerializationFormats.ROW
      }

    val rRunner = new RRunner[Any, Array[Byte]](
      func,
      inputFormat,
      outputFormat,
      packageNames,
      broadcastVars,
      isDataFrame = true,
      colNames = columnNames,
      mode = RRunnerModes.DATAFRAME_DAPPLY)

    // The partition index is not used: Dataset offers no mapPartitionsWithIndex, so -1 is passed.
    val rOutput = rRunner.compute(rInput, -1)

    if (outputFormat != SerializationFormats.ROW) {
      // Leave the bytes as they are and wrap each result in a one-column row.
      rOutput.map(payload => Row.fromSeq(Seq(payload)))
    } else {
      // Decode each result into a row that follows the declared output schema.
      rOutput.map(payload => bytesToRow(payload, outputSchema))
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy