All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
io.ray.serve.poll.LongPollClientFactory Maven / Gradle / Ivy
package io.ray.serve.poll;
import com.google.common.base.Preconditions;
import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.exception.RayActorException;
import io.ray.api.exception.RayTaskException;
import io.ray.api.exception.RayTimeoutException;
import io.ray.api.function.PyActorMethod;
import io.ray.serve.api.Serve;
import io.ray.serve.common.Constants;
import io.ray.serve.config.RayServeConfig;
import io.ray.serve.controller.ServeController;
import io.ray.serve.generated.ActorNameList;
import io.ray.serve.replica.ReplicaContext;
import io.ray.serve.util.CollectionUtil;
import io.ray.serve.util.ServeProtoUtil;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.commons.lang3.builder.ReflectionToStringBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** The long poll client factory that holds a asynchronous singleton thread. */
public class LongPollClientFactory {
private static final Logger LOGGER = LoggerFactory.getLogger(LongPollClientFactory.class);
/** Handle to actor embedding LongPollHost. */
private static BaseActorHandle hostActor;
/** A set mapping keys to callbacks to be called on state update for the corresponding keys. */
private static final Map KEY_LISTENERS = new ConcurrentHashMap<>();
public static final Map SNAPSHOT_IDS = new ConcurrentHashMap<>();
public static final Map OBJECT_SNAPSHOTS = new ConcurrentHashMap<>();
private static ScheduledExecutorService scheduledExecutorService;
private static boolean inited = false;
private static long longPollTimoutS = 1L;
public static final Map> DESERIALIZERS =
new HashMap<>();
static {
DESERIALIZERS.put(LongPollNamespace.ROUTE_TABLE, ServeProtoUtil::parseEndpointSet);
DESERIALIZERS.put(
LongPollNamespace.RUNNING_REPLICAS,
bytes -> ServeProtoUtil.bytesToProto(bytes, ActorNameList::parseFrom));
}
public static void register(BaseActorHandle hostActor, Map keyListeners) {
init(hostActor);
if (!inited) {
return;
}
KEY_LISTENERS.putAll(keyListeners);
for (KeyType keyType : keyListeners.keySet()) {
SNAPSHOT_IDS.put(keyType, -1);
}
LOGGER.info("LongPollClient registered keys: {}.", keyListeners.keySet());
try {
pollNext();
} catch (RayTimeoutException e) {
LOGGER.info("Register poll timeout. keys:{}", keyListeners.keySet());
}
}
public static synchronized void init(BaseActorHandle hostActor) {
if (inited) {
return;
}
long intervalS = 10L;
try {
ReplicaContext replicaContext = Serve.getReplicaContext();
boolean enabled =
Optional.ofNullable(replicaContext.getConfig())
.map(config -> config.get(RayServeConfig.LONG_POOL_CLIENT_ENABLED))
.map(Boolean::valueOf)
.orElse(true);
if (!enabled) {
LOGGER.info("LongPollClient is disabled.");
return;
}
if (null == hostActor) {
hostActor = Ray.getActor(Constants.SERVE_CONTROLLER_NAME, Constants.SERVE_NAMESPACE).get();
}
intervalS =
Optional.ofNullable(replicaContext.getConfig())
.map(config -> config.get(RayServeConfig.LONG_POOL_CLIENT_INTERVAL))
.map(Long::valueOf)
.orElse(10L);
longPollTimoutS =
Optional.ofNullable(replicaContext.getConfig())
.map(config -> config.get(RayServeConfig.LONG_POOL_CLIENT_TIMEOUT_S))
.map(Long::valueOf)
.orElse(10L);
} catch (Exception e) {
LOGGER.info(
"Serve.getReplicaContext()` may only be called from within a Ray Serve deployment.");
}
Preconditions.checkNotNull(hostActor);
LongPollClientFactory.hostActor = hostActor;
scheduledExecutorService =
Executors.newSingleThreadScheduledExecutor(
new ThreadFactory() {
@Override
public Thread newThread(Runnable r) {
Thread thread = new Thread(r, "ray-serve-long-poll-client-thread");
thread.setDaemon(true);
return thread;
}
});
long finalIntervalS = intervalS;
scheduledExecutorService.scheduleWithFixedDelay(
() -> {
try {
pollNext();
} catch (RayTimeoutException e) {
LOGGER.info(
"long poll timeout in {} seconds, execute next poll after {} seconds.",
longPollTimoutS,
finalIntervalS);
} catch (RayActorException e) {
LOGGER.error("LongPollClient failed to connect to host. Shutting down.");
stop();
} catch (RayTaskException e) {
LOGGER.error("LongPollHost errored", e);
} catch (Throwable e) {
LOGGER.error("LongPollClient failed to update object of key {}", SNAPSHOT_IDS, e);
}
},
0L,
intervalS,
TimeUnit.SECONDS);
inited = true;
LOGGER.info("LongPollClient was initialized");
}
/** Poll the updates. */
@SuppressWarnings("unchecked")
public static synchronized void pollNext() {
LOGGER.info("LongPollClient polls next snapshotIds {}", SNAPSHOT_IDS);
LongPollRequest longPollRequest = new LongPollRequest(SNAPSHOT_IDS);
LongPollResult longPollResult = null;
if (hostActor instanceof PyActorHandle) {
// Poll from python controller.
ObjectRef currentRef =
((PyActorHandle) hostActor)
.task(
PyActorMethod.of(Constants.CONTROLLER_LISTEN_FOR_CHANGE_METHOD),
longPollRequest.toProtobuf().toByteArray())
.remote();
Object data = Ray.get(currentRef, longPollTimoutS * 1000);
longPollResult = LongPollResult.parseFrom((byte[]) data);
} else {
// Poll from java controller.
ObjectRef currentRef =
((ActorHandle) hostActor)
.task(ServeController::listenForChange, longPollRequest)
.remote();
longPollResult = LongPollResult.parseFrom(currentRef.get(longPollTimoutS * 1000));
}
processUpdate(longPollResult == null ? null : longPollResult.getUpdatedObjects());
}
public static void processUpdate(Map updates) {
if (updates == null || updates.isEmpty()) {
LOGGER.info("LongPollClient received nothing.");
return;
}
LOGGER.info("LongPollClient received updates for keys: {}", updates.keySet());
for (Map.Entry entry : updates.entrySet()) {
KeyType keyType = entry.getKey();
UpdatedObject updatedObject = entry.getValue();
Object objectSnapshot = updatedObject.getObjectSnapshot();
if (LOGGER.isDebugEnabled()) {
LOGGER.debug(
"The updated object for key {} is {}",
keyType,
ReflectionToStringBuilder.toString(objectSnapshot));
}
KeyListener keyListener = KEY_LISTENERS.get(entry.getKey());
if (keyListener == null) {
LOGGER.warn(
"LongPollClient has no listener for key: {}, maybe this key was garbage collected.",
entry.getKey());
continue;
}
KEY_LISTENERS.get(entry.getKey()).notifyChanged(objectSnapshot);
OBJECT_SNAPSHOTS.put(entry.getKey(), objectSnapshot);
SNAPSHOT_IDS.put(entry.getKey(), entry.getValue().getSnapshotId());
}
}
public static void unregister(Set keys) {
if (CollectionUtil.isEmpty(keys)) {
return;
}
for (KeyType keyType : keys) {
SNAPSHOT_IDS.remove(keyType);
KEY_LISTENERS.remove(keyType);
OBJECT_SNAPSHOTS.remove(keyType);
}
LOGGER.info("LongPollClient unregistered keys: {}.", keys);
}
public static synchronized void stop() {
if (!inited) {
return;
}
if (scheduledExecutorService != null) {
scheduledExecutorService.shutdown();
try {
scheduledExecutorService.awaitTermination(longPollTimoutS, TimeUnit.SECONDS);
} catch (InterruptedException e) {
LOGGER.error("awaitTermination error, the exception is ", e);
}
}
KEY_LISTENERS.clear();
OBJECT_SNAPSHOTS.clear();
SNAPSHOT_IDS.clear();
inited = false;
LOGGER.info("LongPollClient was stopped.");
}
public static boolean isInitialized() {
return inited;
}
}