
// org.mlflow.tracking.MlflowClient Maven / Gradle / Ivy
package org.mlflow.tracking;
import com.google.common.collect.Lists;
import org.apache.http.client.utils.URIBuilder;
import org.mlflow.artifacts.ArtifactRepository;
import org.mlflow.artifacts.ArtifactRepositoryFactory;
import org.mlflow.artifacts.CliBasedArtifactRepository;
import org.mlflow.api.proto.ModelRegistry.*;
import org.mlflow.api.proto.Service.*;
import org.mlflow.tracking.creds.*;
import java.io.Closeable;
import java.io.File;
import java.io.Serializable;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* Client to an MLflow Tracking Server.
*/
public class MlflowClient implements Serializable, Closeable {
/** Experiment ID used by {@code createRun()} when the caller does not name an experiment. */
protected static final String DEFAULT_EXPERIMENT_ID = "0";
/** URI scheme used when {@code downloadModelVersion} builds a model-version URI. */
private static final String DEFAULT_MODELS_ARTIFACT_REPOSITORY_SCHEME = "models";
/** Builds request payloads/paths and parses server responses for this client. */
private final MlflowProtobufMapper mapper = new MlflowProtobufMapper();
/** Produces the per-run {@link ArtifactRepository} used by the artifact methods. */
private final ArtifactRepositoryFactory artifactRepositoryFactory;
/** Issues the GET/POST/PATCH requests made by this client; closed by {@code close()}. */
private final MlflowHttpCaller httpCaller;
/** Credentials provider this client was constructed with. */
private final MlflowHostCredsProvider hostCredsProvider;
/**
 * Create a client whose tracking URI is read from the {@code MLFLOW_TRACKING_URI} environment
 * variable.
 *
 * @throws IllegalStateException if {@code MLFLOW_TRACKING_URI} is not set.
 */
public MlflowClient() {
this(getDefaultTrackingUri());
}
/**
 * Create a client for the given tracking URI. The credentials provider is derived from the URI
 * by {@code getHostCredsProviderFromTrackingUri}.
 *
 * @param trackingUri Tracking URI: an {@code http}/{@code https} URL, the literal
 *                    {@code "databricks"}, or a URI with the {@code databricks} scheme.
 * @throws IllegalArgumentException if the URI is local (no scheme or {@code file}) or uses any
 *                                  other scheme.
 */
public MlflowClient(String trackingUri) {
this(getHostCredsProviderFromTrackingUri(trackingUri));
}
/**
 * Build a client on top of an explicit credentials provider. Prefer {@link #MlflowClient()} or
 * {@link #MlflowClient(String)} when one of them is sufficient.
 *
 * @param hostCredsProvider Provider handed to the HTTP caller and to the artifact repository
 *                          factory created for this client.
 */
public MlflowClient(MlflowHostCredsProvider hostCredsProvider) {
  this.hostCredsProvider = hostCredsProvider;
  httpCaller = new MlflowHttpCaller(hostCredsProvider);
  artifactRepositoryFactory = new ArtifactRepositoryFactory(hostCredsProvider);
}
/**
 * Fetch a run's metadata, params, tags, and metrics. For every metric key exactly one value is
 * included: the most recently logged value at the largest step.
 *
 * @param runId ID of the run to fetch.
 * @return The run identified by {@code runId}.
 */
public Run getRun(String runId) {
  // The ID is sent under both the "run_uuid" and "run_id" query parameters.
  String requestPath = newURIBuilder("runs/get")
      .setParameter("run_uuid", runId)
      .setParameter("run_id", runId)
      .toString();
  String responseJson = httpCaller.get(requestPath);
  return mapper.toGetRunResponse(responseJson).getRun();
}
/**
 * Return the recorded history of one metric of a run, following the server's pagination until
 * no further page token is returned.
 *
 * @param runId ID of the run the metric was logged against.
 * @param key Name of the metric.
 * @return A new, mutable list holding the metric entries of every page, in the order the pages
 *         were received.
 */
public List getMetricHistory(String runId, String key) {
  URIBuilder builder = newURIBuilder("metrics/get-history")
      .setParameter("run_uuid", runId)
      .setParameter("run_id", runId)
      .setParameter("metric_key", key)
      .setParameter("max_results", "25000");
  GetMetricHistory.Response response = mapper
      .toGetMetricHistoryResponse(httpCaller.get(builder.toString()));
  // Copy into our own list before appending: the repeated-field list exposed by a built
  // protobuf message is unmodifiable, so calling addAll() on it directly would throw
  // UnsupportedOperationException as soon as a second page had to be merged in.
  List metrics = new ArrayList<>(response.getMetricsList());
  String token = response.getNextPageToken();
  while (!token.isEmpty()) {
    URIBuilder bld = builder.setParameter("page_token", token);
    GetMetricHistory.Response resp = mapper
        .toGetMetricHistoryResponse(httpCaller.get(bld.toString()));
    metrics.addAll(resp.getMetricsList());
    token = resp.getNextPageToken();
  }
  return metrics;
}
/**
 * Start a new run in the default experiment ({@code DEFAULT_EXPERIMENT_ID}).
 *
 * @return RunInfo of the run created by the server.
 */
public RunInfo createRun() {
  RunInfo created = createRun(DEFAULT_EXPERIMENT_ID);
  return created;
}
/**
 * Create a new run under the given experiment. The start time is set to the current wall-clock
 * time and, when the {@code user.name} system property is available, it is sent as the user ID.
 *
 * @param experimentId ID of the experiment that will own the run.
 * @return RunInfo created by the server.
 */
public RunInfo createRun(String experimentId) {
  CreateRun.Builder request = CreateRun.newBuilder();
  request.setExperimentId(experimentId);
  request.setStartTime(System.currentTimeMillis());
  // userId is deprecated and will be removed in a future release.
  // It should be set as the `mlflow.user` tag instead.
  String username = System.getProperty("user.name");
  if (username != null) {
    // Reuse the value that was just null-checked instead of reading the property a second
    // time, so the value passed to the builder is exactly the one that was validated.
    request.setUserId(username);
  }
  return createRun(request.build());
}
/**
 * Create a new run from a fully specified request. Every field of {@code CreateRun} can be
 * supplied this way, for example:
 *
 * <pre>{@code
 * import org.mlflow.api.proto.Service.CreateRun;
 * CreateRun.Builder request = CreateRun.newBuilder();
 * request.setExperimentId(experimentId);
 * request.setSourceVersion("my-version");
 * createRun(request.build());
 * }</pre>
 *
 * @param request The run-creation request to send.
 * @return RunInfo created by the server.
 */
public RunInfo createRun(CreateRun request) {
  String responseJson = sendPost("runs/create", mapper.toJson(request));
  return mapper.toCreateRunResponse(responseJson).getRun().getInfo();
}
/**
 * List the RunInfos of an experiment. This delegates to {@link #searchRuns(List, String)} with
 * no filter, which searches active runs only and requests a single page of at most 1000 runs.
 *
 * @param experimentId ID of the experiment whose runs are listed.
 * @return A list of RunInfos associated with the given experiment.
 */
public List listRunInfos(String experimentId) {
  List singleExperiment = new ArrayList<>(Collections.singletonList(experimentId));
  return searchRuns(singleExperiment, null);
}
/**
 * Return RunInfos of active runs, from the provided experiments, that satisfy the search query.
 * Only the first page (at most 1000 runs) is requested.
 *
 * @deprecated As of 1.1.0 - please use {@link #searchRuns(List, String, ViewType, int)} or
 *             similar that returns a page of Run results.
 *
 * @param experimentIds List of experiment IDs.
 * @param searchFilter SQL compatible search query string. Format of this query string is
 *                     similar to that specified on MLflow UI.
 *                     Example : "params.model = 'LogisticRegression' and metrics.acc = 0.9"
 *                     If null, the result will be equivalent to having an empty search filter.
 *
 * @return A list of RunInfos that satisfy the search filter.
 */
public List searchRuns(List experimentIds, String searchFilter) {
  RunsPage firstPage = searchRuns(experimentIds, searchFilter, ViewType.ACTIVE_ONLY, 1000);
  return firstPage.getItems().stream()
      .map(Run::getInfo)
      .collect(Collectors.toList());
}
/**
 * Return RunInfos, from the provided experiments, that satisfy the search query and view type.
 * Only the first page (at most 1000 runs) is requested.
 *
 * @deprecated As of 1.1.0 - please use {@link #searchRuns(List, String, ViewType, int)} or
 *             similar that returns a page of Run results.
 *
 * @param experimentIds List of experiment IDs.
 * @param searchFilter SQL compatible search query string. Format of this query string is
 *                     similar to that specified on MLflow UI.
 *                     Example : "params.model = 'LogisticRegression' and metrics.acc != 0.9"
 *                     If null, the result will be equivalent to having an empty search filter.
 * @param runViewType ViewType for expected runs. One of (ACTIVE_ONLY, DELETED_ONLY, ALL)
 *                    If null, only runs with viewtype ACTIVE_ONLY will be searched.
 *
 * @return A list of RunInfos that satisfy the search filter.
 */
public List searchRuns(List experimentIds,
                       String searchFilter,
                       ViewType runViewType) {
  RunsPage firstPage = searchRuns(experimentIds, searchFilter, runViewType, 1000);
  return firstPage.getItems().stream()
      .map(Run::getInfo)
      .collect(Collectors.toList());
}
/**
 * Return the first page of runs, from the provided experiments, that satisfy the search query.
 * No ordering is requested.
 *
 * @param experimentIds List of experiment IDs.
 * @param searchFilter SQL compatible search query string. Format of this query string is
 *                     similar to that specified on MLflow UI.
 *                     Example : "params.model = 'LogisticRegression' and metrics.acc != 0.9"
 *                     If null, the result will be equivalent to having an empty search filter.
 * @param runViewType ViewType for expected runs. One of (ACTIVE_ONLY, DELETED_ONLY, ALL)
 *                    If null, only runs with viewtype ACTIVE_ONLY will be searched.
 * @param maxResults Maximum number of runs desired in one page.
 *
 * @return A page of Runs that satisfy the search filter.
 */
public RunsPage searchRuns(List experimentIds,
                           String searchFilter,
                           ViewType runViewType,
                           int maxResults) {
  List noOrdering = new ArrayList<>();
  return searchRuns(experimentIds, searchFilter, runViewType, maxResults, noOrdering, null);
}
/**
 * Return the first page of runs, from the provided experiments, that satisfy the search query,
 * in the requested order.
 *
 * @param experimentIds List of experiment IDs.
 * @param searchFilter SQL compatible search query string. Format of this query string is
 *                     similar to that specified on MLflow UI.
 *                     Example : "params.model = 'LogisticRegression' and metrics.acc != 0.9"
 *                     If null, the result will be equivalent to having an empty search filter.
 * @param runViewType ViewType for expected runs. One of (ACTIVE_ONLY, DELETED_ONLY, ALL)
 *                    If null, only runs with viewtype ACTIVE_ONLY will be searched.
 * @param maxResults Maximum number of runs desired in one page.
 * @param orderBy List of properties to order by. Example: "metrics.acc DESC".
 *
 * @return A page of Runs that satisfy the search filter.
 */
public RunsPage searchRuns(List experimentIds,
                           String searchFilter,
                           ViewType runViewType,
                           int maxResults,
                           List orderBy) {
  String firstPageToken = null;
  return searchRuns(experimentIds, searchFilter, runViewType, maxResults, orderBy,
      firstPageToken);
}
/**
 * Return one page of runs, from the provided experiments, that satisfy the search query.
 *
 * @param experimentIds List of experiment IDs.
 * @param searchFilter SQL compatible search query string. Format of this query string is
 *                     similar to that specified on MLflow UI.
 *                     Example : "params.model = 'LogisticRegression' and metrics.acc != 0.9"
 *                     If null, the result will be equivalent to having an empty search filter.
 * @param runViewType ViewType for expected runs. One of (ACTIVE_ONLY, DELETED_ONLY, ALL)
 *                    If null, only runs with viewtype ACTIVE_ONLY will be searched.
 * @param maxResults Maximum number of runs desired in one page.
 * @param orderBy List of properties to order by. Example: "metrics.acc DESC".
 * @param pageToken String token specifying the page of results to fetch: the next-page token
 *                  returned with a previous page of the same search. If null, no token is sent.
 *
 * @return A page of Runs that satisfy the search filter.
 */
public RunsPage searchRuns(List experimentIds,
                           String searchFilter,
                           ViewType runViewType,
                           int maxResults,
                           List orderBy,
                           String pageToken) {
  SearchRuns.Builder request = SearchRuns.newBuilder()
      .addAllExperimentIds(experimentIds)
      .addAllOrderBy(orderBy)
      .setMaxResults(maxResults);
  // Optional fields are left unset on the request when the caller passed null.
  if (searchFilter != null) {
    request.setFilter(searchFilter);
  }
  if (runViewType != null) {
    request.setRunViewType(runViewType);
  }
  if (pageToken != null) {
    request.setPageToken(pageToken);
  }
  String responseJson = sendPost("runs/search", mapper.toJson(request.build()));
  SearchRuns.Response response = mapper.toSearchRunsResponse(responseJson);
  // The page keeps the original search arguments alongside the results and next-page token.
  return new RunsPage(response.getRunsList(), response.getNextPageToken(), experimentIds,
      searchFilter, runViewType, maxResults, orderBy, this);
}
/**
 * Return the first page of experiments that satisfy the search query.
 *
 * @param searchFilter SQL compatible search query string.
 *                     Examples:
 *                       - "attribute.name = 'MyExperiment'"
 *                       - "tags.problem_type = 'iris_regression'"
 *                     If null, the result will be equivalent to having an empty search filter.
 * @param experimentViewType ViewType for expected experiments. One of
 *                           (ACTIVE_ONLY, DELETED_ONLY, ALL). If null, only experiments with
 *                           viewtype ACTIVE_ONLY will be searched.
 * @param maxResults Maximum number of experiments desired in one page.
 * @param orderBy List of properties to order by.
 *
 * @return A page of experiments that satisfy the search filter.
 */
public ExperimentsPage searchExperiments(String searchFilter,
                                         ViewType experimentViewType,
                                         int maxResults,
                                         List orderBy) {
  String firstPageToken = null;
  return searchExperiments(searchFilter, experimentViewType, maxResults, orderBy,
      firstPageToken);
}
/**
 * Return up to 1000 active experiments, using an empty filter and no ordering.
 *
 * @return A page of active experiments with up to 1000 items.
 */
public ExperimentsPage searchExperiments() {
  List noOrdering = new ArrayList<>();
  return searchExperiments("", null, 1000, noOrdering, null);
}
/**
 * Return up to the first 1000 active experiments that satisfy the search query.
 *
 * @param searchFilter SQL compatible search query string.
 *                     Examples:
 *                       - "attribute.name = 'MyExperiment'"
 *                       - "tags.problem_type = 'iris_regression'"
 *                     If null, the result will be equivalent to having an empty search filter.
 *
 * @return A page of up to 1000 active experiments that satisfy the search filter.
 */
public ExperimentsPage searchExperiments(String searchFilter) {
  List noOrdering = new ArrayList<>();
  return searchExperiments(searchFilter, null, 1000, noOrdering, null);
}
/**
 * Return one page of experiments that satisfy the search query.
 *
 * @param searchFilter SQL compatible search query string.
 *                     Examples:
 *                       - "attribute.name = 'MyExperiment'"
 *                       - "tags.problem_type = 'iris_regression'"
 *                     If null, the result will be equivalent to having an empty search filter.
 * @param experimentViewType ViewType for expected experiments. One of
 *                           (ACTIVE_ONLY, DELETED_ONLY, ALL). If null, only experiments with
 *                           viewtype ACTIVE_ONLY will be searched.
 * @param maxResults Maximum number of experiments desired in one page.
 * @param orderBy List of properties to order by.
 * @param pageToken String token specifying the page of results to fetch: the next-page token
 *                  returned with a previous page of the same search. If null, no token is sent.
 *
 * @return A page of experiments that satisfy the search filter.
 */
public ExperimentsPage searchExperiments(String searchFilter,
                                         ViewType experimentViewType,
                                         int maxResults,
                                         List orderBy,
                                         String pageToken) {
  SearchExperiments.Builder request = SearchExperiments.newBuilder()
      .addAllOrderBy(orderBy)
      .setMaxResults(maxResults);
  if (searchFilter != null) {
    request.setFilter(searchFilter);
  }
  // A missing view type is sent explicitly as ACTIVE_ONLY.
  request.setViewType(experimentViewType != null ? experimentViewType : ViewType.ACTIVE_ONLY);
  if (pageToken != null) {
    request.setPageToken(pageToken);
  }
  String responseJson = sendPost("experiments/search", mapper.toJson(request.build()));
  SearchExperiments.Response response = mapper.toSearchExperimentsResponse(responseJson);
  // The page is given the caller's original (possibly null) view type, not the substituted one.
  return new ExperimentsPage(response.getExperimentsList(), response.getNextPageToken(),
      searchFilter, experimentViewType, maxResults, orderBy, this);
}
/**
 * Look up an experiment by its ID.
 *
 * @param experimentId ID of the experiment to fetch.
 * @return An experiment with the given ID.
 */
public Experiment getExperiment(String experimentId) {
  String requestPath = newURIBuilder("experiments/get")
      .setParameter("experiment_id", experimentId)
      .toString();
  String responseJson = httpCaller.get(requestPath);
  return mapper.toGetExperimentResponse(responseJson).getExperiment();
}
/**
 * Look up an experiment by its name.
 *
 * @param experimentName Name of the experiment to fetch.
 * @return The experiment associated with the given name, or {@code Optional.empty()} when the
 *         server answers with HTTP 404.
 * @throws MlflowHttpException for any HTTP failure other than a 404.
 */
public Optional getExperimentByName(String experimentName) {
  String requestPath = newURIBuilder("experiments/get-by-name")
      .setParameter("experiment_name", experimentName)
      .toString();
  try {
    String responseJson = httpCaller.get(requestPath);
    return Optional.of(mapper.toGetExperimentByNameResponse(responseJson).getExperiment());
  } catch (MlflowHttpException httpError) {
    // Only "not found" maps to an empty result; every other failure is rethrown unchanged.
    if (httpError.getStatusCode() != 404) {
      throw httpError;
    }
    return Optional.empty();
  }
}
/**
 * Create a new experiment that uses the default artifact location provided by the server.
 *
 * @param experimentName Name of the experiment. This must be unique across all experiments.
 * @return Experiment ID of the newly created experiment.
 */
public String createExperiment(String experimentName) {
  String responseJson = httpCaller.post("experiments/create",
      mapper.makeCreateExperimentRequest(experimentName));
  return mapper.toCreateExperimentResponse(responseJson).getExperimentId();
}
/**
 * Create a new experiment from a fully specified request. Every field of
 * {@code CreateExperiment} can be supplied this way, for example:
 *
 * <pre>{@code
 * import org.mlflow.api.proto.Service.CreateExperiment;
 * CreateExperiment.Builder request = CreateExperiment.newBuilder();
 * request.setName(name);
 * request.setArtifactLocation(artifactLocation);
 * request.addTags(experimentTag);
 * createExperiment(request.build());
 * }</pre>
 *
 * @param request The experiment-creation request to send.
 * @return ID of the experiment created by the server.
 */
public String createExperiment(CreateExperiment request) {
  String responseJson = sendPost("experiments/create", mapper.toJson(request));
  return mapper.toCreateExperimentResponse(responseJson).getExperimentId();
}
/**
 * Mark an experiment and its associated runs, params, metrics, etc. for deletion.
 *
 * @param experimentId ID of the experiment to delete.
 */
public void deleteExperiment(String experimentId) {
  httpCaller.post("experiments/delete", mapper.makeDeleteExperimentRequest(experimentId));
}
/**
 * Restore an experiment that was marked for deletion.
 *
 * @param experimentId ID of the experiment to restore.
 */
public void restoreExperiment(String experimentId) {
  httpCaller.post("experiments/restore", mapper.makeRestoreExperimentRequest(experimentId));
}
/**
 * Change the name of an experiment.
 *
 * @param experimentId ID of the experiment to rename.
 * @param newName New name for the experiment. The new name must be unique.
 */
public void renameExperiment(String experimentId, String newName) {
  httpCaller.post("experiments/update",
      mapper.makeUpdateExperimentRequest(experimentId, newName));
}
/**
 * Delete the run that has the given ID.
 *
 * @param runId ID of the run to delete.
 */
public void deleteRun(String runId) {
  httpCaller.post("runs/delete", mapper.makeDeleteRun(runId));
}
/**
 * Restore a previously deleted run.
 *
 * @param runId ID of the deleted run to restore.
 */
public void restoreRun(String runId) {
  httpCaller.post("runs/restore", mapper.makeRestoreRun(runId));
}
/**
 * Record a parameter, as a key-value pair, against the given run.
 * This cannot be called against the same parameter key more than once.
 *
 * @param runId ID of the run on which to record the parameter.
 * @param key Name of the parameter.
 * @param value Value of the parameter.
 */
public void logParam(String runId, String key, String value) {
  String requestJson = mapper.makeLogParam(runId, key, value);
  sendPost("runs/log-parameter", requestJson);
}
/**
 * Record a metric value against the given run. Metrics are recorded against two axes,
 * timestamp and step; this overload uses the current time in milliseconds since the Unix epoch
 * as the timestamp and the default step of zero.
 *
 * @param runId The ID of the run in which to record the metric.
 * @param key The key identifying the metric for which to record the specified value.
 * @param value The value of the metric.
 */
public void logMetric(String runId, String key, double value) {
  long nowMillis = System.currentTimeMillis();
  logMetric(runId, key, value, nowMillis, 0);
}
/**
 * Record a metric value against the given run at an explicit timestamp and step, the two axes
 * along which metrics are recorded.
 *
 * @param runId The ID of the run in which to record the metric.
 * @param key The key identifying the metric for which to record the specified value.
 * @param value The value of the metric.
 * @param timestamp The timestamp at which to record the metric value.
 * @param step The step at which to record the metric value.
 */
public void logMetric(String runId, String key, double value, long timestamp, long step) {
  String requestJson = mapper.makeLogMetric(runId, key, value, timestamp, step);
  sendPost("runs/log-metric", requestJson);
}
/**
 * Set a tag, as a key-value pair, on the given experiment.
 *
 * @param experimentId The ID of the experiment on which to set the tag.
 * @param key The key used to identify the tag.
 * @param value The value of the tag.
 */
public void setExperimentTag(String experimentId, String key, String value) {
  String requestJson = mapper.makeSetExperimentTag(experimentId, key, value);
  sendPost("experiments/set-experiment-tag", requestJson);
}
/**
 * Set a tag, as a key-value pair, on the given run.
 *
 * @param runId The ID of the run on which to set the tag.
 * @param key The key used to identify the tag.
 * @param value The value of the tag.
 */
public void setTag(String runId, String key, String value) {
  String requestJson = mapper.makeSetTag(runId, key, value);
  sendPost("runs/set-tag", requestJson);
}
/**
 * Remove the tag with the given key from a run. This is irreversible.
 *
 * @param runId String ID of the run.
 * @param key Name of the tag.
 */
public void deleteTag(String runId, String key) {
  String requestJson = mapper.makeDeleteTag(runId, key);
  sendPost("runs/delete-tag", requestJson);
}
/**
 * Record several metrics, params, and/or tags against one run in a single request.
 *
 * @param runId ID of the run to log against.
 * @param metrics Metrics to log; may be null.
 * @param params Params to log; may be null.
 * @param tags Tags to set; may be null.
 */
public void logBatch(String runId,
                     Iterable metrics,
                     Iterable params,
                     Iterable tags) {
  String requestJson = mapper.makeLogBatch(runId, metrics, params, tags);
  sendPost("runs/log-batch", requestJson);
}
/**
 * Mark a run as FINISHED, ending at the current time.
 *
 * @param runId ID of the run to terminate.
 */
public void setTerminated(String runId) {
  RunStatus finished = RunStatus.FINISHED;
  setTerminated(runId, finished);
}
/**
 * Mark a run as completed with the given status, ending at the current time.
 *
 * @param runId ID of the run to terminate.
 * @param status Final status to record for the run.
 */
public void setTerminated(String runId, RunStatus status) {
  long nowMillis = System.currentTimeMillis();
  setTerminated(runId, status, nowMillis);
}
/**
 * Mark a run as completed with the given status and end time.
 *
 * @param runId ID of the run to terminate.
 * @param status Final status to record for the run.
 * @param endTime End time to record for the run.
 */
public void setTerminated(String runId, RunStatus status, long endTime) {
  String requestJson = mapper.makeUpdateRun(runId, status, endTime);
  sendPost("runs/update", requestJson);
}
/**
 * Issue a GET for the given path, which includes any query parameters.
 * This is mostly an internal API, but allows making lower-level or unsupported requests.
 *
 * @param path Request path, including query parameters.
 * @return JSON response from the server.
 */
public String sendGet(String path) {
  String responseJson = httpCaller.get(path);
  return responseJson;
}
/**
 * Issue a POST to the given path with a String-encoded JSON body.
 * This is mostly an internal API, but allows making lower-level or unsupported requests.
 *
 * @param path Request path.
 * @param json String-encoded JSON request body.
 * @return JSON response from the server.
 */
public String sendPost(String path, String json) {
  String responseJson = httpCaller.post(path, json);
  return responseJson;
}
/**
 * Issue a PATCH to the given path with a String-encoded JSON body, mirroring
 * {@link #sendPost(String, String)}.
 *
 * @param path Request path.
 * @param json String-encoded JSON request body.
 * @return Response body returned by the HTTP caller.
 */
public String sendPatch(String path, String json) {
  String responseJson = httpCaller.patch(path, json);
  return responseJson;
}
/**
 * Expose the credentials provider this client was built with. Package-private; visible for
 * testing.
 *
 * @return HostCredsProvider backing this MlflowClient.
 */
MlflowHostCredsProvider getInternalHostCredsProvider() {
  return this.hostCredsProvider;
}
/**
 * Create a {@link URIBuilder} for the given base path, converting the checked
 * {@link URISyntaxException} into an unchecked {@code MlflowClientException}.
 *
 * @param base Base path to start the builder from.
 * @return A builder initialised with {@code base}.
 */
private URIBuilder newURIBuilder(String base) {
  final URIBuilder uriBuilder;
  try {
    uriBuilder = new URIBuilder(base);
  } catch (URISyntaxException cause) {
    throw new MlflowClientException("Failed to construct URI for " + base, cause);
  }
  return uriBuilder;
}
/**
 * Return the tracking URI held in the {@code MLFLOW_TRACKING_URI} environment variable.
 * This lives in a static helper because the no-argument constructor has to pass the value
 * straight into its {@code this(...)} delegation, which must be the constructor's first
 * statement.
 *
 * @return The value of {@code MLFLOW_TRACKING_URI}.
 * @throws IllegalStateException if {@code MLFLOW_TRACKING_URI} is not set.
 */
private static String getDefaultTrackingUri() {
  String defaultTrackingUri = System.getenv("MLFLOW_TRACKING_URI");
  if (defaultTrackingUri == null) {
    // Point callers at the constructor that actually exists on this class; the previous
    // message named a fromTrackingUri() method that this class does not define.
    throw new IllegalStateException("Default client requires MLFLOW_TRACKING_URI is set." +
        " Use MlflowClient(String trackingUri) instead.");
  }
  return defaultTrackingUri;
}
/**
 * Choose the MlflowHostCredsProvider that matches the given tracking URI.
 * This lives in a static helper because the String-argument constructor has to pass the result
 * straight into its {@code this(...)} delegation, which must be the constructor's first
 * statement.
 *
 * @param trackingUri Tracking URI to inspect.
 * @return A basic provider for {@code http}/{@code https} URIs; for the literal
 *         {@code "databricks"} the config-based provider, chained behind the dynamic provider
 *         when one is available; for a {@code databricks} scheme a config-based provider built
 *         from the URI's host component.
 * @throws IllegalArgumentException if the URI has no scheme or a {@code file} scheme (local
 *                                  tracking is unsupported), or uses any other scheme.
 */
private static MlflowHostCredsProvider getHostCredsProviderFromTrackingUri(String trackingUri) {
  URI parsed = URI.create(trackingUri);
  String scheme = parsed.getScheme();
  if ("http".equals(scheme) || "https".equals(scheme)) {
    return new BasicMlflowHostCreds(trackingUri);
  }
  // The bare word "databricks" parses as a scheme-less URI, so it has to be matched on the raw
  // string before the scheme-less check further down.
  if (trackingUri.equals("databricks")) {
    MlflowHostCredsProvider profileProvider = new DatabricksConfigHostCredsProvider();
    MlflowHostCredsProvider dynamicProvider =
        DatabricksDynamicHostCredsProvider.createIfAvailable();
    return dynamicProvider == null
        ? profileProvider
        : new HostCredsProviderChain(dynamicProvider, profileProvider);
  }
  if ("databricks".equals(scheme)) {
    return new DatabricksConfigHostCredsProvider(parsed.getHost());
  }
  if (scheme == null || "file".equals(scheme)) {
    throw new IllegalArgumentException("Java Client currently does not support" +
        " local tracking URIs. Please point to a Tracking Server.");
  }
  throw new IllegalArgumentException("Invalid tracking server uri: " + trackingUri);
}
/**
 * Upload a local file or directory into the run's root artifact directory. A directory is
 * uploaded under its own name. For example,
 *
 * <pre>{@code
 * logArtifact(runId, new File("/my/localModel"))
 * listArtifacts(runId) // returns "localModel"
 * }</pre>
 *
 * @param runId Run ID of an existing MLflow run.
 * @param localFile File or directory to upload. Must exist.
 */
public void logArtifact(String runId, File localFile) {
  // Inspect the local path before resolving the repository, which requires a server call.
  boolean uploadingDirectory = localFile.isDirectory();
  ArtifactRepository repository = getArtifactRepository(runId);
  if (!uploadingDirectory) {
    repository.logArtifact(localFile);
    return;
  }
  repository.logArtifacts(localFile, localFile.getName());
}
/**
 * Upload a local file or directory to an artifactPath within the run's root directory.
 * For example,
 *
 * <pre>{@code
 * logArtifact(runId, new File("/my/localModel"), "model")
 * listArtifacts(runId, "model") // returns "model/localModel"
 * }</pre>
 *
 * (i.e., the localModel file is now available in model/localModel).
 * If logging a directory, the directory is renamed to artifactPath.
 *
 * @param runId Run ID of an existing MLflow run.
 * @param localFile File or directory to upload. Must exist.
 * @param artifactPath Artifact path relative to the run's root directory. Should NOT
 *                     start with a /.
 */
public void logArtifact(String runId, File localFile, String artifactPath) {
  // Inspect the local path before resolving the repository, which requires a server call.
  boolean uploadingDirectory = localFile.isDirectory();
  ArtifactRepository repository = getArtifactRepository(runId);
  if (!uploadingDirectory) {
    repository.logArtifact(localFile, artifactPath);
    return;
  }
  repository.logArtifacts(localFile, artifactPath);
}
/**
 * Upload every file inside the given local directory to the run's root artifact directory.
 * For example, if /my/local/dir/ contains two files "file1" and "file2", then
 *
 * <pre>{@code
 * logArtifacts(runId, new File("/my/local/dir"))
 * listArtifacts(runId) // returns "file1" and "file2"
 * }</pre>
 *
 * @param runId Run ID of an existing MLflow run.
 * @param localDir Directory to upload. Must exist, and must be a directory (not a simple file).
 */
public void logArtifacts(String runId, File localDir) {
  ArtifactRepository repository = getArtifactRepository(runId);
  repository.logArtifacts(localDir);
}
/**
 * Upload every file inside the given local directory to an artifactPath within the run's root
 * artifact directory. For example, if /my/local/dir/ contains two files "file1" and "file2",
 * then
 *
 * <pre>{@code
 * logArtifacts(runId, new File("/my/local/dir"), "model")
 * listArtifacts(runId, "model") // returns "model/file1" and "model/file2"
 * }</pre>
 *
 * (i.e., the contents of the local directory are now available in model/).
 *
 * @param runId Run ID of an existing MLflow run.
 * @param localDir Directory to upload. Must exist, and must be a directory (not a simple file).
 * @param artifactPath Artifact path relative to the run's root directory. Should NOT
 *                     start with a /.
 */
public void logArtifacts(String runId, File localDir, String artifactPath) {
  ArtifactRepository repository = getArtifactRepository(runId);
  repository.logArtifacts(localDir, artifactPath);
}
/**
 * List what sits directly under the run's root artifact directory. The listing is not
 * recursive; directories come back as FileInfos with isDir=true and can be listed further.
 *
 * @param runId Run ID of an existing MLflow run.
 * @return The entries immediately under the run's root artifact directory.
 */
public List listArtifacts(String runId) {
  ArtifactRepository repository = getArtifactRepository(runId);
  return repository.listArtifacts();
}
/**
 * List what sits directly under the given artifactPath within the run's root artifact
 * directory. The listing is not recursive; directories come back as FileInfos with isDir=true
 * and can be listed further.
 *
 * @param runId Run ID of an existing MLflow run.
 * @param artifactPath Artifact path relative to the run's root directory. Should NOT
 *                     start with a /.
 * @return The entries immediately under {@code artifactPath}.
 */
public List listArtifacts(String runId, String artifactPath) {
  ArtifactRepository repository = getArtifactRepository(runId);
  return repository.listArtifacts(artifactPath);
}
/**
 * Download *all* artifacts in the run's artifact directory to a local directory.
 * Note that the entire directory path is downloaded, so this may be expensive if
 * the directory has a lot of data.
 *
 * @param runId Run ID of an existing MLflow run.
 * @return A local directory containing the downloaded artifacts.
 */
public File downloadArtifacts(String runId) {
  ArtifactRepository repository = getArtifactRepository(runId);
  return repository.downloadArtifacts();
}
/**
 * Download everything under the given artifactPath within the run's root artifact directory.
 * For example, if "model/file1" and "model/file2" exist within the artifact directory, then
 *
 * <pre>{@code
 * downloadArtifacts(runId, "model") // returns a local directory containing "file1" and "file2"
 * downloadArtifacts(runId, "model/file1") // returns a local *file* with the contents of file1.
 * }</pre>
 *
 * Note that the entire subdirectory path is downloaded, so this may be expensive if
 * the subdirectory has a lot of data.
 *
 * @param runId Run ID of an existing MLflow run.
 * @param artifactPath Artifact path relative to the run's root directory. Should NOT
 *                     start with a /.
 * @return A local file or directory containing the downloaded artifacts.
 */
public File downloadArtifacts(String runId, String artifactPath) {
  ArtifactRepository repository = getArtifactRepository(runId);
  return repository.downloadArtifacts(artifactPath);
}
/**
 * Resolve the artifact repository of a run. Each call fetches the run via
 * {@link #getRun(String)} to read its artifact URI.
 *
 * @param runId Run ID of an existing MLflow run.
 * @return ArtifactRepository, capable of uploading and downloading MLflow artifacts.
 */
private ArtifactRepository getArtifactRepository(String runId) {
  String artifactUri = getRun(runId).getInfo().getArtifactUri();
  return artifactRepositoryFactory.getArtifactRepository(URI.create(artifactUri), runId);
}
// ********************
// * Model Registry *
// ********************
/**
 * Return the latest model version for each stage.
 * The current available stages are: [None, Staging, Production, Archived].
 *
 * <pre>{@code
 * import org.mlflow.api.proto.ModelRegistry.ModelVersion;
 * List<ModelVersion> detailsList = getLatestVersions("model");
 *
 * for (ModelVersion details : detailsList) {
 *   System.out.println("Model Version: " + details.getVersion());
 *   System.out.println("Current Stage: " + details.getCurrentStage());
 * }
 * }</pre>
 *
 * @param modelName The name of the model
 * @return A collection of {@link org.mlflow.api.proto.ModelRegistry.ModelVersion}
 */
public List getLatestVersions(String modelName) {
  // An empty stage list is passed through to the two-argument overload.
  List noStageFilter = Collections.emptyList();
  return getLatestVersions(modelName, noStageFilter);
}
/**
 * Return the latest model version for each stage requested.
 * The current available stages are: [None, Staging, Production, Archived].
 *
 * <pre>{@code
 * import org.mlflow.api.proto.ModelRegistry.ModelVersion;
 * List<ModelVersion> detailsList =
 *     getLatestVersions("model", Lists.newArrayList("Staging"));
 *
 * for (ModelVersion details : detailsList) {
 *   System.out.println("Model Version: " + details.getVersion());
 *   System.out.println("Current Stage: " + details.getCurrentStage());
 * }
 * }</pre>
 *
 * @param modelName The name of the model
 * @param stages A list of stages
 * @return The latest model versions, as
 *         {@link org.mlflow.api.proto.ModelRegistry.ModelVersion} entries
 */
public List getLatestVersions(String modelName, Iterable stages) {
  String requestPath = mapper.makeGetLatestVersion(modelName, stages);
  String responseJson = sendGet(requestPath);
  return mapper.toGetLatestVersionsResponse(responseJson).getModelVersionsList();
}
/**
 * Fetch a single version of a registered model.
 *
 * <pre>{@code
 * import org.mlflow.api.proto.ModelRegistry.ModelVersion;
 * ModelVersion modelVersion = getModelVersion("model", "version");
 * }</pre>
 *
 * @param modelName Name of the containing registered model.
 * @param version Version number as a string of the model version.
 * @return a single model version
 *         {@link org.mlflow.api.proto.ModelRegistry.ModelVersion}
 */
public ModelVersion getModelVersion(String modelName, String version) {
  String requestPath = mapper.makeGetModelVersion(modelName, version);
  String responseJson = sendGet(requestPath);
  return mapper.toGetModelVersionResponse(responseJson).getModelVersion();
}
/**
 * Fetch a RegisteredModel from the model registry by name.
 *
 * <pre>{@code
 * import org.mlflow.api.proto.ModelRegistry.RegisteredModel;
 * RegisteredModel registeredModel = getRegisteredModel("model");
 * }</pre>
 *
 * @param modelName Name of the registered model.
 * @return a registered model {@link org.mlflow.api.proto.ModelRegistry.RegisteredModel}
 */
public RegisteredModel getRegisteredModel(String modelName) {
  String requestPath = mapper.makeGetRegisteredModel(modelName);
  String responseJson = sendGet(requestPath);
  return mapper.toGetRegisteredModelResponse(responseJson).getRegisteredModel();
}
/**
 * Return the model URI for the given model version. The model URI can be used
 * to download the model version artifacts.
 *
 * <pre>{@code
 * String modelUri = getModelVersionDownloadUri("model", "0");
 * }</pre>
 *
 * @param modelName The name of the model
 * @param version The version number of the model, as a string
 * @return The specified model version's URI.
 */
public String getModelVersionDownloadUri(String modelName, String version) {
  String requestPath = mapper.makeGetModelVersionDownloadUri(modelName, version);
  return mapper.toGetModelVersionDownloadUriResponse(sendGet(requestPath));
}
/**
 * Download all artifacts of the given registered model version to the local file system and
 * return the directory containing them. Note that this method will not work if the
 * `download_uri` refers to a single file (and not a directory) due to the way many
 * ArtifactRepository's `download_artifacts` handle empty subpaths.
 *
 * <pre>{@code
 * File modelVersionDir = downloadModelVersion("model", "0");
 * }</pre>
 *
 * @param modelName The name of the model
 * @param version The version number of the model, as a string
 * @return A directory ({@link java.io.File}) containing model artifacts
 */
public File downloadModelVersion(String modelName, String version) {
  // The download target is addressed as <scheme>:<modelName>/<version> using the models scheme.
  URIBuilder modelsUri = new URIBuilder()
      .setScheme(DEFAULT_MODELS_ARTIFACT_REPOSITORY_SCHEME)
      .setPath(modelName + "/" + version);
  CliBasedArtifactRepository repository =
      new CliBasedArtifactRepository(null, null, hostCredsProvider);
  return repository.downloadArtifactFromUri(modelsUri.toString());
}
/**
 * Download all artifacts of the latest registered model version in the given stage to the
 * local file system and return the directory containing them.
 *
 * <pre>{@code
 * File modelVersionDir = downloadLatestModelVersion("model", "Staging");
 * }</pre>
 *
 * (i.e., the contents of the local directory are now available).
 *
 * @param modelName The name of the model
 * @param stage The name of the stage
 * @return A directory ({@link java.io.File}) containing model artifacts
 * @throws MlflowClientException if no model version is returned for the given model and stage.
 */
public File downloadLatestModelVersion(String modelName, String stage) {
  // Typed local so the first element can be used as a ModelVersion without a cast.
  List<ModelVersion> versions = getLatestVersions(modelName, Lists.newArrayList(stage));
  if (versions.isEmpty()) {
    // Leading space before "and" keeps the model name and the rest of the message apart.
    throw new MlflowClientException("No model version found for " + modelName +
        " and stage " + stage);
  }
  ModelVersion details = versions.get(0);
  return downloadModelVersion(modelName, details.getVersion());
}
/**
 * Return the first page of model versions that satisfy the search query.
 *
 * @param searchFilter SQL compatible search query string.
 *                     Examples:
 *                       - "name = 'model_name'"
 *                       - "run_id = '...'"
 *                     If null, the result will be equivalent to having an empty search filter.
 * @param maxResults Maximum number of model versions desired in one page.
 * @param orderBy List of properties to order by. Example: "name DESC".
 *
 * @return A page of model versions that satisfy the search filter.
 */
public ModelVersionsPage searchModelVersions(String searchFilter,
                                             int maxResults,
                                             List orderBy) {
  String firstPageToken = null;
  return searchModelVersions(searchFilter, maxResults, orderBy, firstPageToken);
}
/**
 * Return up to 1000 model versions, using an empty filter and no ordering.
 *
 * @return A page of model versions with up to 1000 items.
 */
public ModelVersionsPage searchModelVersions() {
  List noOrdering = new ArrayList<>();
  return searchModelVersions("", 1000, noOrdering, null);
}
/**
 * Return up to 1000 model versions that satisfy the search query, with no ordering.
 *
 * @param searchFilter SQL compatible search query string.
 *                     Examples:
 *                       - "name = 'model_name'"
 *                       - "run_id = '...'"
 *                     If null, the result will be equivalent to having an empty search filter.
 *
 * @return A page of model versions with up to 1000 items.
 */
public ModelVersionsPage searchModelVersions(String searchFilter) {
  List noOrdering = new ArrayList<>();
  return searchModelVersions(searchFilter, 1000, noOrdering, null);
}
/**
 * Return one page of model versions that satisfy the search query.
 *
 * @param searchFilter SQL compatible search query string.
 *                     Examples:
 *                       - "name = 'model_name'"
 *                       - "run_id = '...'"
 *                     If null, the result will be equivalent to having an empty search filter.
 * @param maxResults Maximum number of model versions desired in one page.
 * @param orderBy List of properties to order by. Example: "name DESC".
 * @param pageToken String token specifying the page of results to fetch: the next-page token
 *                  returned with a previous page of the same search.
 *
 * @return A page of model versions that satisfy the search filter.
 */
public ModelVersionsPage searchModelVersions(String searchFilter,
                                             int maxResults,
                                             List orderBy,
                                             String pageToken) {
  String requestPath =
      mapper.makeSearchModelVersions(searchFilter, maxResults, orderBy, pageToken);
  String responseJson = sendGet(requestPath);
  SearchModelVersions.Response response = mapper.toSearchModelVersionsResponse(responseJson);
  // The page keeps the original search arguments alongside the results and next-page token.
  return new ModelVersionsPage(response.getModelVersionsList(), response.getNextPageToken(),
      searchFilter, maxResults, orderBy, this);
}
/**
 * Close this MlflowClient by closing its underlying HTTP caller, releasing any resources
 * associated with it.
 */
public void close() {
  httpCaller.close();
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy