All downloads are free. Search and download functionality uses the official Maven repository.

org.datavec.local.transforms.join.ExecuteJoinFromCoGroupFlatMapFunctionAdapter Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.datavec.local.transforms.join;

import com.google.common.collect.Iterables;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.functions.FlatMapFunctionAdapter;
import org.nd4j.linalg.primitives.Pair;

import java.util.ArrayList;
import java.util.List;

/**
 * Execute a join
 *
 * @author Alex Black
 */
public class ExecuteJoinFromCoGroupFlatMapFunctionAdapter implements
        FlatMapFunctionAdapter, Pair>, List>>>, List> {

    private final Join join;

    public ExecuteJoinFromCoGroupFlatMapFunctionAdapter(Join join) {
        this.join = join;
    }

    @Override
    public List> call(
                    Pair, Pair>, List>>> t2)
                    throws Exception {

        Iterable> leftList = t2.getSecond().getFirst();
        Iterable> rightList = t2.getSecond().getSecond();

        List> ret = new ArrayList<>();
        Join.JoinType jt = join.getJoinType();
        switch (jt) {
            case Inner:
                //Return records where key columns appear in BOTH
                //So if no values from left OR right: no return values
                for (List jvl : leftList) {
                    for (List jvr : rightList) {
                        List joined = join.joinExamples(jvl, jvr);
                        ret.add(joined);
                    }
                }
                break;
            case LeftOuter:
                //Return all records from left, even if no corresponding right value (NullWritable in that case)
                for (List jvl : leftList) {
                    if (Iterables.size(rightList) == 0) {
                        List joined = join.joinExamples(jvl, null);
                        ret.add(joined);
                    } else {
                        for (List jvr : rightList) {
                            List joined = join.joinExamples(jvl, jvr);
                            ret.add(joined);
                        }
                    }
                }
                break;
            case RightOuter:
                //Return all records from right, even if no corresponding left value (NullWritable in that case)
                for (List jvr : rightList) {
                    if (Iterables.size(leftList) == 0) {
                        List joined = join.joinExamples(null, jvr);
                        ret.add(joined);
                    } else {
                        for (List jvl : leftList) {
                            List joined = join.joinExamples(jvl, jvr);
                            ret.add(joined);
                        }
                    }
                }
                break;
            case FullOuter:
                //Return all records, even if no corresponding left/right value (NullWritable in that case)
                if (Iterables.size(leftList) == 0) {
                    //Only right values
                    for (List jvr : rightList) {
                        List joined = join.joinExamples(null, jvr);
                        ret.add(joined);
                    }
                } else if (Iterables.size(rightList) == 0) {
                    //Only left values
                    for (List jvl : leftList) {
                        List joined = join.joinExamples(jvl, null);
                        ret.add(joined);
                    }
                } else {
                    //Records from both left and right
                    for (List jvl : leftList) {
                        for (List jvr : rightList) {
                            List joined = join.joinExamples(jvl, jvr);
                            ret.add(joined);
                        }
                    }
                }
                break;
        }

        return ret;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy