com.arosbio.ml.gridsearch.GridResultCSVWriter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of confai Show documentation
Show all versions of confai Show documentation
Conformal AI package, including all data I/O, transformations, machine learning models and predictor classes, without the chemistry-dependent code.
/*
* Copyright (C) Aros Bio AB.
*
* CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
*
* 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
*
* 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
*/
package com.arosbio.ml.gridsearch;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.arosbio.commons.MathUtils;
import com.arosbio.commons.Stopwatch;
import com.arosbio.commons.TypeUtils;
import com.arosbio.ml.gridsearch.GridSearch.EvalStatus;
import com.arosbio.ml.gridsearch.GridSearch.GSResult;
import com.arosbio.ml.metrics.SingleValuedMetric;
public class GridResultCSVWriter implements AutoCloseable {
private static final Logger LOGGER = LoggerFactory.getLogger(GridResultCSVWriter.class);
    private static final int NUM_SIGNIFICANT_FIGURES = 5;

    // Column headers written by this class (in addition to metric and parameter names)
    private static final String CSV_RANK_HEADER = "Rank";
    private static final String CSV_STATUS_HEADER = "Status";
    private static final String CSV_RUNTIME_HEADER = "Runtime";
    private static final String CSV_RUNTIME_MS_HEADER = "Runtime (ms)";
    private static final String CSV_SET_CONFIDENCE_HEADER = "Chosen confidence";
    private static final String CSV_ERROR_MSG_HEADER = "Comment";
    // NOTE(review): presumably written in place of missing values by printRecord - confirm
    private static final char NO_RESULT_INDICATOR = '-';

    // Snapshot of the Builder configuration, taken when the writer is constructed
    private final Builder settings;

    // Initialized lazily by setupHeaderAndInitPrinter(..), when the first record is printed
    private CSVPrinter printer;
    private List<String> headers;

    // State for computing the Rank column across consecutive records.
    // NOTE: "Scoure" is a misspelling of "Score"; the name is kept because it is presumably
    // referenced further down in the class (not visible here) - rename with care.
    private double previousScoure = Double.NaN;
    private int rank = 0;
public static class Builder {
        // When non-null, a "Chosen confidence" column is included in the output
        private Double conf = null;
        // When true, a leading "Rank" column is included in the output
        private boolean useRanking = false;
        private CSVFormat.Builder format = CSVFormat.DEFAULT.builder().setRecordSeparator(System.lineSeparator());
        private List<String> params;
        private Appendable output;

        /**
         * Sets the confidence level used when evaluating, or {@code null} to omit it.
         * When non-null, a "Chosen confidence" column is included in the output.
         *
         * @param conf a value in [0..1], or {@code null}
         * @return the Builder
         */
        public Builder confidence(Double conf) {
            this.conf = conf;
            return this;
        }

        /**
         * Omits the confidence column; equivalent to {@code confidence(null)}.
         *
         * @return the Builder
         */
        public Builder skipConfidence() {
            this.conf = null;
            return this;
        }

        /**
         * Toggles the leading "Rank" column.
         *
         * @param on {@code true} to include a rank column
         * @return the Builder
         */
        public Builder rank(boolean on) {
            this.useRanking = on;
            return this;
        }

        /**
         * Sets the CSV format to use. The writer takes its own snapshot of the format when it is
         * built, so the given builder is not modified by the writer.
         *
         * @param format the CSV format builder
         * @return the Builder
         */
        public Builder format(CSVFormat.Builder format) {
            this.format = format;
            return this;
        }

        /**
         * Sets the CSV format to use.
         *
         * @param format the CSV format
         * @return the Builder
         */
        public Builder format(CSVFormat format){
            this.format = format.builder();
            return this;
        }

        /**
         * Sets the names of the grid-search parameters. These are written as columns after the
         * metric (and optional confidence) columns.
         *
         * @param paramNames parameter names; {@code null} or empty means no parameter columns
         * @return the Builder
         */
        public Builder params(Collection<String> paramNames) {
            if (paramNames == null || paramNames.isEmpty())
                this.params = new ArrayList<>();
            else
                this.params = new ArrayList<>(paramNames);
            return this;
        }

        /**
         * Sets the destination the CSV is written to. Alias of {@link #out(Appendable)}.
         *
         * @param out the output destination
         * @return the Builder
         */
        public Builder log(Appendable out) {
            this.output = out;
            return this;
        }

        /**
         * Sets the destination the CSV is written to.
         *
         * @param out the output destination
         * @return the Builder
         */
        public Builder out(Appendable out) {
            this.output = out;
            return this;
        }

        /**
         * Creates an independent copy of this builder's settings.
         */
        private Builder getCopy() {
            // Snapshot the (mutable) CSVFormat.Builder through build(): sharing the same instance
            // would let the writer's later setHeader(..) call leak into the caller's builder and
            // into every other writer built from it.
            return new Builder()
                .confidence(conf)
                .rank(useRanking)
                .format(format.build())
                .params(params)
                .out(output);
        }

        /**
         * Creates a new writer from a snapshot of the current settings.
         *
         * @return a new {@link GridResultCSVWriter}
         */
        public GridResultCSVWriter build(){
            return new GridResultCSVWriter(this);
        }
    }
private GridResultCSVWriter(Builder builder) {
this.settings = builder.getCopy();
}
/**
     * Builds the CSV header from the first result and creates the underlying printer.
     * <p>
     * Column order:
     * <ul>
     * <li>the optional "Rank" column</li>
     * <li>the optimization metric's column(s)</li>
     * <li>each secondary metric's column(s)</li>
     * <li>the optional chosen-confidence column</li>
     * <li>the grid parameters</li>
     * <li>"Runtime", "Runtime (ms)", "Status" and "Comment"</li>
     * </ul>
     *
     * @param res the first result to be printed; its metrics determine the metric columns
     * @throws IOException if the {@link CSVPrinter} cannot be created (e.g. writing the header fails)
     */
    private void setupHeaderAndInitPrinter(GSResult res) throws IOException {
        headers = new ArrayList<>();
        if (settings.useRanking) {
            headers.add(CSV_RANK_HEADER);
        }
        // The optimization metric first, then any secondary metrics
        headers.addAll(res.getOptimizationMetric().asMap().keySet());
        if (res.getSecondaryMetrics() != null) {
            for (SingleValuedMetric m : res.getSecondaryMetrics()) {
                headers.addAll(m.asMap().keySet());
            }
        }
        if (settings.conf != null) {
            headers.add(CSV_SET_CONFIDENCE_HEADER);
        }
        headers.addAll(settings.params);
        headers.add(CSV_RUNTIME_HEADER);
        headers.add(CSV_RUNTIME_MS_HEADER);
        headers.add(CSV_STATUS_HEADER);
        headers.add(CSV_ERROR_MSG_HEADER);

        // Apply the header to a private copy of the configured format builder, so that a
        // CSVFormat.Builder shared with the caller is not mutated by setHeader(..)
        CSVFormat csvFormat = settings.format.build().builder()
            .setHeader(headers.toArray(new String[0]))
            .build();
        printer = new CSVPrinter(settings.output, csvFormat);
    }
public void printRecord(GSResult res) throws IOException {
if (printer == null) {
LOGGER.debug("Setting up GS CSV printer");
setupHeaderAndInitPrinter(res);
}
// Print the current parameters and the metrics
List