// com.tencent.angel.ps.client.PSClient Maven / Gradle / Ivy
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.ps.client;
import com.tencent.angel.PartitionKey;
import com.tencent.angel.common.location.Location;
import com.tencent.angel.common.transport.ChannelManager2;
import com.tencent.angel.common.transport.NettyChannel;
import com.tencent.angel.conf.AngelConf;
import com.tencent.angel.ps.PSContext;
import com.tencent.angel.ps.ParameterServerId;
import com.tencent.angel.ps.server.data.request.RecoverPartRequest;
import com.tencent.angel.ps.server.data.request.Request;
import com.tencent.angel.ps.server.data.request.UpdateClockRequest;
import com.tencent.angel.ps.server.data.response.Response;
import com.tencent.angel.ps.server.data.response.ResponseType;
import com.tencent.angel.ps.storage.partition.ServerPartition;
import com.tencent.angel.psagent.matrix.transport.FutureResult;
import com.tencent.angel.utils.ByteBufUtils;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * PS RPC client.
 *
 * <p>Sends requests from this parameter server to another parameter server over Netty. Every
 * request gets a sequence id; the id keys two maps (pending result and pending request) that are
 * used to complete the caller's {@link FutureResult} when the response arrives or when the send
 * fails. Whoever removes the pending request entry is responsible for returning the request's
 * channel to the pool, so a channel is returned at most once per request.
 */
public class PSClient {
  private static final Log LOG = LogFactory.getLog(PSClient.class);

  /**
   * Size in bytes of the four int fields written in front of every serialized request
   */
  private static final int REQUEST_HEADER_LEN = 16;

  /**
   * Byte offset of the seq id inside the request header (it is the third int field)
   */
  private static final int SEQ_ID_OFFSET = 8;

  /**
   * netty client bootstrap
   */
  private Bootstrap bootstrap;

  /**
   * netty client thread pool
   */
  private EventLoopGroup eventGroup;

  /**
   * channel pool manager:it maintain a channel pool for every server
   */
  private ChannelManager2 channelManager;

  /**
   * PS context
   */
  private final PSContext context;

  /**
   * Set once by {@link #stop()}; makes stop idempotent and silences send errors during shutdown
   */
  private final AtomicBoolean stopped;

  /**
   * Seq id generator
   */
  private final AtomicInteger seqIdGen;

  /**
   * Request seq id to rpc result map
   */
  private final ConcurrentHashMap<Integer, FutureResult<Response>> seqIdToResultMap;

  /**
   * Request seq id to Request map
   */
  private final ConcurrentHashMap<Integer, Request> seqIdToRequestMap;

  /**
   * Response message queue. NOTE(review): not referenced anywhere else in this class.
   */
  private final LinkedBlockingQueue<ByteBuf> responseQueue;

  /**
   * Is use direct buffer for netty
   */
  private final boolean useDirectBuf;

  /**
   * Create a ps client
   *
   * @param context PS context
   */
  public PSClient(PSContext context) {
    this.context = context;
    stopped = new AtomicBoolean(false);
    seqIdGen = new AtomicInteger(0);
    seqIdToResultMap = new ConcurrentHashMap<>();
    seqIdToRequestMap = new ConcurrentHashMap<>();
    responseQueue = new LinkedBlockingQueue<>();
    useDirectBuf = context.getConf()
        .getBoolean(AngelConf.ANGEL_NETTY_MATRIXTRANSFER_SERVER_USEDIRECTBUFFER,
            AngelConf.DEFAULT_ANGEL_NETTY_MATRIXTRANSFER_SERVER_USEDIRECTBUFFER);
  }

  /**
   * Init: read the transport settings from the configuration, create the channel manager and
   * event loop group, and configure the Netty bootstrap (length-prefixed frames, 4 byte length).
   */
  public void init() {
    bootstrap = new Bootstrap();
    Configuration conf = context.getConf();
    int nettyWorkerNum = conf
        .getInt(AngelConf.ANGEL_PS_HA_SYNC_WORKER_NUM, AngelConf.DEFAULT_ANGEL_PS_HA_SYNC_WORKER_NUM);
    channelManager = new ChannelManager2(bootstrap, nettyWorkerNum);

    int sendBuffSize = conf.getInt(AngelConf.ANGEL_PS_HA_SYNC_SEND_BUFFER_SIZE,
        AngelConf.DEFAULT_ANGEL_PS_HA_SYNC_SEND_BUFFER_SIZE);
    final int maxMessageSize = conf.getInt(AngelConf.ANGEL_NETTY_MATRIXTRANSFER_MAX_MESSAGE_SIZE,
        AngelConf.DEFAULT_ANGEL_NETTY_MATRIXTRANSFER_MAX_MESSAGE_SIZE);
    int ioRatio = conf.getInt(AngelConf.ANGEL_NETTY_MATRIXTRANSFER_CLIENT_IORATIO,
        AngelConf.DEFAULT_ANGEL_NETTY_MATRIXTRANSFER_CLIENT_IORATIO);
    String channelType = conf.get(AngelConf.ANGEL_NETTY_MATRIXTRANSFER_CLIENT_CHANNEL_TYPE,
        AngelConf.DEFAULT_ANGEL_NETTY_MATRIXTRANSFER_CLIENT_CHANNEL_TYPE);

    // Use Epoll for linux
    Class<? extends SocketChannel> channelClass;
    String os = System.getProperty("os.name");
    if (os != null && os.toLowerCase().startsWith("linux") && "epoll".equals(channelType)) {
      LOG.info("Use epoll channel");
      channelClass = EpollSocketChannel.class;
      // Size the epoll group with the configured worker number, like the NIO branch below.
      EpollEventLoopGroup epollGroup = new EpollEventLoopGroup(nettyWorkerNum);
      epollGroup.setIoRatio(ioRatio);
      eventGroup = epollGroup;
    } else {
      LOG.info("Use nio channel");
      channelClass = NioSocketChannel.class;
      NioEventLoopGroup nioGroup = new NioEventLoopGroup(nettyWorkerNum);
      nioGroup.setIoRatio(ioRatio);
      eventGroup = nioGroup;
    }

    bootstrap.group(eventGroup).channel(channelClass).option(ChannelOption.SO_SNDBUF, sendBuffSize)
        .handler(new ChannelInitializer<SocketChannel>() {
          @Override protected void initChannel(SocketChannel ch) throws Exception {
            ChannelPipeline pipeLine = ch.pipeline();
            pipeLine.addLast(new LengthFieldBasedFrameDecoder(maxMessageSize, 0, 4, 0, 4));
            pipeLine.addLast(new LengthFieldPrepender(4));
            pipeLine.addLast(new PSClientHandler());
          }
        });
  }

  /**
   * Start. Currently a no-op; all setup happens in {@link #init()}.
   */
  public void start() {
  }

  /**
   * Stop all services and worker threads. Only the first call has an effect; it is safe to call
   * even if {@link #init()} was never invoked.
   */
  public void stop() {
    if (!stopped.getAndSet(true)) {
      if (channelManager != null) {
        channelManager.clear();
      }
      if (eventGroup != null) {
        eventGroup.shutdownGracefully();
      }
    }
  }

  /**
   * Response Handler: last handler of the client pipeline, forwards every decoded frame to
   * {@link #handleResponse(ByteBuf)}.
   */
  class PSClientHandler extends ChannelInboundHandlerAdapter {
    @Override public void channelActive(ChannelHandlerContext ctx) {
    }

    @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception {
      if (LOG.isDebugEnabled()) {
        LOG.debug("channel " + ctx.channel() + " inactive");
      }
    }

    @Override public void channelRead(ChannelHandlerContext ctx, Object msg) {
      ByteBuf buf = (ByteBuf) msg;
      if (LOG.isDebugEnabled()) {
        LOG.debug("receive a message " + buf.readableBytes());
      }
      handleResponse(buf);
    }
  }

  /**
   * Handle response message: look up the pending request by the seq id at the head of the
   * message, return its channel to the pool, deserialize the response and complete the pending
   * result. The message buffer is released on every path, including responses whose request or
   * result is no longer pending and deserialization failures.
   *
   * @param msg response message
   */
  private void handleResponse(ByteBuf msg) {
    try {
      int seqId = msg.readInt();

      // find the partition request context from cache
      Request request = seqIdToRequestMap.remove(seqId);
      if (request == null) {
        return;
      }
      returnChannel(request);

      FutureResult<Response> result = seqIdToResultMap.remove(seqId);
      if (result == null) {
        return;
      }

      long startTs = System.currentTimeMillis();
      Response response = new Response();
      response.deserialize(msg);
      result.set(response);
      if (LOG.isDebugEnabled()) {
        LOG.debug(
            "handle response of request " + request + " use time=" + (System.currentTimeMillis()
                - startTs));
      }
    } catch (Throwable x) {
      LOG.fatal("hanlder rpc response failed ", x);
      context.getPs().failed("hanlder rpc response failed " + x.getMessage());
    } finally {
      // This handler is the end of the pipeline, so it must release the inbound buffer itself.
      msg.release();
    }
  }

  /**
   * Send request. On any synchronous failure (unknown location, no usable channel, exception) the
   * pending entries of {@code seqId} are removed, an acquired channel is returned to the pool and
   * {@code result} is completed with an error response; nothing is written in that case.
   *
   * @param serverId dest ps id
   * @param location dest ps location
   * @param seqId    request seq id
   * @param request  request
   * @param msg      serialized request
   * @param result   request future result
   * @return true if {@code msg} was handed to Netty (which then owns and releases it); false if
   * the request failed before the write and {@code msg} was left untouched
   */
  private boolean send(ParameterServerId serverId, Location location, int seqId, Request request,
      ByteBuf msg, FutureResult<Response> result) {
    if (location == null) {
      String log = "server " + serverId + " location is null";
      LOG.error(log);
      failPending(seqId, null, result, ResponseType.SERVER_NOT_READY, log);
      return false;
    }

    // get a channel to server from pool
    NettyChannel channel = null;
    boolean handedToNetty = false;
    try {
      channel = getChannel(location);

      // if channel is not valid, it means maybe the connections to the server are closed
      if (!channel.getChannel().isActive() || !channel.getChannel().isOpen()) {
        String log = "channel " + channel + " is not active";
        LOG.error(log);
        failPending(seqId, channel, result, ResponseType.NETWORK_ERROR, log);
        // Do not write to a dead channel; the result has already been completed.
        return false;
      }

      // Remember the channel on the request so it can be returned once the request finishes.
      request.getContext().setChannel(channel);
      // From here on Netty releases msg, whether the write succeeds or fails.
      handedToNetty = true;
      ChannelFuture cf = channel.getChannel().writeAndFlush(msg);
      cf.addListener(new RequesterChannelFutureListener(seqId, request));
    } catch (Throwable x) {
      if (x instanceof InterruptedException) {
        // Keep the interrupt visible to the calling thread.
        Thread.currentThread().interrupt();
      }
      if (!stopped.get()) {
        LOG.error("get channel failed ", x);
      }
      String log = "get channel failed " + x.getMessage();
      failPending(seqId, channel, result, ResponseType.NETWORK_ERROR, log);
    }
    return handedToNetty;
  }

  /**
   * Register a request and its future result under a seq id before the request is sent.
   *
   * @param seqId   request seq id
   * @param request request
   * @param result  request future result
   */
  private void registerPending(int seqId, Request request, FutureResult<Response> result) {
    seqIdToResultMap.put(seqId, result);
    seqIdToRequestMap.put(seqId, request);
  }

  /**
   * Drop the bookkeeping of a request that could not be sent and complete its result with an
   * error response.
   *
   * @param seqId   request seq id
   * @param channel channel acquired for the request, null if none was acquired
   * @param result  request future result
   * @param type    error type
   * @param detail  error detail
   */
  private void failPending(int seqId, NettyChannel channel, FutureResult<Response> result,
      ResponseType type, String detail) {
    seqIdToResultMap.remove(seqId);
    // Only the party that removes the request entry returns the channel, so it is returned once.
    if (seqIdToRequestMap.remove(seqId) != null && channel != null) {
      try {
        channelManager.releaseChannel(channel);
      } catch (Exception x) {
        LOG.error("return channel to channel pool failed ", x);
      }
    }
    result.set(new Response(type, detail));
  }

  /**
   * Recover a matrix partition for a ps
   *
   * @param serverId dest ps id
   * @param location dest ps location
   * @param part     need recover partition
   * @return recover result
   */
  public FutureResult<Response> recoverPart(ParameterServerId serverId, Location location,
      ServerPartition part) {
    // Generate seq id
    int seqId = seqIdGen.incrementAndGet();
    FutureResult<Response> result = new FutureResult<>();

    // Create a RecoverPartRequest
    PartitionKey partKey = part.getPartitionKey();
    RecoverPartRequest request = new RecoverPartRequest(
        context.getClockVectorManager().getClockVec(partKey.getMatrixId(), partKey.getPartitionId()),
        new PartitionKey(partKey.getMatrixId(), partKey.getPartitionId()), part);
    request.getContext().setServerId(serverId);

    // Serialize the request
    // NOTE(review): meaning of the first two header ints (-1, 0) is not visible here — confirm
    // against the receiving side.
    ByteBuf msg = ByteBufUtils.newByteBuf(REQUEST_HEADER_LEN + request.bufferLen(), useDirectBuf);
    msg.writeInt(-1);
    msg.writeInt(0);
    msg.writeInt(seqId);
    msg.writeInt(request.getType().getMethodId());
    request.serialize(msg);

    // Register only after serialization succeeded, so a failure above leaves no stale entries.
    registerPending(seqId, request, result);
    if (!send(serverId, location, seqId, request, msg, result)) {
      // The buffer was allocated here and never reached Netty, so it must be released here.
      msg.release();
    }
    return result;
  }

  /**
   * Put update data to another ps. The seq id inside {@code update} is overwritten with a newly
   * generated one. If the request fails before it is written to a channel, {@code update} is not
   * released by this client and stays owned by the caller.
   *
   * @param serverId dest ps id
   * @param location dest ps location
   * @param request  update request
   * @param update   serialized update request
   * @return update result
   */
  public FutureResult<Response> put(ParameterServerId serverId, Location location, Request request,
      ByteBuf update) {
    // Change the seqId for the request
    int seqId = seqIdGen.incrementAndGet();
    changeSeqId(seqId, update);
    request.getContext().setServerId(serverId);

    FutureResult<Response> result = new FutureResult<>();
    registerPending(seqId, request, result);
    send(serverId, location, seqId, request, update, result);
    return result;
  }

  /**
   * Update partition clock
   *
   * @param serverId  ps id
   * @param location  ps location
   * @param partKey   partition information
   * @param taskIndex task index
   * @param clock     clock value
   * @return future result
   */
  public FutureResult<Response> updateClock(ParameterServerId serverId, Location location,
      PartitionKey partKey, int taskIndex, int clock) {
    int seqId = seqIdGen.incrementAndGet();
    // NOTE(review): unlike recoverPart/put, the server id is not set on the request context here
    // — confirm whether that is intentional.
    UpdateClockRequest request = new UpdateClockRequest(partKey, taskIndex, clock);
    FutureResult<Response> response = new FutureResult<>();

    // Serialize the request
    ByteBuf msg = ByteBufUtils.newByteBuf(REQUEST_HEADER_LEN + request.bufferLen(), useDirectBuf);
    msg.writeInt(-1);
    msg.writeInt(0);
    msg.writeInt(seqId);
    msg.writeInt(request.getType().getMethodId());
    request.serialize(msg);

    // Register only after serialization succeeded, so a failure above leaves no stale entries.
    registerPending(seqId, request, response);
    if (!send(serverId, location, seqId, request, msg, response)) {
      // The buffer was allocated here and never reached Netty, so it must be released here.
      msg.release();
    }
    return response;
  }

  /**
   * Overwrite the seq id field of an already serialized request.
   *
   * @param seqId  new seq id
   * @param update serialized request
   */
  private void changeSeqId(int seqId, ByteBuf update) {
    update.setInt(SEQ_ID_OFFSET, seqId);
  }

  /**
   * Return the channel recorded on the request context to the channel pool, if there is one.
   * Failures are logged and otherwise ignored.
   *
   * @param item request whose channel should be returned
   */
  private void returnChannel(Request item) {
    try {
      if (item.getContext().getChannel() != null) {
        channelManager.releaseChannel(item.getContext().getChannel());
      }
    } catch (Exception x) {
      LOG.error("return channel to channel pool failed ", x);
    }
  }

  /**
   * Write-completion listener: when the write of a request fails, it cleans up the pending
   * entries, returns the channel and completes the result with a network error. A successful
   * write needs no action here; the request is finished by {@link #handleResponse(ByteBuf)}.
   */
  class RequesterChannelFutureListener implements ChannelFutureListener {
    private final Request request;
    private final int seqId;

    public RequesterChannelFutureListener(int seqId, Request request) {
      this.request = request;
      this.seqId = seqId;
    }

    @Override public void operationComplete(ChannelFuture future) throws Exception {
      if (LOG.isDebugEnabled()) {
        LOG.debug("send request " + request + " with seqId=" + seqId + " complete");
      }
      if (!future.isSuccess()) {
        LOG.error("send " + seqId + " failed ", future.cause());
        FutureResult<Response> result = seqIdToResultMap.remove(seqId);
        // Return the channel only if this listener is the one finishing the request.
        if (seqIdToRequestMap.remove(seqId) != null) {
          returnChannel(request);
        }
        // The result may already have been taken by another completion path.
        if (result != null) {
          // Concatenation tolerates a null cause (e.g. a cancelled write).
          result.set(new Response(ResponseType.NETWORK_ERROR,
              "send request failed " + future.cause()));
        }
      }
    }
  }

  /**
   * Get a channel to the given server from the channel pool.
   *
   * @param loc server location
   * @return a pooled channel connected to {@code loc.getIp()} on {@code loc.getPort() + 1}
   * @throws TimeoutException     declared by the channel manager
   * @throws InterruptedException declared by the channel manager
   */
  private NettyChannel getChannel(Location loc) throws TimeoutException, InterruptedException {
    // NOTE(review): presumably the peer's data service listens on its registered port + 1 —
    // confirm against the server side.
    return channelManager.getChannel(new Location(loc.getIp(), loc.getPort() + 1));
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy