All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.ipc.TConnectionManager Maven / Gradle / Ivy

There is a newer version: 3.2.0
Show newest version
/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package com.tencent.angel.ipc;

import com.tencent.angel.Chore;
import com.tencent.angel.Stoppable;
import com.tencent.angel.conf.AngelConf;
import com.tencent.angel.exception.RemoteException;
import com.tencent.angel.io.Addressing;
import com.tencent.angel.master.MasterProtocol;
import com.tencent.angel.ps.server.control.PSProtocol;
import com.tencent.angel.worker.WorkerProtocol;
import org.apache.hadoop.conf.Configuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.*;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;

public class TConnectionManager {

  private static final Logger LOG = LoggerFactory.getLogger(TConnectionManager.class);

  static final Map ML_INSTANCES =
    new LinkedHashMap();

  /**
   * Get the connection that goes with the passed conf configuration instance. If no
   * current connection exists, method creates a new connection for the passed conf
   * instance.
   */
  public static TConnection getConnection(Configuration conf) {
    TConnectionKey connectionKey = new TConnectionKey(conf);
    synchronized (ML_INSTANCES) {
      TConnectionImplementation connection = ML_INSTANCES.get(connectionKey);
      if (connection == null) {
        connection = new TConnectionImplementation(conf, true);
        ML_INSTANCES.put(connectionKey, connection);
      }
      connection.incCount();
      return connection;
    }
  }

  public static class TConnectionImplementation implements TConnection, Closeable {

    static final Logger LOG = LoggerFactory.getLogger(TConnectionImplementation.class);

    private final Class psClientClass;

    public static final String PS_CLIENT_PROTOCOL_CLASS = "ml.ps.client.protocol.class";

    /**
     * Default client protocol class name.
     */
    public static final String DEFAULT_PS_CLIENT_PROTOCOL_CLASS = PSProtocol.class.getName();

    private final Class asyncPSClientClass;

    public static final String ASYNC_PS_CLIENT_PROTOCOL_CLASS = "ml.ps.async.client.protocol.class";

    /**
     * Default ps client protocol class name.
     */
    public static final String DEFAULT_PS_ASYNC_CLIENT_PROTOCOL_CLASS =
      PSProtocol.AsyncProtocol.class.getName();

    private final Class workerClientClass;

    /**
     * Parameter name for what client protocol to use.
     */
    public static final String WORKER_CLIENT_PROTOCOL_CLASS = "ml.worker.client.protocol.class";

    /**
     * Default client protocol class name.
     */
    public static final String DEFAULT_WORKER_CLIENT_PROTOCOL_CLASS =
      WorkerProtocol.class.getName();

    private final Class masterClientClass;

    /**
     * Parameter name for what client protocol to use.
     */
    public static final String MASTER_CLIENT_PROTOCOL_CLASS = "ml.master.client.protocol.class";

    /**
     * Default client protocol class name.
     */
    public static final String DEFAULT_MASTER_CLIENT_PROTOCOL_CLASS =
      MasterProtocol.class.getName();

    private final int rpcTimeout;
    private final int maxRPCAttempts;

    private final Configuration conf;

    private volatile boolean closed;

    // indicates whether this connection's life cycle is managed
    private final boolean managed;

    private final ConcurrentHashMap> servers =
      new ConcurrentHashMap>();

    private final ConcurrentHashMap connectionLock =
      new ConcurrentHashMap();

    private boolean stopProxy = true;

    private int refCount;

    private final DelayedClosing delayedClosing = DelayedClosing.createAndStart(this);

    /**
     * constructor
     *
     * @param conf Configuration object
     */
    @SuppressWarnings("unchecked") public TConnectionImplementation(Configuration conf,
      boolean managed) {
      this.conf = conf;
      this.managed = managed;
      this.closed = false;
      try {
        String clientClassName =
          conf.get(PS_CLIENT_PROTOCOL_CLASS, DEFAULT_PS_CLIENT_PROTOCOL_CLASS);
        this.psClientClass = (Class) Class.forName(clientClassName);

        String asyncClassName =
          conf.get(ASYNC_PS_CLIENT_PROTOCOL_CLASS, DEFAULT_PS_ASYNC_CLIENT_PROTOCOL_CLASS);
        this.asyncPSClientClass =
          (Class) Class.forName(asyncClassName);

        clientClassName =
          conf.get(WORKER_CLIENT_PROTOCOL_CLASS, DEFAULT_WORKER_CLIENT_PROTOCOL_CLASS);
        this.workerClientClass = (Class) Class.forName(clientClassName);
        clientClassName =
          conf.get(MASTER_CLIENT_PROTOCOL_CLASS, DEFAULT_MASTER_CLIENT_PROTOCOL_CLASS);
        this.masterClientClass = (Class) Class.forName(clientClassName);
      } catch (ClassNotFoundException e) {
        throw new UnsupportedOperationException(e);
      }

      this.rpcTimeout = conf.getInt(AngelConf.ML_RPC_TIMEOUT_KEY, AngelConf.DEFAULT_ML_RPC_TIMEOUT);
      this.maxRPCAttempts = conf
        .getInt(AngelConf.ML_CLIENT_RPC_MAXATTEMPTS, AngelConf.DEFAULT_ML_CLIENT_RPC_MAXATTEMPTS);
    }

    @Override public Configuration getConfiguration() {
      return this.conf;
    }

    /**
     * Return if this client has no reference
     *
     * @return true if this client has no reference; false otherwise
     */
    boolean isZeroReference() {
      return refCount == 0;
    }

    /**
     * Increment this client's reference count.
     */
    void incCount() {
      ++refCount;
    }

    /**
     * Decrement this client's reference count.
     */
    void decCount() {
      if (refCount > 0) {
        --refCount;
      }
    }

    /**
     * Either the passed isa is null or hostname can be but not both.
     *
     * @param hostname
     * @param port
     * @param protocolClass
     * @param version
     * @return Proxy.
     * @throws java.io.IOException
     */
    public VersionedProtocol getProtocol(final String hostname, final int port,
      final Class protocolClass, final long version,
      List addrList4Failover) throws IOException {
      String rsName = Addressing.createHostAndPortStr(hostname, port);
      // See if we already have a connection (common case)
      Map protocols = this.servers.get(rsName);
      if (protocols == null) {
        protocols = new HashMap();
        Map existingProtocols =
          this.servers.putIfAbsent(rsName, protocols);
        if (existingProtocols != null) {
          protocols = existingProtocols;
        }
      }
      String protocol = protocolClass.getName();
      VersionedProtocol server = protocols.get(protocol);
      if (server == null) {
        // create a unique lock for this RS + protocol (if necessary)
        String lockKey = protocol + "@" + rsName;
        this.connectionLock.putIfAbsent(lockKey, lockKey);
        // get the RS lock
        synchronized (this.connectionLock.get(lockKey)) {
          // do one more lookup in case we were stalled above
          server = protocols.get(protocol);
          if (server == null) {
            try {
              // Only create isa when we need to.
              InetSocketAddress address = new InetSocketAddress(hostname, port);
              // definitely a cache miss. establish an RPC for
              // this RS
              server = MLRPC
                .waitForProxy(protocolClass, version, address, this.conf, this.maxRPCAttempts,
                  this.rpcTimeout, this.rpcTimeout, addrList4Failover);
              protocols.put(protocol, server);
            } catch (RemoteException e) {
              LOG.warn("RemoteException connecting to RS", e);
              // Throw what the RemoteException was carrying.
              throw e.unwrapRemoteException();
            }
          }
        }
      }
      return server;
    }

    @Override public boolean isClosed() {
      return this.closed;
    }

    /**
     * @return the refCount
     */
    public int getRefCount() {
      return refCount;
    }

    public void stopProxyOnClose(boolean stopProxy) {
      this.stopProxy = stopProxy;
    }

    void close(boolean stopProxy) {
      if (this.closed) {
        return;
      }
      if (LOG.isDebugEnabled()) {
        LOG.error("closing connection, stopProxy: " + stopProxy);
      }
      delayedClosing.stop("Closing connection");
      if (stopProxy) {
        for (Map i : servers.values()) {
          for (VersionedProtocol server : i.values()) {
            LOG.debug("to stop MLRPC proxy!");
            MLRPC.stopProxy(server);
          }
        }
      }
      this.servers.clear();
      this.closed = true;
    }

    @Override public void close() {
      if (LOG.isDebugEnabled()) {
        LOG.debug("connection is closing!");
      }
      if (managed) {
        TConnectionManager.deleteConnection(this, stopProxy, false);
      } else {
        close(true);
      }
    }

    private static class DelayedClosing extends Chore implements Stoppable {
      private TConnectionImplementation hci;
      Stoppable stoppable;

      private DelayedClosing(TConnectionImplementation hci, Stoppable stoppable) {
        super("ZooKeeperWatcher and Master delayed closing for connection " + hci, 60 * 1000, // We
          // check
          // every
          // minutes
          stoppable);
        this.hci = hci;
        this.stoppable = stoppable;
      }

      static DelayedClosing createAndStart(TConnectionImplementation hci) {
        Stoppable stoppable = new Stoppable() {
          private volatile boolean isStopped = false;

          @Override public void stop(String why) {
            isStopped = true;
          }

          @Override public boolean isStopped() {
            return isStopped;
          }
        };

        return new DelayedClosing(hci, stoppable);
      }

      @Override protected void chore() {
        // TODO
      }

      @Override public void stop(String why) {
        stoppable.stop(why);
      }

      @Override public boolean isStopped() {
        return stoppable.isStopped();
      }
    }

