All downloads are free. Search and download functionality uses the official Maven repository.

au.csiro.variantspark.utils.Projection.scala Maven / Gradle / Ivy

The newest version!
package au.csiro.variantspark.utils

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

/**
 * Selects a subset of positions from vectors and integer arrays.
 *
 * Position `i` survives the projection when `indexSet.contains(i) == include`.
 * With `include = true`, only the listed positions are kept. With `include = false`,
 * every position except the listed ones is kept. Surviving elements keep their
 * original relative order, and listed indexes beyond the input length are ignored.
 *
 * @param indexSet the positions of interest
 * @param include  true if `indexSet` lists positions to keep, false if it lists positions to drop
 */
class Projector(indexSet: Set[Int], include: Boolean = true) extends Serializable {

  /** True when position `i` survives this projection. */
  private def keeps(i: Int): Boolean = indexSet.contains(i) == include

  /** Returns a dense vector holding the values of `v` at the surviving positions. */
  def projectVector(v: Vector): Vector = {
    val kept: Array[Double] = v.toArray.zipWithIndex.collect { case (x, i) if keeps(i) => x }
    Vectors.dense(kept)
  }

  /** Returns the elements of `a` at the surviving positions, in their original order. */
  def projectArray(a: Array[Int]): Array[Int] =
    a.zipWithIndex.collect { case (x, i) if keeps(i) => x }

  /** A projector over the same indexes that keeps exactly the positions this one drops. */
  def inverted: Projector = new Projector(indexSet, include = !include)

  /** This projector paired with its complement (`inverted`). */
  def toPair: (Projector, Projector) = this -> inverted

  /** The index set this projector was built from, regardless of include/exclude mode. */
  def indexes: Set[Int] = indexSet
}

/**
 * Factory methods for [[Projector]] and standalone (curried) projection helpers.
 */
object Projector {

  /**
   * Builds a projector from an array of indexes.
   *
   * @param indexes positions of interest; duplicates collapse because they are stored as a Set
   * @param include keep the listed positions when true, drop them when false
   */
  def apply(indexes: Array[Int], include: Boolean = true): Projector =
    new Projector(indexes.toSet, include)

  /**
   * Builds a projector that keeps the indexes chosen by `Sampling.subsampleFraction`
   * over `v.size` positions, using the default RNG `defRng`. Only the size of `v` is read.
   * NOTE(review): presumably a random sample of about `fraction * v.size` indexes; confirm
   * against `Sampling.subsampleFraction`.
   */
  def subsample(v: Vector, fraction: Double): Projector =
    Projector(Sampling.subsampleFraction(v.size, fraction)(defRng))

  /** Subsamples the positions of `v` and returns that projector paired with its complement. */
  def split(v: Vector, fraction: Double): (Projector, Projector) = subsample(v, fraction).toPair

  /**
   * Builds one projector per index group returned by `Sampling.folds(v.size, nFolds)`.
   *
   * @param testFolds when true, each projector keeps its fold's indexes; when false, it keeps
   *                  all indexes outside the fold
   */
  def folds(v: Vector, nFolds: Int, testFolds: Boolean = true): List[Projector] =
    Sampling.folds(v.size, nFolds).map(Projector(_, testFolds))

  /**
   * Like `split`, but sized from the first vector of `rdd`.
   * Assumes every vector in `rdd` has the same size as the first one. `rdd.first` runs a
   * Spark job and throws on an empty RDD.
   */
  def splitRDD(rdd: RDD[Vector], fraction: Double): (Projector, Projector) =
    split(rdd.first, fraction)

  /**
   * Like `folds`, but sized from the first vector of `rdd`.
   * Assumes every vector in `rdd` has the same size as the first one. `rdd.first` runs a
   * Spark job and throws on an empty RDD.
   */
  def rddFolds(rdd: RDD[Vector], nFolds: Int, testFolds: Boolean = true): List[Projector] =
    folds(rdd.first, nFolds, testFolds)

  // TODO: (Refactoring) Find a better place for these (if needed at all)

  /**
   * Curried projection of `v`: keeps the positions in `indexSet`, or every other position
   * when `invert` is true. Delegates to the class method so there is a single implementation.
   */
  def projectVector(indexSet: Set[Int], invert: Boolean = false)(v: Vector): Vector =
    new Projector(indexSet, include = !invert).projectVector(v)

  /**
   * Curried projection of `a`: keeps the positions in `indexSet`, or every other position
   * when `invert` is true. Delegates to the class method so there is a single implementation.
   */
  def projectArray(indexSet: Set[Int], invert: Boolean = false)(a: Array[Int]): Array[Int] =
    new Projector(indexSet, include = !invert).projectArray(a)
}

/**
 * Implicit conversions from vector RDDs to the wrapper classes `VectorRDDFunction` and
 * `IndexedVectorRDDFunction`. Import `RDDProjections._` to bring them into scope.
 */
object RDDProjections {
  // Defining implicit conversions emits a feature warning unless this feature is enabled.
  import scala.language.implicitConversions

  /** Wraps an `RDD[Vector]` in a `VectorRDDFunction`. */
  implicit def toVectorRDD(rdd: RDD[Vector]): VectorRDDFunction =
    new VectorRDDFunction(rdd)

  /** Wraps an `RDD[(Vector, Long)]` in an `IndexedVectorRDDFunction`. */
  implicit def toIndexedVectorRDD(rdd: RDD[(Vector, Long)]): IndexedVectorRDDFunction =
    new IndexedVectorRDDFunction(rdd)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy