// org.kaizen4j.common.algorithm.weighted.WeightedRouter Maven / Gradle / Ivy
package org.kaizen4j.common.algorithm.weighted;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.List;
/**
 * Selects among a fixed set of {@link WeightedServer}s using nginx's smooth weighted
 * round-robin algorithm.
 *
 * <p>Every public method is {@code synchronized} on this router. Selection mutates each
 * server's current/effective weight and reads the whole list, so it must not interleave with
 * another selection or with {@link #reset()} / {@link #removeAll()}.
 */
public final class WeightedRouter {

    private static final Logger logger = LoggerFactory.getLogger(WeightedRouter.class);

    /** Router-owned copy of the servers supplied at construction time. */
    private final List<WeightedServer> weightedServers;

    /**
     * Creates a router over the given servers.
     *
     * @param servers the servers to route across; the list is copied, so later changes to the
     *     caller's list are not observed (the server objects themselves are shared)
     * @throws IllegalArgumentException if {@code servers} is {@code null} or empty
     */
    public WeightedRouter(List<? extends WeightedServer> servers) {
        Preconditions.checkArgument(
                !CollectionUtils.isEmpty(servers), "servers must not be null or empty");
        this.weightedServers = Lists.newArrayList(servers);
    }

    /**
     * Returns the next weighted server.
     *
     * @return the selected server
     * @throws IllegalStateException if there are no servers (i.e. after {@link #removeAll()})
     */
    public synchronized WeightedServer next() {
        if (weightedServers.isEmpty()) {
            throw new IllegalStateException("Not found any weighted server");
        }
        if (weightedServers.size() == 1) {
            // Nothing to balance: return the only server without touching its weights.
            return weightedServers.get(0);
        }
        return nextWeighted();
    }

    /**
     * Resets every server's weights: the effective weight goes back to the configured weight and
     * the current weight goes back to zero, so selection restarts from its initial state.
     */
    public synchronized void reset() {
        weightedServers.forEach(server -> {
            server.setEffectiveWeight(server.getWeight());
            server.setCurrentWeight(0);
        });
    }

    /**
     * Removes all servers from this router. This class offers no way to add servers back, so
     * every later call to {@link #next()} throws {@link IllegalStateException}.
     */
    public synchronized void removeAll() {
        weightedServers.clear();
    }

    /**
     * Runs one round of smooth weighted round-robin. The algorithm follows nginx; see the link
     * below. Callers must hold this router's monitor and ensure the list is non-empty.
     *
     * @return the server with the highest current weight after this round
     * @see "https://github.com/phusion/nginx/commit/27e94984486058d73157038f7950a0a36ecc6e35"
     */
    private WeightedServer nextWeighted() {
        int total = 0;
        WeightedServer bestServer = null;
        for (WeightedServer server : weightedServers) {
            int effectiveWeight = server.getEffectiveWeight();
            // Every server accumulates its effective weight each round.
            int currentWeight = server.addAndGetCurrentWeight(effectiveWeight);
            total += effectiveWeight;
            // The effective weight may have been lowered elsewhere (not visible in this class);
            // it recovers by one per round until it reaches the configured weight again.
            if (effectiveWeight < server.getWeight()) {
                server.addAndGetEffectiveWeight(1);
            }
            // Strict '>' means that on ties the earliest server in the list wins.
            if (bestServer == null || currentWeight > bestServer.getCurrentWeight()) {
                bestServer = server;
            }
        }
        if (bestServer == null) {
            // Unreachable while next() guards emptiness under the same lock; kept defensively.
            throw new IllegalStateException("Not match any weighted server");
        }
        // Charging the winner the round's total keeps the sum of current weights unchanged and
        // produces the smooth interleaving (e.g. a, a, b, a, c, a, a for weights 5, 1, 1).
        bestServer.addAndGetCurrentWeight(-total);
        return bestServer;
    }
}