All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.api.ml.PredictionUtils.scala Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.api.ml

import org.apache.spark.sql.functions.udf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.SparkContext
import org.apache.sysml.runtime.matrix.data.MatrixBlock
import org.apache.sysml.runtime.DMLRuntimeException
import org.apache.sysml.runtime.matrix.MatrixCharacteristics
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils
import org.apache.sysml.api.mlcontext.MLResults
import org.apache.sysml.api.mlcontext.ScriptFactory._
import org.apache.sysml.api.mlcontext.Script
import org.apache.sysml.api.mlcontext.BinaryBlockMatrix

/**
 * Helpers shared by the mllearn estimators and models.
 *
 * Covers building the GLM prediction script, mapping user class labels to the
 * 1-based integer ids the DML scripts work with (and back again), joining on the
 * row-id column, and deriving predicted class labels from a probability matrix.
 */
object PredictionUtils {

  /**
   * Builds the GLM prediction script loaded from `LogisticRegressionModel.scriptPath`,
   * with the trained coefficients bound to the DML variable `B_full`.
   *
   * `$X` and `$B` are bound to a single space. Presumably these are placeholders,
   * because the caller binds the feature matrix under the returned variable name
   * (TODO confirm). Only `means` is registered as an output.
   *
   * @param B_full       trained coefficient matrix
   * @param isSingleNode if true, bind `B_full` as a local MatrixBlock plus its metadata;
   *                     otherwise bind the BinaryBlockMatrix itself
   * @param dfam         value bound to `$dfam` (defaults to 1)
   * @return the configured script, and the name of the input variable the caller
   *         should bind the features to ("X")
   */
  def getGLMPredictionScript(B_full: BinaryBlockMatrix, isSingleNode: Boolean, dfam: java.lang.Integer = 1): (Script, String) = {
    val script = dml(ScriptsUtils.getDMLScript(LogisticRegressionModel.scriptPath))
      .in("$X", " ")
      .in("$B", " ")
      .in("$dfam", dfam)
      .out("means")
    val withCoefficients =
      if (isSingleNode) script.in("B_full", B_full.getMatrixBlock, B_full.getMatrixMetadata)
      else script.in("B_full", B_full)
    (withCoefficients, "X")
  }

  /**
   * Encodes the `label` column of `df` as 1-based integer ids.
   *
   * Each distinct label (compared by its `toString`) gets the id
   * `position in the collected distinct array + 1`. The reverse mapping
   * id -> label string is written into `revLabelMapping`. `distinct` makes no
   * ordering promise, so the ids may differ between runs.
   *
   * @param df              input data containing a `label` column
   * @param revLabelMapping output map, populated with id -> original label string
   * @return one encoded id per row of `df`, rendered as a string
   */
  def fillLabelMapping(df: ScriptsUtils.SparkDataType, revLabelMapping: java.util.HashMap[Int, String]): RDD[String] = {
    val distinctLabels = df.select("label").distinct.rdd.map(_.apply(0).toString).collect()
    // java.util.HashMap is Serializable, so it can be shipped inside the map closure below.
    val labelMapping = new java.util.HashMap[String, Int]
    distinctLabels.zipWithIndex.foreach { case (label, idx) =>
      labelMapping.put(label, idx + 1)
      revLabelMapping.put(idx + 1, label)
    }
    df.select("label").rdd.map(row => labelMapping.get(row.apply(0).toString).toString)
  }

  /**
   * Re-encodes a local label column vector in place as 1-based integer ids.
   *
   * Ids are assigned in order of first appearance, scanning rows top to bottom.
   * Each id -> original value (the Double's `toString`) pair is recorded in
   * `revLabelMapping`.
   *
   * Only rows `0 until getNumRows` are visited. Writes go through `setValue`
   * (the same API `updateLabels` uses) rather than the raw dense array, so:
   *  - spare capacity in an oversized backing array is never scanned;
   *  - an unallocated dense block is handled;
   *  - the block's non-zero bookkeeping stays consistent after remapping.
   *
   * @param y_mb            single-column label matrix, modified in place
   * @param revLabelMapping output map, populated with id -> original label string
   * @throws RuntimeException    if `y_mb` does not have exactly one column
   * @throws DMLRuntimeException if `y_mb` is stored in sparse format
   */
  def fillLabelMapping(y_mb: MatrixBlock, revLabelMapping: java.util.HashMap[Int, String]): Unit = {
    if (y_mb.getNumColumns != 1) {
      throw new RuntimeException("Expected a column vector for y")
    }
    if (y_mb.isInSparseFormat()) {
      throw new DMLRuntimeException("Sparse block is not implemented for fit")
    }
    val labelMapping = new java.util.HashMap[String, Int]
    var nextId: Int = 1
    for (row <- 0 until y_mb.getNumRows) {
      val v = y_mb.getValue(row, 0).toString()
      if (!labelMapping.containsKey(v)) {
        labelMapping.put(v, nextId)
        revLabelMapping.put(nextId, v)
        nextId += 1
      }
      y_mb.setValue(row, 0, labelMapping.get(v).toDouble)
    }
  }

  /**
   * Serializable holder that maps predicted integer ids back to original labels.
   *
   * @param labelMapping id -> original label string
   */
  class LabelMappingData(val labelMapping: java.util.HashMap[Int, String]) extends Serializable {

    /** Returns the original label string for id `x` (truncated to Int). */
    def mapLabelStr(x: Double): String = lookup(x)

    /** Returns the original label for id `x` (truncated to Int), parsed as a Double. */
    def mapLabelDouble(x: Double): Double = lookup(x).toDouble

    /** Shared lookup; throws if `x.toInt` has no entry in `labelMapping`. */
    private def lookup(x: Double): String = {
      val key = x.toInt
      if (labelMapping.containsKey(key)) labelMapping.get(key)
      else throw new RuntimeException("Incorrect label mapping")
    }

    /**
     * Spark UDF that maps ids back to labels.
     *
     * It yields Doubles when every label in the mapping parses as a Double, and
     * Strings otherwise.
     */
    val mapLabel_udf = {
      val allLabelsNumeric =
        try {
          val it = labelMapping.values().iterator()
          while (it.hasNext()) {
            it.next().toDouble
          }
          true
        } catch {
          // Typically NumberFormatException for non-numeric labels: fall back to strings.
          case _: Exception => false
        }
      if (allLabelsNumeric) udf(mapLabelDouble _) else udf(mapLabelStr _)
    }
  }

  /**
   * Maps predicted integer ids back to the original labels.
   *
   * Single-node mode:
   *  - rewrites column 0 of `X` in place;
   *  - each label string must parse as a Double, because X holds Doubles;
   *  - returns null.
   *
   * Distributed mode:
   *  - replaces `labelColName` in `df` via a UDF;
   *  - renames that column to "prediction";
   *  - returns the resulting DataFrame.
   *
   * @param isSingleNode selects between the in-place MatrixBlock path and the DataFrame path
   * @param df           predictions (used only when `isSingleNode` is false)
   * @param X            predicted-id column vector (used only when `isSingleNode` is true)
   * @param labelColName column in `df` holding the predicted ids
   * @param labelMapping id -> original label string
   * @return the relabelled DataFrame, or null in single-node mode
   * @throws RuntimeException if `X` is sparse, or if a predicted id has no mapping
   */
  def updateLabels(isSingleNode: Boolean, df: DataFrame, X: MatrixBlock, labelColName: String, labelMapping: java.util.HashMap[Int, String]): DataFrame = {
    if (isSingleNode) {
      if (X.isInSparseFormat()) {
        throw new RuntimeException("Since predicted label is a column vector, expected it to be in dense format")
      }
      for (i <- 0 until X.getNumRows) {
        val v: Int = X.getValue(i, 0).toInt
        if (!labelMapping.containsKey(v)) {
          throw new RuntimeException("No mapping found for " + v + " in " + labelMapping.toString())
        }
        X.setValue(i, 0, labelMapping.get(v).toDouble)
      }
      null
    } else {
      val serObj = new LabelMappingData(labelMapping)
      df.withColumn(labelColName, serObj.mapLabel_udf(df(labelColName)))
        .withColumnRenamed(labelColName, "prediction")
    }
  }

  /** Inner-joins two DataFrames on the row-id column `RDDConverterUtils.DF_ID_COLUMN`. */
  def joinUsingID(df1: DataFrame, df2: DataFrame): DataFrame = {
    df1.join(df2, RDDConverterUtils.DF_ID_COLUMN)
  }

  /**
   * Runs `rowIndexMax` over a probability matrix to obtain predicted class ids.
   *
   * The predicted ids are 1-based column indices.
   *
   * NOTE(review): the DML text still reads "temp1" and writes "tempOut". Presumably
   * MLContext substitutes the bound `Prob` input and the `Prediction` output for
   * these — confirm.
   *
   * @param mlscoreoutput results containing the probability matrix
   * @param isSingleNode  if true, bind the probabilities as a local MatrixBlock plus metadata
   * @param sc            SparkContext used to create a fresh MLContext
   * @param inProbVar     name of the probability matrix within `mlscoreoutput`
   * @return execution results exposing the `Prediction` output
   */
  def computePredictedClassLabelsFromProbability(mlscoreoutput: MLResults, isSingleNode: Boolean, sc: SparkContext, inProbVar: String): MLResults = {
    val ml = new org.apache.sysml.api.mlcontext.MLContext(sc)
    val script = dml(
        """
        Prob = read("temp1");
        Prediction = rowIndexMax(Prob); # assuming one-based label mapping
        write(Prediction, "tempOut", "csv");
        """).out("Prediction")
    val probVar = mlscoreoutput.getBinaryBlockMatrix(inProbVar)
    val boundScript =
      if (isSingleNode) script.in("Prob", probVar.getMatrixBlock, probVar.getMatrixMetadata)
      else script.in("Prob", probVar)
    ml.execute(boundScript)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy