All downloads are free. Search and download functionality uses the official Maven repository.

com.alibaba.dashscope.protocol.pool.WebsocketPool Maven / Gradle / Ivy

package com.alibaba.dashscope.protocol.pool;

import com.alibaba.dashscope.protocol.WebsocketRpc;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import javax.xml.bind.DatatypeConverter;
import lombok.Getter;
import lombok.Setter;
import lombok.var;
import org.apache.commons.pool2.BasePooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;

public class WebsocketPool {

  static class WsPoolFactory extends BasePooledObjectFactory {

    private final String url;

    private final Map headers;

    public WsPoolFactory(String url, Map headers) {
      this.url = url;
      this.headers = headers;
    }

    @Override
    public WebsocketRpc create() {
      return new WebsocketRpc(url, headers);
    }

    @Override
    public PooledObject wrap(WebsocketRpc o) {
      return new DefaultPooledObject<>(o);
    }
  }

  private static volatile WebsocketPool instance;

  private final Map> wsPoolFactoryMap =
      Maps.newConcurrentMap();

  private final Set borrowed = Sets.newConcurrentHashSet();

  @Getter @Setter private int maxIdle = 8;

  @Getter @Setter private int maxTotal = 8;

  public static WebsocketPool getInstance() {
    if (instance == null) {
      synchronized (HttpPool.class) {
        if (instance == null) {
          instance = new WebsocketPool();
        }
      }
    }
    return instance;
  }

  private WebsocketPool() {}

  @Override
  protected void finalize() throws Throwable {
    for (GenericObjectPool pool : wsPoolFactoryMap.values()) {
      pool.close();
    }
    borrowed.clear();
    super.finalize();
  }

  public static String calculateMd5(Map headers) {
    try {
      List> entries = new ArrayList<>(headers.entrySet());
      entries.sort(Map.Entry.comparingByKey());
      StringBuilder sb = new StringBuilder();
      sb.append("{");
      for (Map.Entry entry : entries) {
        if (sb.length() > 1) {
          sb.append(", ");
        }
        sb.append(entry.getKey()).append("=").append(entry.getValue());
      }
      sb.append("}");
      String mapString = sb.toString();
      MessageDigest md = MessageDigest.getInstance("MD5");
      byte[] hash = md.digest(mapString.getBytes());
      return DatatypeConverter.printHexBinary(hash).toLowerCase();
    } catch (NoSuchAlgorithmException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Get a rpc client from the pool.
   *
   * @param url The connecting url of the websocket client.
   * @param headers The headers to use for connection.
   * @return The websocket rpc client.
   * @throws Exception Error occurs.
   */
  public WebsocketRpc getWsClient(String url, Map headers) throws Exception {
    String key = url + calculateMd5(headers);
    if (!wsPoolFactoryMap.containsKey(key)) {
      synchronized (wsPoolFactoryMap) {
        if (!wsPoolFactoryMap.containsKey(key)) {
          GenericObjectPoolConfig config = new GenericObjectPoolConfig<>();
          config.setMaxTotal(maxTotal);
          config.setMaxIdle(maxIdle);
          wsPoolFactoryMap.put(
              key, new GenericObjectPool<>(new WsPoolFactory(url, headers), config));
        }
      }
    }
    var websocketRpc = wsPoolFactoryMap.get(key).borrowObject();
    borrowed.add(websocketRpc);
    return websocketRpc;
  }

  /**
   * Return the borrowed object to the pool.
   *
   * @param websocketRpc The borrowed websocket rpc client.
   */
  public void returnWsClient(WebsocketRpc websocketRpc) {
    borrowed.remove(websocketRpc);
    wsPoolFactoryMap
        .get(websocketRpc.getUrl() + calculateMd5(websocketRpc.getHeaders()))
        .returnObject(websocketRpc);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy