ml.comet.experiment.impl.BaseExperiment Maven / Gradle / Ivy
package ml.comet.experiment.impl;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.functions.BiFunction;
import io.reactivex.rxjava3.functions.Function;
import lombok.Getter;
import lombok.NonNull;
import ml.comet.experiment.Experiment;
import ml.comet.experiment.artifact.Artifact;
import ml.comet.experiment.artifact.ArtifactAsset;
import ml.comet.experiment.artifact.ArtifactDownloadException;
import ml.comet.experiment.artifact.ArtifactException;
import ml.comet.experiment.artifact.ArtifactNotFoundException;
import ml.comet.experiment.artifact.AssetOverwriteStrategy;
import ml.comet.experiment.artifact.GetArtifactOptions;
import ml.comet.experiment.artifact.InvalidArtifactStateException;
import ml.comet.experiment.artifact.LoggedArtifact;
import ml.comet.experiment.artifact.LoggedArtifactAsset;
import ml.comet.experiment.asset.LoggedExperimentAsset;
import ml.comet.experiment.context.ExperimentContext;
import ml.comet.experiment.exception.CometApiException;
import ml.comet.experiment.exception.CometGeneralException;
import ml.comet.experiment.impl.asset.ArtifactAssetImpl;
import ml.comet.experiment.impl.asset.AssetImpl;
import ml.comet.experiment.impl.asset.DownloadArtifactAssetOptions;
import ml.comet.experiment.impl.http.Connection;
import ml.comet.experiment.impl.http.ConnectionInitializer;
import ml.comet.experiment.impl.rest.ArtifactDto;
import ml.comet.experiment.impl.rest.ArtifactEntry;
import ml.comet.experiment.impl.rest.ArtifactRequest;
import ml.comet.experiment.impl.rest.ArtifactVersionAssetResponse;
import ml.comet.experiment.impl.rest.ArtifactVersionDetail;
import ml.comet.experiment.impl.rest.ArtifactVersionState;
import ml.comet.experiment.impl.rest.CreateExperimentRequest;
import ml.comet.experiment.impl.rest.CreateExperimentResponse;
import ml.comet.experiment.impl.rest.ExperimentStatusResponse;
import ml.comet.experiment.impl.rest.MinMaxResponse;
import ml.comet.experiment.impl.rest.RestApiResponse;
import ml.comet.experiment.impl.utils.CometUtils;
import ml.comet.experiment.impl.utils.ExceptionUtils;
import ml.comet.experiment.impl.utils.FileUtils;
import ml.comet.experiment.impl.utils.SystemUtils;
import ml.comet.experiment.model.ExperimentMetadata;
import ml.comet.experiment.model.GitMetaData;
import ml.comet.experiment.model.Value;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileAlreadyExistsException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import static java.util.Optional.empty;
import static ml.comet.experiment.impl.asset.AssetType.ALL;
import static ml.comet.experiment.impl.asset.AssetType.SOURCE_CODE;
import static ml.comet.experiment.impl.constants.SdkErrorCodes.artifactVersionStateNotClosed;
import static ml.comet.experiment.impl.constants.SdkErrorCodes.artifactVersionStateNotClosedErrorOccurred;
import static ml.comet.experiment.impl.constants.SdkErrorCodes.noArtifactFound;
import static ml.comet.experiment.impl.resources.LogMessages.ARTIFACT_ASSETS_FILE_EXISTS_PRESERVING;
import static ml.comet.experiment.impl.resources.LogMessages.ARTIFACT_DOWNLOAD_FILE_OVERWRITTEN;
import static ml.comet.experiment.impl.resources.LogMessages.ARTIFACT_HAS_NO_DETAILS;
import static ml.comet.experiment.impl.resources.LogMessages.ARTIFACT_NOT_FOUND;
import static ml.comet.experiment.impl.resources.LogMessages.ARTIFACT_NOT_READY;
import static ml.comet.experiment.impl.resources.LogMessages.ARTIFACT_VERSION_CREATED_WITHOUT_PREVIOUS;
import static ml.comet.experiment.impl.resources.LogMessages.ARTIFACT_VERSION_CREATED_WITH_PREVIOUS;
import static ml.comet.experiment.impl.resources.LogMessages.COMPLETED_DOWNLOAD_ARTIFACT_ASSET;
import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_CREATED;
import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_LIVE;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_READ_DATA_FOR_EXPERIMENT;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_REGISTER_EXPERIMENT;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_COMPARE_CONTENT_OF_FILES;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_CREATE_TEMPORARY_ASSET_DOWNLOAD_FILE;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_DELETE_TEMPORARY_ASSET_FILE;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_DOWNLOAD_ASSET;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_DOWNLOAD_ASSET_FILE_ALREADY_EXISTS;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_READ_DOWNLOADED_FILE_SIZE;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_READ_LOGGED_ARTIFACT_ASSETS;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_RESOLVE_ASSET_FILE;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_UPDATE_ARTIFACT_VERSION_STATE;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_UPSERT_ARTIFACT;
import static ml.comet.experiment.impl.resources.LogMessages.GET_ARTIFACT_FAILED_UNEXPECTEDLY;
import static ml.comet.experiment.impl.resources.LogMessages.REMOTE_ASSET_CANNOT_BE_DOWNLOADED;
import static ml.comet.experiment.impl.resources.LogMessages.getString;
import static ml.comet.experiment.impl.utils.AssetUtils.createAssetFromData;
import static ml.comet.experiment.impl.utils.AssetUtils.createAssetFromFile;
import static ml.comet.experiment.impl.utils.RestApiUtils.createArtifactUpsertRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createArtifactVersionStateRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createGitMetadataRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createGraphRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createLogEndTimeRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createLogHtmlRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createLogLineRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createLogMetricRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createLogOtherRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createLogParamRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createLogStartTimeRequest;
import static ml.comet.experiment.impl.utils.RestApiUtils.createTagRequest;
/**
* The base class for all synchronous experiment implementations providing implementation of common routines
* using synchronous networking.
*/
abstract class BaseExperiment implements Experiment {
// Comet API key; must be non-blank (see validateInitialParams()) and is used to initialize the connection.
final String apiKey;
// Base URL of the Comet server; must be non-blank (see validateInitialParams()).
final String baseUrl;
// Value handed to ConnectionInitializer.initConnection() when the connection is created in init().
final int maxAuthRetries;
// Timeout passed to Connection.waitAndClose() when the experiment is ended.
final Duration cleaningTimeout;
// Project name; overwritten with the value reported by the server in registerExperiment().
String projectName;
// Workspace name; overwritten with the value reported by the server in registerExperiment().
String workspaceName;
// Experiment key; when blank, registerExperiment() requests a new one from the server.
String experimentKey;
// Link to the experiment as returned by the server in registerExperiment().
String experimentLink;
// Experiment name; when blank, it is taken from the registration response.
String experimentName;
// Set to true at the end of init() and reset to false by end().
boolean alive;
// REST API client created in init() on top of the connection; disposed by end().
@Getter
private RestApiClient restApiClient;
// Connection created in init(); closed and cleared by end().
@Getter
private Connection connection;
/**
 * Returns logger instance associated with particular experiment. The subclasses should override this method to
 * provide specific logger instance.
 *
 * <p>The returned logger is used by this class for all of its debug, info, warning and error output and is
 * also handed to the connection initializer in {@link #init()}.
 *
 * @return the logger instance associated with particular experiment.
 */
protected abstract Logger getLogger();
/**
 * Creates the experiment without explicit project and workspace names: both are set to the empty string.
 *
 * @param apiKey          the Comet API key.
 * @param baseUrl         the base URL of the Comet server.
 * @param maxAuthRetries  the value passed to the connection initializer.
 * @param experimentKey   the key of an existing experiment, may be blank.
 * @param cleaningTimeout the timeout used when the connection is closed.
 */
BaseExperiment(@NonNull final String apiKey,
               @NonNull final String baseUrl,
               int maxAuthRetries,
               final String experimentKey,
               @NonNull final Duration cleaningTimeout) {
    this(apiKey,
            baseUrl,
            maxAuthRetries,
            experimentKey,
            cleaningTimeout,
            StringUtils.EMPTY,
            StringUtils.EMPTY);
}
/**
 * Creates the experiment and stores all provided values; no network activity happens here
 * (the connection is created later by {@link #init()}).
 *
 * @param apiKey          the Comet API key.
 * @param baseUrl         the base URL of the Comet server.
 * @param maxAuthRetries  the value passed to the connection initializer.
 * @param experimentKey   the key of an existing experiment, may be blank.
 * @param cleaningTimeout the timeout used when the connection is closed.
 * @param projectName     the project name.
 * @param workspaceName   the workspace name.
 */
BaseExperiment(@NonNull final String apiKey,
               @NonNull final String baseUrl,
               int maxAuthRetries,
               final String experimentKey,
               @NonNull final Duration cleaningTimeout,
               final String projectName,
               final String workspaceName) {
    // connection related settings
    this.apiKey = apiKey;
    this.baseUrl = baseUrl;
    this.maxAuthRetries = maxAuthRetries;
    this.cleaningTimeout = cleaningTimeout;

    // experiment identity
    this.experimentKey = experimentKey;
    this.workspaceName = workspaceName;
    this.projectName = projectName;
}
/**
 * Validates the initial parameters, creates the connection together with the REST API client
 * and marks this experiment as alive. Invoked by subclasses to initialize the common fields.
 */
void init() {
    CometUtils.printCometSdkVersion();
    validateInitialParams();

    final Connection conn = ConnectionInitializer.initConnection(
            this.apiKey, this.baseUrl, this.maxAuthRetries, this.getLogger());
    this.connection = conn;
    this.restApiClient = new RestApiClient(conn);

    // everything is in place - the experiment is considered initialized from now on
    this.alive = true;
}
/**
 * Checks that the API key and the base URL are not blank. The API key is checked first.
 *
 * @throws IllegalArgumentException if the API key or the base URL is blank.
 */
private void validateInitialParams() throws IllegalArgumentException {
    final boolean apiKeyMissing = StringUtils.isBlank(this.apiKey);
    if (apiKeyMissing) {
        throw new IllegalArgumentException("API key is not specified!");
    }

    final boolean baseUrlMissing = StringUtils.isBlank(this.baseUrl);
    if (baseUrlMissing) {
        throw new IllegalArgumentException("The Comet base URL is not specified!");
    }
}
/**
 * Synchronously registers experiment at the Comet server. Nothing is registered if this instance already
 * holds a non-blank experiment key. On success the key, link, workspace and project names are taken from
 * the server response; the experiment name is taken from the response only if it was blank before.
 *
 * @throws CometGeneralException if the server returned a blank experiment key or the API call failed.
 */
void registerExperiment() throws CometGeneralException {
    if (!StringUtils.isBlank(this.experimentKey)) {
        getLogger().debug("Not registering a new experiment. Using previous experiment key {}", this.experimentKey);
        return;
    }

    try {
        // blocks until the server answers
        final CreateExperimentResponse response = this.restApiClient
                .registerExperiment(new CreateExperimentRequest(
                        this.workspaceName, this.projectName, this.experimentName))
                .blockingGet();

        final String registeredKey = response.getExperimentKey();
        if (StringUtils.isBlank(registeredKey)) {
            throw new CometGeneralException(getString(FAILED_REGISTER_EXPERIMENT));
        }

        this.experimentKey = registeredKey;
        this.experimentLink = response.getLink();
        this.workspaceName = response.getWorkspaceName();
        this.projectName = response.getProjectName();
        if (StringUtils.isBlank(this.experimentName)) {
            this.experimentName = response.getName();
        }
    } catch (CometApiException ex) {
        final String failure = getString(FAILED_REGISTER_EXPERIMENT);
        this.getLogger().error(failure, ex);
        throw new CometGeneralException(failure, ex);
    }

    final Logger log = getLogger();
    log.info(getString(EXPERIMENT_CREATED, this.workspaceName, this.projectName, this.experimentName));
    log.info(getString(EXPERIMENT_LIVE, this.experimentLink));
}
/**
 * Reads the details of the host system via {@code SystemUtils.readSystemDetails()} and sends them
 * synchronously, i.e., the invoking thread is blocked until the operation completes.
 *
 * @throws CometApiException if API access exception occurs.
 */
void logSystemDetails() throws CometApiException {
    sendSynchronously(
            restApiClient::logSystemDetails,
            SystemUtils.readSystemDetails());
}
/**
 * Returns the experiment key currently held by this instance.
 *
 * @return the current value of the experiment key field.
 */
@Override
public String getExperimentKey() {
    return experimentKey;
}
/**
 * Returns the project name currently held by this instance.
 *
 * @return the current value of the project name field.
 */
@Override
public String getProjectName() {
    return projectName;
}
/**
 * Returns the workspace name currently held by this instance.
 *
 * @return the current value of the workspace name field.
 */
@Override
public String getWorkspaceName() {
    return workspaceName;
}
/**
 * Returns the experiment name currently held by this instance.
 *
 * @return the current value of the experiment name field.
 */
@Override
public String getExperimentName() {
    return experimentName;
}
/**
 * Logs the new name under the {@code "Name"} key via {@link #logOther(String, Object)} and, once that
 * call has returned, remembers the name locally.
 *
 * @param experimentName the new name of the experiment.
 */
@Override
public void setExperimentName(@NonNull String experimentName) {
    // send first: the local field is only updated if logging did not throw
    this.logOther("Name", experimentName);
    this.experimentName = experimentName;
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 *
 * @param metricName  The name for the metric to be logged
 * @param metricValue The new value for the metric. If the values for a metric are plottable we will plot them
 * @param context     the context to be associated with the metric.
 * @throws CometApiException if received response with failure code.
 */
@Override
public void logMetric(@NonNull String metricName, @NonNull Object metricValue,
                      @NonNull ExperimentContext context) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("logMetric {} = {}, context: {}", metricName, metricValue, context);
    }

    sendSynchronously(
            restApiClient::logMetric,
            createLogMetricRequest(metricName, metricValue, context));
}
/**
 * Wraps the step and epoch into an {@link ExperimentContext} and delegates to
 * {@link #logMetric(String, Object, ExperimentContext)}.
 *
 * @param metricName  the name of the metric.
 * @param metricValue the value of the metric.
 * @param step        the step to be associated with the metric.
 * @param epoch       the epoch to be associated with the metric.
 */
@Override
public void logMetric(String metricName, Object metricValue, long step, long epoch) {
    final ExperimentContext stepEpochContext = new ExperimentContext(step, epoch);
    logMetric(metricName, metricValue, stepEpochContext);
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 *
 * @param parameterName The name of the param being logged
 * @param paramValue    The value for the param being logged
 * @param context       the context to be associated with the parameter.
 */
@Override
public void logParameter(String parameterName, Object paramValue, ExperimentContext context) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("logParameter {} = {}, context: {}", parameterName, paramValue, context);
    }

    sendSynchronously(
            restApiClient::logParameter,
            createLogParamRequest(parameterName, paramValue, context));
}
/**
 * Wraps the step into an {@link ExperimentContext} and delegates to
 * {@link #logParameter(String, Object, ExperimentContext)}.
 *
 * @param parameterName the name of the parameter.
 * @param paramValue    the value of the parameter.
 * @param step          the step to be associated with the parameter.
 */
@Override
public void logParameter(String parameterName, Object paramValue, long step) {
    final ExperimentContext stepContext = new ExperimentContext(step);
    logParameter(parameterName, paramValue, stepContext);
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 * The {@code validate()} routine is invoked before anything is sent.
 *
 * @param line    Text to be logged
 * @param offset  Offset describes the place for current text to be inserted
 * @param stderr  the flag to indicate if this is StdErr message.
 * @param context the context to be associated with the line.
 */
@Override
public void logLine(String line, long offset, boolean stderr, String context) {
    validate();

    sendSynchronously(
            restApiClient::logOutputLine,
            createLogLineRequest(line, offset, stderr, context));
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 *
 * @param html     A block of html to be sent to Comet
 * @param override Whether previous html sent should be deleted. If {@code true} the old html will be deleted.
 */
@Override
public void logHtml(@NonNull String html, boolean override) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("logHtml {}, override: {}", html, override);
    }

    sendSynchronously(
            restApiClient::logHtml,
            createLogHtmlRequest(html, override));
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 *
 * @param key   The key for the data to be stored
 * @param value The value for said key
 */
@Override
public void logOther(@NonNull String key, @NonNull Object value) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("logOther {} {}", key, value);
    }

    sendSynchronously(
            restApiClient::logOther,
            createLogOtherRequest(key, value));
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 *
 * @param tag The tag to be added
 */
@Override
public void addTag(@NonNull String tag) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("addTag {}", tag);
    }

    sendSynchronously(
            restApiClient::addTag,
            createTagRequest(tag));
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 *
 * @param graph The graph to be logged.
 */
@Override
public void logGraph(@NonNull String graph) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("logGraph {}", graph);
    }

    sendSynchronously(
            restApiClient::logGraph,
            createGraphRequest(graph));
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 *
 * @param startTimeMillis When you want to say that the experiment started
 */
@Override
public void logStartTime(long startTimeMillis) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("logStartTime {}", startTimeMillis);
    }

    // start and end time share the same REST API client method; only the request differs
    sendSynchronously(
            restApiClient::logStartEndTime,
            createLogStartTimeRequest(startTimeMillis));
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 *
 * @param endTimeMillis When you want to say that the experiment ended
 */
@Override
public void logEndTime(long endTimeMillis) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("logEndTime {}", endTimeMillis);
    }

    // start and end time share the same REST API client method; only the request differs
    sendSynchronously(
            restApiClient::logStartEndTime,
            createLogEndTimeRequest(endTimeMillis));
}
/**
 * Synchronous version that waits for result or exception. Also, it checks the response status for failure.
 *
 * @param gitMetaData The Git Metadata for the experiment.
 */
@Override
public void logGitMetadata(GitMetaData gitMetaData) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("logGitMetadata {}", gitMetaData);
    }

    sendSynchronously(
            restApiClient::logGitMetadata,
            createGitMetadataRequest(gitMetaData));
}
/**
 * Synchronously logs raw source code. The code is encoded as UTF-8 and sent as an asset
 * of the source code type.
 *
 * @param code        the source code to be logged.
 * @param logicalPath the logical path (file name) to be associated with the code.
 * @param context     the context to be associated with the asset.
 */
@Override
public void logCode(@NonNull String code, @NonNull String logicalPath, @NonNull ExperimentContext context) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("log raw source code, file name: {}", logicalPath);
    }

    final byte[] codeBytes = code.getBytes(StandardCharsets.UTF_8);
    final AssetImpl codeAsset = createAssetFromData(
            codeBytes, logicalPath, false, empty(), Optional.of(SOURCE_CODE.type()));
    logAsset(codeAsset, context);
}
/**
 * Logs raw source code with an empty context by delegating to
 * {@link #logCode(String, String, ExperimentContext)}.
 *
 * @param code        the source code to be logged.
 * @param logicalPath the logical path (file name) to be associated with the code.
 */
@Override
public void logCode(String code, String logicalPath) {
    final ExperimentContext noContext = ExperimentContext.empty();
    logCode(code, logicalPath, noContext);
}
/**
 * Synchronously logs the source code file as an asset of the source code type.
 *
 * @param file    the file with the source code.
 * @param context the context to be associated with the asset.
 */
@Override
public void logCode(@NonNull File file, @NonNull ExperimentContext context) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("log source code from file {}", file.getName());
    }

    final AssetImpl codeAsset = createAssetFromFile(
            file, empty(), false, empty(), Optional.of(SOURCE_CODE.type()));
    logAsset(codeAsset, context);
}
/**
 * Logs the source code file with an empty context by delegating to
 * {@link #logCode(File, ExperimentContext)}.
 *
 * @param file the file with the source code.
 */
@Override
public void logCode(File file) {
    final ExperimentContext noContext = ExperimentContext.empty();
    logCode(file, noContext);
}
/**
 * Synchronously uploads the file as an experiment asset under the given logical path.
 *
 * @param file        the file to be uploaded.
 * @param logicalPath the logical path to be associated with the asset.
 * @param overwrite   the overwrite flag passed on when the asset is created.
 * @param context     the context to be associated with the asset.
 */
@Override
public void uploadAsset(@NonNull File file, @NonNull String logicalPath,
                        boolean overwrite, @NonNull ExperimentContext context) {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("uploadAsset from file {}, name {}, override {}, context {}",
                file.getName(), logicalPath, overwrite, context);
    }

    final AssetImpl fileAsset = createAssetFromFile(
            file, Optional.of(logicalPath), overwrite, empty(), empty());
    logAsset(fileAsset, context);
}
/**
 * Wraps the step and epoch into an {@link ExperimentContext} and delegates to
 * {@link #uploadAsset(File, String, boolean, ExperimentContext)}.
 *
 * @param asset       the file to be uploaded.
 * @param logicalPath the logical path to be associated with the asset.
 * @param overwrite   the overwrite flag.
 * @param step        the step to be associated with the asset.
 * @param epoch       the epoch to be associated with the asset.
 */
@Override
public void uploadAsset(@NonNull File asset, String logicalPath, boolean overwrite, long step, long epoch) {
    final ExperimentContext stepEpochContext = new ExperimentContext(step, epoch);
    uploadAsset(asset, logicalPath, overwrite, stepEpochContext);
}
/**
 * Uploads the file using its own name as the logical path by delegating to
 * {@link #uploadAsset(File, String, boolean, ExperimentContext)}.
 *
 * @param asset     the file to be uploaded.
 * @param overwrite the overwrite flag.
 * @param context   the context to be associated with the asset.
 */
@Override
public void uploadAsset(@NonNull File asset, boolean overwrite, @NonNull ExperimentContext context) {
    final String fileName = asset.getName();
    uploadAsset(asset, fileName, overwrite, context);
}
/**
 * Wraps the step and epoch into an {@link ExperimentContext} and delegates to
 * {@link #uploadAsset(File, boolean, ExperimentContext)}.
 *
 * @param asset     the file to be uploaded.
 * @param overwrite the overwrite flag.
 * @param step      the step to be associated with the asset.
 * @param epoch     the epoch to be associated with the asset.
 */
@Override
public void uploadAsset(@NonNull File asset, boolean overwrite, long step, long epoch) {
    final ExperimentContext stepEpochContext = new ExperimentContext(step, epoch);
    uploadAsset(asset, overwrite, stepEpochContext);
}
/**
 * Synchronously logs provided asset. The context is attached to the asset before it is sent.
 *
 * @param asset   the {@link AssetImpl} to be uploaded
 * @param context the context to be set on the asset.
 */
void logAsset(@NonNull final AssetImpl asset, @NonNull ExperimentContext context) {
    asset.setContext(context);
    sendSynchronously(
            restApiClient::logAsset,
            asset);
}
/**
 * Synchronously upsert provided Comet artifact into Comet backend and logs which version was created.
 *
 * @param artifact the {@link ArtifactImpl} instance.
 * @return the {@link ArtifactEntry} describing saved artifact.
 * @throws ArtifactException if operation failed for any reason; the original failure is kept as the cause.
 */
ArtifactEntry upsertArtifact(@NonNull final Artifact artifact) throws ArtifactException {
    try {
        final ArtifactImpl impl = (ArtifactImpl) artifact;
        final ArtifactRequest upsertRequest = createArtifactUpsertRequest(impl);
        final ArtifactEntry entry = validateAndGetExperimentKey()
                .concatMap(key -> getRestApiClient().upsertArtifact(upsertRequest, key))
                .blockingGet();

        // the log message depends on whether the backend reported a previous version
        if (StringUtils.isNotBlank(entry.getPreviousVersion())) {
            getLogger().info(getString(ARTIFACT_VERSION_CREATED_WITH_PREVIOUS,
                    impl.getName(), entry.getCurrentVersion(), entry.getPreviousVersion()));
        } else {
            getLogger().info(getString(ARTIFACT_VERSION_CREATED_WITHOUT_PREVIOUS,
                    impl.getName(), entry.getCurrentVersion()));
        }
        return entry;
    } catch (Throwable failure) {
        throw new ArtifactException(getString(FAILED_TO_UPSERT_ARTIFACT, artifact), failure);
    }
}
/**
 * Synchronously updates the state associated with Comet artifact version.
 *
 * @param artifactVersionId the artifact version identifier.
 * @param state             the state to be associated.
 * @throws ArtifactException if operation failed for any reason; the original failure is kept as the cause.
 */
void updateArtifactVersionState(@NonNull String artifactVersionId, @NonNull ArtifactVersionState state)
        throws ArtifactException {
    try {
        final ArtifactRequest stateRequest = createArtifactVersionStateRequest(artifactVersionId, state);
        sendSynchronously(
                getRestApiClient()::updateArtifactState,
                stateRequest);
    } catch (Throwable failure) {
        throw new ArtifactException(
                getString(FAILED_TO_UPDATE_ARTIFACT_VERSION_STATE, artifactVersionId), failure);
    }
}
/**
 * Synchronously retrieves all data about a specific Artifact Version.
 *
 * @param options the {@link GetArtifactOptions} defining query options.
 * @return the {@link LoggedArtifact} instance holding all data about a specific Artifact Version.
 * @throws ArtifactNotFoundException     if artifact is not found or no artifact data returned.
 * @throws InvalidArtifactStateException if artifact was not closed or has empty artifact data returned.
 * @throws ArtifactException             if failed to get artifact due to the unexpected error.
 */
LoggedArtifact getArtifactVersionDetail(@NonNull GetArtifactOptions options)
        throws ArtifactNotFoundException, InvalidArtifactStateException, ArtifactException {
    try {
        ArtifactVersionDetail detail = validateAndGetExperimentKey()
                .concatMap(experimentKey -> getRestApiClient().getArtifactVersionDetail(options, experimentKey))
                .blockingGet();
        ArtifactDto artifactDto = detail.getArtifact();
        if (artifactDto == null) {
            throw new InvalidArtifactStateException(getString(ARTIFACT_HAS_NO_DETAILS, options));
        }
        return detail.copyToLoggedArtifact(
                new LoggedArtifactImpl(artifactDto.getArtifactName(), artifactDto.getArtifactType(), this));
    } catch (InvalidArtifactStateException stateException) {
        // Raised above for the empty artifact data. It must reach the caller as is: without this clause
        // the generic handler below would re-wrap it into a plain ArtifactException and the documented
        // contract of this method would not hold.
        throw stateException;
    } catch (CometApiException apiException) {
        // translate well-known SDK error codes into the dedicated exception types
        switch (apiException.getSdkErrorCode()) {
            case noArtifactFound:
                throw new ArtifactNotFoundException(getString(ARTIFACT_NOT_FOUND, options), apiException);
            case artifactVersionStateNotClosed:
            case artifactVersionStateNotClosedErrorOccurred:
                throw new InvalidArtifactStateException(getString(ARTIFACT_NOT_READY, options), apiException);
            default:
                throw new ArtifactException(getString(GET_ARTIFACT_FAILED_UNEXPECTEDLY, options), apiException);
        }
    } catch (Throwable e) {
        throw new ArtifactException(getString(GET_ARTIFACT_FAILED_UNEXPECTEDLY, options), e);
    }
}
/**
 * Retrieves the logged artifact by its simple name, workspace and version or alias. This experiment's key
 * is passed as the consumer experiment key.
 *
 * @param name           the simple artifact name, without slash (/) or colon (:) characters.
 * @param workspace      the name of the workspace.
 * @param versionOrAlias the version or alias of the artifact.
 * @return the {@link LoggedArtifact} instance.
 * @throws ArtifactException        if failed to get the artifact.
 * @throws IllegalArgumentException if the name contains slash or colon.
 */
@Override
public LoggedArtifact getArtifact(@NonNull String name, @NonNull String workspace, @NonNull String versionOrAlias)
        throws ArtifactException {
    if (name.contains("/") || name.contains(":")) {
        throw new IllegalArgumentException(
                "Only simple artifact name allowed for this method without slash (/) or colon (:) characters.");
    }

    final GetArtifactOptions options = GetArtifactOptions.Op()
            .name(name)
            .workspaceName(workspace)
            .versionOrAlias(versionOrAlias)
            .consumerExperimentKey(this.experimentKey)
            .build();
    return getArtifact(options);
}
/**
 * Retrieves the logged artifact by its name and workspace. The name is passed to the options builder
 * as the full name; this experiment's key is passed as the consumer experiment key.
 *
 * @param name      the name of the artifact, without workspace or workspace separator (/).
 * @param workspace the name of the workspace.
 * @return the {@link LoggedArtifact} instance.
 * @throws ArtifactException        if failed to get the artifact.
 * @throws IllegalArgumentException if the name contains slash.
 */
@Override
public LoggedArtifact getArtifact(@NonNull String name, @NonNull String workspace) throws ArtifactException {
    if (name.contains("/")) {
        throw new IllegalArgumentException(
                "The name of artifact for this method should not include workspace or workspace separator (/).");
    }

    final GetArtifactOptions options = GetArtifactOptions.Op()
            .fullName(name)
            .workspaceName(workspace)
            .consumerExperimentKey(this.experimentKey)
            .build();
    return getArtifact(options);
}
/**
 * Retrieves the logged artifact by its full name. This experiment's key is passed as the consumer
 * experiment key.
 *
 * @param name the full name of the artifact.
 * @return the {@link LoggedArtifact} instance.
 * @throws ArtifactException if failed to get the artifact.
 */
@Override
public LoggedArtifact getArtifact(@NonNull String name) throws ArtifactException {
    final GetArtifactOptions options = GetArtifactOptions.Op()
            .fullName(name)
            .consumerExperimentKey(this.experimentKey)
            .build();
    return getArtifact(options);
}
/**
 * Retrieves the logged artifact using the provided options by delegating to
 * {@link #getArtifactVersionDetail(GetArtifactOptions)}.
 *
 * @param options the {@link GetArtifactOptions} defining query options.
 * @return the {@link LoggedArtifact} instance.
 * @throws ArtifactException if failed to get the artifact.
 */
LoggedArtifact getArtifact(@NonNull GetArtifactOptions options) throws ArtifactException {
    final LoggedArtifact loggedArtifact = getArtifactVersionDetail(options);
    return loggedArtifact;
}
/**
 * Reads list of assets associated with provided Comet artifact. The assets are queried by the
 * artifact identifier and the version identifier of the provided artifact.
 *
 * @param artifact the {@link LoggedArtifact} to get assets.
 * @return the list of assets associated with provided Comet artifact.
 * @throws ArtifactException if failed to read list of associated assets.
 */
Collection<LoggedArtifactAsset> readArtifactAssets(@NonNull LoggedArtifactImpl artifact)
        throws ArtifactException {
    GetArtifactOptions options = GetArtifactOptions.Op()
            .artifactId(artifact.getArtifactId())
            .versionId(artifact.getVersionId())
            .build();

    try {
        ArtifactVersionAssetResponse response = this.getRestApiClient()
                .getArtifactVersionFiles(options)
                .blockingGet();
        // every returned file record is copied into a fresh asset bound to the given artifact
        return response.getFiles()
                .stream()
                .collect(ArrayList::new,
                        (assets, artifactVersionAsset) -> assets.add(
                                artifactVersionAsset.copyTo(new LoggedArtifactAssetImpl(artifact))),
                        ArrayList::addAll);
    } catch (Throwable t) {
        String message = getString(FAILED_TO_READ_LOGGED_ARTIFACT_ASSETS, artifact.getFullName());
        this.getLogger().error(message, t);
        throw new ArtifactException(message, t);
    }
}
/**
 * Allows to synchronously download specific {@link LoggedArtifactAsset} to the local file system.
 *
 * <p>If the target file already exists and the {@code FAIL_IF_DIFFERENT} strategy is selected, the asset
 * is downloaded into a temporary file whose content is compared with the existing file. The temporary
 * file is always removed afterwards, including the cases when the download or the comparison fails.
 *
 * @param asset             the asset to be downloaded.
 * @param dir               the parent directory where asset file should be stored.
 * @param file              the relative path to the asset file.
 * @param overwriteStrategy the overwrite strategy to be applied if file already exists.
 * @return the {@link ArtifactAsset} instance with details about downloaded asset file.
 * @throws ArtifactDownloadException if failed to download asset.
 */
ArtifactAssetImpl downloadArtifactAsset(@NonNull LoggedArtifactAssetImpl asset, @NonNull Path dir,
                                        @NonNull Path file, @NonNull AssetOverwriteStrategy overwriteStrategy)
        throws ArtifactDownloadException {
    if (asset.isRemote()) {
        throw new ArtifactDownloadException(getString(REMOTE_ASSET_CANNOT_BE_DOWNLOADED, asset));
    }

    Path resolved;
    boolean fileAlreadyExists = false;
    try {
        Optional<Path> optionalPath = FileUtils.resolveAssetPath(dir, file, overwriteStrategy);
        if (optionalPath.isPresent()) {
            // new or overwrite
            resolved = optionalPath.get();
            if (overwriteStrategy == AssetOverwriteStrategy.OVERWRITE) {
                getLogger().warn(getString(ARTIFACT_DOWNLOAD_FILE_OVERWRITTEN, resolved, asset.getLogicalPath(),
                        asset.artifact.getFullName()));
            }
        } else {
            // preserving original file - just warning and return asset pointing to it
            resolved = dir.resolve(file);
            this.getLogger().warn(
                    getString(ARTIFACT_ASSETS_FILE_EXISTS_PRESERVING, resolved, asset.artifact.getFullName()));
            return new ArtifactAssetImpl(asset.getLogicalPath(), resolved, Files.size(resolved),
                    asset.getMetadata(), asset.getAssetType());
        }
    } catch (FileAlreadyExistsException e) {
        if (overwriteStrategy == AssetOverwriteStrategy.FAIL_IF_DIFFERENT) {
            try {
                // Only the last name element is used as the prefix: a prefix with path separators
                // (e.g., logical path "dir/file.txt") makes Files.createTempFile throw
                // IllegalArgumentException, which is not handled here.
                Path fileName = file.getFileName();
                String prefix = fileName != null ? fileName.toString() : null;
                resolved = Files.createTempFile(prefix, null);
                this.getLogger().debug(
                        "File '{}' already exists for asset {} and FAIL override strategy selected. "
                                + "Start downloading to the temporary file '{}'", file, asset, resolved);
            } catch (IOException ex) {
                String msg = getString(FAILED_TO_CREATE_TEMPORARY_ASSET_DOWNLOAD_FILE, file, asset);
                this.getLogger().error(msg, ex);
                throw new ArtifactDownloadException(msg, ex);
            }
            fileAlreadyExists = true;
        } else {
            this.getLogger().error(
                    getString(FAILED_TO_DOWNLOAD_ASSET_FILE_ALREADY_EXISTS, asset, file), e);
            throw new ArtifactDownloadException(
                    getString(FAILED_TO_DOWNLOAD_ASSET_FILE_ALREADY_EXISTS, asset, file), e);
        }
    } catch (IOException e) {
        this.getLogger().error(getString(FAILED_TO_RESOLVE_ASSET_FILE, file, asset), e);
        throw new ArtifactDownloadException(getString(FAILED_TO_RESOLVE_ASSET_FILE, file, asset), e);
    }

    // the file receiving the downloaded bytes; it is the temporary file when fileAlreadyExists is set
    final Path downloadTarget = resolved;
    try {
        DownloadArtifactAssetOptions opts = new DownloadArtifactAssetOptions(
                asset.getAssetId(), asset.getArtifactVersionId(), downloadTarget.toFile());
        RestApiResponse response = validateAndGetExperimentKey()
                .concatMap(experimentKey -> getRestApiClient().downloadArtifactAsset(opts, experimentKey))
                .blockingGet();
        if (response.hasFailed()) {
            this.getLogger().error(getString(FAILED_TO_DOWNLOAD_ASSET, asset, response));
            throw new ArtifactDownloadException(getString(FAILED_TO_DOWNLOAD_ASSET, asset, response));
        }

        // check the content of the downloaded file in case of FAIL overwrite strategy when file already exists
        // this is just to mirror the Python SDK's behavior - potential performance bottleneck and system
        // resource eater
        if (fileAlreadyExists) {
            Path assetFilePath = FileUtils.assetFilePath(dir, file);
            try {
                if (!FileUtils.fileContentsEquals(assetFilePath, downloadTarget)) {
                    this.getLogger().error(
                            getString(FAILED_TO_DOWNLOAD_ASSET_FILE_ALREADY_EXISTS, asset, file));
                    throw new ArtifactDownloadException(
                            getString(FAILED_TO_DOWNLOAD_ASSET_FILE_ALREADY_EXISTS, asset, file));
                }
            } catch (IOException e) {
                this.getLogger().error(getString(FAILED_TO_COMPARE_CONTENT_OF_FILES, file, downloadTarget), e);
                throw new ArtifactDownloadException(
                        getString(FAILED_TO_COMPARE_CONTENT_OF_FILES, file, downloadTarget), e);
            }
            // contents are equal - report the already existing file
            resolved = assetFilePath;
        }
    } finally {
        if (fileAlreadyExists) {
            // the temporary file is needed only for the comparison - remove it whatever the outcome is
            try {
                Files.deleteIfExists(downloadTarget);
            } catch (IOException e) {
                this.getLogger().error(
                        getString(FAILED_TO_DELETE_TEMPORARY_ASSET_FILE, downloadTarget, asset), e);
            }
        }
    }

    getLogger().info(getString(COMPLETED_DOWNLOAD_ARTIFACT_ASSET, asset.getLogicalPath(), resolved));
    try {
        return new ArtifactAssetImpl(asset.getLogicalPath(), resolved, Files.size(resolved),
                asset.getMetadata(), asset.getAssetType());
    } catch (IOException e) {
        this.getLogger().error(getString(FAILED_TO_READ_DOWNLOADED_FILE_SIZE, resolved), e);
        throw new ArtifactDownloadException(getString(FAILED_TO_READ_DOWNLOADED_FILE_SIZE, resolved), e);
    }
}
/**
 * Synchronously loads the metadata of this experiment.
 *
 * @return the {@link ExperimentMetadata} converted from the loaded response.
 */
@Override
public ExperimentMetadata getMetadata() {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("get metadata for experiment {}", this.experimentKey);
    }

    return loadRemote(restApiClient::getMetadata, "METADATA")
            .toExperimentMetadata();
}
/**
 * Synchronously loads the Git metadata of this experiment.
 *
 * @return the {@link GitMetaData} converted from the loaded response.
 */
@Override
public GitMetaData getGitMetadata() {
    final Logger log = getLogger();
    if (log.isDebugEnabled()) {
        log.debug("get git metadata for experiment {}", this.experimentKey);
    }

    return loadRemote(restApiClient::getGitMetadata, "GIT METADATA")
            .toGitMetaData();
}
/**
 * Synchronously loads the HTML logged for this experiment.
 *
 * @return the optional with the HTML; empty if the loaded response holds no HTML.
 */
@Override
public Optional<String> getHtml() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get html for experiment {}", this.experimentKey);
    }
    return Optional.ofNullable(loadRemote(restApiClient::getHtml, "HTML").getHtml());
}
/**
 * Synchronously loads the output logged for this experiment.
 *
 * @return the optional with the output; empty if the loaded response holds no output.
 */
@Override
public Optional<String> getOutput() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get output for experiment {}", this.experimentKey);
    }
    return Optional.ofNullable(loadRemote(restApiClient::getOutput, "StdOut").getOutput());
}
/**
 * Synchronously loads the graph logged for this experiment.
 *
 * @return the optional with the graph; empty if the loaded response holds no graph.
 */
@Override
public Optional<String> getGraph() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get graph for experiment {}", this.experimentKey);
    }
    return Optional.ofNullable(loadRemote(restApiClient::getGraph, "GRAPH").getGraph());
}
/**
 * Synchronously loads the parameters logged for this experiment.
 *
 * @return the list of loaded parameter values.
 */
@Override
public List<Value> getParameters() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get params for experiment {}", this.experimentKey);
    }
    return loadRemoteValues(restApiClient::getParameters, "PARAMETERS");
}
/**
 * Synchronously loads the metrics summary of this experiment.
 *
 * @return the list of loaded metric values.
 */
@Override
public List<Value> getMetrics() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get metrics summary for experiment {}", this.experimentKey);
    }
    return loadRemoteValues(restApiClient::getMetrics, "METRICS");
}
/**
 * Synchronously loads the other data logged for this experiment (see {@link #logOther(String, Object)}).
 *
 * @return the list of loaded values.
 */
@Override
public List<Value> getLogOther() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get log other for experiment {}", this.experimentKey);
    }
    return loadRemoteValues(restApiClient::getLogOther, "OTHER PARAMETERS");
}
/**
 * Synchronously loads the tags of this experiment.
 *
 * @return the list of tags as held by the loaded response.
 */
@Override
public List<String> getTags() {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get tags for experiment {}", this.experimentKey);
    }
    return loadRemote(restApiClient::getTags, "TAGs").getTags();
}
/**
 * Synchronously loads the list of assets of the given type logged for this experiment.
 * A failure to load the list is logged as an error before it is propagated to the caller.
 *
 * @param type the type of the assets to be listed.
 * @return the list of assets converted from the loaded asset links.
 */
@Override
public List<LoggedExperimentAsset> getAssetList(@NonNull String type) {
    if (getLogger().isDebugEnabled()) {
        getLogger().debug("get assets with type {} for experiment {}", type, this.experimentKey);
    }
    return validateAndGetExperimentKey()
            .concatMap(experimentKey -> restApiClient.getAssetList(experimentKey, type))
            .doOnError(ex -> getLogger().error("Failed to read ASSETS list for the experiment, experiment key: {}",
                    this.experimentKey, ex))
            .blockingGet()
            .getAssets()
            .stream()
            .collect(ArrayList::new,
                    (assets, experimentAssetLink) -> assets.add(experimentAssetLink.toExperimentAsset(getLogger())),
                    ArrayList::addAll);
}
/**
 * Synchronously loads the list of assets of all types by delegating to {@link #getAssetList(String)}
 * with the {@code ALL} asset type.
 *
 * @return the list of all assets logged for this experiment.
 */
@Override
public List<LoggedExperimentAsset> getAllAssetList() {
    return this.getAssetList(ALL.type());
}
/**
 * Ends this experiment: marks it as not alive, disposes the REST API client and closes the connection
 * waiting up to the cleaning timeout. Does nothing if the experiment is not alive. A failure to close
 * the connection is logged and not propagated.
 */
@Override
public void end() {
    if (!this.alive) {
        // never initialized or already ended - nothing to release
        return;
    }
    this.alive = false;

    final RestApiClient client = this.restApiClient;
    if (client != null) {
        client.dispose();
    }

    final Connection conn = this.connection;
    if (conn == null) {
        return;
    }
    try {
        conn.waitAndClose(this.cleaningTimeout);
        // the reference is dropped only after the connection was closed without an error
        this.connection = null;
    } catch (Exception e) {
        getLogger().error("failed to close connection", e);
    }
}
/**
 * Sends heartbeat to the server and returns the status response.
 *
 * <p>Nothing is sent when the experiment is not alive. Any error raised while validating the
 * experiment state or sending the request is suppressed, which yields an empty result.
 *
 * @return the status response of the experiment, or an empty {@link Optional} if the experiment
 *         is not alive or the request did not produce a response.
 */
Optional sendExperimentStatus() {
    if (this.alive) {
        // onErrorComplete() turns a failure into an empty completion, so blockingGet() yields null
        return Optional.ofNullable(validateAndGetExperimentKey()
                .concatMap(key -> restApiClient.sendExperimentStatus(key))
                .onErrorComplete()
                .blockingGet());
    }
    return Optional.empty();
}
/**
 * Synchronously loads remote data values.
 *
 * <p>The response is obtained through {@code loadRemote}; every entry of its {@code getValues()}
 * collection is then converted with {@code toValue()} and appended to a newly created list.
 *
 * <p>NOTE(review): the generic type arguments of this signature appear to be missing
 * ({@code Function> loadFunc}) - confirm against the original declaration.
 *
 * @param loadFunc the function to be applied to load remote data.
 * @param alias    the data type alias used for logging.
 * @return the list of values returned by REST API endpoint.
 */
private List loadRemoteValues(final Function> loadFunc, String alias) {
return this.loadRemote(loadFunc, alias)
.getValues()
.stream()
// convert each REST entry and collect the results into a fresh ArrayList
.collect(ArrayList::new,
(values, valueMinMaxRest) -> values.add(valueMinMaxRest.toValue()),
ArrayList::addAll);
}
/**
* Synchronously loads remote data using provided load function or throws an exception.
*
* @param loadFunc the function to be applied to load remote data.
* @param alias the data type alias used for logging.
* @param the data type to be returned.
* @return the loaded data.
*/
private T loadRemote(final Function> loadFunc, String alias) {
return validateAndGetExperimentKey()
.concatMap(loadFunc)
.doOnError(ex -> getLogger().error(
getString(FAILED_READ_DATA_FOR_EXPERIMENT, alias, this.experimentKey), ex))
.blockingGet();
}
/**
* Uses provided function to send request data synchronously. If response indicating the remote error
* received the {@link CometApiException} will be thrown.
*
* @param func the function to be invoked to send request data.
* @param request the request data object.
* @param the type of the request data object.
* @throws CometApiException if received response with error indicating that data was not saved.
*/
private void sendSynchronously(final BiFunction> func,
final T request) throws CometApiException {
CompletableFuture future = new CompletableFuture<>();
Disposable disposable = validateAndGetExperimentKey()
.concatMap(experimentKey -> func.apply(request, experimentKey))
.subscribe(
future::complete,
future::completeExceptionally
);
try {
RestApiResponse response = future.get();
if (response.hasFailed()) {
throw new CometApiException("Failed to log {}, reason: %s, sdk error code: %d",
request, response.getMsg(), response.getSdkErrorCode());
}
} catch (InterruptedException | ExecutionException e) {
Throwable rootCause = ExceptionUtils.unwrap(e);
if (rootCause instanceof CometApiException) {
// the root is CometApiException - rethrow it
throw (CometApiException) rootCause;
} else {
// wrap into runtime exception
throw new RuntimeException(e);
}
} finally {
disposable.dispose();
}
}
/**
 * Checks that the experiment can be used.
 *
 * <p>The experiment key is checked first, the alive flag second.
 *
 * @throws IllegalStateException if the experiment key is blank or the experiment is not alive.
 */
private void validate() throws IllegalStateException {
    final boolean keyMissing = StringUtils.isBlank(this.experimentKey);
    if (keyMissing) {
        throw new IllegalStateException("Experiment key must be present!");
    }
    if (this.alive) {
        // key is present and experiment is alive - all good
        return;
    }
    throw new IllegalStateException("Experiment was not initialized. You need to call init().");
}
/**
 * Checks the experiment state and exposes the outcome as a {@link Single}.
 *
 * <p>The experiment key is checked first, the alive flag second. No exception is thrown from this
 * method itself - a failed check is delivered as an error {@link Single}.
 *
 * @return a {@link Single} emitting the experiment key, or signalling an
 *         {@link IllegalStateException} if the key is blank or the experiment is not alive.
 */
Single validateAndGetExperimentKey() {
    final IllegalStateException failure;
    if (StringUtils.isBlank(this.experimentKey)) {
        failure = new IllegalStateException("Experiment key must be present!");
    } else if (!this.alive) {
        failure = new IllegalStateException("Experiment is not alive or already closed.");
    } else {
        failure = null;
    }

    if (failure != null) {
        return Single.error(failure);
    }
    return Single.just(getExperimentKey());
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy