All downloads are free. The search and download functionality uses the official Maven repository.

hivemall.mix.client.MixClient Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.mix.client;

import hivemall.model.ModelUpdateHandler;
import hivemall.mix.MixMessage;
import hivemall.mix.MixMessage.MixEventName;
import hivemall.mix.MixedModel;
import hivemall.mix.MixedWeight;
import hivemall.mix.NodeInfo;
import hivemall.utils.hadoop.HadoopUtils;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;

import java.io.Closeable;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.HashMap;
import java.util.Map;

import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.net.ssl.SSLException;

/**
 * Client that forwards local model updates to remote mix servers over Netty channels.
 *
 * <p>Connections are opened lazily: no channel is created until the first call to
 * {@link #onUpdate(Object, float, float, short, int)} whose {@code deltaUpdates} reaches the
 * configured mix threshold. At that point one channel is opened for every node returned by the
 * request router, and each message is subsequently sent on the channel of the node the router
 * selects for it.
 *
 * <p>NOTE(review): this class performs no synchronization of its own; it is presumably driven
 * from a single thread -- confirm against callers.
 */
public final class MixClient implements ModelUpdateHandler, Closeable {
    /** Placeholder prefix in a group ID that is replaced by the Hadoop job ID on first use. */
    public static final String DUMMY_JOB_ID = "__DUMMY_JOB_ID__";

    private final MixEventName event;
    private String groupID;
    private final boolean ssl;
    private final int mixThreshold;
    private final MixRequestRouter router;
    private final MixClientHandler msgHandler;
    /** Open channel per mix server; populated by {@link #initialize()}, emptied by {@link #close()}. */
    private final Map<NodeInfo, Channel> channelMap;

    private boolean initialized = false;
    private EventLoopGroup workers;

    /**
     * Creates a client. No network connection is made here.
     *
     * @param event event name attached to every message sent by this client
     * @param groupID group identifier attached to every message; must not be null. A leading
     *        {@link #DUMMY_JOB_ID} is substituted with the Hadoop job ID on first use.
     * @param connectURIs mix server URIs handed to the request router
     * @param ssl whether channels are configured with an SSL context
     * @param mixThreshold minimum {@code deltaUpdates} required before an update is sent;
     *        must be in the range {@code [1, Byte.MAX_VALUE]}
     * @param model model handed to the client-side message handler
     * @throws IllegalArgumentException if {@code groupID} is null or {@code mixThreshold} is
     *         out of range
     */
    public MixClient(@Nonnull MixEventName event, @CheckForNull String groupID,
            @Nonnull String connectURIs, boolean ssl, int mixThreshold, @Nonnull MixedModel model) {
        if (groupID == null) {
            throw new IllegalArgumentException("groupID is null");
        }
        if (mixThreshold < 1 || mixThreshold > Byte.MAX_VALUE) {
            throw new IllegalArgumentException("Invalid mixThreshold: " + mixThreshold);
        }
        this.event = event;
        this.groupID = groupID;
        this.router = new MixRequestRouter(connectURIs);
        this.ssl = ssl;
        this.mixThreshold = mixThreshold;
        this.msgHandler = new MixClientHandler(model);
        this.channelMap = new HashMap<NodeInfo, Channel>();
    }

    /**
     * Opens one channel to every mix server known to the router.
     *
     * <p>If any connection attempt fails, the channels opened so far are closed and the event
     * loop group is shut down before the exception propagates, so that a failed attempt does
     * not leak threads or sockets and a later retry starts from a clean state.
     */
    private void initialize() throws Exception {
        final EventLoopGroup workerGroup = new NioEventLoopGroup();
        boolean connected = false;
        try {
            for (NodeInfo node : router.getAllNodes()) {
                configureBootstrap(new Bootstrap(), workerGroup, node);
            }
            connected = true;
        } finally {
            if (!connected) {
                for (Channel ch : channelMap.values()) {
                    ch.close();
                }
                channelMap.clear();
                workerGroup.shutdownGracefully();
            }
        }
        this.workers = workerGroup;
        this.initialized = true;
    }

    /**
     * Configures the given bootstrap, connects to {@code server} (blocking until the connect
     * completes) and registers the resulting channel in {@link #channelMap}.
     */
    private void configureBootstrap(Bootstrap b, EventLoopGroup workerGroup, NodeInfo server)
            throws SSLException, InterruptedException {
        // Configure SSL.
        final SslContext sslCtx;
        if (ssl) {
            sslCtx = SslContext.newClientContext(InsecureTrustManagerFactory.INSTANCE);
        } else {
            sslCtx = null;
        }

        b.group(workerGroup);
        b.option(ChannelOption.SO_KEEPALIVE, true);
        b.option(ChannelOption.TCP_NODELAY, true);
        b.channel(NioSocketChannel.class);
        b.handler(new MixClientInitializer(msgHandler, sslCtx));

        SocketAddress remoteAddr = server.getSocketAddress();
        ChannelFuture channelFuture = b.connect(remoteAddr).sync();
        Channel channel = channelFuture.channel();

        channelMap.put(server, channel);
    }

    /**
     * Sends an update for {@code feature} to the mix server selected by the router, provided
     * that {@code deltaUpdates} has reached the mix threshold. The first update that is sent
     * also resolves the group ID and opens the server connections.
     *
     * @return true if sent request, otherwise false
     * @throws IllegalStateException if no channel is registered for the selected server
     */
    @Override
    public boolean onUpdate(Object feature, float weight, float covar, short clock,
            int deltaUpdates) throws Exception {
        assert (deltaUpdates > 0) : deltaUpdates;
        if (deltaUpdates < mixThreshold) {
            return false; // avoid mixing
        }

        if (!initialized) {
            replaceGroupIDIfRequired();
            initialize(); // initialize connections to mix servers
        }

        MixMessage msg = new MixMessage(event, feature, weight, covar, clock, deltaUpdates);
        msg.setGroupID(groupID);

        Channel ch = channelFor(router.selectNode(msg));
        ch.writeAndFlush(msg); // send asynchronously in the background
        return true;
    }

    /**
     * Sends a cancel request carrying the weight, covariance and delta-update count of
     * {@code mixed} to the mix server selected by the router. Expects the connections to have
     * been initialized by an earlier {@link #onUpdate(Object, float, float, short, int)} call.
     *
     * @throws IllegalStateException if no channel is registered for the selected server
     */
    @Override
    public void sendCancelRequest(@Nonnull Object feature, @Nonnull MixedWeight mixed)
            throws Exception {
        assert (initialized);

        float weight = mixed.getWeight();
        float covar = mixed.getCovar();
        int deltaUpdates = mixed.getDeltaUpdates();

        MixMessage msg = new MixMessage(event, feature, weight, covar, deltaUpdates, true);
        assert (groupID != null);
        msg.setGroupID(groupID);

        // TODO REVIEWME consider mix server faults (what if mix server dead? Do not send cancel request?)
        Channel ch = channelFor(router.selectNode(msg));
        ch.writeAndFlush(msg); // send asynchronously in the background
    }

    /**
     * Returns the channel registered for {@code server}, attempting a blocking reconnect first
     * if the channel is not active.
     *
     * @throws IllegalStateException if no channel is registered for {@code server}, which is
     *         the case once {@link #close()} has emptied the channel map
     */
    private Channel channelFor(NodeInfo server) throws InterruptedException {
        final Channel ch = channelMap.get(server);
        if (ch == null) {
            throw new IllegalStateException("No channel is registered for mix server: " + server);
        }
        if (!ch.isActive()) {// reconnect
            // NOTE(review): a Netty channel that has already been closed may not accept a new
            // connect(); confirm that this reconnect path behaves as intended.
            SocketAddress remoteAddr = server.getSocketAddress();
            ch.connect(remoteAddr).sync();
        }
        return ch;
    }

    /** Substitutes the Hadoop job ID for {@link #DUMMY_JOB_ID} when the group ID starts with it. */
    private void replaceGroupIDIfRequired() {
        if (groupID.startsWith(DUMMY_JOB_ID)) {
            String jobId = HadoopUtils.getJobId();
            this.groupID = groupID.replace(DUMMY_JOB_ID, jobId);
        }
    }

    /**
     * Closes every open channel and shuts down the event loop group. Does nothing if the
     * connections were never initialized or have already been closed.
     */
    @Override
    public void close() throws IOException {
        if (workers != null) {
            for (Channel ch : channelMap.values()) {
                ch.close();
            }
            channelMap.clear();
            workers.shutdownGracefully();
            this.workers = null;
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy