All downloads are FREE. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.ppml.fl.psi.PsiIntersection Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.ppml.fl.psi;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.*;

public class PsiIntersection {
    public final int maxCollection;
    public final int shuffleSeed;

    protected final int nThreads = Integer.parseInt(System.getProperty(
            "PsiThreads", "6"));

    protected ExecutorService pool = Executors.newFixedThreadPool(nThreads);

    public PsiIntersection(int maxCollection, int shuffleSeed) {
        this.maxCollection = maxCollection;
        this.shuffleSeed = shuffleSeed;
    }

    protected List collections = new ArrayList();
    protected List intersection;

    public int numCollection() {
        return collections.size();
    }

    public void addCollection(
            String[] collection) throws InterruptedException, ExecutionException{
        synchronized (this) {
            if (collections.size() == maxCollection) {
                throw new IllegalArgumentException("Collection is full.");
            }
            collections.add(collection);
            if (collections.size() >= maxCollection) {
                // TODO: sort by collections' size
                String[] current = collections.get(0);
                for(int i = 1; i < maxCollection - 1; i++){
                    Arrays.parallelSort(current);
                    current = findIntersection(current, collections.get(i))
                        .toArray(new String[intersection.size()]);
                }
                Arrays.parallelSort(current);
                List result = findIntersection(current, collections.get(maxCollection - 1));
                Utils.shuffle(result, shuffleSeed);
                intersection = result;
            }
        }
    }

    // Join a with b, a should be sorted.
    private static class FindIntersection implements Callable> {
        protected String[] a;
        protected String[] b;
        protected int bStart;
        protected int length;

        public FindIntersection(String[] a,
                                String[] b,
                                int bStart,
                                int length) {
            this.a = a;
            this.b = b;
            this.bStart = bStart;
            this.length = length;
        }

        @Override
        public List call() {
            return findIntersection(a, b, bStart, length);
        }

        protected static List findIntersection(
                String[] a,
                String[] b,
                int start,
                int length){
            ArrayList intersection = new ArrayList();
            for(int i = start; i < length + start; i++) {
                if (Arrays.binarySearch(a, b[i]) >= 0){
                    intersection.add(b[i]);
                }
            }
            return intersection;
        }
    }

    protected List findIntersection(
            String[] a,
            String[] b) throws InterruptedException, ExecutionException{
        int[] splitPoints = new int[nThreads + 1];
        int extractLen = b.length - nThreads * (b.length / nThreads);
        for(int i = 1; i < splitPoints.length; i++) {
            splitPoints[i] = b.length / nThreads * i;
            if (i <= extractLen) {
                splitPoints[i] += i;
            } else {
                splitPoints[i] += extractLen;
            }
        }

        Future>[] futures = new Future[nThreads];
        for(int i = 0; i < nThreads; i++) {
            futures[i] = pool.submit(new FindIntersection(a, b, splitPoints[i],
                splitPoints[i + 1] - splitPoints[i]));
        }
        List intersection = futures[0].get();
        for(int i = 1; i < nThreads; i++) {
            intersection.addAll(futures[i].get());
        }
        return intersection;
    }

    public List getIntersection() throws InterruptedException{
        synchronized (this) {
            return intersection;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy