io.ray.serve.poll.LongPollClientFactory Maven / Gradle / Ivy
package io.ray.serve.poll;
import com.google.common.base.Preconditions;
import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.exception.RayActorException;
import io.ray.api.exception.RayTaskException;
import io.ray.api.exception.RayTimeoutException;
import io.ray.api.function.PyActorMethod;
import io.ray.serve.api.Serve;
import io.ray.serve.common.Constants;
import io.ray.serve.config.RayServeConfig;
import io.ray.serve.controller.ServeController;
import io.ray.serve.generated.ActorNameList;
import io.ray.serve.replica.ReplicaContext;
import io.ray.serve.util.CollectionUtil;
import io.ray.serve.util.ServeProtoUtil;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.commons.lang3.builder.ReflectionToStringBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** The long poll client factory that holds an asynchronous singleton polling thread. */
public class LongPollClientFactory {
private static final Logger LOGGER = LoggerFactory.getLogger(LongPollClientFactory.class);

/** Handle to actor embedding LongPollHost. */
private static BaseActorHandle hostActor;

/**
 * Maps each registered key to the callback to be called on state update for that key.
 *
 * <p>NOTE(review): the type arguments were missing from this declaration. {@code KeyListener} is
 * assumed to be the callback type declared in this package -- confirm.
 */
private static final Map<KeyType, KeyListener> KEY_LISTENERS = new ConcurrentHashMap<>();

/**
 * Snapshot id per registered key; the whole map is sent with every poll request. {@code
 * register} seeds each newly registered key with {@code -1}.
 */
public static final Map<KeyType, Integer> SNAPSHOT_IDS = new ConcurrentHashMap<>();

/**
 * Objects stored per key.
 *
 * <p>NOTE(review): presumably the latest object received for each key; the code that fills this
 * map is not in view, and the {@code KeyType} key type is assumed by analogy with
 * {@link #SNAPSHOT_IDS} -- confirm.
 */
public static final Map<KeyType, Object> OBJECT_SNAPSHOTS = new ConcurrentHashMap<>();

/** Scheduler created by {@code init} that repeatedly runs {@code pollNext}. */
private static ScheduledExecutorService scheduledExecutorService;

/** Set to {@code true} at the end of a successful {@code init}. */
private static boolean inited = false;

/**
 * Long poll timeout in seconds. Starts at 1 and is overwritten by {@code init} from the replica
 * config (10 when the config has no value for it).
 */
private static long longPollTimoutS = 1L;

/**
 * Per-namespace functions that decode a serialized payload into an object.
 *
 * <p>NOTE(review): the {@code byte[]} input type is reconstructed from the {@code bytes} lambda
 * below -- confirm.
 */
public static final Map<LongPollNamespace, Function<byte[], Object>> DESERIALIZERS =
    new HashMap<>();

static {
  DESERIALIZERS.put(LongPollNamespace.ROUTE_TABLE, ServeProtoUtil::parseEndpointSet);
  DESERIALIZERS.put(
      LongPollNamespace.RUNNING_REPLICAS,
      bytes -> ServeProtoUtil.bytesToProto(bytes, ActorNameList::parseFrom));
}
/**
 * Registers listeners for the given keys and triggers an immediate poll.
 *
 * <p>Calls {@link #init(BaseActorHandle)} first. If the factory is still not initialized
 * afterwards (for example because the client is disabled by configuration), nothing is
 * registered. Every key in {@code keyListeners} has its snapshot id set to {@code -1}, including
 * keys that were registered before.
 *
 * <p>NOTE(review): the parameter's type arguments were missing from this signature. {@code
 * KeyListener} is assumed to be the callback type declared in this package -- confirm.
 *
 * @param hostActor handle to the actor to poll; passed through to {@code init}
 * @param keyListeners callbacks keyed by the key they listen to
 */
public static void register(BaseActorHandle hostActor, Map<KeyType, KeyListener> keyListeners) {
  init(hostActor);
  if (!inited) {
    // init() returned without starting the client, so there is nothing to register with.
    return;
  }
  KEY_LISTENERS.putAll(keyListeners);
  for (KeyType keyType : keyListeners.keySet()) {
    SNAPSHOT_IDS.put(keyType, -1);
  }
  LOGGER.info("LongPollClient registered keys: {}.", keyListeners.keySet());
  try {
    pollNext();
  } catch (RayTimeoutException e) {
    // A timeout here is not fatal: the scheduled poller started by init() keeps polling.
    LOGGER.info("Register poll timeout. keys:{}", keyListeners.keySet(), e);
  }
}
/**
 * Initializes the shared long poll client once and starts the background polling thread.
 *
 * <p>Returns immediately if the factory is already initialized. Otherwise it reads the enabled
 * flag, poll interval and poll timeout from the replica config. If the config disables the
 * client, nothing is started and the factory stays uninitialized. Any exception raised while
 * reading the replica context or config is logged and the current defaults are kept.
 *
 * @param hostActor handle to the actor to poll; when {@code null} and the replica context is
 *     available, the actor named {@code Constants.SERVE_CONTROLLER_NAME} is looked up instead
 * @throws NullPointerException if no host actor could be resolved
 */
public static synchronized void init(BaseActorHandle hostActor) {
  if (inited) {
    return;
  }

  long intervalS = 10L;
  try {
    ReplicaContext replicaContext = Serve.getReplicaContext();
    boolean enabled =
        Optional.ofNullable(replicaContext.getConfig())
            .map(config -> config.get(RayServeConfig.LONG_POOL_CLIENT_ENABLED))
            .map(Boolean::valueOf)
            .orElse(true);
    if (!enabled) {
      LOGGER.info("LongPollClient is disabled.");
      return;
    }
    if (null == hostActor) {
      hostActor = Ray.getActor(Constants.SERVE_CONTROLLER_NAME, Constants.SERVE_NAMESPACE).get();
    }
    intervalS =
        Optional.ofNullable(replicaContext.getConfig())
            .map(config -> config.get(RayServeConfig.LONG_POOL_CLIENT_INTERVAL))
            .map(Long::valueOf)
            .orElse(10L);
    longPollTimoutS =
        Optional.ofNullable(replicaContext.getConfig())
            .map(config -> config.get(RayServeConfig.LONG_POOL_CLIENT_TIMEOUT_S))
            .map(Long::valueOf)
            .orElse(10L);
  } catch (Exception e) {
    // This catch also covers config parsing and controller lookup failures, so the cause is
    // attached to keep those diagnosable instead of being hidden behind a fixed message.
    LOGGER.info(
        "Serve.getReplicaContext()` may only be called from within a Ray Serve deployment.", e);
  }

  Preconditions.checkNotNull(
      hostActor, "LongPollClient host actor is null and could not be resolved.");
  LongPollClientFactory.hostActor = hostActor;

  scheduledExecutorService =
      Executors.newSingleThreadScheduledExecutor(
          runnable -> {
            Thread thread = new Thread(runnable, "ray-serve-long-poll-client-thread");
            // Daemon thread, so the poller never keeps the JVM alive on its own.
            thread.setDaemon(true);
            return thread;
          });

  // intervalS is reassigned above, so the lambda needs an effectively final copy.
  final long finalIntervalS = intervalS;
  scheduledExecutorService.scheduleWithFixedDelay(
      () -> {
        try {
          pollNext();
        } catch (RayTimeoutException e) {
          LOGGER.info(
              "long poll timeout in {} seconds, execute next poll after {} seconds.",
              longPollTimoutS,
              finalIntervalS);
        } catch (RayActorException e) {
          LOGGER.error("LongPollClient failed to connect to host. Shutting down.", e);
          stop();
        } catch (RayTaskException e) {
          LOGGER.error("LongPollHost errored", e);
        } catch (Throwable e) {
          // Catch everything: an exception escaping the task would cancel all future runs.
          LOGGER.error("LongPollClient failed to update object of key {}", SNAPSHOT_IDS, e);
        }
      },
      0L,
      finalIntervalS,
      TimeUnit.SECONDS);
  inited = true;
  LOGGER.info("LongPollClient was initialized");
}
/** Poll the updates. */
@SuppressWarnings("unchecked")
public static synchronized void pollNext() {
LOGGER.info("LongPollClient polls next snapshotIds {}", SNAPSHOT_IDS);
LongPollRequest longPollRequest = new LongPollRequest(SNAPSHOT_IDS);
LongPollResult longPollResult = null;
if (hostActor instanceof PyActorHandle) {
// Poll from python controller.
ObjectRef
© 2015 - 2024 Weber Informatics LLC | Privacy Policy