All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.mlflow.tracking.utils.DatabricksContext Maven / Gradle / Ivy

package org.mlflow.tracking.utils;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

/**
 * Detects whether the current JVM is running inside a Databricks notebook or job and derives
 * MLflow run tags from that context.
 *
 * <p>Context values are read from a config provider that is instantiated reflectively by class
 * name ({@link #CONFIG_PROVIDER_CLASS_NAME} by default) and accessed as a
 * {@code Map<String, String>}. If that class cannot be loaded, no context is created.
 * NOTE(review): the provider class is presumably only on the classpath inside a Databricks
 * runtime — confirm.
 */
public class DatabricksContext {
  /** Fully-qualified name of the class instantiated reflectively as the default config provider. */
  public static final String CONFIG_PROVIDER_CLASS_NAME =
    "com.databricks.config.DatabricksClientSettingsProvider";
  private static final Logger logger = LoggerFactory.getLogger(
    DatabricksContext.class);

  // Keys looked up in the config provider.
  private static final String NOTEBOOK_ID_KEY = "notebookId";
  private static final String NOTEBOOK_PATH_KEY = "notebookPath";
  private static final String HOST_KEY = "host";
  private static final String JOB_ID_KEY = "jobId";
  private static final String JOB_RUN_ID_KEY = "jobRunId";
  private static final String JOB_TYPE_KEY = "jobType";

  /** Source of all context values; non-null (only constructed via createIfAvailable). */
  private final Map<String, String> configProvider;

  private DatabricksContext(Map<String, String> configProvider) {
    this.configProvider = configProvider;
  }

  /**
   * Creates a context backed by the default Databricks config provider.
   *
   * @return a new context, or {@code null} if the provider class is unavailable or could not be
   *     instantiated
   */
  public static DatabricksContext createIfAvailable() {
    return createIfAvailable(CONFIG_PROVIDER_CLASS_NAME);
  }

  /**
   * Creates a context backed by the config provider with the given class name.
   *
   * @param className fully-qualified name of a {@link Map} implementation to instantiate
   * @return a new context, or {@code null} if the provider is unavailable
   */
  @VisibleForTesting
  static DatabricksContext createIfAvailable(String className) {
    Map<String, String> configProvider = getConfigProviderIfAvailable(className);
    if (configProvider == null) {
      return null;
    }
    return new DatabricksContext(configProvider);
  }

  /**
   * Returns MLflow tags describing the current Databricks context.
   *
   * <p>Notebook tags are returned when a notebook id is present; otherwise job tags when a job id
   * is present; otherwise an empty map.
   *
   * @return a new, mutable map of tag names to values; never {@code null}
   */
  public Map<String, String> getTags() {
    if (isInDatabricksNotebook()) {
      return getTagsForDatabricksNotebook();
    } else if (isInDatabricksJob()) {
      return getTagsForDatabricksJob();
    } else {
      return new HashMap<>();
    }
  }

  /**
   * Reports whether the config provider exposes a notebook id.
   *
   * @return {@code true} if the {@code notebookId} value is non-null
   */
  public boolean isInDatabricksNotebook() {
    return configProvider.get(NOTEBOOK_ID_KEY) != null;
  }

  /**
   * Builds notebook tags (id, path/source name/source type, webapp URL), skipping absent values.
   * Should only be called if isInDatabricksNotebook() is true.
   */
  private Map<String, String> getTagsForDatabricksNotebook() {
    Map<String, String> tagsForNotebook = new HashMap<>();
    // Read directly rather than via getNotebookId(), so a provider whose value changes between
    // calls cannot make this method throw.
    String notebookId = configProvider.get(NOTEBOOK_ID_KEY);
    if (notebookId != null) {
      tagsForNotebook.put(MlflowTagConstants.DATABRICKS_NOTEBOOK_ID, notebookId);
    }
    String notebookPath = configProvider.get(NOTEBOOK_PATH_KEY);
    if (notebookPath != null) {
      tagsForNotebook.put(MlflowTagConstants.SOURCE_NAME, notebookPath);
      tagsForNotebook.put(MlflowTagConstants.DATABRICKS_NOTEBOOK_PATH, notebookPath);
      tagsForNotebook.put(MlflowTagConstants.SOURCE_TYPE, "NOTEBOOK");
    }
    String webappUrl = configProvider.get(HOST_KEY);
    if (webappUrl != null) {
      tagsForNotebook.put(MlflowTagConstants.DATABRICKS_WEBAPP_URL, webappUrl);
    }
    return tagsForNotebook;
  }

  /**
   * Returns the notebook id. Should only be called if isInDatabricksNotebook() is true.
   *
   * @return the {@code notebookId} value from the config provider
   * @throws IllegalArgumentException if isInDatabricksNotebook() is false (kept as this type,
   *     rather than IllegalStateException, for compatibility with existing callers)
   */
  public String getNotebookId() {
    if (!isInDatabricksNotebook()) {
      throw new IllegalArgumentException(
        "getNotebookId() should not be called when isInDatabricksNotebook() is false"
      );
    }
    return configProvider.get(NOTEBOOK_ID_KEY);
  }

  /**
   * Returns the notebook path. Should only be called if isInDatabricksNotebook() is true.
   *
   * @return the {@code notebookPath} value from the config provider; may be {@code null}
   * @throws IllegalArgumentException if isInDatabricksNotebook() is false
   */
  public String getNotebookPath() {
    if (!isInDatabricksNotebook()) {
      throw new IllegalArgumentException(
        "getNotebookPath() should not be called when isInDatabricksNotebook() is false"
      );
    }
    return configProvider.get(NOTEBOOK_PATH_KEY);
  }

  /** Reports whether the config provider exposes a job id. */
  private boolean isInDatabricksJob() {
    return configProvider.get(JOB_ID_KEY) != null;
  }

  /**
   * Builds job tags, skipping absent values. Job id, run id, source type and source name are only
   * added when both job id and run id are present.
   * Should only be called if isInDatabricksJob() is true.
   */
  private Map<String, String> getTagsForDatabricksJob() {
    Map<String, String> tagsForJob = new HashMap<>();
    String jobId = configProvider.get(JOB_ID_KEY);
    String jobRunId = configProvider.get(JOB_RUN_ID_KEY);
    String jobType = configProvider.get(JOB_TYPE_KEY);
    String webappUrl = configProvider.get(HOST_KEY);
    if (jobId != null && jobRunId != null) {
      tagsForJob.put(MlflowTagConstants.DATABRICKS_JOB_ID, jobId);
      tagsForJob.put(MlflowTagConstants.DATABRICKS_JOB_RUN_ID, jobRunId);
      tagsForJob.put(MlflowTagConstants.SOURCE_TYPE, "JOB");
      tagsForJob.put(MlflowTagConstants.SOURCE_NAME,
                          String.format("jobs/%s/run/%s", jobId, jobRunId));
    }
    if (jobType != null) {
      tagsForJob.put(MlflowTagConstants.DATABRICKS_JOB_TYPE, jobType);
    }
    if (webappUrl != null) {
      tagsForJob.put(MlflowTagConstants.DATABRICKS_WEBAPP_URL, webappUrl);
    }
    return tagsForJob;
  }

  /**
   * Reflectively instantiates the named class as a config provider.
   *
   * <p>The class must implement {@link Map} and have an accessible no-argument constructor. Keys
   * and values are assumed to be strings; that part of the cast is unchecked.
   *
   * @param className fully-qualified name of the provider class
   * @return a new provider instance, or {@code null} if the class is not on the classpath, does
   *     not implement {@link Map}, or could not be instantiated (the last two cases are logged)
   */
  public static Map<String, String> getConfigProviderIfAvailable(String className) {
    Class<?> cls;
    try {
      cls = Class.forName(className);
    } catch (ClassNotFoundException e) {
      // The provider is optional; its absence simply means "no Databricks context".
      return null;
    }
    if (!Map.class.isAssignableFrom(cls)) {
      // Previously this surfaced as an uncaught ClassCastException.
      logger.warn("Dynamic config provider {} does not implement java.util.Map", className);
      return null;
    }
    try {
      // Constructor.newInstance (unlike the deprecated Class.newInstance) wraps constructor
      // exceptions in InvocationTargetException, so they are handled here instead of escaping.
      @SuppressWarnings("unchecked")
      Map<String, String> provider =
        (Map<String, String>) cls.getDeclaredConstructor().newInstance();
      return provider;
    } catch (ReflectiveOperationException e) {
      logger.warn("Found but failed to invoke dynamic config provider", e);
      return null;
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy