All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.databricks.sdk.mixin.ClustersExt Maven / Gradle / Ivy

There is a newer version: 0.35.0
Show newest version
package com.databricks.sdk.mixin;

import static com.databricks.sdk.service.compute.CloudProviderNodeStatus.*;

import com.databricks.sdk.core.ApiClient;
import com.databricks.sdk.core.DatabricksError;
import com.databricks.sdk.service.compute.*;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Convenience extension of {@link ClustersAPI} that adds helpers for picking a Spark runtime
 * version, picking a node type, and making sure a cluster is running.
 */
public class ClustersExt extends ClustersAPI {
  private static final Logger LOG = LoggerFactory.getLogger(ClustersExt.class);
  private static final String INVALID_STATE = "INVALID_STATE";

  /**
   * Ordering applied to node types before filtering in {@link #selectNodeType}: the first node
   * type that passes every filter is returned, so this ordering decides which match wins. Built
   * once instead of on every comparison.
   */
  private static final Comparator<NodeType> NODE_SORTING_COMPARATOR =
      Comparator.comparing(NodeType::getIsDeprecated, Comparator.nullsLast(Boolean::compare))
          .thenComparing(NodeType::getNumCores, Comparator.nullsLast(Double::compare))
          .thenComparing(NodeType::getMemoryMb, Comparator.nullsLast(Long::compare))
          .thenComparing(
              instanceTypeComparator(NodeInstanceType::getLocalDisks),
              Comparator.nullsLast(Long::compare))
          .thenComparing(
              instanceTypeComparator(NodeInstanceType::getLocalDiskSizeGb),
              Comparator.nullsLast(Long::compare))
          .thenComparing(
              instanceTypeComparator(NodeInstanceType::getLocalNvmeDisks),
              Comparator.nullsLast(Long::compare))
          .thenComparing(
              instanceTypeComparator(NodeInstanceType::getLocalNvmeDiskSizeGb),
              Comparator.nullsLast(Long::compare))
          .thenComparing(NodeType::getNumGpus, Comparator.nullsLast(Long::compare))
          .thenComparing(NodeType::getInstanceTypeId, Comparator.nullsLast(String::compareTo));

  public ClustersExt(ApiClient apiClient) {
    super(apiClient);
  }

  public ClustersExt(ClustersService mock) {
    super(mock);
  }

  /**
   * Returns the key of the Spark version that matches the given selector.
   *
   * @param selector criteria (Scala version, ML/GPU/Photon/etc. flags, LTS, exact Spark version)
   * @return the key of the single matching version, or of the highest one (by {@code SemVer}
   *     order) when {@code selector.latest} is set and several versions match
   * @throws IllegalArgumentException if nothing matches, or if several versions match and {@code
   *     selector.latest} is not set
   */
  public String selectSparkVersion(SparkVersionSelector selector) throws IllegalArgumentException {
    // Logic ported from
    // https://github.com/databricks/databricks-sdk-go/blob/main/service/clusters/spark_version.go
    List<String> versions = new ArrayList<>();
    GetSparkVersionsResponse sv = sparkVersions();
    // A response without a version list is treated like an empty one, so the caller gets the
    // "no results" error below rather than a NullPointerException.
    if (sv.getVersions() != null) {
      for (SparkVersion version : sv.getVersions()) {
        // Both the name and the key are inspected below; entries missing either cannot match.
        if (version.getName() == null || version.getKey() == null) {
          continue;
        }
        if (selector.scala != null && !version.getKey().contains("-scala" + selector.scala)) {
          continue;
        }
        boolean matches =
            ((!version.getKey().contains("apache-spark-"))
                && ((version.getKey().contains("-ml-")) == selector.ml)
                && ((version.getKey().contains("-hls-")) == selector.genomics)
                && ((version.getKey().contains("-gpu-")) == selector.gpu)
                && ((version.getKey().contains("-photon-")) == selector.photon)
                && ((version.getKey().contains("-aarch64-")) == selector.graviton)
                && ((version.getName().contains("Beta")) == selector.beta));
        if (matches && selector.longTermSupport) {
          matches = version.getName().contains("LTS") || version.getKey().contains("-esr-");
        }
        if (matches && selector.sparkVersion != null) {
          matches = ("Apache Spark " + selector.sparkVersion).equals(version.getName());
        }
        if (matches) {
          versions.add(version.getKey());
        }
      }
    }
    if (versions.isEmpty()) {
      throw new IllegalArgumentException("spark versions query returned no results");
    }
    if (versions.size() > 1) {
      if (!selector.latest) {
        throw new IllegalArgumentException("spark versions query returned multiple results");
      }
      // Descending order: the highest version ends up at index 0.
      versions.sort((v1, v2) -> SemVer.parse(v2).compareTo(SemVer.parse(v1)));
    }
    return versions.get(0);
  }

  /**
   * Returns the id of the first node type, in {@link #NODE_SORTING_COMPARATOR} order, that
   * satisfies every criterion set on the selector. Criteria left {@code null} on the selector are
   * not applied, except GPUs: when {@code minGpus} is unset, node types with GPUs are excluded.
   * A node type that lacks an attribute required by an active criterion is treated as not
   * satisfying it.
   *
   * @param selector the node type criteria
   * @return the node type id of the first match
   * @throws IllegalArgumentException if no node type satisfies the selector
   */
  public String selectNodeType(NodeTypeSelector selector) {
    // Logic ported from
    // https://github.com/databricks/databricks-sdk-go/blob/main/service/clusters/node_type.go
    ListNodeTypesResponse res = this.listNodeTypes();
    List<NodeType> types =
        res.getNodeTypes().stream().sorted(NODE_SORTING_COMPARATOR).collect(Collectors.toList());
    for (NodeType nt : types) {
      if (shouldNodeBeSkipped(nt)) {
        continue;
      }
      // Whole gigabytes (integer division); unknown memory counts as 0.
      long gbs = nt.getMemoryMb() != null ? nt.getMemoryMb() / 1024 : 0L;
      if (selector.fleet != null
          && (nt.getNodeTypeId() == null || !nt.getNodeTypeId().contains(selector.fleet))) {
        continue;
      }
      if (selector.minMemoryGb != null && gbs < selector.minMemoryGb) {
        continue;
      }
      if (selector.gbPerCore != null
          && (nt.getNumCores() == null || gbs / nt.getNumCores() < selector.gbPerCore)) {
        continue;
      }
      if (selector.minCores != null
          && (nt.getNumCores() == null || nt.getNumCores() < selector.minCores)) {
        continue;
      }
      // Unknown GPU count counts as 0. With minGpus set the node needs at least that many;
      // without it, GPU nodes are excluded.
      long gpus = nt.getNumGpus() != null ? nt.getNumGpus() : 0L;
      if (selector.minGpus != null ? gpus < selector.minGpus : gpus > 0) {
        continue;
      }
      if (selector.localDisk != null || selector.localDiskMinSize != null) {
        NodeInstanceType instanceType = nt.getNodeInstanceType();
        if (instanceType == null) {
          continue;
        }
        long localDisks = instanceType.getLocalDisks() != null ? instanceType.getLocalDisks() : 0L;
        long localNvmeDisks =
            instanceType.getLocalNvmeDisks() != null ? instanceType.getLocalNvmeDisks() : 0L;
        if (localDisks < 1 && localNvmeDisks < 1) {
          continue;
        }
        long localDiskSizeGb =
            instanceType.getLocalDiskSizeGb() != null ? instanceType.getLocalDiskSizeGb() : 0L;
        long localNvmeDiskSizeGb =
            instanceType.getLocalNvmeDiskSizeGb() != null
                ? instanceType.getLocalNvmeDiskSizeGb()
                : 0L;
        long allDisksSize = localDiskSizeGb + localNvmeDiskSizeGb;
        if (selector.localDiskMinSize != null && allDisksSize < selector.localDiskMinSize) {
          continue;
        }
      }
      // The selector value is the receiver in the checks below so that a node type with a
      // missing attribute simply fails the comparison instead of throwing.
      if (selector.category != null && !selector.category.equalsIgnoreCase(nt.getCategory())) {
        continue;
      }
      if (selector.isIoCacheEnabled != null
          && !selector.isIoCacheEnabled.equals(nt.getIsIoCacheEnabled())) {
        continue;
      }
      if (selector.supportPortForwarding != null
          && !selector.supportPortForwarding.equals(nt.getSupportPortForwarding())) {
        continue;
      }
      if (selector.photonDriverCapable != null
          && !selector.photonDriverCapable.equals(nt.getPhotonDriverCapable())) {
        continue;
      }
      if (selector.photonWorkerCapable != null
          && !selector.photonWorkerCapable.equals(nt.getPhotonWorkerCapable())) {
        continue;
      }
      if (selector.graviton != null && !selector.graviton.equals(nt.getIsGraviton())) {
        continue;
      }
      return nt.getNodeTypeId();
    }
    throw new IllegalArgumentException("Cannot determine smallest node type");
  }

  /**
   * Adapts a {@link NodeInstanceType} getter into a sort key over {@link NodeType}: a node type
   * with no instance type, or with a {@code null} value for the attribute, yields {@code 0}.
   */
  private static Function<NodeType, Long> instanceTypeComparator(
      Function<NodeInstanceType, Long> yy) {
    return nodeType -> {
      NodeInstanceType nodeInstanceType = nodeType.getNodeInstanceType();
      if (nodeInstanceType == null) {
        return 0L;
      }
      Long value = yy.apply(nodeInstanceType);
      if (value == null) {
        return 0L;
      }
      return value;
    };
  }

  /**
   * Returns {@code true} when the node type reports a status of {@code NOT_AVAILABLE_IN_REGION}
   * or {@code NOT_ENABLED_ON_SUBSCRIPTION}; node types without status information are kept.
   */
  private static boolean shouldNodeBeSkipped(NodeType nt) {
    if (nt.getNodeInfo() == null) {
      return false;
    }
    if (nt.getNodeInfo().getStatus() == null) {
      return false;
    }
    for (CloudProviderNodeStatus st : nt.getNodeInfo().getStatus()) {
      if (st == NOT_AVAILABLE_IN_REGION || st == NOT_ENABLED_ON_SUBSCRIPTION) {
        return true;
      }
    }
    return false;
  }

  /**
   * Brings the given cluster to a running state, starting it or waiting on an in-flight
   * transition as needed. Retries for up to 20 minutes when the service answers with an {@code
   * INVALID_STATE} error.
   *
   * @param clusterId id of the cluster
   * @throws TimeoutException if the deadline passes, or if one of the underlying waits times out
   * @throws DatabricksError if the cluster is in the {@code ERROR} or {@code UNKNOWN} state, or
   *     the service fails with an error code other than {@code INVALID_STATE}
   */
  public void ensureClusterIsRunning(String clusterId) throws TimeoutException {
    Duration outerTimeout = Duration.ofMinutes(20);
    long deadline = System.currentTimeMillis() + outerTimeout.toMillis();
    while (System.currentTimeMillis() < deadline) {
      try {
        ClusterDetails info = get(clusterId);
        if (info.getState() == State.TERMINATED) {
          start(clusterId).get();
        } else if (info.getState() == State.TERMINATING) {
          waitGetClusterTerminated(clusterId);
          start(clusterId).get();
        } else if (Arrays.asList(State.PENDING, State.RESIZING, State.RESTARTING)
            .contains(info.getState())) {
          waitGetClusterRunning(clusterId);
        } else if (Arrays.asList(State.ERROR, State.UNKNOWN).contains(info.getState())) {
          throw new DatabricksError(info.getState().name(), info.getStateMessage());
        }
        // running, reconfiguring
        // NOTE: `info` is the snapshot fetched before any start/wait above, so the logged state
        // is the one observed on entry to this iteration.
        LOG.debug("Cluster is {}: {}", info.getState(), info.getStateMessage());
        return;
      } catch (DatabricksError e) {
        // Constant-first comparison: an error without an error code must be rethrown as-is
        // rather than being masked by a NullPointerException.
        if (INVALID_STATE.equals(e.getErrorCode())) {
          LOG.debug("Cluster was started by other process: {} Retrying.", e.getMessage());
          continue;
        }
        LOG.debug("Received {} error code", e.getErrorCode());
        throw e;
      }
    }
    throw new TimeoutException("Cannot ensure cluster to start");
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy