All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.ml.common.FlinkMLTools.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.common

import org.apache.flink.api.common.functions.Partitioner
import org.apache.flink.api.common.io.FileOutputFormat.OutputDirectoryMode
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.io.{TypeSerializerInputFormat, TypeSerializerOutputFormat}
import org.apache.flink.api.scala._
import org.apache.flink.core.fs.FileSystem.WriteMode
import org.apache.flink.core.fs.Path

import scala.reflect.ClassTag

/** FlinkMLTools contains a set of convenience functions for Flink's machine learning library:
  *
  *  - persist:
  *  Takes up to 5 [[DataSet]]s and file paths. Each [[DataSet]] is written to the specified
  *  path and subsequently re-read from disk. This method can be used to effectively split the
  *  execution graph at the given [[DataSet]]. Writing it to disk triggers its materialization
  *  and specifying it as a source will prevent the re-execution of it.
  *
  *  - block:
  *  Takes a DataSet of elements T and groups them in n blocks.
  *
  */
object FlinkMLTools {

  /** Registers the FlinkML and Breeze vector/matrix classes used by FlinkML for Kryo
    * serialization.
    *
    * For the Breeze types, the runtime class of a zero-sized `Double` instance is registered in
    * addition to the generic class, so that the specialized variants are covered as well.
    *
    * @param env [[ExecutionEnvironment]] on which the types are registered
    */
  def registerFlinkMLTypes(env: ExecutionEnvironment): Unit = {
    // FlinkML's own vector and matrix types
    val flinkMLTypes: Seq[Class[_]] = Seq(
      classOf[org.apache.flink.ml.math.DenseVector],
      classOf[org.apache.flink.ml.math.SparseVector],
      classOf[org.apache.flink.ml.math.DenseMatrix],
      classOf[org.apache.flink.ml.math.SparseMatrix])

    // Breeze vectors: generic classes first, then the classes of concrete Double instances
    val breezeVectorTypes: Seq[Class[_]] = Seq(
      classOf[breeze.linalg.DenseVector[_]],
      classOf[breeze.linalg.SparseVector[_]],
      breeze.linalg.DenseVector.zeros[Double](0).getClass,
      breeze.linalg.SparseVector.zeros[Double](0).getClass)

    // Breeze matrices: generic classes first, then the classes of concrete Double instances
    val breezeMatrixTypes: Seq[Class[_]] = Seq(
      classOf[breeze.linalg.DenseMatrix[Double]],
      classOf[breeze.linalg.CSCMatrix[Double]],
      breeze.linalg.DenseMatrix.zeros[Double](0, 0).getClass,
      breeze.linalg.CSCMatrix.zeros[Double](0, 0).getClass)

    // Register everything in the order listed above
    (flinkMLTypes ++ breezeVectorTypes ++ breezeMatrixTypes).foreach {
      typeClass => env.registerType(typeClass)
    }
  }

  /** Materializes a [[DataSet]] on disk and returns a new [[DataSet]] that reads it back.
    *
    * The data set is written to `path` with a [[TypeSerializerOutputFormat]] in overwrite mode,
    * the job is executed right away, and a source backed by a [[TypeSerializerInputFormat]] on
    * the same path is created on the data set's [[ExecutionEnvironment]].
    *
    * @param dataset [[DataSet]] to write to disk
    * @param path File path to write dataset to
    * @tparam T Type of the [[DataSet]] elements
    * @return [[DataSet]] reading the just written file
    */
  def persist[T: ClassTag: TypeInformation](dataset: DataSet[T], path: String): DataSet[T] = {
    val environment = dataset.getExecutionEnvironment

    // Sink side: serialize the elements to the target location, replacing existing data
    val sink = new TypeSerializerOutputFormat[T]
    val location = new Path(path)
    sink.setWriteMode(WriteMode.OVERWRITE)
    sink.setOutputFilePath(location)
    dataset.output(sink)

    // Running the job here is what actually writes the data
    environment.execute("FlinkTools persist")

    // Source side: read the freshly written data with the data set's own type information
    val source = new TypeSerializerInputFormat[T](dataset.getType)
    source.setFilePath(location)

    environment.createInput(source)
  }

  /** Materializes two [[DataSet]]s on disk in a single job and returns [[DataSet]]s that read
    * them back.
    *
    * Both sinks are registered first and then executed together on the
    * [[ExecutionEnvironment]] of `ds1`.
    *
    * @param ds1 First [[DataSet]] to write to disk
    * @param ds2 Second [[DataSet]] to write to disk
    * @param path1 Path for ds1
    * @param path2 Path for ds2
    * @tparam A Type of the first [[DataSet]]'s elements
    * @tparam B Type of the second [[DataSet]]'s elements
    * @return Tuple of [[DataSet]]s reading the just written files
    */
  def persist[A: ClassTag: TypeInformation, B: ClassTag: TypeInformation](
    ds1: DataSet[A],
    ds2: DataSet[B],
    path1: String,
    path2: String)
  : (DataSet[A], DataSet[B]) = {
    val env = ds1.getExecutionEnvironment

    // Registers an overwriting serializer sink for `data` and returns the target path
    def writeTo[X](data: DataSet[X], location: String): Path = {
      val target = new Path(location)
      val sink = new TypeSerializerOutputFormat[X]
      sink.setOutputFilePath(target)
      sink.setWriteMode(WriteMode.OVERWRITE)
      data.output(sink)
      target
    }

    // Creates a source that reads the serialized elements of `data` back from `location`
    def readBack[X: ClassTag: TypeInformation](data: DataSet[X], location: Path): DataSet[X] = {
      val source = new TypeSerializerInputFormat[X](data.getType)
      source.setFilePath(location)
      env.createInput(source)
    }

    val f1 = writeTo(ds1, path1)
    val f2 = writeTo(ds2, path2)

    // One execution writes all registered sinks
    env.execute("FlinkTools persist")

    (readBack(ds1, f1), readBack(ds2, f2))
  }

  /** Materializes three [[DataSet]]s on disk in a single job and returns [[DataSet]]s that read
    * them back.
    *
    * All sinks are registered first and then executed together on the
    * [[ExecutionEnvironment]] of `ds1`.
    *
    * @param ds1 First [[DataSet]] to write to disk
    * @param ds2 Second [[DataSet]] to write to disk
    * @param ds3 Third [[DataSet]] to write to disk
    * @param path1 Path for ds1
    * @param path2 Path for ds2
    * @param path3 Path for ds3
    * @tparam A Type of first [[DataSet]]'s elements
    * @tparam B Type of second [[DataSet]]'s elements
    * @tparam C Type of third [[DataSet]]'s elements
    * @return Tuple of [[DataSet]]s reading the just written files
    */
  def persist[
    A: ClassTag: TypeInformation,
    B: ClassTag: TypeInformation,
    C: ClassTag: TypeInformation](
    ds1: DataSet[A],
    ds2: DataSet[B],
    ds3: DataSet[C],
    path1: String,
    path2: String,
    path3: String)
  : (DataSet[A], DataSet[B], DataSet[C]) = {
    val env = ds1.getExecutionEnvironment

    // Registers an overwriting serializer sink for `data` and returns the target path
    def writeTo[X](data: DataSet[X], location: String): Path = {
      val target = new Path(location)
      val sink = new TypeSerializerOutputFormat[X]
      sink.setOutputFilePath(target)
      sink.setWriteMode(WriteMode.OVERWRITE)
      data.output(sink)
      target
    }

    // Creates a source that reads the serialized elements of `data` back from `location`
    def readBack[X: ClassTag: TypeInformation](data: DataSet[X], location: Path): DataSet[X] = {
      val source = new TypeSerializerInputFormat[X](data.getType)
      source.setFilePath(location)
      env.createInput(source)
    }

    val f1 = writeTo(ds1, path1)
    val f2 = writeTo(ds2, path2)
    val f3 = writeTo(ds3, path3)

    // One execution writes all registered sinks
    env.execute("FlinkTools persist")

    (readBack(ds1, f1), readBack(ds2, f2), readBack(ds3, f3))
  }

  /** Materializes four [[DataSet]]s on disk in a single job and returns [[DataSet]]s that read
    * them back.
    *
    * All sinks are registered first and then executed together on the
    * [[ExecutionEnvironment]] of `ds1`.
    *
    * @param ds1 First [[DataSet]] to write to disk
    * @param ds2 Second [[DataSet]] to write to disk
    * @param ds3 Third [[DataSet]] to write to disk
    * @param ds4 Fourth [[DataSet]] to write to disk
    * @param path1 Path for ds1
    * @param path2 Path for ds2
    * @param path3 Path for ds3
    * @param path4 Path for ds4
    * @tparam A Type of first [[DataSet]]'s elements
    * @tparam B Type of second [[DataSet]]'s elements
    * @tparam C Type of third [[DataSet]]'s elements
    * @tparam D Type of fourth [[DataSet]]'s elements
    * @return Tuple of [[DataSet]]s reading the just written files
    */
  def persist[
    A: ClassTag: TypeInformation,
    B: ClassTag: TypeInformation,
    C: ClassTag: TypeInformation,
    D: ClassTag: TypeInformation](
    ds1: DataSet[A],
    ds2: DataSet[B],
    ds3: DataSet[C],
    ds4: DataSet[D],
    path1: String,
    path2: String,
    path3: String,
    path4: String)
  : (DataSet[A], DataSet[B], DataSet[C], DataSet[D]) = {
    val env = ds1.getExecutionEnvironment

    // Registers an overwriting serializer sink for `data` and returns the target path
    def writeTo[X](data: DataSet[X], location: String): Path = {
      val target = new Path(location)
      val sink = new TypeSerializerOutputFormat[X]
      sink.setOutputFilePath(target)
      sink.setWriteMode(WriteMode.OVERWRITE)
      data.output(sink)
      target
    }

    // Creates a source that reads the serialized elements of `data` back from `location`
    def readBack[X: ClassTag: TypeInformation](data: DataSet[X], location: Path): DataSet[X] = {
      val source = new TypeSerializerInputFormat[X](data.getType)
      source.setFilePath(location)
      env.createInput(source)
    }

    val f1 = writeTo(ds1, path1)
    val f2 = writeTo(ds2, path2)
    val f3 = writeTo(ds3, path3)
    val f4 = writeTo(ds4, path4)

    // One execution writes all registered sinks
    env.execute("FlinkTools persist")

    (readBack(ds1, f1), readBack(ds2, f2), readBack(ds3, f3), readBack(ds4, f4))
  }

  /** Materializes five [[DataSet]]s on disk in a single job and returns [[DataSet]]s that read
    * them back.
    *
    * All sinks are registered first and then executed together on the
    * [[ExecutionEnvironment]] of `ds1`.
    *
    * @param ds1 First [[DataSet]] to write to disk
    * @param ds2 Second [[DataSet]] to write to disk
    * @param ds3 Third [[DataSet]] to write to disk
    * @param ds4 Fourth [[DataSet]] to write to disk
    * @param ds5 Fifth [[DataSet]] to write to disk
    * @param path1 Path for ds1
    * @param path2 Path for ds2
    * @param path3 Path for ds3
    * @param path4 Path for ds4
    * @param path5 Path for ds5
    * @tparam A Type of first [[DataSet]]'s elements
    * @tparam B Type of second [[DataSet]]'s elements
    * @tparam C Type of third [[DataSet]]'s elements
    * @tparam D Type of fourth [[DataSet]]'s elements
    * @tparam E Type of fifth [[DataSet]]'s elements
    * @return Tuple of [[DataSet]]s reading the just written files
    */
  def persist[
    A: ClassTag: TypeInformation,
    B: ClassTag: TypeInformation,
    C: ClassTag: TypeInformation,
    D: ClassTag: TypeInformation,
    E: ClassTag: TypeInformation](
    ds1: DataSet[A],
    ds2: DataSet[B],
    ds3: DataSet[C],
    ds4: DataSet[D],
    ds5: DataSet[E],
    path1: String,
    path2: String,
    path3: String,
    path4: String,
    path5: String)
  : (DataSet[A], DataSet[B], DataSet[C], DataSet[D], DataSet[E]) = {
    val env = ds1.getExecutionEnvironment

    // Registers an overwriting serializer sink for `data` and returns the target path.
    // `alwaysDirectory` additionally forces the sink into directory output mode.
    def writeTo[X](
      data: DataSet[X],
      location: String,
      alwaysDirectory: Boolean = false)
    : Path = {
      val target = new Path(location)
      val sink = new TypeSerializerOutputFormat[X]
      sink.setOutputFilePath(target)
      if (alwaysDirectory) {
        sink.setOutputDirectoryMode(OutputDirectoryMode.ALWAYS)
      }
      sink.setWriteMode(WriteMode.OVERWRITE)
      data.output(sink)
      target
    }

    // Creates a source that reads the serialized elements of `data` back from `location`
    def readBack[X: ClassTag: TypeInformation](data: DataSet[X], location: Path): DataSet[X] = {
      val source = new TypeSerializerInputFormat[X](data.getType)
      source.setFilePath(location)
      env.createInput(source)
    }

    val f1 = writeTo(ds1, path1)
    // NOTE(review): only this sink forces directory output mode; the sinks of the other data
    // sets (and of the other persist overloads) keep the default. Kept as-is to preserve the
    // on-disk layout -- confirm whether this is intentional.
    val f2 = writeTo(ds2, path2, alwaysDirectory = true)
    val f3 = writeTo(ds3, path3)
    val f4 = writeTo(ds4, path4)
    val f5 = writeTo(ds5, path5)

    // One execution writes all registered sinks
    env.execute("FlinkTools persist")

    (
      readBack(ds1, f1),
      readBack(ds2, f2),
      readBack(ds3, f3),
      readBack(ds4, f4),
      readBack(ds5, f5))
  }

  /** Groups the elements of the input [[DataSet]] into at most `numBlocks` blocks.
    *
    * Every element is tagged with the block ID `element.hashCode() % numBlocks`, shifted into
    * the range `[0, numBlocks)` if the remainder is negative. All elements sharing a block ID
    * are then collected into a single [[Block]] carrying that ID as its index.
    *
    * @param input [[DataSet]] whose elements are grouped into blocks
    * @param numBlocks Number of Blocks; must be positive
    * @param partitionerOption Optional partitioner to control the partitioning. If defined, the
    *                          tagged elements are custom-partitioned on their block ID before
    *                          they are grouped
    * @tparam T Type of the [[DataSet]] elements
    * @return [[DataSet]] of [[Block]]s, one per block ID that occurs in the input
    * @throws IllegalArgumentException if `numBlocks` is not positive
    */
  def block[T: TypeInformation: ClassTag](
    input: DataSet[T],
    numBlocks: Int,
    partitionerOption: Option[Partitioner[Int]] = None)
  : DataSet[Block[T]] = {
    // Fail when the plan is built instead of with a division by zero (numBlocks == 0) or
    // meaningless block IDs (numBlocks < 0) inside the map function at job runtime.
    require(numBlocks > 0, s"numBlocks must be positive, but was $numBlocks.")

    // Tag every element with its block ID
    val blockIDInput = input map {
      element =>
        val remainder = element.hashCode() % numBlocks

        // % keeps the sign of the hash code, so negative remainders are shifted up
        val blockID = if (remainder < 0) remainder + numBlocks else remainder

        (blockID, element)
    }

    // Optionally control how the tagged elements are distributed, keyed on the block ID (field 0)
    val preGroupBlockIDInput = partitionerOption match {
      case Some(partitioner) =>
        blockIDInput partitionCustom(partitioner, 0)

      case None => blockIDInput
    }

    preGroupBlockIDInput.groupBy(0).reduceGroup {
      iter => {
        val array = iter.toVector

        // All tuples of a group share the same block ID, so take it from the first one
        val blockID = array(0)._1
        val elements = array.map(_._2)

        Block[T](blockID, elements)
      }
    }.withForwardedFields("0 -> index")
  }

  /** [[Partitioner]] for `Int` keys that assigns a key to the channel given by the key modulo
    * the number of partitions.
    *
    * A negative remainder (possible for negative keys) is shifted up by `numPartitions`, so
    * that for a positive `numPartitions` the result lies in `[0, numPartitions)`.
    */
  object ModuloKeyPartitioner extends Partitioner[Int] {
    override def partition(key: Int, numPartitions: Int): Int = {
      key % numPartitions match {
        case negative if negative < 0 => negative + numPartitions
        case nonNegative => nonNegative
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy