All downloads are free. Search and download functionality uses the official Maven repository.

io.virtdata.stathelpers.aliasmethod.AliasElementSampler Maven / Gradle / Ivy

There is a newer version: 2.12.15
Show newest version
package io.virtdata.stathelpers.aliasmethod;

import io.virtdata.stathelpers.ElemProbD;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.function.DoubleFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Uses the alias sampling method to encode and sample from discrete probabilities,
 * even over larger sets of data. This form requires a unit interval sample value
 * between 0.0 and 1.0. Assuming the maximal amount of memory is used for distinct
 * outcomes N, a memory buffer of N*16 bytes is required for this implementation,
 * requiring 32MB of memory for 1M entries.
 *
 * This sampler should be shared between threads, and will be by default, in order
 * to avoid many instances of a 32MB buffer on heap.
 */
public class AliasElementSampler implements DoubleFunction {

    private double[] biases;
    private T[] elements;
    private double slotCount; // The number of fair die-roll slotCount that contain unfair coin probabilities


    /**
     * Setup an alias table for T type objects.
     * @param biases An array of the unfair die model values
     * @param elements An array of elements of type T, two values per bias value. index 2n is bot, index 2n+1 is top.
     */
    AliasElementSampler(double[] biases, T[] elements) {
        this.biases = biases;
        this.elements = elements;
    }

    AliasElementSampler(Collection elements, Function weightFunction) {
        this(elements.stream().map(e -> new ElemProbD<>(e,weightFunction.apply(e))).collect(Collectors.toList()));
    }

    @SuppressWarnings("unchecked")
    public AliasElementSampler(List> events) {
        int size = events.size();

        LinkedList> small = new LinkedList<>();
        LinkedList> large = new LinkedList<>();
        List> slots = new ArrayList<>();

        // array-size normalization
        double sumProbability = events.stream().mapToDouble(ElemProbD::getProbability).sum();
        events = events.stream().map(
                e -> new ElemProbD<>(e.getElement(), (e.getProbability() / sumProbability) * size)
        ).collect(Collectors.toList());

        // presort
        for (ElemProbD event : events) {
            (event.getProbability()<1.0D ? small : large).addLast(event);
        }

        while (small.peekFirst()!=null && large.peekFirst()!=null) {
            ElemProbD l = small.removeFirst();
            ElemProbD g = large.removeFirst();
            slots.add(new Slot<>(g.getElement(), l.getElement(), l.getProbability()));
            g.setProbability((g.getProbability()+l.getProbability())-1);
            (g.getProbability()<1.0D ? small : large).addLast(g); // requeue
        }
        while (large.peekFirst()!=null) {
            ElemProbD g = large.removeFirst();
            slots.add(new Slot<>(g.getElement(),g.getElement(),1.0));
        }
        while (small.peekFirst()!=null) {
            ElemProbD l = small.removeFirst();
            slots.add(new Slot<>(l.getElement(),l.getElement(),1.0));
        }
        if (slots.size()!=size) {
            throw new RuntimeException("basis for average probability is incorrect, because only " + slots.size() + " slotCount of " + size + " were created.");
        }
        // align to indexes
        for (int i = 0; i < slots.size(); i++) {
            slots.get(i).rescale(i, i+1);
        }
        this.biases=new double[slots.size()];

        //noinspection unchecked
        elements = (T[]) new Object[biases.length*2];

        for (int i = 0; i < biases.length; i++) {
            biases[i]=slots.get(i).botProb;
            elements[i*2] = slots.get(i).botItx;
            elements[(i*2)+1] = slots.get(i).topIdx;
        }
        this.slotCount = biases.length;

    }

    @Override
    public T apply(double value) {
        double fractionlPoint = value * slotCount;
        int offsetPoint = (int) fractionlPoint;
        double divider = biases[offsetPoint];
        int index = fractionlPoint>divider? (offsetPoint<<1)+1 : (offsetPoint<<1);
        T element = elements[index];
        return element;
    }

    private static class Slot {
        public T topIdx;
        public T botItx;
        public double botProb;

        public Slot(T topIdx, T botItx, double botProb) {
            this.topIdx = topIdx;
            this.botItx = botItx;
            this.botProb = botProb;
        }

        public String toString() {
            return "top:" + topIdx + ", bot:" + botItx + ", botProb: " + botProb;
        }

        public Slot rescale(double min, double max) {
            botProb = (min + (botProb*(max-min)));
            return this;
        }
    }

    public static interface Weighted {
        double getWeight();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy