// io.virtdata.stathelpers.aliasmethod.AliasElementSampler Maven / Gradle / Ivy
package io.virtdata.stathelpers.aliasmethod;
import io.virtdata.annotations.ThreadSafeMapper;
import io.virtdata.stathelpers.ElemProbD;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.function.DoubleFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* Uses the alias sampling method to encode and sample from discrete probabilities,
* even over larger sets of data. This form requires a unit interval sample value
* between 0.0 and 1.0. Assuming the maximal amount of memory is used for distinct
* outcomes N, a memory buffer of N*16 bytes is required for this implementation,
* requiring 32MB of memory for 1M entries. Not bad, eh?
*
* This sampler should be shared between threads, and will be by default, in order
* to avoid many instances of a 32MB buffer on heap.
*/
@ThreadSafeMapper
public class AliasElementSampler implements DoubleFunction {
private double[] biases;
private T[] elements;
private double slotCount; // The number of fair die-roll slotCount that contain unfair coin probabilities
/**
* Setup an alias table for T type objects.
* @param biases An array of the unfair die model values
* @param elements An array of elements of type T, two values per bias value. index 2n is bot, index 2n+1 is top.
*/
AliasElementSampler(double[] biases, T[] elements) {
this.biases = biases;
this.elements = elements;
}
AliasElementSampler(Collection elements, Function weightFunction) {
this(elements.stream().map(e -> new ElemProbD<>(e,weightFunction.apply(e))).collect(Collectors.toList()));
}
public AliasElementSampler(List> events) {
int size = events.size();
LinkedList> small = new LinkedList<>();
LinkedList> large = new LinkedList<>();
List> slots = new ArrayList<>();
// array-size normalization
double sumProbability = events.stream().mapToDouble(ElemProbD::getProbability).sum();
events = events.stream().map(
e -> new ElemProbD<>(e.getElement(), (e.getProbability() / sumProbability) * size)
).collect(Collectors.toList());
// presort
for (ElemProbD event : events) {
(event.getProbability()<1.0D ? small : large).addLast(event);
}
while (small.peekFirst()!=null && large.peekFirst()!=null) {
ElemProbD l = small.removeFirst();
ElemProbD g = large.removeFirst();
slots.add(new Slot<>(g.getElement(), l.getElement(), l.getProbability()));
g.setProbability((g.getProbability()+l.getProbability())-1);
(g.getProbability()<1.0D ? small : large).addLast(g); // requeue
}
while (large.peekFirst()!=null) {
ElemProbD g = large.removeFirst();
slots.add(new Slot<>(g.getElement(),g.getElement(),1.0));
}
while (small.peekFirst()!=null) {
ElemProbD l = small.removeFirst();
slots.add(new Slot<>(l.getElement(),l.getElement(),1.0));
}
if (slots.size()!=size) {
throw new RuntimeException("basis for average probability is incorrect, because only " + slots.size() + " slotCount of " + size + " were created.");
}
// align to indexes
for (int i = 0; i < slots.size(); i++) {
slots.get(i).rescale(i, i+1);
}
this.biases=new double[slots.size()];
//noinspection unchecked
elements = (T[]) new Object[biases.length*2];
for (int i = 0; i < biases.length; i++) {
biases[i]=slots.get(i).botProb;
elements[i*2] = slots.get(i).botItx;
elements[(i*2)+1] = slots.get(i).topIdx;
}
this.slotCount = biases.length;
}
@Override
public T apply(double value) {
double fractionlPoint = value * slotCount;
int offsetPoint = (int) fractionlPoint;
double divider = biases[offsetPoint];
int index = fractionlPoint>divider? (offsetPoint<<1)+1 : (offsetPoint<<1);
T element = elements[index];
return element;
}
private static class Slot {
public T topIdx;
public T botItx;
public double botProb;
public Slot(T topIdx, T botItx, double botProb) {
this.topIdx = topIdx;
this.botItx = botItx;
this.botProb = botProb;
}
public String toString() {
return "top:" + topIdx + ", bot:" + botItx + ", botProb: " + botProb;
}
public Slot rescale(double min, double max) {
botProb = (min + (botProb*(max-min)));
return this;
}
}
public static interface Weighted {
double getWeight();
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy