All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.cli.cv.FoldsEnumerator Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.cli.cv;

import com.expleague.commons.random.FastRandom;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.Pair;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.data.tools.Pool;

import java.util.NoSuchElementException;

/**
 * User: qdeee
 * Date: 16.09.15
 */
/**
 * Enumerates cross-validation (learn, test) splits of a {@link Pool}.
 *
 * <p>On construction, the item indices {@code [0, sourcePool.size())} are split into
 * {@code foldsCount} groups by {@link DataTools#splitAtRandom} using uniform probabilities.
 * This is presumably a random partition with folds of roughly equal size; see
 * {@code DataTools} to confirm. Each call to {@link #next()} then produces one split, in fold
 * order:
 * <ul>
 *   <li>the test part is the current fold;</li>
 *   <li>the learn part is the union of all other folds.</li>
 * </ul>
 * Both parts are materialized via {@code sourcePool.sub(...)}.
 *
 * <p>Instances are stateful: they track the current fold in an unsynchronized field.
 *
 * User: qdeee
 * Date: 16.09.15
 */
public class FoldsEnumerator {
  private final Pool sourcePool;
  private final int foldsCount;

  /** Pool indices per fold, as returned by {@link DataTools#splitAtRandom}. */
  private final int[][] foldIndices;
  /** Fold that the next call to {@link #next()} will use as the test part. */
  private int currentFold = 0;

  /**
   * Randomly splits {@code sourcePool} into {@code foldsCount} folds.
   *
   * @param sourcePool pool to split; its {@code size()} defines the index range
   * @param random     source of randomness passed to {@link DataTools#splitAtRandom}
   * @param foldsCount number of folds (and therefore of splits); must be positive
   * @throws IllegalArgumentException if {@code foldsCount} is not positive
   */
  public FoldsEnumerator(final Pool sourcePool, final FastRandom random, final int foldsCount) {
    // Without this check, 0 yields Infinity probabilities and a negative value throws an
    // opaque NegativeArraySizeException.
    if (foldsCount <= 0) {
      throw new IllegalArgumentException("foldsCount must be positive, got " + foldsCount);
    }
    this.sourcePool = sourcePool;
    this.foldsCount = foldsCount;

    // Uniform probabilities: every fold gets the same weight.
    final double[] probs = ArrayTools.fill(new double[foldsCount], 1. / foldsCount);
    this.foldIndices = DataTools.splitAtRandom(sourcePool.size(), random, probs);
  }

  /**
   * Returns the number of folds requested at construction time.
   *
   * @return the fold count, which is also the total number of splits {@link #next()} yields
   */
  public int getFoldsCount() {
    return foldsCount;
  }

  /**
   * Tells whether another split is available.
   *
   * @return {@code true} while fewer than {@code foldsCount} splits have been returned
   */
  public boolean hasNext() {
    return currentFold < foldsCount;
  }

  /**
   * Returns the next (learn, test) split and advances to the following fold.
   *
   * @return a pair whose first element is the learn sub-pool (all folds but the current one)
   *     and whose second element is the test sub-pool (the current fold)
   * @throws NoSuchElementException if all folds have already been enumerated
   */
  public Pair next() {
    if (!hasNext()) {
      throw new NoSuchElementException(
          "All " + foldsCount + " folds have already been enumerated");
    }
    final int[] learnIndices = getLearnIndices();
    final int[] testIndices = foldIndices[currentFold];
    currentFold++;
    return Pair.create(sourcePool.sub(learnIndices), sourcePool.sub(testIndices));
  }

  /**
   * Concatenates the index arrays of every fold except {@code currentFold}, in fold order.
   *
   * <p>The size is computed from the folds actually produced rather than from
   * {@code sourcePool.size()}. Otherwise, a split that did not cover every index exactly once
   * would cause an out-of-bounds copy or leave zero-filled (item 0) entries in the result.
   *
   * @return the learn-part indices for the current fold
   */
  private int[] getLearnIndices() {
    int learnSize = 0;
    for (int i = 0; i < foldIndices.length; i++) {
      if (i != currentFold) {
        learnSize += foldIndices[i].length;
      }
    }

    final int[] learnIndices = new int[learnSize];
    int offset = 0;
    for (int i = 0; i < foldIndices.length; i++) {
      if (i == currentFold) {
        continue;
      }
      final int[] fold = foldIndices[i];
      System.arraycopy(fold, 0, learnIndices, offset, fold.length);
      offset += fold.length;
    }
    return learnIndices;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy