All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.mlflow.tracking.MlflowHttpCaller Maven / Gradle / Ivy

The newest version!
package org.mlflow.tracking;

import com.google.common.annotations.VisibleForTesting;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPatch;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.mlflow.tracking.creds.MlflowHostCreds;
import org.mlflow.tracking.creds.MlflowHostCredsProvider;


class MlflowHttpCaller {
  private static final Logger logger = LoggerFactory.getLogger(MlflowHttpCaller.class);
  private static final String BASE_API_PATH = "api/2.0/mlflow";
  protected CloseableHttpClient httpClient;
  private final MlflowHostCredsProvider hostCredsProvider;
  private final int maxRateLimitIntervalMillis;
  private final int rateLimitRetrySleepInitMillis;
  private final int maxRetryAttempts;

  /**
   * Construct a new MlflowHttpCaller with a default configuration for request retries.
   */
  MlflowHttpCaller(MlflowHostCredsProvider hostCredsProvider) {
    this(hostCredsProvider, 60000, 1000, 3);
  }

  /**
   * Construct a new MlflowHttpCaller.
   *
   * @param maxRateLimitIntervalMs The maximum amount of time, in milliseconds, to spend retrying a
   *                               single request in response to rate limiting (error code 429).
   * @param rateLimitRetrySleepInitMs The initial backoff delay, in milliseconds, when retrying a
   *                                  request in response to rate limiting (error code 429). The
   *                                  delay is increased exponentially after each rate limiting
   *                                  response until the total delay incurred across all retries for
   *                                  the request exceeds the specified maxRateLimitIntervalSeconds.
   * @param maxRetryAttempts The maximum number of times to retry a request, excluding rate limit
   *                         retries.
   */
  MlflowHttpCaller(MlflowHostCredsProvider hostCredsProvider,
                   int maxRateLimitIntervalMs,
                   int rateLimitRetrySleepInitMs,
                   int maxRetryAttempts) {
    this.hostCredsProvider = hostCredsProvider;
    this.maxRateLimitIntervalMillis = maxRateLimitIntervalMs;
    this.rateLimitRetrySleepInitMillis = rateLimitRetrySleepInitMs;
    this.maxRetryAttempts = maxRetryAttempts;
  }

  @VisibleForTesting
  MlflowHttpCaller(MlflowHostCredsProvider hostCredsProvider,
                   int maxRateLimitIntervalMs,
                   int rateLimitRetrySleepInitMs,
                   int maxRetryAttempts,
                   CloseableHttpClient client) {
    this(
      hostCredsProvider, maxRateLimitIntervalMs, rateLimitRetrySleepInitMs, maxRetryAttempts);
    this.httpClient = client;
  }

  private HttpResponse executeRequestWithRateLimitRetries(HttpRequestBase request)
      throws IOException {
    int timeLeft = maxRateLimitIntervalMillis;
    int sleepFor = rateLimitRetrySleepInitMillis;
    HttpResponse response = httpClient.execute(request);
    while (response.getStatusLine().getStatusCode() == 429 && timeLeft > 0) {
      logger.warn("Request returned with status code 429 (Rate limit exceeded). Retrying after "
                  + sleepFor
                  + " milliseconds. Will continue to retry 429s for up to "
                  + timeLeft
                  + " milliseconds.");
      try {
        Thread.sleep(sleepFor);
      } catch (InterruptedException e) {
        throw new RuntimeException(e);
      }
      timeLeft -= sleepFor;
      sleepFor = Math.min(timeLeft, 2 * sleepFor);
      response = httpClient.execute(request);
    }
    checkError(response);
    return response;
  }

  private HttpResponse executeRequest(HttpRequestBase request) throws IOException {
    HttpResponse response = null;
    int attemptsRemaining = this.maxRetryAttempts;
    while (attemptsRemaining > 0) {
      attemptsRemaining -= 1;
      try {
        response = executeRequestWithRateLimitRetries(request);
        break;
      } catch (MlflowHttpException e) {
        if (attemptsRemaining > 0 && e.getStatusCode() != 429) {
          logger.warn("Request returned with status code {} (Rate limit exceeded)."
                      + " Retrying up to {} more times. Response body: {}",
                      e.getStatusCode(),
                      attemptsRemaining,
                      e.getBodyMessage());
          continue;
        } else {
          throw e;
        }
      }
    }
    return response;
  }

  String get(String path) {
    logger.debug("Sending GET " + path);
    HttpGet request = new HttpGet();
    fillRequestSettings(request, path);
    try {
      HttpResponse response = executeRequest(request);
      String responseJson = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
      logger.debug("Response: " + responseJson);
      return responseJson;
    } catch (IOException e) {
      throw new MlflowClientException(e);
    }
  }

  // TODO(aaron) Convert to InputStream.
  byte[] getAsBytes(String path) {
    logger.debug("Sending GET " + path);
    HttpGet request = new HttpGet();
    fillRequestSettings(request, path);
    try {
      HttpResponse response = executeRequest(request);
      byte[] bytes = EntityUtils.toByteArray(response.getEntity());
      logger.debug("response: #bytes=" + bytes.length);
      return bytes;
    } catch (IOException e) {
      throw new MlflowClientException(e);
    }
  }

  String post(String path, String json) {
    logger.debug("Sending POST " + path + ": " + json);
    HttpPost request = new HttpPost();
    return send(request, path, json);
  }

  String patch(String path, String json) {
    logger.debug("Sending PATCH " + path + ": " + json);
    HttpPatch request = new HttpPatch();
    return send(request, path, json);
  }

  private String send(HttpEntityEnclosingRequestBase request, String path, String json) {
    fillRequestSettings(request, path);
    request.setEntity(new StringEntity(json, StandardCharsets.UTF_8));
    request.setHeader("Content-Type", "application/json");
    try {
      HttpResponse response = executeRequest(request);
      String responseJson = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
      logger.debug("Response: " + responseJson);
      return responseJson;
    } catch (IOException e) {
      throw new MlflowClientException(e);
    }
  }

  private void checkError(HttpResponse response) throws MlflowClientException, IOException {
    int statusCode = response.getStatusLine().getStatusCode();
    String reasonPhrase = response.getStatusLine().getReasonPhrase();
    if (isError(statusCode)) {
      String bodyMessage = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
      if (statusCode >= 400 && statusCode <= 499) {
        throw new MlflowHttpException(statusCode, reasonPhrase, bodyMessage);
      }
      if (statusCode >= 500 && statusCode <= 599) {
        throw new MlflowHttpException(statusCode, reasonPhrase, bodyMessage);
      }
      throw new MlflowHttpException(statusCode, reasonPhrase, bodyMessage);
    }
  }

  private void fillRequestSettings(HttpRequestBase request, String path) {
    MlflowHostCreds hostCreds = hostCredsProvider.getHostCreds();
    createHttpClientIfNecessary(hostCreds.shouldIgnoreTlsVerification());
    String uri = hostCreds.getHost() + "/" + BASE_API_PATH + "/" + path;
    request.setURI(URI.create(uri));
    String username = hostCreds.getUsername();
    String password = hostCreds.getPassword();
    String token = hostCreds.getToken();
    if (username != null && password != null) {
      String authHeader = Base64.getEncoder()
        .encodeToString((username + ":" + password).getBytes(StandardCharsets.UTF_8));
      request.addHeader("Authorization", "Basic " + authHeader);
    } else if (token != null) {
      request.addHeader("Authorization", "Bearer " + token);
    }

    String userAgent = "mlflow-java-client";
    String clientVersion = MlflowClientVersion.getClientVersion();
    if (!clientVersion.isEmpty()) {
      userAgent += "/" + clientVersion;
    }
    request.addHeader("User-Agent", userAgent);
  }

  private boolean isError(int statusCode) {
    return statusCode < 200 || statusCode > 399;
  }

  private void createHttpClientIfNecessary(boolean noTlsVerify) {
    if (httpClient != null) {
      return;
    }

    HttpClientBuilder builder = HttpClientBuilder.create();
    if (noTlsVerify) {
      try {
        SSLContextBuilder sslBuilder = new SSLContextBuilder()
          .loadTrustMaterial(null, new TrustSelfSignedStrategy());
        SSLConnectionSocketFactory connectionFactory =
          new SSLConnectionSocketFactory(sslBuilder.build(), new NoopHostnameVerifier());
        builder.setSSLSocketFactory(connectionFactory);
      } catch (NoSuchAlgorithmException | KeyStoreException | KeyManagementException e) {
        logger.warn("Could not set noTlsVerify to true, verification will remain", e);
      }
    }

    this.httpClient = builder.build();
  }

  void close() {
    if (httpClient != null) {
      try {
	  httpClient.close();
      } catch(IOException e){
	  logger.warn("Unable to close connection to mlflow backend", e);
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy