All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.data.softBorders.GibbsExpWeightedPermutationsWalker Maven / Gradle / Ivy

package com.expleague.ml.data.softBorders;

import com.expleague.commons.random.FastRandom;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.Combinatorics;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Random walk over permutations that re-samples short windows of the current permutation.
 *
 * <p>The walker keeps a current permutation in {@code cursor} (initialised with
 * {@code ArrayTools.sequence(0, n)} - presumably the identity {@code 0..n-1}; confirm against
 * that helper). One step picks a window of {@code blockSize} consecutive positions and replaces
 * its content with one of the reorderings of the values currently inside the window. A reordering
 * is chosen with probability proportional to
 * {@code exp(-lambda * sum_j |W[value_j] - W[position_j]|)}, where {@code W} is the cumulative
 * sum of the rank weights passed to the constructor.
 *
 * <p>Instances hold mutable scratch buffers and a private random generator; nothing in this class
 * synchronises access to them.
 *
 * Created by noxoomo on 06/11/2016.
 */

public class GibbsExpWeightedPermutationsWalker implements Sampler {
  /** Upper bound on the window length; all {@code blockSize!} reorderings are enumerated. */
  private static final int MAX_BLOCK_SIZE = 4;

  private final int blockSize;
  private final double lambda;
  private final List<int[]> blockPermutations;
  /** Scratch buffer: one entry per window reordering (energy first, then unnormalised weight). */
  private final double[] localWeights;
  /** Scratch buffer: snapshot of the window content before it is re-sampled. */
  private final int[] localRanks;
  private final FastRandom random = new FastRandom();

  /** Cumulative sums of the per-rank weights (private copy, never shared with the caller). */
  private final double[] rankWeights;
  /** Current state of the walk. */
  private final int[] cursor;

  /**
   * Creates a walker in which every rank has weight {@code 1.0}.
   *
   * @param n      number of elements in the permutation
   * @param lambda multiplier applied to the displacement cost inside the exponent
   */
  public GibbsExpWeightedPermutationsWalker(final int n,
                                            final double lambda) {
    this(n, lambda, null);
  }

  /**
   * Creates a walker with explicit per-rank weights.
   *
   * @param n           number of elements in the permutation
   * @param lambda      multiplier applied to the displacement cost inside the exponent
   * @param rankWeights per-rank weights of length {@code n}, or {@code null} to use {@code 1.0}
   *                    for every rank; the array is copied and left untouched
   * @throws IllegalArgumentException if {@code rankWeights} is non-null and its length is not
   *                                  {@code n}
   */
  public GibbsExpWeightedPermutationsWalker(final int n,
                                            final double lambda,
                                            final double[] rankWeights) {
    // An explicit check instead of `assert`: assertions are normally disabled in production, and
    // a wrong length would otherwise only show up later as an index error or wrong weights.
    if (rankWeights != null && rankWeights.length != n) {
      throw new IllegalArgumentException(
          "rankWeights.length (" + rankWeights.length + ") must be equal to n (" + n + ")");
    }

    this.blockSize = Math.min(MAX_BLOCK_SIZE, n);
    this.lambda = lambda;
    this.cursor = ArrayTools.sequence(0, n);

    if (rankWeights != null) {
      // Work on a copy: the prefix sum below must not rewrite the caller's array.
      this.rankWeights = rankWeights.clone();
    } else {
      this.rankWeights = new double[n];
      Arrays.fill(this.rankWeights, 1.0);
    }
    // Turn the per-rank weights into cumulative sums.
    for (int i = 1; i < n; ++i) {
      this.rankWeights[i] += this.rankWeights[i - 1];
    }

    // Enumerate every reordering of a window once; sampleBlock reuses the list.
    this.blockPermutations = new ArrayList<>();
    final Combinatorics.Permutations permutations = new Combinatorics.Permutations(blockSize);
    while (permutations.hasNext()) {
      blockPermutations.add(permutations.next());
    }

    this.localWeights = new double[blockPermutations.size()];
    this.localRanks = new int[blockSize];
  }

  /**
   * Re-samples the window {@code cursor[start .. start + blockSize - 1]} in place.
   *
   * @param start index of the first position of the window
   */
  private void sampleBlock(final int start) {
    System.arraycopy(cursor, start, localRanks, 0, blockSize);

    final int count = blockPermutations.size();

    // Pass 1: energy (lambda * total displacement cost) of every reordering, and the minimum.
    double minEnergy = Double.POSITIVE_INFINITY;
    for (int idx = 0; idx < count; ++idx) {
      final int[] permutation = blockPermutations.get(idx);

      double energy = 0;
      for (int j = 0; j < blockSize; ++j) {
        final int rk = localRanks[permutation[j]];
        energy += Math.abs(rankWeights[rk] - rankWeights[start + j]);
      }
      energy *= lambda;

      localWeights[idx] = energy;
      if (energy < minEnergy) {
        minEnergy = energy;
      }
    }

    // Pass 2: weights relative to the minimum energy. The common factor exp(minEnergy) cancels
    // in the normalisation, so the distribution is unchanged, but the best reordering now has
    // weight exactly 1 and the total can no longer underflow to zero for large lambda.
    double totalWeight = 0;
    for (int idx = 0; idx < count; ++idx) {
      final double weight = Math.exp(minEnergy - localWeights[idx]);
      localWeights[idx] = weight;
      totalWeight += weight;
    }

    // Roulette-wheel selection. The index defaults to the last entry so that a zero draw,
    // rounding in the running subtraction, or NaN weights can never leave it out of range.
    double gain = random.nextDouble() * totalWeight;
    int takenIdx = count - 1;
    for (int idx = 0; idx < count - 1; ++idx) {
      gain -= localWeights[idx];
      if (gain < 0) {
        takenIdx = idx;
        break;
      }
    }

    final int[] permutation = blockPermutations.get(takenIdx);
    for (int i = 0; i < blockSize; ++i) {
      cursor[start + i] = localRanks[permutation[i]];
    }
  }

  /**
   * Advances the walk by {@code cursor.length} window updates at randomly chosen offsets.
   *
   * @return the walker's internal state array; it is overwritten by the next call, so copy it
   *         if the values must be retained
   */
  public int[] sample() {
    final int startBound = cursor.length - blockSize + 1;
    for (int iter = 0; iter < cursor.length; ++iter) {
      sampleBlock(random.nextInt(startBound));
    }
    return cursor;
  }

  /**
   * Demo entry point: prints 50 successive samples of a 20-element walker with lambda 0.5.
   *
   * @param args ignored
   */
  public static void main(final String[] args) {
    final GibbsExpWeightedPermutationsWalker permutations = new GibbsExpWeightedPermutationsWalker(20, 0.5);
    for (int i = 0; i < 50; ++i) {
      System.out.println(Arrays.toString(permutations.sample()));
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy