
prerna.cluster.util.RemoteClientServerZK Maven / Gradle / Ivy
The newest version!
package prerna.cluster.util;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.CuratorFrameworkFactory;
import org.apache.curator.framework.recipes.cache.CuratorCache;
import org.apache.curator.framework.recipes.cache.CuratorCacheListener;
import org.apache.curator.framework.state.ConnectionState;
import org.apache.curator.retry.RetryOneTime;
import org.apache.http.HttpEntity;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;

import prerna.engine.api.RemoteModelStateEnum;
/**
* This is a Singleton class used to exclusively manage a connection to ZooKeeper for tracking RemoteClientServer deployments on our cluster.
* This will track the state of each model deployment, including whether it is warming or active.
* This class handles holding requests until models are in an active state by performing health checks on the FastAPI service running in the container.
* This is required to give the FastAPI service a grace time between when the model is added to the active path and the time it requires to start up in the container.
* This is used by engines that extend AbstractRemoteModelEngine, e.g. NEREngine.
*/
public class RemoteClientServerZK implements IRemoteClientServer {
private static final Logger classLogger = LogManager.getLogger(RemoteClientServerZK.class);

// Lazily created singleton. Volatile because getInstance() reads it outside the
// class lock; a non-volatile read there is not a safe publication.
private static volatile RemoteClientServerZK instance;

// ZooKeeper paths: children of WARMING_PATH / ACTIVE_PATH are named by model ID;
// the node at MODEL_SCALER_PATH carries the model scaler address as its data.
private static final String WARMING_PATH = "/models/warming";
private static final String ACTIVE_PATH = "/models/active";
private static final String MODEL_SCALER_PATH = "/services/kube-model-deployer";

// Connection-related fields
private CuratorFramework client;
// Default connect string; replaced by the ZK_SERVER environment variable when set.
private String zkServer = "localhost:2181";
// Updated from the Curator connection-state listener, i.e. from a thread other
// than the one that created this object, hence volatile.
private volatile boolean connected = false;
private Map<String, String> env;

// State tracking, all keyed by model ID
private final ConcurrentMap<String, RemoteModelStateEnum> modelStates = new ConcurrentHashMap<>();
private final ConcurrentMap<String, String> modelClusterIps = new ConcurrentHashMap<>();
private final ConcurrentMap<String, String> modelNames = new ConcurrentHashMap<>();
private CuratorCache warmingCache;
private CuratorCache activeCache;

// Kept public for backward compatibility; prefer getModelScalerIp().
public String modelScalerIp;

// When true, HTTP calls go to fixed localhost ports instead of the addresses found in ZooKeeper.
private boolean devPortFowarding = false;

// Single-threaded scheduler that runs the periodic model-state refresh.
private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
/**
 * Private constructor: the only visible call site is {@code getInstance()}.
 * It only logs that initialization has started; the ZooKeeper connection is
 * established separately in {@code init()}.
 */
private RemoteClientServerZK() {
classLogger.info("RemoteClientServerZK being initialized...");
}
/**
 * Returns the process-wide instance, creating and initializing it on first use.
 *
 * <p>The new object is stored in the static field only after {@code init()} has
 * returned successfully. This prevents two problems: a concurrent caller observing
 * an instance whose ZooKeeper client is not yet set up, and a failed initialization
 * leaving a permanently broken singleton behind. When initialization fails, the
 * partially built instance is closed so its Curator client does not keep running,
 * and the next call retries from scratch.
 *
 * <p>NOTE(review): the unsynchronized fast-path read below is only fully safe under
 * the Java memory model if the {@code instance} field is declared {@code volatile}.
 *
 * @return the initialized singleton
 * @throws RuntimeException if the ZooKeeper connection cannot be initialized
 */
public static RemoteClientServerZK getInstance() {
    RemoteClientServerZK existing = instance;
    if (existing != null) {
        return existing;
    }
    synchronized (RemoteClientServerZK.class) {
        if (instance == null) {
            RemoteClientServerZK created = new RemoteClientServerZK();
            try {
                created.init();
            } catch (RuntimeException e) {
                // Release whatever init() managed to start before it failed.
                created.close();
                throw e;
            }
            // Publish only a fully initialized object.
            instance = created;
        }
        return instance;
    }
}
/**
 * Connects to ZooKeeper and prepares all local state.
 *
 * <p>Steps, in order: resolve the connect string (the {@code ZK_SERVER} environment
 * variable overrides the default), build and start the Curator client, wait up to
 * 10 seconds for the connection, then create the base paths, start the caches, load
 * the current state and schedule the periodic refresh.
 *
 * @throws RuntimeException wrapping the underlying cause if any step fails; if the
 *         cause is an interrupt, the calling thread's interrupt flag is restored first
 */
private void init() {
    try {
        env = System.getenv();
        // Look the variable up directly so the String result does not depend on how
        // the env map field is typed.
        String configuredServer = System.getenv("ZK_SERVER");
        if (configuredServer != null) {
            zkServer = configuredServer;
        }

        // Initialize Curator client
        RetryOneTime retryPolicy = new RetryOneTime(1000);
        client = CuratorFrameworkFactory.builder()
                .connectString(zkServer)
                .retryPolicy(retryPolicy)
                .connectionTimeoutMs(5000)
                .sessionTimeoutMs(10000)
                .maxCloseWaitMs(2000)
                .build();

        // Track the connection state reported by Curator
        client.getConnectionStateListenable().addListener((curator, state) -> {
            classLogger.info("ZooKeeper connection state changed to: {}", state);
            if (state == ConnectionState.CONNECTED || state == ConnectionState.RECONNECTED) {
                connected = true;
            } else if (state == ConnectionState.LOST || state == ConnectionState.SUSPENDED) {
                connected = false;
            }
        });
        client.start();

        // Wait for connection
        if (!client.blockUntilConnected(10, TimeUnit.SECONDS)) {
            throw new IllegalStateException("Failed to connect to ZooKeeper");
        }

        setupZKPaths();
        setupCaches();
        loadInitialState();
        setupCacheRefresher();
    } catch (Exception e) {
        if (e instanceof InterruptedException) {
            // Catching the exception cleared the flag; restore it for the caller.
            Thread.currentThread().interrupt();
        }
        classLogger.error("Failed to initialize ZooKeeper connection", e);
        throw new RuntimeException("Failed to initialize ZooKeeper connection", e);
    }
}
/**
 * Populates the local maps from what is currently stored in ZooKeeper: the model
 * scaler address, every model under the active path, and every model under the
 * warming path.
 *
 * <p>Each of the three steps handles its own failures, and each active model is
 * processed independently, so one unreadable or malformed node does not prevent the
 * rest of the state from being loaded. Errors are logged, never thrown.
 */
private void loadInitialState() {
    loadModelScalerIp();
    loadActiveModelsFromZk();
    loadWarmingModelsFromZk();
}

/** Reads the model scaler address from MODEL_SCALER_PATH, if that node exists and has data. */
private void loadModelScalerIp() {
    try {
        if (client.checkExists().forPath(MODEL_SCALER_PATH) == null) {
            classLogger.warn("Model scaler path {} does not exist", MODEL_SCALER_PATH);
            return;
        }
        byte[] scalerData = client.getData().forPath(MODEL_SCALER_PATH);
        if (scalerData != null && scalerData.length > 0) {
            modelScalerIp = new String(scalerData, StandardCharsets.UTF_8);
            classLogger.info("Discovered model scaler IP: {}", modelScalerIp);
        } else {
            classLogger.error("Model scaler path exists but no IP data found");
        }
    } catch (Exception e) {
        classLogger.error("Error loading initial state (model scaler IP)", e);
    }
}

/** Marks every child of ACTIVE_PATH as ACTIVE and caches its IP and name when the node data is readable. */
private void loadActiveModelsFromZk() {
    List<String> activeModels;
    try {
        activeModels = client.getChildren().forPath(ACTIVE_PATH);
    } catch (Exception e) {
        classLogger.error("Error loading initial state (active models)", e);
        return;
    }
    for (String modelId : activeModels) {
        // Being listed under the active path is what makes a model ACTIVE; the
        // node payload only supplies the IP and name.
        modelStates.put(modelId, RemoteModelStateEnum.ACTIVE);
        String path = ACTIVE_PATH + "/" + modelId;
        try {
            classLogger.info("Loading data for active model at path: {}", path);
            byte[] data = client.getData().forPath(path);
            if (data == null || data.length == 0) {
                classLogger.warn("Active model {} has no data at path: {}", modelId, path);
                continue;
            }
            JSONObject jsonData = new JSONObject(new String(data, StandardCharsets.UTF_8));
            String clusterIp = jsonData.getString("ip");
            String modelName = jsonData.getString("model_name");
            modelClusterIps.put(modelId, clusterIp);
            modelNames.put(modelId, modelName);
            classLogger.info("Loaded active model {} ({}) with IP: {} from ZK.",
                    modelId, modelName, clusterIp);
        } catch (Exception e) {
            // Keep going: one bad node must not hide the remaining models.
            classLogger.error("Error loading data for active model {}", modelId, e);
        }
    }
}

/** Marks every child of WARMING_PATH as WARMING unless it is already known to be ACTIVE. */
private void loadWarmingModelsFromZk() {
    try {
        List<String> warmingModels = client.getChildren().forPath(WARMING_PATH);
        for (String modelId : warmingModels) {
            // Same precedence as the periodic refresh: ACTIVE wins over WARMING.
            if (modelStates.get(modelId) != RemoteModelStateEnum.ACTIVE) {
                modelStates.put(modelId, RemoteModelStateEnum.WARMING);
                classLogger.info("Loaded warming model: {}", modelId);
            }
        }
    } catch (Exception e) {
        classLogger.error("Error loading initial state (warming models)", e);
    }
}
/**
 * Schedules the periodic reconciliation of the local model-state map with ZooKeeper.
 * The task first runs 30 seconds after scheduling and then every 30 seconds.
 */
private void setupCacheRefresher() {
    classLogger.info("Setting up cache refresher for remote models...");
    // Failures are caught and logged inside the task so that an exception in one
    // run never escapes to the scheduler.
    Runnable refreshTask = () -> {
        try {
            refreshModelStates();
        } catch (Exception e) {
            classLogger.error("Error refreshing model states", e);
        }
    };
    final long intervalSeconds = 30;
    scheduler.scheduleAtFixedRate(refreshTask, intervalSeconds, intervalSeconds, TimeUnit.SECONDS);
}
/**
 * Reconciles the local model-state map with the children currently present under
 * the active and warming paths.
 *
 * <ul>
 *   <li>every child of the active path becomes ACTIVE;</li>
 *   <li>every child of the warming path that is not also under the active path
 *       becomes WARMING (this also covers models the map has never seen, which
 *       previously caused a NullPointerException that aborted the refresh);</li>
 *   <li>every other model already in the map becomes COLD.</li>
 * </ul>
 *
 * @throws Exception if listing the children of either path fails
 */
private void refreshModelStates() throws Exception {
    classLogger.debug("Refreshing model states...");

    // Refresh active models
    List<String> activeModels = client.getChildren().forPath(ACTIVE_PATH);
    Set<String> liveModelIds = new HashSet<>(activeModels);
    for (String modelId : activeModels) {
        modelStates.put(modelId, RemoteModelStateEnum.ACTIVE);
    }

    // Refresh warming models. Decide from the ZooKeeper listing rather than from
    // the cached state, so a stale ACTIVE entry cannot block the update.
    List<String> warmingModels = client.getChildren().forPath(WARMING_PATH);
    for (String modelId : warmingModels) {
        if (!liveModelIds.contains(modelId)) {
            modelStates.put(modelId, RemoteModelStateEnum.WARMING);
        }
    }
    liveModelIds.addAll(warmingModels);

    // Clean up stale entries: anything no longer listed under either path is COLD.
    modelStates.replaceAll((modelId, state) ->
            liveModelIds.contains(modelId) ? state : RemoteModelStateEnum.COLD);

    classLogger.debug("Refreshed model states, current map: {}", modelStates);
}
/**
 * Returns the current value of the {@code modelScalerIp} field.
 *
 * @return the model scaler address (within this class it is set from the data of the
 *         model scaler node during initial state loading), or {@code null} if it has not been set
 */
public String getModelScalerIp() {
return modelScalerIp;
}
/**
 * Releases everything this instance started: the periodic refresh scheduler, both
 * Curator caches and the Curator client.
 *
 * <p>Each resource is closed independently, so a failure while closing one does not
 * leave the others open. Failures are logged and never thrown.
 *
 * <p>NOTE(review): the static singleton reference is not cleared here, so
 * {@code getInstance()} keeps returning this closed instance afterwards.
 */
public void close() {
    // Stop the refresh task first so it cannot run against a closed client.
    scheduler.shutdownNow();
    closeQuietly("warming cache", warmingCache);
    closeQuietly("active cache", activeCache);
    closeQuietly("ZooKeeper client", client);
}

/** Closes one resource if it is non-null, logging instead of propagating any failure. */
private void closeQuietly(String label, AutoCloseable resource) {
    if (resource == null) {
        return;
    }
    try {
        resource.close();
    } catch (Exception e) {
        classLogger.error("Error closing {}", label, e);
    }
}
/**
 * Makes sure the warming and active base paths exist, creating each one (with any
 * missing parents) when the existence check reports it absent. Failures are logged
 * and not propagated; an error on the first path skips the second.
 */
private void setupZKPaths() {
    final String[] basePaths = { WARMING_PATH, ACTIVE_PATH };
    try {
        for (String basePath : basePaths) {
            boolean missing = client.checkExists().forPath(basePath) == null;
            if (missing) {
                client.create().creatingParentsIfNeeded().forPath(basePath);
            }
        }
    } catch (Exception e) {
        classLogger.error("Error setting up ZK paths", e);
    }
}
/**
 * Builds and starts the Curator caches that watch the warming and active paths and
 * keep the local state maps in sync with node creations and deletions.
 *
 * <p>Events for the watched root nodes themselves are ignored: a CuratorCache also
 * reports the node it is rooted at, and treating {@code /models/warming} or
 * {@code /models/active} as a model would register bogus "warming"/"active" model
 * IDs and, for the active root, run the empty-data retry on the event thread.
 *
 * <p>Failures while building or starting the caches are logged and not propagated.
 */
private void setupCaches() {
    try {
        // Set up cache for warming path
        warmingCache = CuratorCache.build(client, WARMING_PATH);
        warmingCache.listenable().addListener(buildWarmingListener());
        warmingCache.start();

        // Set up cache for active path
        activeCache = CuratorCache.build(client, ACTIVE_PATH);
        activeCache.listenable().addListener(buildActiveListener());
        activeCache.start();
    } catch (Exception e) {
        classLogger.error("Error setting up ZK caches", e);
    }
}

/** Listener for the warming path: created nodes become WARMING, deleted nodes become COLD unless ACTIVE. */
private CuratorCacheListener buildWarmingListener() {
    return CuratorCacheListener.builder()
            .forCreates(node -> {
                if (WARMING_PATH.equals(node.getPath())) {
                    return; // the root node is not a model
                }
                String modelId = getModelIdFromPath(node.getPath());
                classLogger.info("Model {} entered warming state", modelId);
                modelStates.put(modelId, RemoteModelStateEnum.WARMING);
            })
            .forDeletes(node -> {
                if (WARMING_PATH.equals(node.getPath())) {
                    return;
                }
                String modelId = getModelIdFromPath(node.getPath());
                classLogger.info("Model {} left warming state", modelId);
                // Only update if not active. Reference comparison is null-safe, so a
                // model missing from the map no longer throws here.
                if (modelStates.get(modelId) != RemoteModelStateEnum.ACTIVE) {
                    modelStates.put(modelId, RemoteModelStateEnum.COLD);
                }
            })
            .build();
}

/** Listener for the active path: created nodes are registered as ACTIVE, deleted nodes fall back to WARMING or COLD. */
private CuratorCacheListener buildActiveListener() {
    return CuratorCacheListener.builder()
            .forCreates(node -> {
                if (ACTIVE_PATH.equals(node.getPath())) {
                    return; // the root node is not a model
                }
                registerActiveModel(node.getPath(), node.getData());
            })
            .forDeletes(node -> {
                if (ACTIVE_PATH.equals(node.getPath())) {
                    return;
                }
                String modelId = getModelIdFromPath(node.getPath());
                String modelName = modelNames.get(modelId);
                classLogger.info("Model {} ({}) is no longer active", modelId, modelName);
                modelClusterIps.remove(modelId);
                modelNames.remove(modelId);
                if (isModelWarming(modelId)) {
                    modelStates.put(modelId, RemoteModelStateEnum.WARMING);
                } else {
                    modelStates.put(modelId, RemoteModelStateEnum.COLD);
                }
            })
            .build();
}

/**
 * Marks the model at {@code path} as ACTIVE and caches the IP and name found in its
 * JSON payload. If the event carried no data, the node is re-read a few times before
 * giving up. Problems are logged; the model stays ACTIVE either way.
 */
private void registerActiveModel(String path, byte[] initialData) {
    String modelId = getModelIdFromPath(path);
    classLogger.info("Model {} became active, processing data from path: {}", modelId, path);
    modelStates.put(modelId, RemoteModelStateEnum.ACTIVE);
    try {
        byte[] data = initialData;
        if (data == null || data.length == 0) {
            classLogger.debug("Initial data read empty for {}, starting retry sequence", modelId);
            data = getNodeDataWithRetry(path, 5, 500); // 5 retries, 500ms delay
        }
        if (data == null || data.length == 0) {
            classLogger.error("No data found for active model {} after retries", modelId);
            return;
        }
        String rawData = new String(data, StandardCharsets.UTF_8);
        classLogger.info("Processing raw data for model {}: {}", modelId, rawData);
        JSONObject jsonData = new JSONObject(rawData);
        String clusterIp = jsonData.getString("ip");
        String modelName = jsonData.getString("model_name");
        modelClusterIps.put(modelId, clusterIp);
        modelNames.put(modelId, modelName);
        classLogger.info("Successfully registered model {} ({}) with IP: {}",
                modelId, modelName, clusterIp);
    } catch (Exception e) {
        classLogger.error("Error processing data for model {}: {}", modelId, e.getMessage(), e);
    }
}
/**
 * Reads a node's data, retrying while the read fails or returns no bytes.
 *
 * <p>The delay is applied between attempts whether the previous attempt returned
 * empty data or threw, and not after the final attempt. If the thread is
 * interrupted, the interrupt flag is restored and the method returns immediately.
 *
 * @param path       full ZooKeeper path of the node to read
 * @param maxRetries maximum number of read attempts
 * @param delayMs    pause between attempts, in milliseconds
 * @return the first non-empty data read, or {@code null} if every attempt failed or
 *         the thread was interrupted
 */
private byte[] getNodeDataWithRetry(String path, int maxRetries, long delayMs) {
    for (int attempt = 1; attempt <= maxRetries; attempt++) {
        try {
            byte[] data = client.getData().forPath(path);
            if (data != null && data.length > 0) {
                return data;
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
                return null;
            }
            classLogger.debug("Attempt {} to read data from {} failed: {}",
                    attempt, path, e.getMessage());
        }
        if (attempt < maxRetries) {
            try {
                Thread.sleep(delayMs);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return null;
            }
        }
    }
    return null;
}
/**
 * Looks up the cached name of a model.
 *
 * @param modelId the model ID used as the key of the {@code modelNames} map
 * @return the cached name, or {@code null} if the map holds no entry for this ID
 */
public String getModelName(String modelId) {
return modelNames.get(modelId);
}
/**
 * Extracts the model ID from a ZooKeeper node path, i.e. everything after the last
 * {@code '/'}. A path without any slash is returned unchanged.
 *
 * @param path node path such as {@code /models/active/<modelId>}
 * @return the final path segment
 */
private String getModelIdFromPath(String path) {
    int lastSlash = path.lastIndexOf('/');
    // lastIndexOf returns -1 when there is no slash, so this starts at index 0.
    int idStart = lastSlash + 1;
    return path.substring(idStart);
}
/**
 * Returns the locally cached state of a model.
 *
 * @param modelId the model ID to look up
 * @return the cached state, or {@code COLD} when nothing is cached for this ID
 */
public RemoteModelStateEnum getModelState(String modelId) {
    RemoteModelStateEnum cached = modelStates.get(modelId);
    if (cached == null) {
        return RemoteModelStateEnum.COLD;
    }
    return cached;
}
/**
 * Maps a model ID to the node name appended to the warming/active base paths.
 * Currently this is the identity mapping: the ID is returned unchanged.
 *
 * @param modelId the model ID
 * @return the same value that was passed in
 */
private String getModelPath(String modelId) {
return modelId;
}
/**
 * Reads a model's cluster IP straight from its node under the active path and
 * refreshes the cached IP and name from the same JSON payload.
 *
 * @param modelId the model ID, used as the node name under the active path
 * @return the value of the payload's {@code "ip"} field, or {@code null} when the
 *         node does not exist, holds no data, or cannot be read or parsed
 */
public String getModelClusterIp(String modelId) {
    final String nodePath = ACTIVE_PATH + "/" + modelId;
    try {
        boolean nodeMissing = client.checkExists().forPath(nodePath) == null;
        if (nodeMissing) {
            classLogger.error("Path does not exist: {}", nodePath);
            return null;
        }

        byte[] payload = client.getData().forPath(nodePath);
        boolean hasPayload = payload != null && payload.length > 0;
        if (!hasPayload) {
            classLogger.error("No data found at path: {}", nodePath);
            return null;
        }

        JSONObject parsed = new JSONObject(new String(payload, "UTF-8"));
        String clusterIp = parsed.getString("ip");

        // Update cache maps
        modelClusterIps.put(modelId, clusterIp);
        modelNames.put(modelId, parsed.getString("model_name"));
        return clusterIp;
    } catch (Exception e) {
        classLogger.error("Error getting cluster IP for model {}: {}", modelId, e.getMessage(), e);
        return null;
    }
}
/**
 * Checks directly in ZooKeeper whether a node for the model exists under the warming path.
 *
 * @param modelId the model ID
 * @return {@code true} if the node exists; {@code false} if it does not or the check fails
 */
public boolean isModelWarming(String modelId) {
    boolean warming;
    try {
        String nodePath = WARMING_PATH + "/" + getModelPath(modelId);
        warming = client.checkExists().forPath(nodePath) != null;
    } catch (Exception e) {
        classLogger.error("Error checking warming state for model {}", modelId, e);
        warming = false;
    }
    return warming;
}
/**
 * Checks directly in ZooKeeper whether a node for the model exists under the active path.
 *
 * @param modelId the model ID
 * @return {@code true} if the node exists; {@code false} if it does not or the check fails
 */
public boolean isModelActive(String modelId) {
    boolean active;
    try {
        String nodePath = ACTIVE_PATH + "/" + getModelPath(modelId);
        active = client.checkExists().forPath(nodePath) != null;
    } catch (Exception e) {
        classLogger.error("Error checking active state for model {}", modelId, e);
        active = false;
    }
    return active;
}
/**
 * Probes the readiness endpoint of a model's service.
 *
 * <p>Needed because a model can already be listed under the active path while the
 * FastAPI service inside its container is not yet ready to serve requests.
 *
 * @param modelId the model whose cached cluster IP is probed
 * @return {@code true} only when the endpoint answers HTTP 200 with a JSON body whose
 *         {@code "status"} is {@code "ok"}; {@code false} when no IP is cached, on any
 *         other response, or on any error
 */
private boolean checkModelHealth(String modelId) {
    final String ip = modelClusterIps.get(modelId);
    final String name = modelNames.get(modelId);

    boolean ipUsable = ip != null && !ip.trim().isEmpty();
    if (!ipUsable) {
        classLogger.error("No valid cluster IP available for health check of model {} ({})",
                modelId, name);
        return false;
    }

    String healthUrl;
    if (devPortFowarding) {
        healthUrl = "http://localhost:8888/v2/health/ready";
    } else {
        healthUrl = String.format("http://%s/v2/health/ready", ip);
    }
    classLogger.info("Attempting health check at URL: {}", healthUrl);

    // Short timeouts: callers poll this method repeatedly.
    RequestConfig timeouts = RequestConfig.custom()
            .setConnectTimeout(1000)
            .setSocketTimeout(1000)
            .build();

    try (CloseableHttpClient httpClient = HttpClients.custom()
                 .setDefaultRequestConfig(timeouts)
                 .build();
         CloseableHttpResponse response = httpClient.execute(new HttpGet(healthUrl))) {
        if (response.getStatusLine().getStatusCode() != 200) {
            return false;
        }
        HttpEntity body = response.getEntity();
        if (body == null) {
            return false;
        }
        JSONObject healthResponse = new JSONObject(EntityUtils.toString(body));
        return "ok".equals(healthResponse.optString("status"));
    } catch (Exception e) {
        classLogger.error("Health check failed for model {} ({}): {}",
                modelId, name, e.getMessage());
        return false;
    }
}
/**
 * Blocks until a model is both listed under the active path and passing its health
 * check, or until the timeout elapses.
 *
 * <p>Phase one polls ZooKeeper every 3 seconds for the model's active node; phase two
 * polls the health endpoint every second. Both phases share one time budget, and each
 * pause is capped to the time remaining so the call does not overrun
 * {@code timeoutMs} by a full poll interval.
 *
 * @param modelId   the model to wait for
 * @param timeoutMs total time budget in milliseconds
 * @return {@code true} if the model became active and healthy within the budget;
 *         {@code false} on timeout, on interruption (the interrupt flag is restored),
 *         or on any unexpected error
 */
public boolean waitForModelActive(String modelId, long timeoutMs) {
    final long activePollMs = 3000L;
    final long healthPollMs = 1000L;
    final long startTime = System.currentTimeMillis();
    try {
        // Waiting for model to appear in active path
        boolean foundInActivePath = false;
        while (System.currentTimeMillis() - startTime < timeoutMs) {
            if (isModelActive(modelId)) {
                classLogger.info("Model {} was found in the active path", modelId);
                foundInActivePath = true;
                break;
            }
            classLogger.info("Model {} in a warming wait loop..", modelId);
            long remaining = timeoutMs - (System.currentTimeMillis() - startTime);
            if (remaining > 0) {
                Thread.sleep(Math.min(activePollMs, remaining));
            }
        }
        if (!foundInActivePath) {
            classLogger.warn("Timeout waiting for model {} to appear in active path after {}ms",
                    modelId, timeoutMs);
            return false;
        }

        // Wait for container to be healthy, using whatever is left of the budget
        while (System.currentTimeMillis() - startTime < timeoutMs) {
            if (checkModelHealth(modelId)) {
                classLogger.info("Model {} health check passed", modelId);
                return true;
            }
            long remaining = timeoutMs - (System.currentTimeMillis() - startTime);
            if (remaining > 0) {
                Thread.sleep(Math.min(healthPollMs, remaining));
            }
        }
        classLogger.warn("Timeout waiting for model {} to become healthy after appearing in active path",
                modelId);
        return false;
    } catch (InterruptedException e) {
        // Preserve the interrupt so the caller can react to the cancellation.
        Thread.currentThread().interrupt();
        classLogger.warn("Interrupted while waiting for model {} to become active", modelId);
        return false;
    } catch (Exception e) {
        classLogger.error("Error waiting for model {} to become active", modelId, e);
        return false;
    }
}
/**
 * Waits for a model to reach a specific state, with a timeout.
 *
 * <p>The locally cached state is polled once per second. Each pause is capped to the
 * time remaining, and the state is checked one final time when the budget runs out,
 * so a transition that happens during the last pause is still reported.
 *
 * @param modelId The ID of the model to wait for
 * @param desiredState The state to wait for
 * @param timeoutMs Maximum time to wait in milliseconds
 * @return true if the model reached the desired state within the timeout; false on
 *         timeout, if the model reaches FAILED while waiting for ACTIVE, on interruption
 *         (the interrupt flag is restored), or on any unexpected error
 */
public boolean waitForState(String modelId, RemoteModelStateEnum desiredState, long timeoutMs) {
    final long pollIntervalMs = 1000L;
    final long startTime = System.currentTimeMillis();
    try {
        while (System.currentTimeMillis() - startTime < timeoutMs) {
            RemoteModelStateEnum currentState = getModelState(modelId);
            if (currentState == desiredState) {
                return true;
            }
            // If we're waiting for ACTIVE but hit FAILED, break early
            if (desiredState == RemoteModelStateEnum.ACTIVE
                    && currentState == RemoteModelStateEnum.FAILED) {
                classLogger.error("Model {} failed while waiting for active state", modelId);
                return false;
            }
            long remaining = timeoutMs - (System.currentTimeMillis() - startTime);
            if (remaining > 0) {
                Thread.sleep(Math.min(pollIntervalMs, remaining));
            }
        }
        // Final look: the state may have changed during the last pause.
        if (getModelState(modelId) == desiredState) {
            return true;
        }
        classLogger.warn("Timeout waiting for model {} to reach state {} after {}ms",
                modelId, desiredState, timeoutMs);
        return false;
    } catch (InterruptedException e) {
        // Preserve the interrupt so the caller can react to the cancellation.
        Thread.currentThread().interrupt();
        classLogger.warn("Interrupted while waiting for model {} to reach state {}", modelId, desiredState);
        return false;
    } catch (Exception e) {
        classLogger.error("Error waiting for model {} to reach state {}", modelId, desiredState, e);
        return false;
    }
}
/**
 * Gets a list of all active models with their associated information.
 *
 * <p>The model IDs come from the children of the active path. The name is taken from
 * the local cache or, when missing there, read from the node's JSON payload (and then
 * cached). Models whose name cannot be determined are skipped. Each model is handled
 * independently, so one unreadable or malformed node does not truncate the list.
 *
 * <p>A model with no cached state is reported as ACTIVE, since it is listed under the
 * active path; this mirrors how the warming list defaults to WARMING.
 *
 * @return List of RemoteModelInfo objects containing model details; empty if the
 *         active path cannot be listed
 */
public List<RemoteModelInfo> getActiveModels() {
    List<RemoteModelInfo> activeModels = new ArrayList<>();
    List<String> activeModelIds;
    try {
        activeModelIds = client.getChildren().forPath(ACTIVE_PATH);
    } catch (Exception e) {
        classLogger.error("Error getting active models with details", e);
        return activeModels;
    }

    for (String modelId : activeModelIds) {
        try {
            String name = modelNames.get(modelId);
            if (name == null) {
                classLogger.warn("No cached name for active model {}, reading it from ZooKeeper", modelId);
                name = fetchActiveModelName(modelId);
            }
            if (name == null) {
                continue; // node has no payload yet; nothing meaningful to report
            }
            RemoteModelStateEnum state =
                    modelStates.getOrDefault(modelId, RemoteModelStateEnum.ACTIVE);
            activeModels.add(new RemoteModelInfo(modelId, name, state));
        } catch (Exception e) {
            classLogger.error("Error getting details for active model {}", modelId, e);
        }
    }
    return activeModels;
}

/**
 * Reads {@code "model_name"} from the model's node under the active path and caches it.
 * Returns {@code null} when the node holds no data.
 */
private String fetchActiveModelName(String modelId) throws Exception {
    byte[] data = client.getData().forPath(ACTIVE_PATH + "/" + modelId);
    if (data == null || data.length == 0) {
        return null;
    }
    JSONObject jsonData = new JSONObject(new String(data, StandardCharsets.UTF_8));
    String name = jsonData.getString("model_name");
    modelNames.put(modelId, name);
    return name;
}
/**
 * Gets a list of all warming models with their associated information.
 *
 * <p>The model IDs come from the children of the warming path. A model without a
 * cached name is reported with the placeholder name {@code "Warming..."}, and a model
 * without a cached state is reported as WARMING.
 *
 * @return List of RemoteModelInfo objects containing model details; empty if the
 *         warming path cannot be listed
 */
public List<RemoteModelInfo> getWarmingModels() {
    List<RemoteModelInfo> warmingModels = new ArrayList<>();
    try {
        // Typed as List<String> so the for-each below compiles without casts.
        List<String> warmingModelIds = client.getChildren().forPath(WARMING_PATH);
        for (String modelId : warmingModelIds) {
            String name = modelNames.get(modelId);
            RemoteModelStateEnum state =
                    modelStates.getOrDefault(modelId, RemoteModelStateEnum.WARMING);
            warmingModels.add(new RemoteModelInfo(
                    modelId,
                    name != null ? name : "Warming...",
                    state));
        }
    } catch (Exception e) {
        classLogger.error("Error getting warming models with details", e);
    }
    return warmingModels;
}
/**
 * Asks the model scaler service whether the given model can be run.
 *
 * <p>Sends {@code {"model_id": hfModelId}} as UTF-8 JSON to the scaler's
 * {@code /api/can-it-run} endpoint and returns the top-level fields of the JSON
 * answer as a map.
 *
 * @param hfModelId the model identifier sent as {@code model_id}
 * @return the top-level keys and values of the service's JSON response
 * @throws RuntimeException if no scaler address is known, if the request or response
 *         handling fails (the cause is attached), or if the service answers with a
 *         non-200 status (the message then carries the service's own error text)
 */
public Map<String, Object> canItRun(String hfModelId) throws Exception {
    String scalerIp = getModelScalerIp();
    if (scalerIp == null) {
        classLogger.error("Unable to get model scaler IP from ZooKeeper");
        throw new RuntimeException("Failed to get model scaler IP");
    }

    String canItRunUrl = devPortFowarding
            ? "http://localhost:8000/api/can-it-run"
            : String.format("http://%s/api/can-it-run", scalerIp);

    RequestConfig requestConfig = RequestConfig.custom()
            .setConnectTimeout(5000)
            .setSocketTimeout(5000)
            .build();

    int failedStatus;
    String errorMessage;
    try (CloseableHttpClient httpClient = HttpClients.custom()
            .setDefaultRequestConfig(requestConfig)
            .build()) {
        HttpPost httpPost = new HttpPost(canItRunUrl);
        httpPost.setHeader("Content-Type", "application/json");
        JSONObject requestBody = new JSONObject();
        requestBody.put("model_id", hfModelId);
        // Encode explicitly: StringEntity(String) alone would use ISO-8859-1.
        httpPost.setEntity(new StringEntity(requestBody.toString(), StandardCharsets.UTF_8));

        try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
            int statusCode = response.getStatusLine().getStatusCode();
            HttpEntity responseEntity = response.getEntity();
            // UTF-8 is only the fallback when the response declares no charset.
            String responseString = responseEntity == null
                    ? ""
                    : EntityUtils.toString(responseEntity, StandardCharsets.UTF_8);

            if (statusCode == 200) {
                JSONObject jsonResponse = new JSONObject(responseString);
                Map<String, Object> result = new HashMap<>();
                for (String key : jsonResponse.keySet()) {
                    result.put(key, jsonResponse.get(key));
                }
                classLogger.info("Successfully checked compatibility for model: {} - Can run: {}",
                        hfModelId, result.get("can_run"));
                return result;
            }
            failedStatus = statusCode;
            errorMessage = extractErrorMessage(responseString);
        }
    } catch (Exception e) {
        classLogger.error("Error making request to model scaler: {}", e.getMessage(), e);
        throw new RuntimeException("Failed to check model compatibility", e);
    }

    // Thrown outside the try so the service's own message is not re-wrapped and
    // logged a second time by the transport-error handler above.
    classLogger.error("Error checking model compatibility: {} (Status: {})",
            errorMessage, failedStatus);
    throw new RuntimeException("Failed to check model compatibility: " + errorMessage);
}

/**
 * Pulls a human-readable message out of an error response body. Accepts
 * {@code {"detail": {"message": ...}}} as well as a plain {@code "detail"} value, and
 * falls back to the raw body when it is not JSON or has no {@code "detail"}.
 */
private static String extractErrorMessage(String responseBody) {
    if (responseBody == null || responseBody.trim().isEmpty()) {
        return "empty response body";
    }
    try {
        Object detail = new JSONObject(responseBody).opt("detail");
        if (detail instanceof JSONObject) {
            String message = ((JSONObject) detail).optString("message", "");
            return message.isEmpty() ? detail.toString() : message;
        }
        if (detail != null) {
            return String.valueOf(detail);
        }
    } catch (Exception ignored) {
        // Body is not a JSON object; report it verbatim below.
    }
    return responseBody;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy