org.apache.sysml.api.ml.PredictionUtils.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.api.ml
import org.apache.spark.sql.functions.udf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.SparkContext
import org.apache.sysml.runtime.matrix.data.MatrixBlock
import org.apache.sysml.runtime.DMLRuntimeException
import org.apache.sysml.runtime.matrix.MatrixCharacteristics
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils
import org.apache.sysml.api.mlcontext.MLResults
import org.apache.sysml.api.mlcontext.ScriptFactory._
import org.apache.sysml.api.mlcontext.Script
import org.apache.sysml.api.mlcontext.BinaryBlockMatrix
object PredictionUtils {
def getGLMPredictionScript(B_full: BinaryBlockMatrix, isSingleNode:Boolean, dfam:java.lang.Integer=1): (Script, String) = {
val script = dml(ScriptsUtils.getDMLScript(LogisticRegressionModel.scriptPath))
.in("$X", " ")
.in("$B", " ")
.in("$dfam", dfam)
.out("means")
val ret = if(isSingleNode) {
script.in("B_full", B_full.getMatrixBlock, B_full.getMatrixMetadata)
}
else {
script.in("B_full", B_full)
}
(ret, "X")
}
def fillLabelMapping(df: ScriptsUtils.SparkDataType, revLabelMapping: java.util.HashMap[Int, String]): RDD[String] = {
val temp = df.select("label").distinct.rdd.map(_.apply(0).toString).collect()
val labelMapping = new java.util.HashMap[String, Int]
for(i <- 0 until temp.length) {
labelMapping.put(temp(i), i+1)
revLabelMapping.put(i+1, temp(i))
}
df.select("label").rdd.map( x => labelMapping.get(x.apply(0).toString).toString )
}
def fillLabelMapping(y_mb: MatrixBlock, revLabelMapping: java.util.HashMap[Int, String]): Unit = {
val labelMapping = new java.util.HashMap[String, Int]
if(y_mb.getNumColumns != 1) {
throw new RuntimeException("Expected a column vector for y")
}
if(y_mb.isInSparseFormat()) {
throw new DMLRuntimeException("Sparse block is not implemented for fit")
}
else {
val denseBlock = y_mb.getDenseBlock()
var id:Int = 1
for(i <- 0 until denseBlock.length) {
val v = denseBlock(i).toString()
if(!labelMapping.containsKey(v)) {
labelMapping.put(v, id)
revLabelMapping.put(id, v)
id += 1
}
denseBlock.update(i, labelMapping.get(v))
}
}
}
class LabelMappingData(val labelMapping: java.util.HashMap[Int, String]) extends Serializable {
def mapLabelStr(x:Double):String = {
if(labelMapping.containsKey(x.toInt))
labelMapping.get(x.toInt)
else
throw new RuntimeException("Incorrect label mapping")
}
def mapLabelDouble(x:Double):Double = {
if(labelMapping.containsKey(x.toInt))
labelMapping.get(x.toInt).toDouble
else
throw new RuntimeException("Incorrect label mapping")
}
val mapLabel_udf = {
try {
val it = labelMapping.values().iterator()
while(it.hasNext()) {
it.next().toDouble
}
udf(mapLabelDouble _)
} catch {
case e: Exception => udf(mapLabelStr _)
}
}
}
def updateLabels(isSingleNode:Boolean, df:DataFrame, X: MatrixBlock, labelColName:String, labelMapping: java.util.HashMap[Int, String]): DataFrame = {
if(isSingleNode) {
if(X.isInSparseFormat()) {
throw new RuntimeException("Since predicted label is a column vector, expected it to be in dense format")
}
for(i <- 0 until X.getNumRows) {
val v:Int = X.getValue(i, 0).toInt
if(labelMapping.containsKey(v)) {
X.setValue(i, 0, labelMapping.get(v).toDouble)
}
else {
throw new RuntimeException("No mapping found for " + v + " in " + labelMapping.toString())
}
}
return null
}
else {
val serObj = new LabelMappingData(labelMapping)
return df.withColumn(labelColName, serObj.mapLabel_udf(df(labelColName)))
.withColumnRenamed(labelColName, "prediction")
}
}
def joinUsingID(df1:DataFrame, df2:DataFrame):DataFrame = {
df1.join(df2, RDDConverterUtils.DF_ID_COLUMN)
}
def computePredictedClassLabelsFromProbability(mlscoreoutput:MLResults, isSingleNode:Boolean, sc:SparkContext, inProbVar:String): MLResults = {
val ml = new org.apache.sysml.api.mlcontext.MLContext(sc)
val script = dml(
"""
Prob = read("temp1");
Prediction = rowIndexMax(Prob); # assuming one-based label mapping
write(Prediction, "tempOut", "csv");
""").out("Prediction")
val probVar = mlscoreoutput.getBinaryBlockMatrix(inProbVar)
if(isSingleNode) {
ml.execute(script.in("Prob", probVar.getMatrixBlock, probVar.getMatrixMetadata))
}
else {
ml.execute(script.in("Prob", probVar))
}
}
}