// ml.comet.experiment.impl.BaseExperiment Maven / Gradle / Ivy
package ml.comet.experiment.impl;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.functions.BiFunction;
import io.reactivex.rxjava3.functions.Function;
import lombok.Getter;
import lombok.NonNull;
import ml.comet.experiment.Experiment;
import ml.comet.experiment.context.ExperimentContext;
import ml.comet.experiment.exception.CometApiException;
import ml.comet.experiment.exception.CometGeneralException;
import ml.comet.experiment.impl.asset.Asset;
import ml.comet.experiment.impl.asset.AssetType;
import ml.comet.experiment.impl.http.Connection;
import ml.comet.experiment.impl.http.ConnectionInitializer;
import ml.comet.experiment.impl.utils.CometUtils;
import ml.comet.experiment.model.CreateExperimentRequest;
import ml.comet.experiment.model.CreateExperimentResponse;
import ml.comet.experiment.model.ExperimentAssetLink;
import ml.comet.experiment.model.ExperimentMetadataRest;
import ml.comet.experiment.model.ExperimentStatusResponse;
import ml.comet.experiment.model.GitMetadata;
import ml.comet.experiment.model.GitMetadataRest;
import ml.comet.experiment.model.LogDataResponse;
import ml.comet.experiment.model.ValueMinMaxDto;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import static ml.comet.experiment.impl.asset.AssetType.ASSET_TYPE_ASSET;
import static ml.comet.experiment.impl.asset.AssetType.ASSET_TYPE_SOURCE_CODE;
import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_CLEANUP_PROMPT;
import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_LIVE;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_READ_DATA_FOR_EXPERIMENT;
import static ml.comet.experiment.impl.resources.LogMessages.getString;
import static ml.comet.experiment.impl.utils.DataUtils.createGraphRequest;
import static ml.comet.experiment.impl.utils.DataUtils.createLogEndTimeRequest;
import static ml.comet.experiment.impl.utils.DataUtils.createLogHtmlRequest;
import static ml.comet.experiment.impl.utils.DataUtils.createLogLineRequest;
import static ml.comet.experiment.impl.utils.DataUtils.createLogMetricRequest;
import static ml.comet.experiment.impl.utils.DataUtils.createLogOtherRequest;
import static ml.comet.experiment.impl.utils.DataUtils.createLogParamRequest;
import static ml.comet.experiment.impl.utils.DataUtils.createLogStartTimeRequest;
import static ml.comet.experiment.impl.utils.DataUtils.createTagRequest;
/**
* The base class for all synchronous experiment implementations providing implementation of common routines
* using synchronous networking.
*/
abstract class BaseExperiment implements Experiment {
// API key handed to ConnectionInitializer in init(); must be non-blank (see validateInitialParams()).
final String apiKey;
// Comet base URL handed to ConnectionInitializer in init(); must be non-blank (see validateInitialParams()).
final String baseUrl;
// Retry limit handed to ConnectionInitializer.initConnection() when the connection is created.
final int maxAuthRetries;
// Timeout passed to Connection.waitAndClose() when the experiment is ended.
final Duration cleaningTimeout;
// Project name sent to the server by registerExperiment().
String projectName;
// Workspace name sent to the server by registerExperiment().
String workspaceName;
// Either supplied through the constructor (existing experiment) or assigned by registerExperiment().
String experimentKey;
// Link returned by the server on registration; stays unset when an existing key is reused.
String experimentLink;
// Name sent on registration; also updated by setExperimentName().
String experimentName;
// Set to true at the end of init() and back to false by end().
boolean alive;
// REST client created in init() on top of the connection; Lombok generates getRestApiClient().
@Getter
private RestApiClient restApiClient;
// Connection created in init() and reset to null by end() after a successful close;
// Lombok generates getConnection().
@Getter
private Connection connection;
/**
 * Returns the logger instance associated with the particular experiment. Subclasses must implement this
 * method to provide their specific logger; all logging done by this base class goes through it.
 *
 * @return the logger instance associated with the particular experiment.
 */
protected abstract Logger getLogger();
/**
 * Creates an experiment without explicit project and workspace names; both are set to an empty string.
 *
 * @param apiKey          the Comet API key, must not be {@code null}.
 * @param baseUrl         the base URL of the Comet server, must not be {@code null}.
 * @param maxAuthRetries  the retry limit used when the connection is created.
 * @param experimentKey   the key of an existing experiment, or blank to have a new one registered.
 * @param cleaningTimeout the timeout used when closing the connection, must not be {@code null}.
 */
BaseExperiment(@NonNull final String apiKey,
               @NonNull final String baseUrl,
               int maxAuthRetries,
               final String experimentKey,
               @NonNull final Duration cleaningTimeout) {
    this(apiKey, baseUrl, maxAuthRetries, experimentKey, cleaningTimeout,
            StringUtils.EMPTY, StringUtils.EMPTY);
}

/**
 * Creates an experiment and stores all provided values. Nothing is validated or connected here apart from
 * the {@code @NonNull} checks; see {@link #init()}.
 *
 * @param apiKey          the Comet API key, must not be {@code null}.
 * @param baseUrl         the base URL of the Comet server, must not be {@code null}.
 * @param maxAuthRetries  the retry limit used when the connection is created.
 * @param experimentKey   the key of an existing experiment, or blank to have a new one registered.
 * @param cleaningTimeout the timeout used when closing the connection, must not be {@code null}.
 * @param projectName     the project name used on registration.
 * @param workspaceName   the workspace name used on registration.
 */
BaseExperiment(@NonNull final String apiKey,
               @NonNull final String baseUrl,
               int maxAuthRetries,
               final String experimentKey,
               @NonNull final Duration cleaningTimeout,
               final String projectName,
               final String workspaceName) {
    // connection related settings
    this.apiKey = apiKey;
    this.baseUrl = baseUrl;
    this.maxAuthRetries = maxAuthRetries;
    this.cleaningTimeout = cleaningTimeout;
    // experiment identity
    this.experimentKey = experimentKey;
    this.workspaceName = workspaceName;
    this.projectName = projectName;
}
/**
 * Validates the initial parameters, opens the connection, creates the REST API client on top of it and
 * finally marks the experiment as alive. Invoked to initialize the fields shared by all subclasses.
 */
void init() {
    CometUtils.printCometSdkVersion();
    this.validateInitialParams();

    this.connection = ConnectionInitializer.initConnection(
            this.apiKey,
            this.baseUrl,
            this.maxAuthRetries,
            this.getLogger());
    this.restApiClient = new RestApiClient(this.connection);

    // initialization is complete
    this.alive = true;
}
/**
 * Checks that both the API key and the Comet base URL are non-blank. The API key is checked first.
 *
 * @throws IllegalArgumentException if either value is blank.
 */
private void validateInitialParams() throws IllegalArgumentException {
    String problem = null;
    if (StringUtils.isBlank(this.apiKey)) {
        problem = "API key is not specified!";
    } else if (StringUtils.isBlank(this.baseUrl)) {
        problem = "The Comet base URL is not specified!";
    }

    if (problem != null) {
        throw new IllegalArgumentException(problem);
    }
}
/**
 * Synchronously registers the experiment at the Comet server. If an experiment key is already present,
 * nothing is registered and the existing key is reused.
 *
 * <p>The key returned by the server is checked before any field is updated and before the
 * "experiment is live" message is logged, so a failed registration neither changes the state of this
 * object nor announces an experiment that does not exist.
 *
 * @throws CometGeneralException if the server response carries no experiment key.
 */
void registerExperiment() throws CometGeneralException {
    if (StringUtils.isNotBlank(this.experimentKey)) {
        getLogger().debug("Not registering a new experiment. Using previous experiment key {}", this.experimentKey);
        return;
    }

    // do synchronous call to register experiment
    CreateExperimentResponse result = this.restApiClient.registerExperiment(
                    new CreateExperimentRequest(this.workspaceName, this.projectName, this.experimentName))
            .blockingGet();

    // fail fast: do not publish a blank key or log a link for an experiment that was not created
    final String registeredKey = result.getExperimentKey();
    if (StringUtils.isBlank(registeredKey)) {
        throw new CometGeneralException("Failed to register onlineExperiment with Comet ML");
    }

    this.experimentKey = registeredKey;
    this.experimentLink = result.getLink();
    getLogger().info(getString(EXPERIMENT_LIVE, this.experimentLink));
}
/**
 * Returns the key of this experiment.
 *
 * @return the experiment key, which may be blank until the experiment is registered.
 */
@Override
public String getExperimentKey() {
    return experimentKey;
}

/**
 * Returns the project name of this experiment.
 *
 * @return the project name.
 */
@Override
public String getProjectName() {
    return projectName;
}

/**
 * Returns the workspace name of this experiment.
 *
 * @return the workspace name.
 */
@Override
public String getWorkspaceName() {
    return workspaceName;
}

/**
 * Returns the name of this experiment.
 *
 * @return the experiment name, or {@code null} if none was set.
 */
@Override
public String getExperimentName() {
    return experimentName;
}
/**
 * Logs the new name to Comet under the key {@code "Name"} and then stores it locally. The local field is
 * only updated if logging did not throw.
 *
 * @param experimentName the new experiment name, must not be {@code null}.
 */
@Override
public void setExperimentName(@NonNull String experimentName) {
    this.logOther("Name", experimentName);
    this.experimentName = experimentName;
}
/**
 * Logs a metric synchronously: waits for the result or an exception and checks the response status for
 * failure.
 *
 * @param metricName  the name of the metric to be logged.
 * @param metricValue the new value of the metric.
 * @param context     the context to be associated with the metric.
 * @throws CometApiException if a response with a failure code was received.
 */
@Override
public void logMetric(@NonNull String metricName,
                      @NonNull Object metricValue,
                      @NonNull ExperimentContext context) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("logMetric {} = {}, context: {}", metricName, metricValue, context);
    }
    this.sendSynchronously(
            this.restApiClient::logMetric,
            createLogMetricRequest(metricName, metricValue, context));
}

/**
 * Logs a metric using a context built from the given step and epoch.
 *
 * @param metricName  the name of the metric to be logged.
 * @param metricValue the new value of the metric.
 * @param step        the step to put into the context.
 * @param epoch       the epoch to put into the context.
 */
@Override
public void logMetric(String metricName, Object metricValue, long step, long epoch) {
    final ExperimentContext stepEpochContext = new ExperimentContext(step, epoch);
    this.logMetric(metricName, metricValue, stepEpochContext);
}
/**
 * Logs a parameter synchronously: waits for the result or an exception and checks the response status for
 * failure.
 *
 * @param parameterName the name of the parameter being logged.
 * @param paramValue    the value of the parameter being logged.
 * @param context       the context to be associated with the parameter.
 */
@Override
public void logParameter(String parameterName, Object paramValue, ExperimentContext context) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("logParameter {} = {}, context: {}", parameterName, paramValue, context);
    }
    this.sendSynchronously(
            this.restApiClient::logParameter,
            createLogParamRequest(parameterName, paramValue, context));
}

/**
 * Logs a parameter using a context built from the given step.
 *
 * @param parameterName the name of the parameter being logged.
 * @param paramValue    the value of the parameter being logged.
 * @param step          the step to put into the context.
 */
@Override
public void logParameter(String parameterName, Object paramValue, long step) {
    final ExperimentContext stepContext = new ExperimentContext(step);
    this.logParameter(parameterName, paramValue, stepContext);
}
/**
 * Logs an output line synchronously: waits for the result or an exception and checks the response status
 * for failure. The experiment state is validated up front, before the request is built.
 *
 * @param line    the text to be logged.
 * @param offset  the offset describing the place for the current text to be inserted.
 * @param stderr  the flag to indicate if this is a StdErr message.
 * @param context the context to be associated with the line.
 */
@Override
public void logLine(String line, long offset, boolean stderr, String context) {
    this.validate();
    this.sendSynchronously(
            this.restApiClient::logOutputLine,
            createLogLineRequest(line, offset, stderr, context));
}
/**
 * Logs a block of HTML synchronously: waits for the result or an exception and checks the response status
 * for failure.
 *
 * @param html     a block of HTML to be sent to Comet.
 * @param override whether previously sent HTML should be deleted. If {@code true}, the old HTML will be
 *                 deleted.
 */
@Override
public void logHtml(@NonNull String html, boolean override) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("logHtml {}, override: {}", html, override);
    }
    this.sendSynchronously(
            this.restApiClient::logHtml,
            createLogHtmlRequest(html, override));
}
/**
 * Logs an arbitrary key/value pair synchronously: waits for the result or an exception and checks the
 * response status for failure.
 *
 * @param key   the key for the data to be stored.
 * @param value the value for said key.
 */
@Override
public void logOther(@NonNull String key, @NonNull Object value) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("logOther {} {}", key, value);
    }
    this.sendSynchronously(
            this.restApiClient::logOther,
            createLogOtherRequest(key, value));
}
/**
 * Adds a tag synchronously: waits for the result or an exception and checks the response status for
 * failure.
 *
 * @param tag the tag to be added.
 */
@Override
public void addTag(@NonNull String tag) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("addTag {}", tag);
    }
    this.sendSynchronously(
            this.restApiClient::addTag,
            createTagRequest(tag));
}
/**
 * Logs a graph synchronously: waits for the result or an exception and checks the response status for
 * failure.
 *
 * @param graph the graph to be logged.
 */
@Override
public void logGraph(@NonNull String graph) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("logGraph {}", graph);
    }
    this.sendSynchronously(
            this.restApiClient::logGraph,
            createGraphRequest(graph));
}
/**
 * Logs the experiment start time synchronously: waits for the result or an exception and checks the
 * response status for failure.
 *
 * @param startTimeMillis the time, in milliseconds, to record as the start of the experiment.
 */
@Override
public void logStartTime(long startTimeMillis) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("logStartTime {}", startTimeMillis);
    }
    this.sendSynchronously(
            this.restApiClient::logStartEndTime,
            createLogStartTimeRequest(startTimeMillis));
}
/**
 * Logs the experiment end time synchronously: waits for the result or an exception and checks the
 * response status for failure.
 *
 * @param endTimeMillis the time, in milliseconds, to record as the end of the experiment.
 */
@Override
public void logEndTime(long endTimeMillis) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("logEndTime {}", endTimeMillis);
    }
    this.sendSynchronously(
            this.restApiClient::logStartEndTime,
            createLogEndTimeRequest(endTimeMillis));
}
/**
 * Logs Git metadata synchronously: waits for the result or an exception and checks the response status
 * for failure. The metadata object itself is sent as the request.
 *
 * @param gitMetadata the Git metadata for the experiment.
 */
@Override
public void logGitMetadata(GitMetadata gitMetadata) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("logGitMetadata {}", gitMetadata);
    }
    this.sendSynchronously(this.restApiClient::logGitMetadata, gitMetadata);
}
/**
 * Logs raw source code synchronously as a source-code asset. The code is sent as UTF-8 bytes.
 *
 * @param code     the source code text.
 * @param fileName the file name to be associated with the code.
 * @param context  the context to be associated with the asset.
 */
@Override
public void logCode(@NonNull String code, @NonNull String fileName, @NonNull ExperimentContext context) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("log raw source code, file name: {}", fileName);
    }

    final Asset sourceCode = new Asset();
    sourceCode.setFileLikeData(code.getBytes(StandardCharsets.UTF_8));
    sourceCode.setFileName(fileName);
    sourceCode.setExperimentContext(context);
    sourceCode.setType(ASSET_TYPE_SOURCE_CODE);

    this.sendSynchronously(this.restApiClient::logAsset, sourceCode);
}

/**
 * Logs raw source code with an empty context.
 *
 * @param code     the source code text.
 * @param fileName the file name to be associated with the code.
 */
@Override
public void logCode(String code, String fileName) {
    final ExperimentContext noContext = ExperimentContext.empty();
    this.logCode(code, fileName, noContext);
}

/**
 * Logs the content of a source file synchronously as a source-code asset, named after the file.
 *
 * @param file    the source code file.
 * @param context the context to be associated with the asset.
 */
@Override
public void logCode(@NonNull File file, @NonNull ExperimentContext context) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("log source code from file {}", file.getName());
    }

    final Asset sourceCode = new Asset();
    sourceCode.setFile(file);
    sourceCode.setFileName(file.getName());
    sourceCode.setExperimentContext(context);
    sourceCode.setType(ASSET_TYPE_SOURCE_CODE);

    this.sendSynchronously(this.restApiClient::logAsset, sourceCode);
}

/**
 * Logs the content of a source file with an empty context.
 *
 * @param file the source code file.
 */
@Override
public void logCode(File file) {
    final ExperimentContext noContext = ExperimentContext.empty();
    this.logCode(file, noContext);
}
/**
 * Uploads a file synchronously as a generic asset.
 *
 * @param file      the file to be uploaded.
 * @param fileName  the name under which the asset is stored.
 * @param overwrite the overwrite flag set on the asset.
 * @param context   the context to be associated with the asset.
 */
@Override
public void uploadAsset(@NonNull File file, @NonNull String fileName,
                        boolean overwrite, @NonNull ExperimentContext context) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("uploadAsset from file {}, name {}, override {}, context {}",
                file.getName(), fileName, overwrite, context);
    }

    final Asset upload = new Asset();
    upload.setFile(file);
    upload.setFileName(fileName);
    upload.setExperimentContext(context);
    upload.setOverwrite(overwrite);
    upload.setType(ASSET_TYPE_ASSET);

    this.sendSynchronously(this.restApiClient::logAsset, upload);
}

/**
 * Uploads a file under the given name, using a context built from the step and epoch.
 *
 * @param asset     the file to be uploaded.
 * @param fileName  the name under which the asset is stored.
 * @param overwrite the overwrite flag set on the asset.
 * @param step      the step to put into the context.
 * @param epoch     the epoch to put into the context.
 */
@Override
public void uploadAsset(@NonNull File asset, String fileName, boolean overwrite, long step, long epoch) {
    final ExperimentContext stepEpochContext = new ExperimentContext(step, epoch);
    this.uploadAsset(asset, fileName, overwrite, stepEpochContext);
}

/**
 * Uploads a file under its own file name.
 *
 * @param asset     the file to be uploaded.
 * @param overwrite the overwrite flag set on the asset.
 * @param context   the context to be associated with the asset.
 */
@Override
public void uploadAsset(@NonNull File asset, boolean overwrite, @NonNull ExperimentContext context) {
    final String ownName = asset.getName();
    this.uploadAsset(asset, ownName, overwrite, context);
}

/**
 * Uploads a file under its own file name, using a context built from the step and epoch.
 *
 * @param asset     the file to be uploaded.
 * @param overwrite the overwrite flag set on the asset.
 * @param step      the step to put into the context.
 * @param epoch     the epoch to put into the context.
 */
@Override
public void uploadAsset(@NonNull File asset, boolean overwrite, long step, long epoch) {
    final ExperimentContext stepEpochContext = new ExperimentContext(step, epoch);
    this.uploadAsset(asset, overwrite, stepEpochContext);
}
/**
 * Synchronously loads the metadata of this experiment from the server.
 *
 * @return the experiment metadata.
 */
@Override
public ExperimentMetadataRest getMetadata() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get metadata for experiment {}", this.experimentKey);
    }
    final ExperimentMetadataRest metadata = this.loadRemote(this.restApiClient::getMetadata, "METADATA");
    return metadata;
}
/**
 * Synchronously loads the Git metadata of this experiment from the server.
 *
 * @return the Git metadata.
 */
@Override
public GitMetadataRest getGitMetadata() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get git metadata for experiment {}", this.experimentKey);
    }
    final GitMetadataRest gitMetadata = this.loadRemote(this.restApiClient::getGitMetadata, "GIT METADATA");
    return gitMetadata;
}
/**
 * Synchronously loads the HTML logged for this experiment.
 *
 * @return the HTML wrapped into an {@link Optional}; empty if the response carries none.
 */
@Override
public Optional getHtml() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get html for experiment {}", this.experimentKey);
    }
    return Optional.ofNullable(
            this.loadRemote(this.restApiClient::getHtml, "HTML").getHtml());
}
/**
 * Synchronously loads the output logged for this experiment.
 *
 * @return the output wrapped into an {@link Optional}; empty if the response carries none.
 */
@Override
public Optional getOutput() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get output for experiment {}", this.experimentKey);
    }
    return Optional.ofNullable(
            this.loadRemote(this.restApiClient::getOutput, "StdOut").getOutput());
}
/**
 * Synchronously loads the graph logged for this experiment.
 *
 * @return the graph wrapped into an {@link Optional}; empty if the response carries none.
 */
@Override
public Optional getGraph() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get graph for experiment {}", this.experimentKey);
    }
    return Optional.ofNullable(
            this.loadRemote(this.restApiClient::getGraph, "GRAPH").getGraph());
}
/**
 * Synchronously loads the parameters logged for this experiment.
 *
 * @return the values carried by the parameters response.
 */
@Override
public List getParameters() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get params for experiment {}", this.experimentKey);
    }
    return this.loadRemote(this.restApiClient::getParameters, "PARAMETERS")
            .getValues();
}
/**
 * Synchronously loads the metrics summary of this experiment.
 *
 * @return the values carried by the metrics response.
 */
@Override
public List getMetrics() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get metrics summary for experiment {}", this.experimentKey);
    }
    return this.loadRemote(this.restApiClient::getMetrics, "METRICS")
            .getValues();
}
/**
 * Synchronously loads the data logged for this experiment via {@code logOther}.
 *
 * @return the values carried by the response.
 */
@Override
public List getLogOther() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get log other for experiment {}", this.experimentKey);
    }
    return this.loadRemote(this.restApiClient::getLogOther, "OTHER PARAMETERS")
            .getValues();
}
/**
 * Synchronously loads the tags of this experiment.
 *
 * @return the tags carried by the response.
 */
@Override
public List getTags() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get tags for experiment {}", this.experimentKey);
    }
    return this.loadRemote(this.restApiClient::getTags, "TAGs")
            .getTags();
}
/**
 * Synchronously loads the list of assets of the given type for this experiment. A failure is logged as an
 * error before it propagates to the caller.
 *
 * @param type the type of assets to list, must not be {@code null}.
 * @return the assets carried by the response.
 */
@Override
public List getAssetList(@NonNull AssetType type) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get assets with type {} for experiment {}", type, this.experimentKey);
    }
    return this.validateAndGetExperimentKey()
            .concatMap(key -> this.restApiClient.getAssetList(key, type))
            .doOnError(failure -> getLogger().error(
                    "Failed to read ASSETS list for the experiment, experiment key: {}",
                    this.experimentKey, failure))
            .blockingGet()
            .getAssets();
}
/**
 * Ends the experiment: marks it as not alive, disposes the REST API client and closes the connection,
 * waiting up to the cleaning timeout. Does nothing if the experiment is not alive. A failure to close the
 * connection is logged and not rethrown; in that case the connection reference is kept.
 */
@Override
public void end() {
    if (this.alive) {
        getLogger().info(getString(EXPERIMENT_CLEANUP_PROMPT, cleaningTimeout.getSeconds()));

        // no further requests are accepted from here on
        this.alive = false;

        // release the REST API client
        if (this.restApiClient != null) {
            this.restApiClient.dispose();
        }

        // wait for pending work and close the connection
        if (this.connection != null) {
            try {
                this.connection.waitAndClose(this.cleaningTimeout);
                this.connection = null;
            } catch (Exception closeFailure) {
                getLogger().error("failed to close connection", closeFailure);
            }
        }
    }
}
/**
 * Sends a heartbeat to the server and returns the status response. Any error, including a failed
 * experiment state validation, is swallowed and results in an empty {@link Optional}.
 *
 * @return the status response of the experiment, or empty if it could not be obtained.
 */
Optional sendExperimentStatus() {
    // the lambda keeps the REST client lookup lazy: it is only touched when state validation passed
    return Optional.ofNullable(
            this.validateAndGetExperimentKey()
                    .concatMap(key -> this.restApiClient.sendExperimentStatus(key))
                    .onErrorComplete()
                    .blockingGet());
}
/**
* Synchronously loads remote data using provided load function or throws an exception.
*
* @param loadFunc the function to be applied to load remote data.
* @param alias the data type alias used for logging.
* @param the data type to be returned.
* @return the loaded data.
*/
private T loadRemote(final Function> loadFunc, String alias) {
return validateAndGetExperimentKey()
.concatMap(loadFunc)
.doOnError(ex -> getLogger().error(
getString(FAILED_READ_DATA_FOR_EXPERIMENT, alias, this.experimentKey), ex))
.blockingGet();
}
/**
* Uses provided function to send request data synchronously. If response indicating the remote error
* received the {@link CometApiException} will be thrown.
*
* @param func the function to be invoked to send request data.
* @param request the request data object.
* @param the type of the request data object.
* @throws CometApiException if received response with error indicating that data was not saved.
*/
private void sendSynchronously(final BiFunction> func,
final T request) throws CometApiException {
LogDataResponse response = validateAndGetExperimentKey()
.concatMap(experimentKey -> func.apply(request, experimentKey))
.blockingGet();
if (response.hasFailed()) {
throw new CometApiException("Failed to log {}, reason: %s", request, response.getMsg());
}
}
/**
 * Validates the state of the experiment. The experiment key is checked first, then the alive flag.
 *
 * @throws IllegalStateException if no experiment key is present or the experiment is not alive.
 */
private void validate() throws IllegalStateException {
    final boolean keyMissing = StringUtils.isBlank(this.experimentKey);
    if (keyMissing) {
        throw new IllegalStateException("Experiment key must be present!");
    }
    if (this.alive) {
        return;
    }
    throw new IllegalStateException("Experiment was not initialized. You need to call init().");
}
/**
 * Validates the experiment state and returns the experiment key or an error as a {@link Single}. Nothing
 * is thrown by this method itself; a failed validation is delivered through the returned {@link Single}.
 * The experiment key is checked first, then the alive flag.
 *
 * @return a {@link Single} emitting the experiment key, or signalling an {@link IllegalStateException}
 *         if the key is blank or the experiment is not alive.
 */
Single<String> validateAndGetExperimentKey() {
    if (StringUtils.isBlank(this.experimentKey)) {
        return Single.error(new IllegalStateException("Experiment key must be present!"));
    }
    if (!this.alive) {
        return Single.error(new IllegalStateException("Experiment is not alive or already closed."));
    }
    return Single.just(getExperimentKey());
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy