// au.csiro.variantspark.api.ImportanceAnalysis.scala
// variant-spark (variant-spark_2.11): Genomic variants interpretation toolkit
package au.csiro.variantspark.api
import java.util
import au.csiro.pbdava.ssparkle.spark.SparkUtils
import au.csiro.variantspark.algo.{RandomForest, RandomForestModel, RandomForestParams}
import au.csiro.variantspark.data.BoundedOrdinalVariable
import au.csiro.variantspark.input.{FeatureSource, LabelSource}
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import scala.collection.JavaConverters._
/**
 * Represents a single run of the importance analysis: trains a random forest
 * on the provided features and labels and exposes per-variable importance scores.
 *
 * @constructor Create a new `ImportanceAnalysis` by specifying the parameters listed below.
 *
 * @param sqlContext The SQL context.
 * @param featureSource The feature source.
 * @param labelSource The label source.
 * @param rfParams The random forest parameters.
 * @param nTrees The number of decision trees to train.
 * @param rfBatchSize The number of trees to train per batch.
 * @param varOrdinalLevels The number of levels of the bounded ordinal variables.
 *
 * @example ImportanceAnalysis(featureSource, labelSource, nTrees = 1000)
 */
class ImportanceAnalysis(val sqlContext: SQLContext, val featureSource: FeatureSource,
val labelSource: LabelSource, val rfParams: RandomForestParams, val nTrees: Int,
val rfBatchSize: Int, varOrdinalLevels: Int) {
  private def sc = featureSource.features.sparkContext

  // Pair each feature with a stable long index and cache the RDD: it is reused
  // for training, for the importance DataFrame, and for top-variable name lookups.
  private lazy val inputData = featureSource.features.zipWithIndex().cache()
val variableImportanceSchema: StructType =
StructType(Seq(StructField("variable", StringType, true),
StructField("importance", DoubleType, true)))
  lazy val rfModel: RandomForestModel = {
    val labels = labelSource.getLabels(featureSource.sampleNames)
    val rf = new RandomForest(rfParams)
    rf.batchTrain(inputData, labels, nTrees, rfBatchSize)
  }

  // Lazy as well: reading the OOB error forces `rfModel`, i.e. trains the forest.
  lazy val oobError: Double = rfModel.oobError
  private lazy val br_normalizedVariableImportance = {
    val indexImportance = rfModel.normalizedVariableImportance()
    // Broadcast the importance map as a primitive-keyed fastutil map. The cast
    // only re-types the boxed keys/values (sound under erasure) so the Scala
    // map can be viewed as a Java map for the Long2DoubleOpenHashMap constructor.
    sc.broadcast(
      new Long2DoubleOpenHashMap(
        indexImportance.asInstanceOf[Map[java.lang.Long, java.lang.Double]].asJava))
  }
  def variableImportance: DataFrame = {
    // Capture the broadcast in a local val so the closure below serializes only
    // the broadcast handle, not the enclosing ImportanceAnalysis instance.
    val local_br_normalizedVariableImportance = br_normalizedVariableImportance
val importanceRDD = inputData.map({
case (f, i) => Row(f.label, local_br_normalizedVariableImportance.value.get(i))
})
sqlContext.createDataFrame(importanceRDD, variableImportanceSchema)
}
  def importantVariables(nTopLimit: Int = 100): Seq[(String, Double)] = {
    // Take the nTopLimit variables with the highest normalized importance,
    // then build an index -> name map so they can be reported by label.
    val topImportantVariables =
      rfModel.normalizedVariableImportance().toSeq.sortBy(-_._2).take(nTopLimit)
val topImportantVariableIndexes = topImportantVariables.map(_._1).toSet
val index =
SparkUtils.withBroadcast(featureSource.features.sparkContext)(topImportantVariableIndexes) {
br_indexes =>
inputData
.filter(t => br_indexes.value.contains(t._2))
.map({ case (f, i) => (i, f.label) })
.collectAsMap()
}
topImportantVariables.map({ case (i, importance) => (index(i), importance) })
}
  def importantVariablesJavaMap(nTopLimit: Int = 100): util.Map[String, Double] = {
    // Copy into a mutable map and expose it through the Java collections API.
    val impVarMap = collection.mutable.Map(importantVariables(nTopLimit): _*)
    impVarMap.asJava
  }
}
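
// A minimal usage sketch (illustrative only, not part of this file): it
// assumes an implicit SqlContextHolder is in scope and that `features` and
// `labels` were loaded with one of variant-spark's FeatureSource/LabelSource
// implementations.
//
//   implicit val vsContext: SqlContextHolder = ???  // supplied by the caller
//   val analysis = ImportanceAnalysis(features, labels, nTrees = 500)
//   val top: Seq[(String, Double)] = analysis.importantVariables(nTopLimit = 20)
//   analysis.variableImportance.show()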
object ImportanceAnalysis {
val defaultRFParams: RandomForestParams = RandomForestParams()
  /** Builds an ImportanceAnalysis with default forest parameters;
   *  `mtryFraction`, `oob` and `seed`, when given, override the
   *  corresponding RandomForestParams fields. */
  def apply(featureSource: FeatureSource, labelSource: LabelSource, nTrees: Int = 1000,
mtryFraction: Option[Double] = None, oob: Boolean = true, seed: Option[Long] = None,
batchSize: Int = 100, varOrdinalLevels: Int = 3)(
implicit vsContext: SqlContextHolder): ImportanceAnalysis = {
new ImportanceAnalysis(vsContext.sqlContext, featureSource, labelSource,
rfParams = RandomForestParams(
nTryFraction = mtryFraction.getOrElse(defaultRFParams.nTryFraction),
seed = seed.getOrElse(defaultRFParams.seed), oob = oob),
nTrees = nTrees, rfBatchSize = batchSize, varOrdinalLevels = varOrdinalLevels)
}
  /** Builds an ImportanceAnalysis from fully specified RandomForestParams. */
  def fromParams(featureSource: FeatureSource, labelSource: LabelSource,
rfParams: RandomForestParams, nTrees: Int = 1000, batchSize: Int = 100,
varOrdinalLevels: Int = 3)(implicit vsContext: SqlContextHolder): ImportanceAnalysis = {
new ImportanceAnalysis(vsContext.sqlContext, featureSource, labelSource, rfParams = rfParams,
nTrees = nTrees, rfBatchSize = batchSize, varOrdinalLevels = varOrdinalLevels)
}
}
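
// A sketch of the fully parameterised factory above (the `copy` call assumes
// RandomForestParams is a case class, as its named-argument construction in
// `apply` suggests):
//
//   val params = ImportanceAnalysis.defaultRFParams.copy(nTryFraction = 0.1)
//   val analysis = ImportanceAnalysis.fromParams(features, labels,
//     rfParams = params, nTrees = 2000, batchSize = 50)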