com.spotify.scio.values.SampleSCollectionFunctions.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of scio-core_2.13 Show documentation
Show all versions of scio-core_2.13 Show documentation
Scio - A Scala API for Apache Beam and Google Cloud Dataflow
The newest version!
/*
* Copyright 2024 Spotify AB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.spotify.scio.values
import com.spotify.scio.util.ScioUtil
import com.spotify.scio.util.random.{BernoulliSampler, PoissonSampler}
import org.apache.beam.sdk.transforms.Sample
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Random
object SampleSCollectionFunctions {
// weight-based sampling models
final private case class WeightedSample[T](id: Long, value: T, weight: Long)
private object WeightedSample {
implicit def ordering[T]: Ordering[WeightedSample[T]] =
Ordering.by[WeightedSample[T], Long](_.id).reverse
}
final private case class WeightedCombiner[T](
weight: Long = 0L,
queue: mutable.PriorityQueue[WeightedSample[T]] = mutable.PriorityQueue.empty[WeightedSample[T]]
)
}
class SampleSCollectionFunctions[T](self: SCollection[T]) {
import SampleSCollectionFunctions._
/**
* Return a sampled subset of this SCollection containing exactly `sampleSize` items. Involves
* combine operation resulting in shuffling. All the elements of the output should fit into main
* memory of a single worker machine.
*
* @return
* a new SCollection whose single value is an `Iterable` of the samples
* @group transform
*/
def sample(sampleSize: Int): SCollection[Iterable[T]] = self.transform {
import self.coder
_.pApply(Sample.fixedSizeGlobally(sampleSize)).map(_.asScala)
}
/**
* Return a sampled subset of this SCollection containing weighted items with at most
* `totalWeight` sum. Involves combine operation resulting in shuffling. All the elements of the
* output should fit into main memory of a single worker machine.
*
* @return
* a new SCollection whose single value is an `Iterable` of the samples
* @group transform
*/
def sampleWeighted(
totalWeight: Long,
cost: T => Long
): SCollection[Iterable[T]] = {
import self.coder
val mergeValue = { (c: WeightedCombiner[T], x: T) =>
val sample = WeightedSample(Random.nextLong(), x, cost(x))
val queue = c.queue
var weight = c.weight
// drop all elements with lower priority than the current sample if the queue is full
while (weight + sample.weight > totalWeight && queue.headOption.exists(_.id < sample.id)) {
val removed = queue.dequeue()
weight -= removed.weight
}
// add the current sample if there is enough space
if (weight + sample.weight <= totalWeight) {
queue += sample
weight += sample.weight
}
WeightedCombiner(weight, queue)
}
val mergeCombiners = { (l: WeightedCombiner[T], r: WeightedCombiner[T]) =>
// merge the two queues, reusing the left one
val queue = l.queue ++= r.queue
var weight = l.weight + r.weight
// drop all elements with low priority until the total weight is less than the limit
while (weight > totalWeight && queue.nonEmpty) {
val removed = queue.dequeue()
weight -= removed.weight
}
WeightedCombiner(weight, queue)
}
self
.aggregate(WeightedCombiner[T]())(mergeValue, mergeCombiners)
.map(_.queue.iterator.map(_.value).toSeq)
}
/**
* Return a sampled subset of this SCollection containing at most serializable `totalByteSize`.
* Involves combine operation resulting in shuffling. All the elements of the output should fit
* into main memory of a single worker machine.
*
* @return
* a new SCollection whose single value is an `Iterable` of the samples
* @group transform
*/
def sampleByteSized(totalByteSize: Long): SCollection[Iterable[T]] = {
import self.coder
sampleWeighted(totalByteSize, ScioUtil.elementByteSize(self.context))
}
/**
* Return a sampled subset of this SCollection. Does not trigger shuffling.
*
* @param withReplacement
* if `true` the same element can be produced more than once, otherwise the same element will be
* sampled only once
* @param fraction
* the sampling fraction
* @group transform
*/
def sample(withReplacement: Boolean, fraction: Double): SCollection[T] = {
import self.coder
if (withReplacement) {
self.parDo(new PoissonSampler[T](fraction))
} else {
self.parDo(new BernoulliSampler[T](fraction))
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy