All Downloads are FREE. Search and download functionalities are using the official Maven repository.

eu.stratosphere.api.scala.analysis.UserDefinedFunction.scala Maven / Gradle / Ivy

/***********************************************************************************************************************
 * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 **********************************************************************************************************************/

package eu.stratosphere.api.scala.analysis

import scala.collection.mutable

/**
 * Common base for user-defined-function descriptors.
 *
 * Holds the output [[UDT]] and the field set derived from it, and offers helpers that assign
 * global positions to the output fields or tie them to other field sets.
 *
 * @tparam R the type described by `outputUDT`
 */
abstract class UDF[R] extends Serializable {

  /** UDT describing the output; supplied by the concrete subclass. */
  val outputUDT: UDT[R]

  /** Output field set created from `outputUDT` via `FieldSet.newOutputSet`. */
  val outputFields = FieldSet.newOutputSet(outputUDT)

  /** Serializer obtained from `outputUDT` for the serializer index array of `outputFields`. */
  def getOutputSerializer = outputUDT.getSerializer(outputFields.toSerializerIndexArray)

  /**
   * Length of the output record as implied by the output fields.
   *
   * @return the largest value in `outputFields.toIndexSet` plus one, or 0 if that set is empty
   */
  def getOutputLength = {
    val usedIndexes = outputFields.toIndexSet
    if (usedIndexes.nonEmpty) usedIndexes.max + 1 else 0
  }

  /**
   * Marks the output fields as globalized and hands out indexes to every global position that
   * matches `GlobalPos.Unknown()`, counting upwards from `startPos`.
   *
   * @param startPos first index to hand out
   * @return `startPos`, unchanged
   */
  def allocateOutputGlobalIndexes(startPos: Int): Int = {

    outputFields.setGlobalized()

    val positions = outputFields.map(_.globalPos)
    positions.foldLeft(startPos) { (nextFree, pos) =>
      pos match {
        case unknown @ GlobalPos.Unknown() =>
          unknown.setIndex(nextFree)
          nextFree + 1
        case _ =>
          nextFree
      }
    }

    // NOTE(review): the running counter computed by the fold above is discarded and the
    // original start position is returned. This may be deliberate; confirm with the callers
    // before changing it to return the next free index.
    startPos
  }

  /**
   * Marks the output fields as globalized and makes each output field's global position a
   * reference to the global position of the field at the same local position in `sameAs`.
   *
   * @param sameAs field set whose global positions are referenced
   */
  def assignOutputGlobalIndexes(sameAs: FieldSet[Field]): Unit = {

    outputFields.setGlobalized()

    outputFields.foreach {
      case OutputField(local, global) =>
        val referenced = sameAs(local).globalPos
        global.setReference(referenced)
    }
  }

  /**
   * Dispatches to [[assignOutputGlobalIndexes]] when a reference field set is given and to
   * [[allocateOutputGlobalIndexes]] otherwise.
   *
   * @param startPos first index to hand out when allocating
   * @param sameAs   optional field set to reference instead of allocating
   * @return `startPos` when `sameAs` is defined, otherwise the result of
   *         `allocateOutputGlobalIndexes(startPos)`
   */
  def setOutputGlobalIndexes(startPos: Int, sameAs: Option[FieldSet[Field]]): Int = sameAs match {
    case Some(reference) =>
      assignOutputGlobalIndexes(reference)
      startPos
    case None =>
      allocateOutputGlobalIndexes(startPos)
  }

  /**
   * Marks `inputFields` as globalized and makes each input field's global position a reference
   * to the global position of the output field at the same local position.
   *
   * @param inputFields input field set to attach to this UDF's outputs
   */
  def attachOutputsToInputs(inputFields: FieldSet[InputField]): Unit = {

    inputFields.setGlobalized()

    inputFields.foreach {
      case InputField(local, global) =>
        val produced = outputFields(local).globalPos
        global.setReference(produced)
    }
  }

  /**
   * Records that the output field at `outputLocalPos` is a copy of an input: its global position
   * is set to reference `inputGlobalPos` and its `isUsed` flag is cleared.
   *
   * @param inputGlobalPos global position of the copied input field
   * @param outputLocalPos local position of the receiving output field
   */
  protected def markFieldCopied(inputGlobalPos: GlobalPos, outputLocalPos: Int): Unit = {
    val target = outputFields(outputLocalPos)
    target.globalPos.setReference(inputGlobalPos)
    target.isUsed = false
  }
}

/**
 * [[UDF]] variant that declares no input-side members; it only supplies the output UDT.
 *
 * @param outputUDT UDT describing the output
 * @tparam R the type described by `outputUDT`
 */
class UDF0[R](
  val outputUDT: UDT[R]
) extends UDF[R]

/**
 * [[UDF]] variant with a single input.
 *
 * Besides the output information inherited from [[UDF]], it tracks the input field set, the
 * (input field, output field) pairs recorded as forwarded, and a set of discarded global
 * positions.
 *
 * @param inputUDT  UDT describing the input
 * @param outputUDT UDT describing the output
 * @tparam T the type described by `inputUDT`
 * @tparam R the type described by `outputUDT`
 */
class UDF1[T, R](val inputUDT: UDT[T], val outputUDT: UDT[R]) extends UDF[R] {

  /** Input field set created from `inputUDT` via `FieldSet.newInputSet`. */
  val inputFields = FieldSet.newInputSet(inputUDT)

  /** Pairs of (input field, output field) registered through `markFieldCopied`. */
  val forwardSet = mutable.Set[(InputField, OutputField)]()

  /** Discarded global positions; this class only reads the set, it never adds to it. */
  val discardSet = mutable.Set[GlobalPos]()

  /** Serializer obtained from `inputUDT` for the serializer index array of `inputFields`. */
  def getInputDeserializer = inputUDT.getSerializer(inputFields.toSerializerIndexArray)

  /** Local positions of the input side of every forwarded pair. */
  def getForwardIndexSetFrom = forwardSet.map { case (source, _) => source.localPos }

  /** Local positions of the output side of every forwarded pair. */
  def getForwardIndexSetTo = forwardSet.map { case (_, target) => target.localPos }

  /** Array form of [[getForwardIndexSetFrom]]. */
  def getForwardIndexArrayFrom = getForwardIndexSetFrom.toArray

  /** Array form of [[getForwardIndexSetTo]]. */
  def getForwardIndexArrayTo = getForwardIndexSetTo.toArray

  /** Values (`getValue`) of all entries of `discardSet`, as an array. */
  def getDiscardIndexArray = discardSet.map(_.getValue).toArray

  /**
   * Output length that also accounts for forwarded fields.
   *
   * @return the larger of the inherited output length and one more than the highest output
   *         local position in `forwardSet` (the latter counts as 0 when nothing is forwarded)
   */
  override def getOutputLength = {
    // -1 stands for "nothing forwarded", so that adding one below yields a length of 0.
    val highestForwarded =
      if (forwardSet.nonEmpty) forwardSet.map(_._2.localPos).max
      else -1
    math.max(super.getOutputLength, highestForwarded + 1)
  }

  /**
   * Clears the `isUsed` flag of the input field at `localPos`.
   *
   * @param localPos local position of the input field
   */
  def markInputFieldUnread(localPos: Int): Unit = {
    val field = inputFields(localPos)
    field.isUsed = false
  }

  /**
   * Registers the input field at `inputLocalPos` as copied to the output field at
   * `outputLocalPos`: the pair is added to `forwardSet` and the output field is then marked as
   * copied from the input's global position.
   *
   * @param inputLocalPos  local position of the source input field
   * @param outputLocalPos local position of the receiving output field
   */
  def markFieldCopied(inputLocalPos: Int, outputLocalPos: Int): Unit = {
    val source = inputFields(inputLocalPos)
    forwardSet += ((source, outputFields(outputLocalPos)))
    markFieldCopied(source.globalPos, outputLocalPos)
  }
}

/**
 * [[UDF]] variant with two inputs, referred to as left and right.
 *
 * For each side it tracks the input field set, the (input field, output field) pairs recorded
 * as forwarded, and a set of discarded global positions. Positions on a specific side are
 * addressed with `Either`: `Left(pos)` for the left input, `Right(pos)` for the right input.
 *
 * @param leftInputUDT  UDT describing the left input
 * @param rightInputUDT UDT describing the right input
 * @param outputUDT     UDT describing the output
 * @tparam T1 the type described by `leftInputUDT`
 * @tparam T2 the type described by `rightInputUDT`
 * @tparam R  the type described by `outputUDT`
 */
class UDF2[T1, T2, R](val leftInputUDT: UDT[T1], val rightInputUDT: UDT[T2], val outputUDT: UDT[R]) extends UDF[R] {

  /** Left input field set created from `leftInputUDT`. */
  val leftInputFields = FieldSet.newInputSet(leftInputUDT)

  /** Forwarded (input field, output field) pairs originating from the left input. */
  val leftForwardSet = mutable.Set[(InputField, OutputField)]()

  /** Discarded global positions for the left input; only read by this class. */
  val leftDiscardSet = mutable.Set[GlobalPos]()

  /** Right input field set created from `rightInputUDT`. */
  val rightInputFields = FieldSet.newInputSet(rightInputUDT)

  /** Forwarded (input field, output field) pairs originating from the right input. */
  val rightForwardSet = mutable.Set[(InputField, OutputField)]()

  /** Discarded global positions for the right input; only read by this class. */
  val rightDiscardSet = mutable.Set[GlobalPos]()

  /** Serializer obtained from `leftInputUDT` for the index array of `leftInputFields`. */
  def getLeftInputDeserializer = leftInputUDT.getSerializer(leftInputFields.toSerializerIndexArray)

  /** Local positions of the input side of every left forwarded pair. */
  def getLeftForwardIndexSetFrom = leftForwardSet.map { case (source, _) => source.localPos }

  /** Local positions of the output side of every left forwarded pair. */
  def getLeftForwardIndexSetTo = leftForwardSet.map { case (_, target) => target.localPos }

  /** Array form of [[getLeftForwardIndexSetFrom]]. */
  def getLeftForwardIndexArrayFrom = getLeftForwardIndexSetFrom.toArray

  /** Array form of [[getLeftForwardIndexSetTo]]. */
  def getLeftForwardIndexArrayTo = getLeftForwardIndexSetTo.toArray

  /** Values (`getValue`) of all entries of `leftDiscardSet`, as an array. */
  def getLeftDiscardIndexArray = leftDiscardSet.map(_.getValue).toArray

  /** Serializer obtained from `rightInputUDT` for the index array of `rightInputFields`. */
  def getRightInputDeserializer = rightInputUDT.getSerializer(rightInputFields.toSerializerIndexArray)

  /** Local positions of the input side of every right forwarded pair. */
  def getRightForwardIndexSetFrom = rightForwardSet.map { case (source, _) => source.localPos }

  /** Local positions of the output side of every right forwarded pair. */
  def getRightForwardIndexSetTo = rightForwardSet.map { case (_, target) => target.localPos }

  /** Array form of [[getRightForwardIndexSetFrom]]. */
  def getRightForwardIndexArrayFrom = getRightForwardIndexSetFrom.toArray

  /** Array form of [[getRightForwardIndexSetTo]]. */
  def getRightForwardIndexArrayTo = getRightForwardIndexSetTo.toArray

  /** Values (`getValue`) of all entries of `rightDiscardSet`, as an array. */
  def getRightDiscardIndexArray = rightDiscardSet.map(_.getValue).toArray

  /**
   * Output length that also accounts for fields forwarded from either input.
   *
   * @return the larger of the inherited output length and one more than the highest output
   *         local position found in either forward set (0 when both are empty)
   */
  override def getOutputLength = {
    // -1 stands for "nothing forwarded", so that adding one below yields a length of 0.
    val highestFromLeft =
      if (leftForwardSet.nonEmpty) leftForwardSet.map(_._2.localPos).max
      else -1
    val highestFromRight =
      if (rightForwardSet.nonEmpty) rightForwardSet.map(_._2.localPos).max
      else -1
    val highestForwarded = math.max(highestFromLeft, highestFromRight)
    math.max(super.getOutputLength, highestForwarded + 1)
  }

  /** Looks up the input field addressed by `localPos` on the side selected by the `Either`. */
  private def getInputField(localPos: Either[Int, Int]): InputField = localPos match {
    case Left(pos)  => leftInputFields(pos)
    case Right(pos) => rightInputFields(pos)
  }

  /**
   * Clears the `isUsed` flag of the addressed input field.
   *
   * @param localPos `Left(pos)` for the left input, `Right(pos)` for the right input
   */
  def markInputFieldUnread(localPos: Either[Int, Int]): Unit = {
    getInputField(localPos).isUsed = false
  }

  /**
   * Registers the addressed input field as copied to the output field at `outputLocalPos`: the
   * pair is added to the forward set of the matching side and the output field is then marked
   * as copied from the input's global position.
   *
   * @param inputLocalPos  `Left(pos)` for the left input, `Right(pos)` for the right input
   * @param outputLocalPos local position of the receiving output field
   */
  def markFieldCopied(inputLocalPos: Either[Int, Int], outputLocalPos: Int): Unit = {
    val source = getInputField(inputLocalPos)
    val sideForwardSet = inputLocalPos match {
      case Left(_)  => leftForwardSet
      case Right(_) => rightForwardSet
    }
    sideForwardSet += ((source, outputFields(outputLocalPos)))
    markFieldCopied(source.globalPos, outputLocalPos)
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy