com.expleague.ml.cli.modes.impl.CrossValidation Maven / Gradle / Ivy
package com.expleague.ml.cli.modes.impl;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.util.BestHolder;
import com.expleague.ml.cli.builders.data.impl.DataBuilderClassic;
import com.expleague.ml.cli.cv.KFoldCrossValidation;
import com.expleague.ml.cli.gridsearch.ParametersExtractor;
import com.expleague.commons.io.StreamTools;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.text.StringUtils;
import com.expleague.ml.cli.builders.methods.MethodsBuilder;
import com.expleague.ml.cli.builders.methods.grid.GridBuilder;
import com.expleague.ml.cli.modes.AbstractMode;
import com.expleague.ml.cli.modes.CliPoolReaderHelper;
import com.expleague.ml.data.tools.Pool;
import com.expleague.ml.BFGrid;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.MissingArgumentException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static com.expleague.ml.cli.JMLLCLI.*;
/**
* User: qdeee
* Date: 16.09.15
*
* Запуск K-fold кросс-валидации: делим пул на K фолдов, на K-1 обучаемся, на последнем измеряем целевую функцию, повторяем K раз.
* Далее усредняем по числу фолдов.
*
* Количество фолдов и сид рандома передаются через CROSS_VALIDATION_OPTIONS в формате -X / (-X 100500/3)
* Доступно 2 варианта запуска:
* 1. Перебор сетки параметров.
* В команде запуска перебираемые параметры оставляем как %s и добавляем -r ::. Пример команды :
* "cross-validation -X 100500/3 -f pool.tsv -T blockwise.BlockwiseMLLLogit -O GradientBoosting(weak=MultiClassSplit(weak=GreedyObliviousTree),iterations=%s,step=%s) -r 10:15:1;0.1:0.9:0.1"
* 2. Одиночный запуск для фиксированных значений параметров. В этом случае передавать RANGES не нужно, а параметры указываются как обычно.
*/
public class CrossValidation extends AbstractMode {
private static final Logger LOG = LoggerFactory.getLogger(CrossValidation.class);
public void run(final CommandLine command) throws MissingArgumentException, IOException {
if (!command.hasOption(LEARN_OPTION)) {
throw new MissingArgumentException("Please provide 'LEARN_OPTION'");
}
final boolean schemeBatchComparision = command.hasOption(LOAD_OPTIMIZATION_SCHEMES_FROM_FILE_OPTION);
if (!schemeBatchComparision && !command.hasOption(OPTIMIZATION_OPTION)) {
throw new MissingArgumentException("Please provide 'OPTIMIZATION_OPTION'");
}
if (!command.hasOption(CROSS_VALIDATION_OPTION)) {
throw new MissingArgumentException("Please provide 'CROSS_VALIDATION_OPTIONS");
}
final DataBuilderClassic dataBuilder;
dataBuilder = new DataBuilderClassic();
dataBuilder.setLearnPath(command.getOptionValue(LEARN_OPTION));
CliPoolReaderHelper.setPoolReader(command, dataBuilder);
final Pool sourcePool = dataBuilder.create().getFirst();
final GridBuilder gridBuilder = new GridBuilder();
if (command.hasOption(GRID_OPTION)) {
gridBuilder.setGrid(BFGrid.CONVERTER.convertFrom(StreamTools.readFile(new File(command.getOptionValue(GRID_OPTION)))));
} else {
gridBuilder.setBinsCount(Integer.valueOf(command.getOptionValue(BIN_FOLDS_COUNT_OPTION, "32")));
gridBuilder.setDataSet(sourcePool.vecData());
}
final String[] cvOptions = StringUtils.split(command.getOptionValue(CROSS_VALIDATION_OPTION), "/", 2);
final FastRandom random = new FastRandom(Long.valueOf(cvOptions[0]));
final int foldsCount = Integer.parseInt(cvOptions[1]);
final MethodsBuilder methodsBuilder = new MethodsBuilder();
methodsBuilder.setRandom(random);
methodsBuilder.setGridBuilder(gridBuilder);
final String targetClassName = command.getOptionValue(TARGET_OPTION, DEFAULT_TARGET);
final KFoldCrossValidation crossValidation = new KFoldCrossValidation(sourcePool, random, foldsCount, targetClassName, methodsBuilder);
if (schemeBatchComparision) {
if (command.hasOption(RANGES_OPTION)) {
throw new RuntimeException("Error: range option is not supported for batch model comparision cv mode");
}
List schemes = loadSchemesFromFile(command.getOptionValue(LOAD_OPTIMIZATION_SCHEMES_FROM_FILE_OPTION));
KFoldCrossValidation.CrossValidationModelComparisonResult result = crossValidation.evaluateSchemesBatch(schemes);
dumpResult(result, command.getOptionValue(CROSS_VALIDATION_RESULT_OPTION, DEFAULT_MODELS_COMPARISION_CV_OUTPUT_FILE));
} else {
final String commonScheme = command.getOptionValue(OPTIMIZATION_OPTION);
final double score;
if (command.hasOption(RANGES_OPTION)) {
final String[][] parametersSpace = ParametersExtractor.parse(command.getOptionValue(RANGES_OPTION));
final BestHolder