com.alibaba.dashscope.protocol.pool.WebsocketPool Maven / Gradle / Ivy
package com.alibaba.dashscope.protocol.pool;
import com.alibaba.dashscope.protocol.WebsocketRpc;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import javax.xml.bind.DatatypeConverter;
import lombok.Getter;
import lombok.Setter;
import lombok.var;
import org.apache.commons.pool2.BasePooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
public class WebsocketPool {
static class WsPoolFactory extends BasePooledObjectFactory {
private final String url;
private final Map headers;
public WsPoolFactory(String url, Map headers) {
this.url = url;
this.headers = headers;
}
@Override
public WebsocketRpc create() {
return new WebsocketRpc(url, headers);
}
@Override
public PooledObject wrap(WebsocketRpc o) {
return new DefaultPooledObject<>(o);
}
}
private static volatile WebsocketPool instance;
private final Map> wsPoolFactoryMap =
Maps.newConcurrentMap();
private final Set borrowed = Sets.newConcurrentHashSet();
@Getter @Setter private int maxIdle = 8;
@Getter @Setter private int maxTotal = 8;
public static WebsocketPool getInstance() {
if (instance == null) {
synchronized (HttpPool.class) {
if (instance == null) {
instance = new WebsocketPool();
}
}
}
return instance;
}
private WebsocketPool() {}
@Override
protected void finalize() throws Throwable {
for (GenericObjectPool pool : wsPoolFactoryMap.values()) {
pool.close();
}
borrowed.clear();
super.finalize();
}
public static String calculateMd5(Map headers) {
try {
List> entries = new ArrayList<>(headers.entrySet());
entries.sort(Map.Entry.comparingByKey());
StringBuilder sb = new StringBuilder();
sb.append("{");
for (Map.Entry entry : entries) {
if (sb.length() > 1) {
sb.append(", ");
}
sb.append(entry.getKey()).append("=").append(entry.getValue());
}
sb.append("}");
String mapString = sb.toString();
MessageDigest md = MessageDigest.getInstance("MD5");
byte[] hash = md.digest(mapString.getBytes());
return DatatypeConverter.printHexBinary(hash).toLowerCase();
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}
/**
* Get a rpc client from the pool.
*
* @param url The connecting url of the websocket client.
* @param headers The headers to use for connection.
* @return The websocket rpc client.
* @throws Exception Error occurs.
*/
public WebsocketRpc getWsClient(String url, Map headers) throws Exception {
String key = url + calculateMd5(headers);
if (!wsPoolFactoryMap.containsKey(key)) {
synchronized (wsPoolFactoryMap) {
if (!wsPoolFactoryMap.containsKey(key)) {
GenericObjectPoolConfig config = new GenericObjectPoolConfig<>();
config.setMaxTotal(maxTotal);
config.setMaxIdle(maxIdle);
wsPoolFactoryMap.put(
key, new GenericObjectPool<>(new WsPoolFactory(url, headers), config));
}
}
}
var websocketRpc = wsPoolFactoryMap.get(key).borrowObject();
borrowed.add(websocketRpc);
return websocketRpc;
}
/**
* Return the borrowed object to the pool.
*
* @param websocketRpc The borrowed websocket rpc client.
*/
public void returnWsClient(WebsocketRpc websocketRpc) {
borrowed.remove(websocketRpc);
wsPoolFactoryMap
.get(websocketRpc.getUrl() + calculateMd5(websocketRpc.getHeaders()))
.returnObject(websocketRpc);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy