
prerna.cluster.util.RemoteClientServerZKRESTProxy Maven / Gradle / Ivy
The newest version!
package prerna.cluster.util;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.http.HttpEntity;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpHead;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONArray;
import org.json.JSONObject;
import prerna.engine.api.RemoteModelStateEnum;
/**
 * Provides the same functionality as RemoteClientServerZK, but connects to the REST proxy for
 * ZooKeeper over HTTP instead of connecting directly to ZooKeeper.
 * It is designed for local development environments where direct ZooKeeper access isn't available.
 *
 * <p>The proxy base URL is read from the {@code ZK_INGRESS} environment variable. Model states,
 * cluster IPs and model names are cached in concurrent maps that a single background thread
 * refreshes every 30 seconds.
 */
public class RemoteClientServerZKRESTProxy implements IRemoteClientServer {

    private static final Logger classLogger = LogManager.getLogger(RemoteClientServerZKRESTProxy.class);

    // volatile is required for the double-checked locking in getInstance() to publish safely.
    private static volatile RemoteClientServerZKRESTProxy instance;

    private static final String WARMING_PATH = "/models/warming";
    private static final String ACTIVE_PATH = "/models/active";
    private static final String MODEL_SCALER_PATH = "/services/kube-model-deployer";

    private static final String ZNODE_API_PREFIX = "znodes/v1/";
    private static final int HTTP_TIMEOUT_MS = 5000;
    private static final int EXISTS_TIMEOUT_MS = 3000;
    private static final long REFRESH_INTERVAL_SECONDS = 30;
    private static final long ACTIVE_POLL_INTERVAL_MS = 3000;
    private static final long STATE_POLL_INTERVAL_MS = 1000;
    private static final long MAX_HEALTH_WAIT_MS = 1000000;

    private String zkRestProxyBaseUrl;

    private final ConcurrentMap<String, RemoteModelStateEnum> modelStates = new ConcurrentHashMap<>();
    private final ConcurrentMap<String, String> modelClusterIps = new ConcurrentHashMap<>();
    private final ConcurrentMap<String, String> modelNames = new ConcurrentHashMap<>();

    // Periodic refreshes. The thread is a daemon so an instance that is never closed
    // does not keep the JVM alive.
    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1, runnable -> {
        Thread thread = new Thread(runnable, "zk-rest-proxy-cache-refresher");
        thread.setDaemon(true);
        return thread;
    });

    private RemoteClientServerZKRESTProxy() {
        classLogger.info("RemoteClientServerZKRestProxy being initialized...");
    }

    /**
     * Returns the process-wide instance, creating and initializing it on first use.
     *
     * <p>The instance is published only after {@link #init()} succeeds. Other threads therefore
     * never observe a half-initialized object, and a failed initialization is retried on the next
     * call instead of being cached.
     *
     * @return the initialized singleton
     * @throws RuntimeException if initialization fails (missing {@code ZK_INGRESS}, proxy unreachable)
     */
    public static RemoteClientServerZKRESTProxy getInstance() {
        RemoteClientServerZKRESTProxy local = instance;
        if (local == null) {
            synchronized (RemoteClientServerZKRESTProxy.class) {
                local = instance;
                if (local == null) {
                    local = new RemoteClientServerZKRESTProxy();
                    local.init();
                    instance = local;
                }
            }
        }
        return local;
    }

    /**
     * Reads the proxy URL from {@code ZK_INGRESS}, checks connectivity, loads the initial model
     * state and starts the periodic refresher.
     *
     * @throws RuntimeException wrapping any failure during these steps
     */
    private void init() {
        try {
            String ingress = System.getenv("ZK_INGRESS");
            if (ingress == null || ingress.trim().isEmpty()) {
                throw new IllegalStateException("ZK_INGRESS environment variable is required for ZK REST Proxy mode");
            }
            ingress = ingress.trim();
            zkRestProxyBaseUrl = ingress.endsWith("/") ? ingress : ingress + "/";
            classLogger.info("Using ZK_INGRESS from environment: {}", zkRestProxyBaseUrl);
            // Validate connection to ZK REST Proxy
            validateConnection();
            // Initial load of model states
            loadInitialState();
            // Setup periodic refresh
            setupCacheRefresher();
        } catch (Exception e) {
            classLogger.error("Failed to initialize ZooKeeper REST Proxy connection", e);
            throw new RuntimeException("Failed to initialize ZooKeeper REST Proxy connection", e);
        }
    }

    /**
     * Calls the proxy's {@code health} endpoint and requires an HTTP 200 response.
     *
     * @throws IllegalStateException on a non-200 status or any I/O failure
     */
    private void validateConnection() {
        String healthCheckUrl = zkRestProxyBaseUrl + "health";
        try (CloseableHttpClient httpClient = newHttpClient(HTTP_TIMEOUT_MS);
             CloseableHttpResponse response = httpClient.execute(new HttpGet(healthCheckUrl))) {
            int statusCode = response.getStatusLine().getStatusCode();
            if (statusCode != 200) {
                throw new IllegalStateException("ZooKeeper REST Proxy health check failed with status code: " + statusCode);
            }
            classLogger.info("Successfully connected to ZooKeeper REST Proxy");
        } catch (IOException e) {
            classLogger.error("Failed to connect to ZooKeeper REST Proxy", e);
            throw new IllegalStateException("Failed to connect to ZooKeeper REST Proxy: " + e.getMessage(), e);
        }
    }

    /**
     * Creates a short-lived HTTP client whose connect and socket timeouts are both {@code timeoutMs}.
     * The caller owns the client and must close it.
     */
    private static CloseableHttpClient newHttpClient(int timeoutMs) {
        RequestConfig requestConfig = RequestConfig.custom()
                .setConnectTimeout(timeoutMs)
                .setSocketTimeout(timeoutMs)
                .build();
        return HttpClients.custom()
                .setDefaultRequestConfig(requestConfig)
                .build();
    }

    /**
     * Builds the proxy URL for a ZNode path.
     *
     * @param path ZNode path; a single leading {@code /} is stripped
     * @return absolute URL of the form {@code <base>znodes/v1/<path>}
     */
    private String toZNodeUrl(String path) {
        String relativePath = path.startsWith("/") ? path.substring(1) : path;
        return zkRestProxyBaseUrl + ZNODE_API_PREFIX + relativePath;
    }

    /**
     * GETs a ZNode resource from the proxy and parses the body as a JSON object.
     *
     * @param path        the ZNode path
     * @param query       query string appended to the URL (e.g. {@code "?view=children"}), or {@code ""}
     * @param description what is being fetched, used only in the error log
     * @return the parsed body. Returns an empty object when the ZNode does not exist (404) or the
     *         body is empty, and {@code null} when the proxy answers with any other non-200 status.
     * @throws IOException on transport failures; JSON parse errors surface as unchecked exceptions
     */
    private JSONObject fetchZNodeJson(String path, String query, String description) throws IOException {
        String url = toZNodeUrl(path) + query;
        try (CloseableHttpClient httpClient = newHttpClient(HTTP_TIMEOUT_MS);
             CloseableHttpResponse response = httpClient.execute(new HttpGet(url))) {
            int statusCode = response.getStatusLine().getStatusCode();
            if (statusCode == 404) {
                classLogger.warn("ZNode not found at path: {}", path);
                return new JSONObject();
            }
            if (statusCode != 200) {
                classLogger.error("Failed to get ZNode {}, status code: {}", description, statusCode);
                return null;
            }
            HttpEntity entity = response.getEntity();
            if (entity == null) {
                return new JSONObject();
            }
            // Without an explicit default, EntityUtils falls back to ISO-8859-1 when the
            // response carries no charset; JSON is UTF-8.
            return new JSONObject(EntityUtils.toString(entity, StandardCharsets.UTF_8));
        }
    }

    /**
     * Returns the raw value of the {@code data} field of a ZNode response.
     *
     * @return the value, or {@code null} when the node is missing, the request failed, or the field is absent/null
     * @throws IOException on transport failures
     */
    private Object fetchZNodeDataValue(String path) throws IOException {
        JSONObject responseObj = fetchZNodeJson(path, "", "data");
        if (responseObj == null || !responseObj.has("data") || responseObj.isNull("data")) {
            return null;
        }
        return responseObj.get("data");
    }

    /**
     * Gets ZNode data as a JSONObject from the REST proxy response.
     * This method extracts the nested "data" object from the response. A "data" value that
     * arrives as a JSON-encoded string is parsed as well.
     *
     * @param path The ZNode path
     * @return JSONObject containing the data field contents, or null if error/not found
     */
    private JSONObject getZNodeDataAsJson(String path) {
        try {
            Object data = fetchZNodeDataValue(path);
            if (data == null) {
                return null;
            }
            if (data instanceof JSONObject) {
                return (JSONObject) data;
            }
            if (data instanceof String) {
                return new JSONObject((String) data);
            }
            throw new IllegalStateException("ZNode data is not a JSON object");
        } catch (Exception e) {
            classLogger.error("Error getting ZNode data for path {}: {}", path, e.getMessage());
            return null;
        }
    }

    /**
     * Gets ZNode data as text.
     *
     * @param path The ZNode path
     * @return the {@code data} field as a string. A non-string JSON value is returned as its JSON
     *         text. Returns {@code null} if the node is missing or an error occurred.
     */
    private String getZNodeData(String path) {
        try {
            Object data = fetchZNodeDataValue(path);
            // getZNodeDataAsJson treats this field as a nested object for model ZNodes, so accept
            // any JSON value here rather than failing on non-strings.
            return data == null ? null : String.valueOf(data);
        } catch (Exception e) {
            classLogger.error("Error getting ZNode data for path {}: {}", path, e.getMessage());
            return null;
        }
    }

    /**
     * Lists the children of a ZNode, distinguishing "no children" from "request failed".
     *
     * @return the child names; an empty list if the node does not exist; {@code null} on any failure
     */
    private List<String> fetchZNodeChildren(String path) {
        try {
            JSONObject responseObj = fetchZNodeJson(path, "?view=children", "children");
            if (responseObj == null) {
                return null;
            }
            List<String> children = new ArrayList<>();
            if (responseObj.has("children")) {
                JSONArray childrenArray = responseObj.getJSONArray("children");
                for (int i = 0; i < childrenArray.length(); i++) {
                    children.add(childrenArray.getString(i));
                }
            }
            return children;
        } catch (Exception e) {
            classLogger.error("Error getting ZNode children for path {}: {}", path, e.getMessage());
            return null;
        }
    }

    /**
     * Lists the children of a ZNode.
     *
     * @return the child names, or an empty list if the node is missing or an error occurred
     */
    private List<String> getZNodeChildren(String path) {
        List<String> children = fetchZNodeChildren(path);
        return children != null ? children : new ArrayList<>();
    }

    /**
     * Checks whether a ZNode exists by issuing a HEAD request (3s timeouts).
     *
     * @return true only on an HTTP 200 response; false otherwise or on error
     */
    private boolean checkZNodeExists(String path) {
        try (CloseableHttpClient httpClient = newHttpClient(EXISTS_TIMEOUT_MS);
             CloseableHttpResponse response = httpClient.execute(new HttpHead(toZNodeUrl(path)))) {
            return response.getStatusLine().getStatusCode() == 200;
        } catch (Exception e) {
            classLogger.error("Error checking if ZNode exists at path {}: {}", path, e.getMessage());
            return false;
        }
    }

    /**
     * Populates the caches from the proxy at startup: logs the model scaler ZNode, loads every
     * active model's IP and name, then marks warming models.
     *
     * <p>Errors are logged, not thrown. A malformed model ZNode only skips that model.
     */
    private void loadInitialState() {
        try {
            String modelScalerIpData = getZNodeData(MODEL_SCALER_PATH);
            if (modelScalerIpData != null) {
                classLogger.info("Model scaler IP from REST proxy: {}", modelScalerIpData);
            }
            for (String modelId : getZNodeChildren(ACTIVE_PATH)) {
                String modelPath = ACTIVE_PATH + "/" + modelId;
                classLogger.info("Loading data for active model at path: {}", modelPath);
                if (updateModelInfoFromZNode(modelId, modelPath)) {
                    modelStates.put(modelId, RemoteModelStateEnum.ACTIVE);
                    classLogger.info("Loaded active model {} ({}) with IP: {} from ZK REST Proxy.",
                            modelId, modelNames.get(modelId), modelClusterIps.get(modelId));
                }
            }
            for (String modelId : getZNodeChildren(WARMING_PATH)) {
                // Like refreshModelStates(): a model already loaded as ACTIVE is not downgraded.
                if (modelStates.get(modelId) != RemoteModelStateEnum.ACTIVE) {
                    modelStates.put(modelId, RemoteModelStateEnum.WARMING);
                    classLogger.info("Loaded warming model: {}", modelId);
                }
            }
        } catch (Exception e) {
            classLogger.error("Error loading initial state from ZK REST Proxy", e);
        }
    }

    /** Schedules {@link #refreshModelStates()} every 30 seconds after an initial 30-second delay. */
    private void setupCacheRefresher() {
        classLogger.info("Setting up cache refresher for remote models...");
        scheduler.scheduleAtFixedRate(() -> {
            try {
                refreshModelStates();
            } catch (Exception e) {
                classLogger.error("Error refreshing model states", e);
            }
        }, REFRESH_INTERVAL_SECONDS, REFRESH_INTERVAL_SECONDS, TimeUnit.SECONDS);
    }

    /**
     * Re-syncs the caches with the proxy.
     * <ul>
     *   <li>Models under the active path become ACTIVE.</li>
     *   <li>Models only under the warming path become WARMING.</li>
     *   <li>Every other cached model becomes COLD.</li>
     *   <li>Cluster IPs are dropped for models that are no longer active.</li>
     * </ul>
     * If either listing fails, the cycle is skipped so a transient proxy error does not mark
     * every model COLD.
     */
    private void refreshModelStates() {
        classLogger.debug("Refreshing model states from ZK REST Proxy...");
        try {
            List<String> activeModels = fetchZNodeChildren(ACTIVE_PATH);
            List<String> warmingModels = fetchZNodeChildren(WARMING_PATH);
            if (activeModels == null || warmingModels == null) {
                classLogger.warn("Skipping model state refresh: could not list models from ZK REST Proxy");
                return;
            }
            Set<String> activeIds = new HashSet<>(activeModels);
            Set<String> warmingIds = new HashSet<>(warmingModels);

            for (String modelId : activeIds) {
                modelStates.put(modelId, RemoteModelStateEnum.ACTIVE);
            }
            for (String modelId : warmingIds) {
                // Decide on the current listing, not the previously cached state. Otherwise a model
                // that moved from active back to warming would stay ACTIVE forever.
                if (!activeIds.contains(modelId)) {
                    modelStates.put(modelId, RemoteModelStateEnum.WARMING);
                }
            }
            for (String modelId : activeIds) {
                updateModelInfoFromZNode(modelId, ACTIVE_PATH + "/" + modelId);
            }

            // Clean up stale entries (ConcurrentHashMap iteration tolerates concurrent puts)
            for (String modelId : modelStates.keySet()) {
                if (!activeIds.contains(modelId) && !warmingIds.contains(modelId)) {
                    modelStates.put(modelId, RemoteModelStateEnum.COLD);
                }
            }
            // An IP cached for a model that is no longer active is stale. getModelClusterIp
            // re-fetches it on demand.
            modelClusterIps.keySet().removeIf(modelId -> !activeIds.contains(modelId));

            classLogger.debug("Refreshed model states, current map: {}", modelStates);
        } catch (Exception e) {
            classLogger.error("Error refreshing model states from ZK REST Proxy", e);
        }
    }

    /**
     * Reads {@code ip} and {@code model_name} from a model ZNode and caches them.
     *
     * @return true if both fields were read and cached; false if the node was unavailable or malformed
     */
    private boolean updateModelInfoFromZNode(String modelId, String path) {
        JSONObject modelData = getZNodeDataAsJson(path);
        if (modelData == null) {
            return false;
        }
        try {
            String clusterIp = modelData.getString("ip");
            String modelName = modelData.getString("model_name");
            // Update cache maps
            modelClusterIps.put(modelId, clusterIp);
            modelNames.put(modelId, modelName);
            classLogger.debug("Updated model info for {}: name={}, ip={}",
                    modelId, modelName, clusterIp);
            return true;
        } catch (Exception e) {
            classLogger.error("Error extracting model data fields for {}: {}", modelId, e.getMessage());
            return false;
        }
    }

    /**
     * Stops the background refresher, waiting up to 5 seconds before forcing shutdown.
     * The singleton reference is not cleared, so {@link #getInstance()} keeps returning this
     * (no longer refreshing) instance.
     */
    public void close() {
        try {
            if (scheduler != null && !scheduler.isShutdown()) {
                scheduler.shutdown();
                try {
                    if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) {
                        scheduler.shutdownNow();
                    }
                } catch (InterruptedException e) {
                    scheduler.shutdownNow();
                    Thread.currentThread().interrupt();
                }
            }
        } catch (Exception e) {
            classLogger.error("Error closing resources", e);
        }
    }

    /** Returns the cached state of the model, or COLD if the model is unknown. */
    @Override
    public RemoteModelStateEnum getModelState(String modelId) {
        return modelStates.getOrDefault(modelId, RemoteModelStateEnum.COLD);
    }

    /**
     * Returns the model's cluster IP from the cache. On a cache miss, reads the model's active
     * ZNode and caches its IP and name.
     *
     * @return the IP, or null if it is not cached and cannot be read from the proxy
     */
    @Override
    public String getModelClusterIp(String modelId) {
        String clusterIp = modelClusterIps.get(modelId);
        if (clusterIp == null) {
            String path = ACTIVE_PATH + "/" + modelId;
            JSONObject modelData = getZNodeDataAsJson(path);
            if (modelData != null && !modelData.isEmpty()) {
                try {
                    classLogger.info("modelData: {}", modelData);
                    clusterIp = modelData.getString("ip");
                    String modelName = modelData.getString("model_name");
                    modelClusterIps.put(modelId, clusterIp);
                    modelNames.put(modelId, modelName);
                } catch (Exception e) {
                    classLogger.error("Error parsing model data for {}: {}", modelId, e.getMessage());
                }
            }
        }
        return clusterIp;
    }

    /** Returns the cached model name, or null if none has been cached. */
    @Override
    public String getModelName(String modelId) {
        return modelNames.get(modelId);
    }

    /** Checks the proxy directly (bypassing the cache) for the model's warming ZNode. */
    @Override
    public boolean isModelWarming(String modelId) {
        return checkZNodeExists(WARMING_PATH + "/" + modelId);
    }

    /** Checks the proxy directly (bypassing the cache) for the model's active ZNode. */
    @Override
    public boolean isModelActive(String modelId) {
        return checkZNodeExists(ACTIVE_PATH + "/" + modelId);
    }

    /**
     * Polls every 3 seconds until the model's active ZNode appears, resolves its cluster IP, then
     * waits out the remaining budget (capped at 1,000,000 ms) before reporting it healthy.
     *
     * <p>NOTE(review): no real health probe is performed. A successful call blocks for
     * (nearly) the whole timeout. Confirm that this is intended.
     *
     * @param modelId   the model to wait for
     * @param timeoutMs total time budget in milliseconds
     * @return true if the model became active and the wait completed without interruption
     */
    @Override
    public boolean waitForModelActive(String modelId, long timeoutMs) {
        try {
            long startTime = System.currentTimeMillis();
            boolean foundInActivePath = false;
            while (System.currentTimeMillis() - startTime < timeoutMs) {
                if (isModelActive(modelId)) {
                    classLogger.info("Model {} was found in the active path", modelId);
                    foundInActivePath = true;
                    break;
                }
                classLogger.info("Model {} in a warming wait loop..", modelId);
                long remaining = timeoutMs - (System.currentTimeMillis() - startTime);
                // Do not oversleep the deadline; Thread.sleep also rejects negative values.
                Thread.sleep(Math.max(0L, Math.min(ACTIVE_POLL_INTERVAL_MS, remaining)));
            }
            if (!foundInActivePath) {
                classLogger.warn("Timeout waiting for model {} to appear in active path after {}ms",
                        modelId, timeoutMs);
                return false;
            }
            long healthCheckStart = System.currentTimeMillis();
            long remainingTimeout = timeoutMs - (healthCheckStart - startTime);
            String clusterIp = getModelClusterIp(modelId);
            if (clusterIp == null) {
                classLogger.error("No cluster IP available for model {}", modelId);
                return false;
            }
            boolean isHealthy = false;
            try {
                // The HTTP calls above can exhaust the budget. Clamp at 0 so Thread.sleep does
                // not throw IllegalArgumentException for a model that is in fact active.
                long waitTime = Math.max(0L, Math.min(remainingTimeout, MAX_HEALTH_WAIT_MS));
                Thread.sleep(waitTime);
                isHealthy = true;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                classLogger.warn("Interrupted while waiting for model to become healthy");
            }
            modelStates.put(modelId, RemoteModelStateEnum.ACTIVE);
            return isHealthy;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            classLogger.warn("Interrupted while waiting for model {} to become active", modelId);
            return false;
        } catch (Exception e) {
            classLogger.error("Error waiting for model {} to become active", modelId, e);
            return false;
        }
    }

    /**
     * Polls the cached state every second until it equals {@code desiredState} or the timeout
     * elapses. When waiting for ACTIVE while the cache says COLD, the active ZNode is also checked
     * directly, and the cache is updated if it exists.
     *
     * @return true if the desired state was observed in time
     */
    @Override
    public boolean waitForState(String modelId, RemoteModelStateEnum desiredState, long timeoutMs) {
        try {
            long startTime = System.currentTimeMillis();
            while (System.currentTimeMillis() - startTime < timeoutMs) {
                RemoteModelStateEnum currentState = getModelState(modelId);
                if (currentState == desiredState) {
                    return true;
                }
                // If we're waiting for ACTIVE but the model is COLD, try to check active path directly
                if (desiredState == RemoteModelStateEnum.ACTIVE && currentState == RemoteModelStateEnum.COLD
                        && isModelActive(modelId)) {
                    modelStates.put(modelId, RemoteModelStateEnum.ACTIVE);
                    return true;
                }
                Thread.sleep(STATE_POLL_INTERVAL_MS);
            }
            classLogger.warn("Timeout waiting for model {} to reach state {} after {}ms",
                    modelId, desiredState, timeoutMs);
            return false;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            classLogger.warn("Interrupted while waiting for model {} to reach state {}", modelId, desiredState);
            return false;
        } catch (Exception e) {
            classLogger.error("Error waiting for model {} to reach state {}", modelId, desiredState, e);
            return false;
        }
    }

    /**
     * Lists models under the active path with their names and cached states.
     *
     * <p>A missing name is fetched from the model's ZNode and cached. If it cannot be read, the
     * name is reported as {@code "Unknown"}. The state defaults to ACTIVE when the model is not cached.
     */
    @Override
    public List<RemoteModelInfo> getActiveModels() {
        List<RemoteModelInfo> activeModels = new ArrayList<>();
        try {
            for (String modelId : getZNodeChildren(ACTIVE_PATH)) {
                RemoteModelStateEnum state = modelStates.getOrDefault(modelId, RemoteModelStateEnum.ACTIVE);
                String name = modelNames.get(modelId);
                if (name == null) {
                    JSONObject modelData = getZNodeDataAsJson(ACTIVE_PATH + "/" + modelId);
                    // optString avoids one malformed ZNode truncating the whole list
                    name = modelData == null ? null : modelData.optString("model_name", null);
                    if (name != null) {
                        modelNames.put(modelId, name);
                    }
                }
                activeModels.add(new RemoteModelInfo(modelId, name != null ? name : "Unknown", state));
            }
        } catch (Exception e) {
            classLogger.error("Error getting active models with details", e);
        }
        return activeModels;
    }

    /**
     * Lists models under the warming path. The name is the cached name, or {@code "Warming..."}
     * if none is cached. The state defaults to WARMING when the model is not cached.
     */
    @Override
    public List<RemoteModelInfo> getWarmingModels() {
        List<RemoteModelInfo> warmingModels = new ArrayList<>();
        try {
            for (String modelId : getZNodeChildren(WARMING_PATH)) {
                String name = modelNames.get(modelId);
                RemoteModelStateEnum state = modelStates.getOrDefault(modelId, RemoteModelStateEnum.WARMING);
                warmingModels.add(new RemoteModelInfo(
                        modelId,
                        name != null ? name : "Warming...",
                        state
                ));
            }
        } catch (Exception e) {
            classLogger.error("Error getting warming models with details", e);
        }
        return warmingModels;
    }

    /**
     * Development-mode stub: always reports that the model can run.
     *
     * @return a map with {@code can_run = true} and an explanatory {@code message}
     */
    @Override
    public Map<String, Object> canItRun(String hfModelId) throws Exception {
        Map<String, Object> result = new HashMap<>();
        result.put("can_run", true);
        result.put("message", "Development mode - assuming model can run");
        classLogger.info("Development mode: Assuming model {} can run", hfModelId);
        return result;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy