/*
 * All downloads are FREE. Search and download functionality uses the official Maven repository.
 * Please wait. This can take a few minutes ...
 * Downloading a project requires many resources. Please understand that we have to cover our server costs. Thank you in advance.
 * Project price: only $1
 * You can buy this project and download/modify it as often as you want.
 * com.expleague.ml.cli.modes.impl.Fit Maven / Gradle / Ivy
 */
package com.expleague.ml.cli.modes.impl;
import com.expleague.ml.BFGrid;
import com.expleague.ml.ProgressHandler;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.cli.modes.CliPoolReaderHelper;
import com.expleague.ml.cli.output.ModelWriter;
import com.expleague.commons.func.WeakListenerHolder;
import com.expleague.commons.io.StreamTools;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.text.StringUtils;
import com.expleague.commons.util.Pair;
import com.expleague.commons.util.logging.Interval;
import com.expleague.ml.cli.output.printers.*;
import com.expleague.ml.cli.builders.data.DataBuilder;
import com.expleague.ml.cli.builders.data.impl.DataBuilderClassic;
import com.expleague.ml.cli.builders.data.impl.DataBuilderCrossValidation;
import com.expleague.ml.cli.builders.methods.MethodsBuilder;
import com.expleague.ml.cli.builders.methods.grid.DynamicGridBuilder;
import com.expleague.ml.cli.builders.methods.grid.GridBuilder;
import com.expleague.ml.cli.modes.AbstractMode;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.data.tools.MCTools;
import com.expleague.ml.data.tools.Pool;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.loss.blockwise.BlockwiseMultiLabelLogit;
import com.expleague.ml.loss.multiclass.ClassicMulticlassLoss;
import com.expleague.ml.loss.multilabel.ClassicMultiLabelLoss;
import com.expleague.ml.loss.multilabel.MultiLabelOVRLogit;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.multiclass.JoinedBinClassModel;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.MissingArgumentException;
import java.io.File;
import static com.expleague.ml.cli.JMLLCLI.*;
/**
* User: qdeee
* Date: 16.09.15
*/
public class Fit extends AbstractMode {
@Override
public void run(final CommandLine command) throws Exception {
if (!command.hasOption(LEARN_OPTION)) {
throw new MissingArgumentException("Please provide 'LEARN_OPTION'");
}
//data loading
final DataBuilder dataBuilder;
if (command.hasOption(CROSS_VALIDATION_OPTION)) {
final DataBuilderCrossValidation dataBuilderCrossValidation = new DataBuilderCrossValidation();
final String[] cvOptions = StringUtils.split(command.getOptionValue(CROSS_VALIDATION_OPTION), "/", 2);
dataBuilderCrossValidation.setRandomSeed(Long.valueOf(cvOptions[0]));
dataBuilderCrossValidation.setPartition(cvOptions[1]);
dataBuilder = dataBuilderCrossValidation;
} else {
dataBuilder = new DataBuilderClassic();
((DataBuilderClassic) dataBuilder).setTestPath(command.getOptionValue(TEST_OPTION));
}
dataBuilder.setLearnPath(command.getOptionValue(LEARN_OPTION));
CliPoolReaderHelper.setPoolReader(command, dataBuilder);
final Pair extends Pool, ? extends Pool> pools = dataBuilder.create();
final Pool learn = pools.getFirst();
final Pool test = pools.getSecond();
//loading grid (if needed)
final GridBuilder gridBuilder = new GridBuilder();
if (command.hasOption(GRID_OPTION)) {
gridBuilder.setGrid(BFGrid.CONVERTER.convertFrom(StreamTools.readFile(new File(command.getOptionValue(GRID_OPTION)))));
} else {
gridBuilder.setBinsCount(Integer.valueOf(command.getOptionValue(BIN_FOLDS_COUNT_OPTION, "32")));
gridBuilder.setDataSet(learn.vecData());
}
final DynamicGridBuilder dynamicGridBuilder = new DynamicGridBuilder();
dynamicGridBuilder.setBinsCount(Integer.valueOf(command.getOptionValue(BIN_FOLDS_COUNT_OPTION, "1")));
dynamicGridBuilder.setDataSet(learn.vecData());
//choose optimization method
final MethodsBuilder methodsBuilder = new MethodsBuilder();
methodsBuilder.setRandom(new FastRandom());
methodsBuilder.setGridBuilder(gridBuilder);
methodsBuilder.setDynamicGridBuilder(dynamicGridBuilder);
final VecOptimization method = methodsBuilder.create(command.getOptionValue(OPTIMIZATION_OPTION, DEFAULT_OPTIMIZATION_SCHEME));
//set target
final String target = command.getOptionValue(TARGET_OPTION, DEFAULT_TARGET);
final TargetFunc loss = learn.target(DataTools.targetByName(target));
//set metrics
final String[] metricNames = command.getOptionValues(METRICS_OPTION);
final Func[] metrics;
if (metricNames != null) {
metrics = new Func[metricNames.length];
for (int i = 0; i < metricNames.length; i++) {
metrics[i] = test.targetByName(metricNames[i]);
}
} else {
metrics = new Func[]{test.targetByName(target)};
}
//added progress handlers
ProgressHandler progressPrinter = null;
if (method instanceof WeakListenerHolder && command.hasOption(VERBOSE_OPTION) && !command.hasOption(FAST_OPTION)) {
final int printPeriod = Integer.valueOf(command.getOptionValue(PRINT_PERIOD, "10"));
if (loss instanceof BlockwiseMLLLogit) {
progressPrinter = new MulticlassProgressPrinter(learn, test, printPeriod); //f*ck you with your custom different-dimensional metrics
} else if (loss instanceof BlockwiseMultiLabelLogit || loss instanceof MultiLabelOVRLogit) {
progressPrinter = new MultiLabelLogitProgressPrinter(learn, test, printPeriod);
} else {
progressPrinter = new DefaultProgressPrinter(learn, test, loss, metrics, printPeriod);
}
//noinspection unchecked
((WeakListenerHolder) method).addListener(progressPrinter);
}
if (method instanceof WeakListenerHolder && command.hasOption(HIST_OPTION)) {
final ProgressHandler histogramPrinter = new HistogramPrinter();
//noinspection unchecked
((WeakListenerHolder) method).addListener(histogramPrinter);
}
//fitting
Interval.start();
Interval.suspend();
final Trans result = method.fit(learn.vecData(), loss);
Interval.stopAndPrint("Total fit time:");
//calc & print scores
if (!command.hasOption(FAST_OPTION) && !command.hasOption(SKIP_FINAL_EVAL_OPTION)) {
ResultsPrinter.printResults(result, learn, test, loss, metrics);
if (loss instanceof BlockwiseMLLLogit) {
ResultsPrinter.printMulticlassResults(result, learn, test);
} else if (loss instanceof ClassicMulticlassLoss) {
final int printPeriod = Integer.valueOf(command.getOptionValue(PRINT_PERIOD, "20"));
MCTools.makeOneVsRestReport(learn, test, (JoinedBinClassModel) result, printPeriod);
} else if (loss instanceof ClassicMultiLabelLoss || loss instanceof BlockwiseMultiLabelLogit || loss instanceof MultiLabelOVRLogit) {
ResultsPrinter.printMultilabelResult(result, learn, test);
}
}
//write model
final String outName = getOutputName(command);
final ModelWriter modelWriter = new ModelWriter(outName);
if (command.hasOption(WRITE_BIN_FORMULA)) {
modelWriter.tryWriteBinFormula(result);
} else {
if (!command.hasOption(GRID_OPTION)) {
modelWriter.tryWriteGrid(result);
}
modelWriter.tryWriteDynamicGrid(result);
modelWriter.writeModel(result);
}
}
}