All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.ppml.fl.psi.PSIServiceImpl Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.ppml.fl.psi;

import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.*;
import com.intel.analytics.bigdl.ppml.fl.generated.PSIServiceGrpc;
import com.intel.analytics.bigdl.ppml.fl.generated.PSIServiceProto.*;
import io.grpc.stub.StreamObserver;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;


import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;

public class PSIServiceImpl extends PSIServiceGrpc.PSIServiceImplBase {
    private static final Logger logger = LogManager.getLogger(PSIServiceImpl.class);
    // TODO Thread safe
    protected PsiIntersection psiTask;
    // This psiCollections is
    //            TaskId,   ClientId, UploadRequest
    protected Map psiCollections = new ConcurrentHashMap<>();
    int clientNum;
    String clientSalt;
    String clientSecret;
    // Stores the seed used in shuffling for each taskId
    int clientShuffleSeed = 0;
    protected int splitSize = 1000000;

    public PSIServiceImpl(int clientNum) {
        this.clientNum = clientNum;
    }
    @Override
    public void getSalt(SaltRequest req, StreamObserver responseObserver) {
        String salt;
        // Store salt
        String taskId = req.getTaskId();
        if (clientSalt != null) {
            salt = clientSalt;
        } else {
            salt = Utils.getRandomUUID();
            clientSalt = salt;
        }

        // Store secure
        if (clientSecret == null) {
            clientSecret = req.getSecureCode();
        } else if (!clientSecret.equals(req.getSecureCode())) {
            // TODO Reply empty
            salt = "";
        }
        // Store random seed for shuffling
        if (clientShuffleSeed == 0) {
            clientShuffleSeed = Utils.getRandomInt();
        }
        SaltReply reply = SaltReply.newBuilder().setSaltReply(salt).build();
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    @Override
    public void uploadSet(UploadSetRequest request,
                          StreamObserver responseObserver) {
        SIGNAL signal;

        signal= SIGNAL.SUCCESS;
        int clientId = request.getClientId();
        int numSplit = request.getNumSplit();
        int splitLength = request.getSplitLength();
        int totalLength = request.getTotalLength();

        if(!psiCollections.containsKey(clientId)){
            if(psiCollections.size() >= clientNum) {
                logger.error("Too many clients, already has " +
                        psiCollections.keySet() +
                        ". The new one is " + clientId);
            }
            psiCollections.put(clientId, new String[totalLength]);
        }
        String[] collectionStorage = psiCollections.get(clientId);
        String[] ids = request.getHashedIDList().toArray(new String[request.getHashedIDList().size()]);
        int split = request.getSplit();
        // TODO: verify requests' splits are unique.
        System.arraycopy(ids, 0, collectionStorage, split * splitLength, ids.length);
        logger.info("ClientId" + clientId + ",split: " + split + ", numSplit: " + numSplit + ".");
        if (split == numSplit - 1) {
            synchronized (this) {
                try {
                    if (psiTask != null) {
                        logger.info("Adding " + (psiTask.numCollection() + 1) +
                                "th collections");
                        long st = System.currentTimeMillis();
                        psiTask.addCollection(collectionStorage);
                        logger.info("Added " + (psiTask.numCollection()) +
                                "th collections. Find Intersection time cost: " + (System.currentTimeMillis()-st) + " ms");
                    } else {
                        logger.info("Adding 1th collections.");
                        PsiIntersection pi = new PsiIntersection(clientNum,
                                clientShuffleSeed);
                        pi.addCollection(collectionStorage);
                        psiTask = pi;
                        logger.info("Added 1th collections.");
                    }
                    psiCollections.remove(clientId);
                } catch (InterruptedException | ExecutionException e){
                    logger.error(e.getMessage());
                    signal= SIGNAL.ERROR;
                } catch (IllegalArgumentException iae) {
                    logger.error("Current client ids are " + psiCollections.keySet());
                    logger.error(iae.getMessage());
                    throw iae;
                }
            }
        }


        UploadSetResponse response = UploadSetResponse.newBuilder()
                .setStatus(signal)
                .build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    @Override
    public void downloadIntersection(DownloadIntersectionRequest request,
                                     StreamObserver responseObserver) {
        SIGNAL signal = SIGNAL.SUCCESS;
        if (psiTask != null) {
            try {
                List intersection = psiTask.getIntersection();
                if (intersection == null) {
                    DownloadIntersectionResponse response = DownloadIntersectionResponse.newBuilder()
                            .setStatus(SIGNAL.EMPTY_INPUT).build();
                    responseObserver.onNext(response);
                    responseObserver.onCompleted();
                    return;
                }
                int split = request.getSplit();
                int numSplit = Utils.getTotalSplitNum(intersection, splitSize);
                List splitIntersection = Utils.getSplit(intersection, split, numSplit, splitSize);
                DownloadIntersectionResponse response = DownloadIntersectionResponse.newBuilder()
                        .setStatus(signal)
                        .setSplit(split)
                        .setNumSplit(numSplit)
                        .setTotalLength(intersection.size())
                        .setSplitLength(splitSize)
                        .addAllIntersection(splitIntersection).build();
                responseObserver.onNext(response);
                responseObserver.onCompleted();
            } catch (InterruptedException e) {
                logger.error(e.getMessage());
                signal = SIGNAL.ERROR;
                DownloadIntersectionResponse response = DownloadIntersectionResponse.newBuilder()
                        .setStatus(signal).build();
                responseObserver.onNext(response);
                responseObserver.onCompleted();
            }
        } else {
            DownloadIntersectionResponse response = DownloadIntersectionResponse.newBuilder()
                    .setStatus(SIGNAL.ERROR).build();
            responseObserver.onNext(response);
            responseObserver.onCompleted();
        }
    }


}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy