All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.ppml.fl.psi.PSIStub Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.ppml.fl.psi;

import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.*;
import com.intel.analytics.bigdl.ppml.fl.generated.PSIServiceGrpc;
import com.intel.analytics.bigdl.ppml.fl.generated.PSIServiceProto.*;

import io.grpc.Channel;
import io.grpc.Internal;
import io.grpc.StatusRuntimeException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

public class PSIStub {
    private static final Logger logger = LoggerFactory.getLogger(PSIStub.class);
    private PSIServiceGrpc.PSIServiceBlockingStub stub;
    Integer clientID;
    public PSIStub(Channel channel, Integer clientID) {
        this.clientID = clientID;
        stub = PSIServiceGrpc.newBlockingStub(channel);
    }
    protected String salt;
    protected int splitSize = 1000000;

    public String getSalt() {
        return getSalt("");
    }
    /**
     * For PSI usage only
     * To get salt from FL Server, will get a new one if its salt does not exist on server
     * @param secureCode String, secure code
     * @return String, the salt get from server
     */
    public String getSalt(String secureCode) {
        logger.info(clientID + " getting salt from PSI service");
        SaltRequest request = SaltRequest.newBuilder()
                .setSecureCode(secureCode).build();
        SaltReply response;
        try {
            response = stub.getSalt(request);
        } catch (StatusRuntimeException e) {
            throw new RuntimeException("RPC failed: " + e.getMessage());
        }
        if (!response.getSaltReply().isEmpty()) {
            salt = response.getSaltReply();
        }
        return response.getSaltReply();
    }

    /**
     * For PSI usage only
     * Upload local set to FL Server in VFL
     * @param hashedIdArray List of String, the set trained at local
     */
    public void uploadSet(List hashedIdArray) {
        int numSplit = Utils.getTotalSplitNum(hashedIdArray, splitSize);
        int split = 0;
        while (split < numSplit) {
            List splitArray = Utils.getSplit(hashedIdArray, split, numSplit, splitSize);
            UploadSetRequest request = UploadSetRequest.newBuilder()
                    .setSplit(split)
                    .setNumSplit(numSplit)
                    .setSplitLength(splitSize)
                    .setTotalLength(hashedIdArray.size())
                    .setClientId(clientID)
                    .addAllHashedID(splitArray)
                    .build();
            try {
                stub.uploadSet(request);
            } catch (StatusRuntimeException e) {
                throw new RuntimeException("RPC failed: " + e.getMessage());
            }
            split ++;
        }
    }

    /**
     * For PSI usage only
     * Download intersection from FL Server in VFL
     * @return List of String, the intersection downloaded
     */
    public List downloadIntersection() throws Exception {
        List result = new ArrayList();
        try {
            logger.info("Downloading 0th intersection");
            DownloadIntersectionRequest request = DownloadIntersectionRequest.newBuilder()
                    .setSplit(0)
                    .build();
            DownloadIntersectionResponse response = stub.downloadIntersection(request);
            if (response.getStatus() == SIGNAL.ERROR) {
                throw new Exception("Task ID does not exist on server, please upload set first.");
            }
            if (response.getStatus() == SIGNAL.EMPTY_INPUT) {
                // empty intersection, just return
                return null;
            }
            logger.info("Downloaded 0th intersection");
            result.addAll(response.getIntersectionList());
            for (int i = 1; i < response.getNumSplit(); i++) {
                request = DownloadIntersectionRequest.newBuilder()
                        .setSplit(i)
                        .build();
                logger.info("Downloading " + i + "th intersection");
                response = stub.downloadIntersection(request);
                logger.info("Downloaded " + i + "th intersection");
                result.addAll(response.getIntersectionList());
            }
            assert(result.size() == response.getTotalLength());
        } catch (StatusRuntimeException e) {
            throw new RuntimeException("RPC failed: " + e.getMessage());
        }
        return result;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy