
org.apache.flink.ml.MLUtils.scala Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml
import org.apache.flink.api.common.functions.{RichFlatMapFunction, RichMapFunction}
import org.apache.flink.api.java.operators.DataSink
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.math.SparseVector
import org.apache.flink.util.Collector
/** Convenience functions for machine learning tasks
*
* This object contains convenience functions for machine learning tasks:
*
* - readLibSVM:
* Reads a libSVM/SVMLight input file and returns a data set of [[LabeledVector]].
* The file format is specified [http://svmlight.joachims.org/ here].
*
* - writeLibSVM:
* Writes a data set of [[LabeledVector]] in libSVM/SVMLight format to disk. THe file format
* is specified [http://svmlight.joachims.org/ here].
*/
object MLUtils {
val DIMENSION = "dimension"
/** Reads a file in libSVM/SVMLight format and converts the data into a data set of
* [[LabeledVector]]. The dimension of the [[LabeledVector]] is determined automatically.
*
* Since the libSVM/SVMLight format stores a vector in its sparse form, the [[LabeledVector]]
* will also be instantiated with a [[SparseVector]].
*
* @param env executionEnvironment [[ExecutionEnvironment]]
* @param filePath Path to the input file
* @return [[DataSet]] of [[LabeledVector]] containing the information of the libSVM/SVMLight
* file
*/
def readLibSVM(env: ExecutionEnvironment, filePath: String): DataSet[LabeledVector] = {
val labelCOODS = env.readTextFile(filePath).flatMap(
new RichFlatMapFunction[String, (Double, Array[(Int, Double)])] {
val splitPattern = "\\s+".r
override def flatMap(
line: String,
out: Collector[(Double, Array[(Int, Double)])]
): Unit = {
val commentFreeLine = line.takeWhile(_ != '#').trim
if (commentFreeLine.nonEmpty) {
val splits = splitPattern.split(commentFreeLine)
val label = splits.head.toDouble
val sparseFeatures = splits.tail
val coos = sparseFeatures.flatMap { str =>
val pair = str.split(':')
require(pair.length == 2, "Each feature entry has to have the form :")
// libSVM index is 1-based, but we expect it to be 0-based
val index = pair(0).toInt - 1
val value = pair(1).toDouble
Some((index, value))
}
out.collect((label, coos))
}
}
})
// Calculate maximum dimension of vectors
val dimensionDS = labelCOODS.map {
labelCOO =>
labelCOO._2.map( _._1 + 1 ).max
}.reduce(scala.math.max(_, _))
labelCOODS.map{ new RichMapFunction[(Double, Array[(Int, Double)]), LabeledVector] {
var dimension = 0
override def open(configuration: Configuration): Unit = {
dimension = getRuntimeContext.getBroadcastVariable(DIMENSION).get(0)
}
override def map(value: (Double, Array[(Int, Double)])): LabeledVector = {
new LabeledVector(value._1, SparseVector.fromCOO(dimension, value._2))
}
}}.withBroadcastSet(dimensionDS, DIMENSION)
}
/** Writes a [[DataSet]] of [[LabeledVector]] to a file using the libSVM/SVMLight format.
*
* @param filePath Path to output file
* @param labeledVectors [[DataSet]] of [[LabeledVector]] to write to disk
* @return
*/
def writeLibSVM(filePath: String, labeledVectors: DataSet[LabeledVector]): DataSink[String] = {
val stringRepresentation = labeledVectors.map{
labeledVector =>
val vectorStr = labeledVector.vector.
// remove zero entries
filter( _._2 != 0).
map{case (idx, value) => (idx + 1) + ":" + value}.
mkString(" ")
labeledVector.label + " " + vectorStr
}
stringRepresentation.writeAsText(filePath)
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy