/*
* Copyright 2019 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.spotify.scio.values
import com.spotify.scio.util.StatCounter
/** Extra functions available on SCollections of `Double`s through an implicit conversion. */
class DoubleSCollectionFunctions(self: SCollection[Double]) {
/**
* Return an SCollection with a single [[com.spotify.scio.util.StatCounter StatCounter]] object
* that captures the mean, variance and count of the SCollection's elements in one operation.
*/
def stats: SCollection[StatCounter] =
self.combine(StatCounter(_))(_.merge(_))(_.merge(_))
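// A minimal usage sketch (hypothetical pipeline; assumes the implicit conversion to
// DoubleSCollectionFunctions provided by Scio is in scope):
//   val sc: ScioContext = ???
//   val s = sc.parallelize(Seq(1.0, 2.0, 3.0)).stats
//   s.map(st => (st.count, st.mean, st.variance)) // ==> (3, 2.0, ~0.667)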
// Implemented in SCollection
// def mean: SCollection[Double] = this.stats.map(_.mean)
// Implemented in SCollection
// def sum: SCollection[Double] = this.stats.map(_.sum)
/** Compute the standard deviation of this SCollection's elements. */
def stdev: SCollection[Double] = self.transform(_.stats.map(_.stdev))
/** Compute the variance of this SCollection's elements. */
def variance: SCollection[Double] = self.transform(_.stats.map(_.variance))
/**
* Compute the sample standard deviation of this SCollection's elements (which corrects for bias
* in estimating the standard deviation by dividing by N-1 instead of N).
*/
def sampleStdev: SCollection[Double] =
self.transform(_.stats.map(_.sampleStdev))
/**
* Compute the sample variance of this SCollection's elements (which corrects for bias in
* estimating the variance by dividing by N-1 instead of N).
*/
def sampleVariance: SCollection[Double] =
self.transform(_.stats.map(_.sampleVariance))
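// Worked sketch of the four statistics above (hypothetical values): for the elements
// Seq(1.0, 2.0, 3.0, 4.0) the mean is 2.5 and the sum of squared deviations is 5.0, so
//   variance       == 5.0 / 4 == 1.25            (divide by N)
//   stdev          == math.sqrt(1.25)    ~= 1.118
//   sampleVariance == 5.0 / 3            ~= 1.667 (divide by N-1)
//   sampleStdev    == math.sqrt(5.0 / 3) ~= 1.291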
// Ported from org.apache.spark.rdd.DoubleRDDFunctions
/**
* Compute a histogram of the data using `bucketCount` buckets evenly spaced between the minimum
* and maximum of the SCollection. For example, if the min value is 0, the max is 100, and there
* are two buckets, the resulting buckets will be `[0, 50)` and `[50, 100]`. `bucketCount` must be
* at least 1. If the SCollection contains infinity or NaN, an exception is thrown. If the
* elements of the SCollection do not vary (max == min), a single bucket is always returned.
*/
def histogram(bucketCount: Int): (SCollection[Array[Double]], SCollection[Array[Long]]) = {
// Compute the minimum and the maximum
val minMax =
self.aggregate((Double.PositiveInfinity, Double.NegativeInfinity))(
(acc, x) => (x.min(acc._1), x.max(acc._2)),
(l, r) => (l._1.min(r._1), l._2.max(r._2))
)
val buckets = minMax.map { case (min, max) =>
if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity) {
throw new UnsupportedOperationException(
"Histogram on either an empty SCollection or SCollection containing +/-infinity or NaN"
)
}
val range = if (min != max) {
// Range.Double.inclusive(min, max, increment)
// The above code doesn't always work. See Scala bug #SI-8782.
// https://issues.scala-lang.org/browse/SI-8782
val span = max - min
val steps = bucketCount
Range.Int(0, steps, 1).map(s => min + (s * span) / steps) :+ max
} else {
List(min, min)
}
range.toArray
}
(buckets, histogramImpl(buckets, true))
}
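// A minimal usage sketch (hypothetical values): with elements Seq(0.0, 25.0, 50.0, 75.0, 100.0)
// and bucketCount = 2, the bucket boundaries are Array(0.0, 50.0, 100.0) and the counts are
// Array(2L, 3L): 0.0 and 25.0 fall in [0, 50), while 50.0, 75.0 and 100.0 fall in [50, 100].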
/**
* Compute a histogram using the provided buckets. The buckets are all open to the right except
* for the last, which is closed. For example, for the array `[1, 10, 20, 50]` the buckets are
* `[1, 10)`, `[10, 20)` and `[20, 50]`, i.e. `1<=x<10`, `10<=x<20` and `20<=x<=50`. On the input
* of 1 and 50 we would have a histogram of `[1, 0, 1]`.
*
* Note: if your buckets are evenly spaced (e.g. `[0, 10, 20, 30]`), insertion can be switched
* from O(log n) to O(1) per element (where n = number of buckets) by setting `evenBuckets` to
* true.
*
* The buckets array must be sorted, contain no duplicates, and have at least two elements. All
* NaN entries are treated the same: if you have a NaN bucket it must be the maximum value of the
* last position, and all NaN entries will be counted in that bucket.
*/
def histogram(buckets: Array[Double], evenBuckets: Boolean = false): SCollection[Array[Long]] =
histogramImpl(self.context.parallelize(Seq(buckets)), evenBuckets)
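// A minimal usage sketch (hypothetical pipeline) mirroring the scaladoc example above:
//   val counts: SCollection[Array[Long]] =
//     sc.parallelize(Seq(1.0, 50.0)).histogram(Array(1.0, 10.0, 20.0, 50.0))
//   // ==> Array(1L, 0L, 1L): 1.0 falls in [1, 10) and 50.0 in [20, 50]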
private def histogramImpl(
buckets: SCollection[Array[Double]],
evenBuckets: Boolean
): SCollection[Array[Long]] = {
import com.spotify.scio.values.BucketFunctions._
// Map buckets into a side input of bucket function
val side = buckets.map { b =>
require(b.length >= 2, "buckets array must have at least two elements")
// Decide which bucket function to pass to histogramPartition. We decide here
// rather than having a general function so that the decision need only be made
// once rather than once per shard
val bucketParams: Either[(Double, Double, Int), Array[Double]] = if (evenBuckets) {
Left((b.head, b.last, b.length - 1))
} else {
Right(b)
}
bucketParams
}.asSingletonSideInput
val bucketSize = buckets.map(_.length - 1)
val hist = self
.withSideInputs(side)
.flatMap { (x, c) =>
// Map values to buckets
val bucketFunction = c(side) match {
case Left(p) => (fastBucketFunction _).tupled(p)
case Right(b) => basicBucketFunction(b) _
}
bucketFunction(x).iterator
}
.toSCollection
.countByValue // Count occurrences of each bucket
.cross(bucketSize) // Replicate bucket size
.map { case ((bin, count), size) =>
val b = Array.fill(size)(0L)
b(bin) = count
b
}
.reduce { (x, y) =>
val r = x.clone()
var i = 0
while (i < x.length) {
r(i) += y(i)
i += 1
}
r
}
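// Sketch of the aggregation above (hypothetical input, 3 buckets): countByValue yields
// e.g. (0, 2L) and (2, 1L); the map one-hot encodes them as Array(2L, 0L, 0L) and
// Array(0L, 0L, 1L); the element-wise reduce then sums them into Array(2L, 0L, 1L).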
// Workaround since hist may be empty: if the input SCollection is empty, emit an all-zero
// histogram of the correct size instead of nothing
val bSide = bucketSize.asSingletonSideInput
val hSide = hist.asListSideInput
self.context
.parallelize(Seq(0))
.withSideInputs(bSide, hSide)
.map { (_, c) =>
val h = c(hSide)
if (h.isEmpty) {
Array.fill(c(bSide))(0L)
} else {
h.head
}
}
.toSCollection
}
}
private object BucketFunctions {
// Basic bucket function. This works using Java's built-in array
// binary search. Takes O(log(size(buckets))) per element.
def basicBucketFunction(b: Array[Double])(e: Double): Option[Int] = {
val location = java.util.Arrays.binarySearch(b, e)
if (location < 0) {
// If the location is less than 0 then the insertion point in the array
// to keep it sorted is -location-1
val insertionPoint = -location - 1
// If we have to insert before the first element or after the last one,
// it's out of bounds.
// We do this rather than buckets.lengthCompare(insertionPoint)
// because Array[Double] fails to override it (for now).
if (insertionPoint > 0 && insertionPoint < b.length) {
Some(insertionPoint - 1)
} else {
None
}
} else if (location < b.length - 1) {
// Exact match, just insert here
Some(location)
} else {
// Exact match to the last element
Some(location - 1)
}
}
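// A few worked examples (hypothetical buckets b = Array(1.0, 10.0, 20.0, 50.0)):
//   basicBucketFunction(b)(15.0) // ==> Some(1): insertion point 2, bucket [10, 20)
//   basicBucketFunction(b)(50.0) // ==> Some(2): exact match on the last element
//   basicBucketFunction(b)(0.5)  // ==> None: below the first bucket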
// Determine the bucket function in constant time. Requires that buckets are evenly spaced
def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] =
// If the input is NaN or falls outside the bucket range, fail fast
if (e.isNaN || e < min || e > max) {
None
} else {
// Compute ratio of e's distance along range to total range first, for better precision
val bucketNumber = (((e - min) / (max - min)) * count).toInt
// should be less than count, but will equal count if e == max, in which case
// it's part of the last end-range-inclusive bucket, so return count-1
Some(math.min(bucketNumber, count - 1))
}
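// A few worked examples (hypothetical range min = 0.0, max = 100.0, count = 4):
//   fastBucketFunction(0.0, 100.0, 4)(25.0)  // ==> Some(1): ((25 - 0) / 100) * 4 == 1
//   fastBucketFunction(0.0, 100.0, 4)(100.0) // ==> Some(3): bucketNumber == count, clamped
//   fastBucketFunction(0.0, 100.0, 4)(-1.0)  // ==> None: below min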
}