com.alibaba.middleware.ushura.Chooser Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of eas-sdk Show documentation
Show all versions of eas-sdk Show documentation
SDK for PAI-EAS online inference services
package com.alibaba.middleware.ushura;
import com.alibaba.middleware.ushura.poller.GenericPoller;
import com.alibaba.middleware.ushura.poller.Poller;
import com.alibaba.middleware.ushura.poller.PowerOfTwoPoller;
import com.alibaba.middleware.ushura.util.ThreadLocalRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class Chooser {
private K uniqueKey;
private volatile Ref ref;
public T random() {
List items = ref.items;
if (items.size() == 0)
return null;
if (items.size() == 1)
return items.get(0);
return items.get(ThreadLocalRandom.current().nextInt(items.size()));
}
public T randomWithWeight() {
if(ref.items.size() == 0)
return null;
Ref ref = this.ref;
double random = ThreadLocalRandom.current().nextDouble(0, 1);
int index = Arrays.binarySearch(ref.weights, random);
if (index < 0) {
index = -index - 1;
} else {
return ref.items.get(index);
}
if (index >= 0 && index < ref.weights.length) {
if (random < ref.weights[index]) {
return ref.items.get(index);
}
}
/* This should never happen, but it ensures we will return a correct
* object in case there is some floating point inequality problem
* wrt the cumulative probabilities. */
return ref.items.get(ref.items.size() - 1);
}
public T poll() {
if (ref.poller == null)
throw new IllegalStateException("You cannot call this method before you set a poller to a chooser.");
return ref.poller.next();
}
public Chooser(K uniqueKey) {
this(uniqueKey, new ArrayList>());
}
public Chooser(K uniqueKey, List> pairs) {
Ref ref = new Ref(pairs);
ref.refresh();
this.uniqueKey = uniqueKey;
this.ref = ref;
}
public Chooser poller(Poller.PollerType type) {
switch (type) {
case Generic: {
ref.poller = new GenericPoller(ref.items);
break;
}
case PowerOfTwoPoller: {
ref.poller = new PowerOfTwoPoller(ref.items);
}
default: {
ref.poller = new GenericPoller(ref.items);
}
}
return this;
}
public K getUniqueKey() {
return uniqueKey;
}
public Ref getRef() {
return ref;
}
public void refresh(List> itemsWithWeight) {
Ref newRef = new Ref(itemsWithWeight);
newRef.refresh();
newRef.poller = this.ref.poller.refresh(newRef.items);
this.ref = newRef;
}
public class Ref {
private List> itemsWithWeight = new ArrayList>();
private List items = new ArrayList();
private Poller poller = new GenericPoller(items);
private double[] weights;
@SuppressWarnings("unchecked")
public Ref(List> itemsWithWeight) {
this.itemsWithWeight = itemsWithWeight;
}
public void refresh() {
Double originWeightSum = (double) 0;
for (Pair item : itemsWithWeight) {
double weight = item.weight();
if (!(weight > 0)) //ignore item which weight is zero.see test_randomWithWeight_weight0 in ChooserTest
continue;
items.add(item.item());
if (Double.isInfinite(weight)) {
weight = 10000.0D;
}
if (Double.isNaN(weight)) {
weight = 1.0D;
}
originWeightSum += weight;
}
double[] exactWeights = new double[items.size()];
int index = 0;
for (Pair item : itemsWithWeight) {
double singleWeight = item.weight();
if(!(singleWeight > 0)) continue; //ignore item which weight is zero.see test_randomWithWeight_weight0 in ChooserTest
exactWeights[index++] = singleWeight / originWeightSum;
}
weights = new double[items.size()];
double randomRange = 0D;
for (int i = 0; i < index; i++) {
weights[i] = randomRange + exactWeights[i];
randomRange += exactWeights[i];
}
if (index != 0 && !(Math.abs(weights[index - 1] - Double.valueOf(1)) < 0.0001)) {
throw new IllegalStateException("Cumulative Weight caculate wrong , the sum of probabilities does not equals 1.");
}
}
@Override
public int hashCode() {
return itemsWithWeight.hashCode();
}
@SuppressWarnings("unchecked")
@Override
public boolean equals(Object other) {
if (this == other)
return true;
if (other == null)
return false;
if (getClass() != other.getClass())
return false;
if (!(other.getClass().getGenericInterfaces()[0].equals(this.getClass().getGenericInterfaces()[0])))
return false;
Ref otherRef = (Ref) other;
if (itemsWithWeight == null) {
if (otherRef.itemsWithWeight != null)
return false;
} else {
if (otherRef.itemsWithWeight == null)
return false;
else
return this.itemsWithWeight.equals(otherRef.itemsWithWeight);
}
return true;
}
public double[] weights() {
return weights;
}
public boolean contains(T item) {
return items.contains(item);
}
}
@Override
public int hashCode() {
return uniqueKey.hashCode();
}
@Override
public boolean equals(Object other) {
if (this == other)
return true;
if (other == null)
return false;
if (getClass() != other.getClass())
return false;
Chooser otherChooser = (Chooser) other;
if (this.uniqueKey == null) {
if (otherChooser.getUniqueKey() != null)
return false;
} else {
if (otherChooser.getUniqueKey() == null)
return false;
else {
if (!this.uniqueKey.equals(otherChooser.getUniqueKey()))
return false;
}
}
if (this.ref == null) {
if (otherChooser.getRef() != null)
return false;
} else {
if (otherChooser.getRef() == null)
return false;
else {
if (!this.ref.equals(otherChooser.getRef()))
return false;
}
}
return true;
}
}