All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.virtdata.stathelpers.aliasmethod.AliasSamplerDoubleInt Maven / Gradle / Ivy

There is a newer version: 2.12.15
Show newest version
package io.virtdata.stathelpers.aliasmethod;

import io.virtdata.stathelpers.EvProbD;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.function.DoubleToIntFunction;
import java.util.stream.Collectors;

/**
 * Uses the alias sampling method to encode and sample from discrete probabilities,
 * even over larger sets of data. This form requires a unit interval sample value
 * between 0.0 and 1.0. Assuming the maximal amount of memory is used for distinct
 * outcomes N, a memory buffer of N*16 bytes is required for this implementation,
 * requiring 32MB of memory for 1M entries.
 *
 * This sampler should be shared between threads, and will be by default, in order
 * to avoid many instances of a 32MB buffer on heap.
 */
public class AliasSamplerDoubleInt implements DoubleToIntFunction {

    private ByteBuffer stats; // tuples of double,int,int (unfair coin, direct pointers to referents)
    private double slotCount; // The number of fair die-roll slotCount that contain unfair coin probabilities
    private static int _r0=0;
    private static int _r1=_r0+Double.BYTES;
    private static int _r2=_r1+Integer.BYTES;
    public static int RECORD_LEN = _r2 + Integer.BYTES; // Record size for the above.

    // for testing
    AliasSamplerDoubleInt(ByteBuffer stats) {
        this.stats = stats;
        if ((stats.capacity()% RECORD_LEN)!=0) {
            throw new RuntimeException("Misaligned ByteBuffer size, must be a multiple of " + RECORD_LEN);
        }
        slotCount = (stats.capacity()/ RECORD_LEN);
    }

    public AliasSamplerDoubleInt(List events) {
        int size = events.size();

        int[] alias = new int[events.size()];
        double[] prob = new double[events.size()];

        LinkedList small = new LinkedList<>();
        LinkedList large = new LinkedList<>();
        List slots = new ArrayList<>();

        // array-size normalization
        double sumProbability = events.stream().mapToDouble(EvProbD::getProbability).sum();
        events = events.stream().map(e -> new EvProbD(e.getEventId(), (e.getProbability()/sumProbability)*size)).collect(Collectors.toList());

        // presort
        for (EvProbD event : events) {
            (event.getProbability()<1.0D ? small : large).addLast(event);
        }

        while (small.peekFirst()!=null && large.peekFirst()!=null) {
            EvProbD l = small.removeFirst();
            EvProbD g = large.removeFirst();
            slots.add(new Slot(g.getEventId(), l.getEventId(), l.getProbability()));
            g.setProbability((g.getProbability()+l.getProbability())-1);
            (g.getProbability()<1.0D ? small : large).addLast(g); // requeue
        }
        while (large.peekFirst()!=null) {
            EvProbD g = large.removeFirst();
            slots.add(new Slot(g.getEventId(),g.getEventId(),1.0));
        }
        while (small.peekFirst()!=null) {
            EvProbD l = small.removeFirst();
            slots.add(new Slot(l.getEventId(),l.getEventId(),1.0));
        }
        if (slots.size()!=size) {
            throw new RuntimeException("basis for average probability is incorrect, because only " + slots.size() + " slotCount of " + size + " were created.");
        }
        // align to indexes
        for (int i = 0; i < slots.size(); i++) {
            slots.get(i).rescale(i, i+1);
        }
        this.stats = ByteBuffer.allocate(slots.size()* RECORD_LEN);

        for (Slot slot : slots) {
            stats.putDouble(slot.botProb);
            stats.putInt(slot.botItx);
            stats.putInt(slot.topIdx);
        }
        stats.flip();
        this.slotCount = (stats.capacity()/ RECORD_LEN);

    }

    @Override
    public int applyAsInt(double value) {
        double fractionlPoint = value * slotCount;
        int offsetPoint = (int) fractionlPoint * RECORD_LEN;
        double divider = stats.getDouble(offsetPoint);
        int selector = offsetPoint+ (fractionlPoint>divider?_r2:_r1);
        int referentId = stats.getInt(selector);
        return referentId;
    }

    private static class Slot {
        public int topIdx;
        public int botItx;
        public double botProb;

        public Slot(int topIdx, int botItx, double botProb) {
            this.topIdx = topIdx;
            this.botItx = botItx;
            this.botProb = botProb;
        }

        public String toString() {
            return "top:" + topIdx + ", bot:" + botItx + ", botProb: " + botProb;
        }

        public Slot rescale(double min, double max) {
            botProb = (min + (botProb*(max-min)));
            return this;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy