com.vesoft.nebula.driver.graph.net.RoundRobinLoadBalancer Maven / Gradle / Ivy
The newest version!
package com.vesoft.nebula.driver.graph.net;
import com.vesoft.nebula.driver.graph.data.HostAddress;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A {@link LoadBalancer} that hands out server addresses in round-robin order, skipping
 * addresses whose most recent health check failed.
 *
 * <p>Each address has a status of {@code S_OK} or {@code S_BAD}; every address starts as
 * {@code S_BAD}. Statuses are refreshed by {@link #updateServersStatus()}, which is called by
 * {@link #isServersOK()} and, when {@code healthCheckTime > 0}, periodically by a background
 * single-thread scheduler until {@link #close()} is invoked.
 */
public class RoundRobinLoadBalancer implements LoadBalancer, Serializable {
    private static final Logger logger = LoggerFactory.getLogger(RoundRobinLoadBalancer.class);
    private static final int S_OK = 0;
    private static final int S_BAD = 1;

    /** Candidate addresses in the order supplied to the constructor. */
    private final List<HostAddress> addresses = new ArrayList<>();
    /** Latest status per address; written by health checks, read by {@link #getAddress()}. */
    private final Map<HostAddress, Integer> serversStatus = new ConcurrentHashMap<>();
    private final boolean strictlyServerHealthy;
    private final String userName;
    private final Map<String, Object> authOptions;
    /** Round-robin cursor; always advanced atomically and wrapped to 0, so never negative. */
    private final AtomicInteger pos = new AtomicInteger(0);
    /**
     * Background health-check scheduler, or {@code null} when periodic checks are disabled.
     * Transient because executors are not serializable.
     */
    private final transient ScheduledExecutorService schedule;

    /**
     * Creates a balancer over the given addresses.
     *
     * @param addresses server addresses to balance across; copied, each initially marked bad
     * @param userName user name used when probing servers in {@link #ping(HostAddress)}
     * @param authOptions authentication options used when probing servers
     * @param strictlyServerHealthy if true, {@link #isServersOK()} requires every server to be
     *     healthy; otherwise at least one healthy server suffices
     * @param healthCheckTime interval in milliseconds between background health checks; a value
     *     {@code <= 0} disables the background checker
     */
    public RoundRobinLoadBalancer(List<HostAddress> addresses,
                                  String userName,
                                  Map<String, Object> authOptions,
                                  boolean strictlyServerHealthy,
                                  long healthCheckTime) {
        for (HostAddress addr : addresses) {
            this.addresses.add(addr);
            this.serversStatus.put(addr, S_BAD);
        }
        this.strictlyServerHealthy = strictlyServerHealthy;
        this.userName = userName;
        this.authOptions = authOptions;
        if (healthCheckTime > 0) {
            schedule = Executors.newScheduledThreadPool(1);
            // Initial delay of 0: the first check runs right away instead of after one interval.
            schedule.scheduleAtFixedRate(this::scheduleTask, 0,
                    healthCheckTime, TimeUnit.MILLISECONDS);
        } else {
            schedule = null;
        }
    }

    /** Stops the background health checker if one was started; further calls are no-ops. */
    public void close() {
        if (schedule != null && !schedule.isShutdown()) {
            schedule.shutdownNow();
        }
    }

    /**
     * Returns the next healthy address in round-robin order.
     *
     * <p>Examines at most {@code addresses.size()} candidates, continuing from where the previous
     * call stopped, and returns the first one whose status is {@code S_OK}.
     *
     * @return a healthy address, or {@code null} if no candidate is currently healthy
     */
    @Override
    public HostAddress getAddress() {
        int size = addresses.size();
        for (int tried = 0; tried < size; tried++) {
            // Advance atomically and wrap to 0 rather than overflowing: a negative cursor would
            // produce a negative list index. A separate check-then-reset is not safe here because
            // the cursor is advanced several times per call and by concurrent callers.
            int cursor = pos.getAndUpdate(p -> p == Integer.MAX_VALUE ? 0 : p + 1);
            HostAddress addr = addresses.get(cursor % size);
            if (serversStatus.get(addr) == S_OK) {
                return addr;
            }
        }
        return null;
    }

    /** Pings every address and records the outcome as {@code S_OK} or {@code S_BAD}. */
    public void updateServersStatus() {
        for (HostAddress hostAddress : addresses) {
            serversStatus.put(hostAddress, ping(hostAddress) ? S_OK : S_BAD);
        }
    }

    /**
     * Returns the addresses whose most recently recorded status is healthy.
     *
     * @return a new list; its order follows the internal status map's iteration order, not the
     *     constructor's address order
     */
    public List<HostAddress> getGoodAddresses() {
        List<HostAddress> goodHosts = new ArrayList<>();
        for (Map.Entry<HostAddress, Integer> server : serversStatus.entrySet()) {
            if (server.getValue() == S_OK) {
                goodHosts.add(server.getKey());
            }
        }
        return goodHosts;
    }

    /**
     * Probes a server by building a client to it and closing it immediately.
     *
     * @param addr server to probe
     * @return true if the client was built and closed without an exception; false otherwise, in
     *     which case the failure is logged
     */
    public boolean ping(HostAddress addr) {
        try {
            NebulaClient client = NebulaClient
                    .builder(addr.toString(), userName)
                    .withAuthOptions(authOptions)
                    .build();
            client.close();
            return true;
        } catch (Exception e) {
            logger.error("ping failed, ", e);
            return false;
        }
    }

    /**
     * Refreshes every server's status, then reports whether the servers are usable.
     *
     * @return when {@code strictlyServerHealthy} is set, true only if no server is unhealthy;
     *     otherwise true if at least one server is healthy
     */
    public boolean isServersOK() {
        this.updateServersStatus();
        int numServersWithOkStatus = 0;
        int numServersWithBadStatus = 0;
        for (HostAddress hostAddress : addresses) {
            if (serversStatus.get(hostAddress) == S_OK) {
                numServersWithOkStatus++;
            } else {
                numServersWithBadStatus++;
            }
        }
        return strictlyServerHealthy
                ? numServersWithBadStatus == 0
                : numServersWithOkStatus > 0;
    }

    /** Body of the periodic health check run by the background scheduler. */
    private void scheduleTask() {
        updateServersStatus();
    }
}