    @Override public PSProtocol getPSService(String hostname, int port) throws IOException {
      return (PSProtocol) getProtocol(hostname, port, psClientClass, 0L, null);
    }

    @Override public PSProtocol.AsyncProtocol getAsyncPSService(String hostname, int port)
      throws IOException {
      return (PSProtocol.AsyncProtocol) getProtocol(hostname, port, asyncPSClientClass, 0L, null);
    }

    @Override public WorkerProtocol getWorkerService(String hostname, int port) throws IOException {
      LOG.info("workerClientClass=" + workerClientClass.getName());
      return (WorkerProtocol) getProtocol(hostname, port, workerClientClass, 0L, null);
    }

    @Override public MasterProtocol getMasterService(String hostname, int port) throws IOException {
      return (MasterProtocol) getProtocol(hostname, port, masterClientClass, 0L, null);
    }
  }


  static class TConnectionKey {
    public static String[] CONNECTION_PROPERTIES = new String[] {};

    private Map properties;
    private String username;

    public TConnectionKey(Configuration conf) {
      Map m = new HashMap();
      if (conf != null) {
        for (String property : CONNECTION_PROPERTIES) {
          String value = conf.get(property);
          if (value != null) {
            m.put(property, value);
          }
        }
      }
      this.properties = Collections.unmodifiableMap(m);
    }

    @Override public int hashCode() {
      final int prime = 31;
      int result = 1;
      if (username != null) {
        result = username.hashCode();
      }
      for (String property : CONNECTION_PROPERTIES) {
        String value = properties.get(property);
        if (value != null) {
          result = prime * result + value.hashCode();
        }
      }

      return result;
    }

    @Override public boolean equals(Object obj) {
      if (this == obj)
        return true;
      if (obj == null)
        return false;
      if (getClass() != obj.getClass())
        return false;
      TConnectionKey that = (TConnectionKey) obj;
      if (this.username != null && !this.username.equals(that.username)) {
        return false;
      } else if (this.username == null && that.username != null) {
        return false;
      }
      if (this.properties == null) {
        if (that.properties != null) {
          return false;
        }
      } else {
        if (that.properties == null) {
          return false;
        }
        for (String property : CONNECTION_PROPERTIES) {
          String thisValue = this.properties.get(property);
          String thatValue = that.properties.get(property);
          if (thisValue == thatValue) {
            continue;
          }
          if (thisValue == null || !thisValue.equals(thatValue)) {
            return false;
          }
        }
      }
      return true;
    }

    @Override public String toString() {
      return "TConnectionKey{" + "properties=" + properties + ", username='" + username + '\''
        + '}';
    }
  }

  private static void deleteConnection(TConnection connection, boolean stopProxy,
    boolean staleConnection) {
    synchronized (ML_INSTANCES) {
      for (Entry connectionEntry : ML_INSTANCES
        .entrySet()) {
        if (connectionEntry.getValue() == connection) {
          deleteConnection(connectionEntry.getKey(), stopProxy, staleConnection);
          break;
        }
      }
    }
  }

  private static void deleteConnection(TConnectionKey connectionKey, boolean stopProxy,
    boolean staleConnection) {
    synchronized (ML_INSTANCES) {
      TConnectionImplementation connection = ML_INSTANCES.get(connectionKey);
      if (connection != null) {
        connection.decCount();
        if (connection.isZeroReference() || staleConnection) {
          LOG.debug("to remove connectionKey: ", connectionKey);
          ML_INSTANCES.remove(connectionKey);
          connection.close(stopProxy);
        } else if (stopProxy) {
          connection.stopProxyOnClose(stopProxy);
        }
      } else {
        LOG.error(
          "Connection not found in the list, can't delete it " + "(connection key=" + connectionKey
            + "). May be the key was modified?");
      }
    }
  }

  /**
   * Delete connection information for the instance specified by configuration. If there are no more
   * references to it, this will then close connection to the zookeeper ensemble and let go of all
   * resources.
   *
   * @param conf      configuration whose identity is used to find {@link TConnection} instance.
   * @param stopProxy Shuts down all the proxy's put up to cluster members including to cluster
   *                  TMaster. .
   */
  public static void deleteConnection(Configuration conf, boolean stopProxy) {
    deleteConnection(new TConnectionKey(conf), stopProxy, false);
  }

  /**
   * Delete information for all connections.
   *
   * @param stopProxy stop the proxy as well
   * @throws java.io.IOException
   */
  public static void deleteAllConnections(boolean stopProxy) {
    synchronized (ML_INSTANCES) {
      Set connectionKeys = new HashSet();
      connectionKeys.addAll(ML_INSTANCES.keySet());
      for (TConnectionKey connectionKey : connectionKeys) {
        deleteConnection(connectionKey, stopProxy, false);
      }
      ML_INSTANCES.clear();
    }
  }

  public static void shutDown() {
    MLRPC.shutDown();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy