All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.ml.MLUtils.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml

import org.apache.flink.api.common.functions.{RichFlatMapFunction, RichMapFunction}
import org.apache.flink.api.java.operators.DataSink
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.math.SparseVector
import org.apache.flink.util.Collector

/** Convenience functions for machine learning tasks
  *
  * This object contains convenience functions for machine learning tasks:
  *
  * - readLibSVM:
  *   Reads a libSVM/SVMLight input file and returns a data set of [[LabeledVector]].
  *   The file format is specified [[http://svmlight.joachims.org/ here]].
  *
  * - writeLibSVM:
  *   Writes a data set of [[LabeledVector]] in libSVM/SVMLight format to disk. The file format
  *   is specified [[http://svmlight.joachims.org/ here]].
  */
object MLUtils {

  /** Name under which the vector dimension is registered as a broadcast set in [[readLibSVM]] */
  val DIMENSION = "dimension"

  /** Reads a file in libSVM/SVMLight format and converts the data into a data set of
    * [[LabeledVector]]. The dimension of the [[LabeledVector]] is determined automatically as
    * the largest (1-based) feature index found in the file.
    *
    * Since the libSVM/SVMLight format stores a vector in its sparse form, the [[LabeledVector]]
    * will also be instantiated with a [[SparseVector]].
    *
    * Everything following a '#' on a line is treated as a comment; lines which are empty after
    * removing the comment are skipped. A line which only consists of a label yields a vector
    * without any non-zero entries.
    *
    * @param env executionEnvironment [[ExecutionEnvironment]]
    * @param filePath Path to the input file
    * @return [[DataSet]] of [[LabeledVector]] containing the information of the libSVM/SVMLight
    *        file
    */
  def readLibSVM(env: ExecutionEnvironment, filePath: String): DataSet[LabeledVector] = {
    val labelCOODS = env.readTextFile(filePath).flatMap(
      new RichFlatMapFunction[String, (Double, Array[(Int, Double)])] {
        val splitPattern = "\\s+".r

        override def flatMap(
          line: String,
          out: Collector[(Double, Array[(Int, Double)])]
        ): Unit = {
          // Strip the trailing comment (if any) and the surrounding whitespace
          val commentFreeLine = line.takeWhile(_ != '#').trim

          if (commentFreeLine.nonEmpty) {
            val splits = splitPattern.split(commentFreeLine)
            val label = splits.head.toDouble
            val coos = splits.tail.map { str =>
              val pair = str.split(':')
              require(
                pair.length == 2,
                "Each feature entry has to have the form <index>:<value>")

              // libSVM index is 1-based, but we expect it to be 0-based
              (pair(0).toInt - 1, pair(1).toDouble)
            }

            out.collect((label, coos))
          }
        }
      })

    // Calculate maximum dimension of vectors
    val dimensionDS = labelCOODS.map {
      labelCOO =>
        val entries = labelCOO._2

        // A line without feature entries must not break the max computation
        if (entries.isEmpty) 0 else entries.map( _._1 + 1 ).max
    }.reduce(scala.math.max(_, _))

    labelCOODS.map{ new RichMapFunction[(Double, Array[(Int, Double)]), LabeledVector] {
      var dimension = 0

      override def open(configuration: Configuration): Unit = {
        val dimensions = getRuntimeContext.getBroadcastVariable[Int](DIMENSION)

        // Guard against an empty broadcast set (e.g. an input without any data lines);
        // in that case the default dimension of 0 is kept
        if (!dimensions.isEmpty) {
          dimension = dimensions.get(0)
        }
      }

      override def map(value: (Double, Array[(Int, Double)])): LabeledVector = {
        new LabeledVector(value._1, SparseVector.fromCOO(dimension, value._2))
      }
    }}.withBroadcastSet(dimensionDS, DIMENSION)
  }

  /** Writes a [[DataSet]] of [[LabeledVector]] to a file using the libSVM/SVMLight format.
    *
    * Each [[LabeledVector]] is written as one line consisting of the label followed by the
    * space separated non-zero entries in the form index:value, where the index is 1-based.
    *
    * @param filePath Path to output file
    * @param labeledVectors [[DataSet]] of [[LabeledVector]] to write to disk
    * @return [[DataSink]] which writes the text representation to the given path
    */
  def writeLibSVM(filePath: String, labeledVectors: DataSet[LabeledVector]): DataSink[String] = {
    val stringRepresentation = labeledVectors.map{
      labeledVector =>
        val vectorStr = labeledVector.vector.
          // remove zero entries
          filter( _._2 != 0).
          // libSVM index is 1-based, whereas the vector index is 0-based
          map{case (idx, value) => s"${idx + 1}:$value"}.
          mkString(" ")

        s"${labeledVector.label} $vectorStr"
    }

    stringRepresentation.writeAsText(filePath)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy