All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.hazelcast.jet.impl.pipeline.transform.HashJoinTransform Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
/*
 * Copyright (c) 2008-2023, Hazelcast, Inc. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.hazelcast.jet.impl.pipeline.transform;

import com.hazelcast.function.BiFunctionEx;
import com.hazelcast.function.FunctionEx;
import com.hazelcast.jet.core.Edge;
import com.hazelcast.jet.core.Vertex;
import com.hazelcast.jet.datamodel.ItemsByTag;
import com.hazelcast.jet.datamodel.Tag;
import com.hazelcast.jet.function.TriFunction;
import com.hazelcast.jet.impl.pipeline.PipelineImpl.Context;
import com.hazelcast.jet.impl.pipeline.Planner;
import com.hazelcast.jet.impl.pipeline.Planner.PlannerVertex;
import com.hazelcast.jet.impl.processor.HashJoinCollectP;
import com.hazelcast.jet.impl.processor.HashJoinP;
import com.hazelcast.jet.pipeline.JoinClause;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;

import static com.hazelcast.jet.core.Edge.from;
import static com.hazelcast.jet.core.Vertex.LOCAL_PARALLELISM_USE_DEFAULT;
import static com.hazelcast.jet.impl.pipeline.Planner.applyRebalancing;
import static com.hazelcast.jet.impl.pipeline.Planner.tailList;
import static com.hazelcast.jet.impl.util.Util.toList;

@SuppressWarnings("rawtypes")
public class HashJoinTransform extends AbstractTransform {

    private static final long serialVersionUID = 1L;

    @Nonnull
    private final List> clauses;
    @Nonnull
    private final List tags;
    @Nullable
    private final BiFunctionEx mapToOutputBiFn;
    @Nullable
    private final TriFunction mapToOutputTriFn;
    @Nullable
    private final List whereNullsNotAllowed;

    public HashJoinTransform(
            @Nonnull List upstream,
            @Nonnull List> clauses,
            @Nonnull List tags,
            @Nonnull BiFunctionEx mapToOutputBiFn
    ) {
        super(upstream.size() + "-way hash-join", upstream);
        this.clauses = clauses;
        this.tags = tags;
        this.mapToOutputBiFn = mapToOutputBiFn;
        this.mapToOutputTriFn = null;
        this.whereNullsNotAllowed = null;
    }

    public  HashJoinTransform(
            @Nonnull List upstream,
            @Nonnull List> clauses,
            @Nonnull List tags,
            @Nonnull TriFunction mapToOutputTriFn
    ) {
        super(upstream.size() + "-way hash-join", upstream);
        this.clauses = clauses;
        this.tags = tags;
        this.mapToOutputBiFn = null;
        this.mapToOutputTriFn = mapToOutputTriFn;
        this.whereNullsNotAllowed = null;
    }
    public HashJoinTransform(
            @Nonnull List upstream,
            @Nonnull List> clauses,
            @Nonnull List tags,
            @Nonnull BiFunctionEx mapToOutputBiFn,
            @Nonnull List whereNullsNotAllowed
    ) {
        super(upstream.size() + "-way hash-join", upstream);
        this.clauses = clauses;
        this.tags = tags;
        this.mapToOutputBiFn = mapToOutputBiFn;
        this.mapToOutputTriFn = null;
        this.whereNullsNotAllowed = whereNullsNotAllowed;
    }

    //         ---------           ----------           ----------
    //        | primary |         | joined-1 |         | joined-2 |
    //         ---------           ----------           ----------
    //             |                   |                     |
    //             |              distributed          distributed
    //             |               broadcast            broadcast
    //             |                   v                     v
    //             |             -------------         -------------
    //             |            | collector-1 |       | collector-2 |
    //             |            | localPara=1 |       | localPara=1 |
    //             |             -------------         -------------
    //             |                   |                     |
    //             |                 local                 local
    //           local             broadcast             broadcast
    //          unicast           prioritized           prioritized
    //         ordinal 0           ordinal 1             ordinal 2
    //             \                   |                     |
    //              ----------------\  |   /----------------/
    //                              v  v  v
    //                              --------
    //                             | joiner |
    //                              --------
    @Override
    @SuppressWarnings("unchecked")
    public void addToDag(Planner p, Context context) {
        determineLocalParallelism(LOCAL_PARALLELISM_USE_DEFAULT, context, p.isPreserveOrder());
        PlannerVertex primary = p.transform2vertex.get(this.upstream().get(0));
        List keyFns = toList(this.clauses, JoinClause::leftKeyFn);

        List tags = this.tags;
        BiFunctionEx mapToOutputBiFn = this.mapToOutputBiFn;
        TriFunction mapToOutputTriFn = this.mapToOutputTriFn;

        // must be extracted to variable, probably because of serialization bug
        BiFunctionEx, Object[], ItemsByTag> tupleToItems = tupleToItemsByTag(whereNullsNotAllowed);

        Vertex joiner = p.addVertex(this, name() + "-joiner", determinedLocalParallelism(),
                () -> new HashJoinP<>(keyFns, tags, mapToOutputBiFn, mapToOutputTriFn, tupleToItems)).v;
        Edge edgeToJoiner = from(primary.v, primary.nextAvailableOrdinal()).to(joiner, 0);
        if (p.isPreserveOrder()) {
            edgeToJoiner.isolated();
        } else {
            applyRebalancing(edgeToJoiner, this);
        }
        p.dag.edge(edgeToJoiner);

        String collectorName = name() + "-collector";
        int collectorOrdinal = 1;
        for (Transform fromTransform : tailList(this.upstream())) {
            PlannerVertex fromPv = p.transform2vertex.get(fromTransform);
            JoinClause clause = this.clauses.get(collectorOrdinal - 1);
            FunctionEx getKeyFn = (FunctionEx) clause.rightKeyFn();
            FunctionEx projectFn = (FunctionEx) clause.rightProjectFn();
            Vertex collector = p.dag.newVertex(collectorName + collectorOrdinal,
                    () -> new HashJoinCollectP(getKeyFn, projectFn));
            collector.localParallelism(1);
            p.dag.edge(from(fromPv.v, fromPv.nextAvailableOrdinal())
                    .to(collector, 0)
                    .distributed().broadcast());
            p.dag.edge(from(collector, 0)
                    .to(joiner, collectorOrdinal)
                    .broadcast().priority(-1));
            collectorOrdinal++;
        }
    }

    private static BiFunctionEx, Object[], ItemsByTag> tupleToItemsByTag(List nullsNotAllowed) {
        return (tagList, tuple) -> {
            ItemsByTag res = new ItemsByTag();
            for (int i = 0; i < tagList.size(); i++) {
                if (tuple[i] == null && nullsNotAllowed.get(i)) {
                    return null;
                }
                res.put(tagList.get(i), tuple[i]);
            }
            return res;
        };
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy