All downloads are free. The search and download functionality uses the official Maven repository.

com.expleague.ml.cli.builders.data.impl.DataBuilderCrossValidation Maven / Gradle / Ivy

package com.expleague.ml.cli.builders.data.impl;

import com.expleague.commons.random.FastRandom;
import com.expleague.ml.cli.builders.data.DataBuilder;
import com.expleague.ml.cli.builders.data.PoolReader;
import com.expleague.ml.data.tools.Pool;
import com.expleague.ml.data.tools.SubPool;
import com.expleague.commons.util.Pair;
import com.expleague.ml.data.tools.DataTools;

/**
 * {@link DataBuilder} that reads one pool and splits it at random into two sub-pools.
 *
 * <p>The pool is loaded from {@code learnPath} through the configured {@link PoolReader}. The split
 * is produced by {@code DataTools.splitAtRandom}, called with the pool size, a {@link FastRandom}
 * seeded with {@code randomSeed}, and the two proportions {@code partition} and
 * {@code 1.0 - partition}.
 *
 * User: qdeee
 * Date: 03.09.14
 */
public class DataBuilderCrossValidation implements DataBuilder {
  /** Location handed to the reader in {@link #create()}. */
  private String learnPath;
  /** Seed of the random generator used for the split; defaults to the construction time. */
  private long randomSeed = System.currentTimeMillis();
  /** Proportion passed for the first sub-pool; the second one gets {@code 1.0 - partition}. */
  private double partition = 0.8;
  /** Reader used to load the pool; must be set before {@link #create()} is called. */
  private PoolReader reader;

  /**
   * Sets the location the pool is read from.
   *
   * @param learnPath path passed unchanged to {@link PoolReader#read}
   */
  @Override
  public void setLearnPath(final String learnPath) {
    this.learnPath = learnPath;
  }

  /**
   * Sets the reader that loads the pool.
   *
   * @param reader reader invoked by {@link #create()}
   */
  @Override
  public void setReader(final PoolReader reader) {
    this.reader = reader;
  }

  /**
   * Sets the seed of the random generator so that a split can be reproduced.
   *
   * @param randomSeed seed given to {@link FastRandom}
   */
  public void setRandomSeed(final long randomSeed) {
    this.randomSeed = randomSeed;
  }

  /**
   * Sets the split proportion from its textual form.
   *
   * <p>An integer {@code n > 1} is interpreted as {@code 1/n}; any other value is taken as the
   * proportion itself. Surrounding whitespace is ignored. No range validation is performed.
   *
   * @param partition integer such as {@code "5"} (meaning {@code 0.2}) or a decimal such as
   *                  {@code "0.8"}
   * @throws NumberFormatException if the text is neither an integer nor a decimal number
   * @throws NullPointerException if {@code partition} is {@code null}
   */
  public void setPartition(final String partition) {
    // Integer.parseInt, unlike Double.parseDouble, does not tolerate surrounding whitespace.
    final String text = partition.trim();
    double proportion;
    try {
      // The integer form has to be tried first: Double.parseDouble accepts "5" as 5.0, which
      // used to make the 1/n interpretation unreachable.
      final int parts = Integer.parseInt(text);
      // 0, 1 and negative integers keep their plain numeric value (this also avoids 1/0).
      proportion = parts > 1 ? 1. / parts : parts;
    }
    catch (NumberFormatException notAnInteger) {
      proportion = Double.parseDouble(text);
    }
    this.partition = proportion;
  }

  /**
   * Reads the pool and splits it at random into two sub-pools.
   *
   * @return pair whose first element is built from the first index set returned by
   *         {@code DataTools.splitAtRandom} (proportion {@code partition}) and whose second
   *         element is built from the second index set (proportion {@code 1.0 - partition});
   *         presumably the learn and test parts, in that order (confirm against callers)
   */
  @Override
  public Pair create() {
    final Pool pool = reader.read(learnPath);
    final FastRandom rnd = new FastRandom(randomSeed);

    final int[][] cvSplit = DataTools.splitAtRandom(pool.size(), rnd, partition, 1.0 - partition);
    return Pair.create(new SubPool(pool, cvSplit[0]), new SubPool(pool, cvSplit[1]));
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy