/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.exec;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.persistence.RowContainer;
import org.apache.hadoop.hive.ql.exec.tez.RecordSource;
import org.apache.hadoop.hive.ql.exec.tez.TezContext;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.CommonMergeJoinDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.api.OperatorType;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.WritableComparator;

/*
 * This operator was introduced to consolidate the join algorithms to either hash based joins
 * (MapJoinOperator) or sort-merge based joins. It executes a sort-merge based algorithm and
 * replaces both the JoinOperator and the SMBMapJoinOperator on the Tez side of things. It works
 * in either the map phase or the reduce phase.
 *
 * The basic algorithm is as follows:
 *
 * 1. The processOp receives a row from a "big" table.
 * 2. In order to process it, the operator does a fetch for rows from the other tables.
 * 3. Once we have a set of rows from the other tables (till we hit a new key), more rows are
 *    brought in from the big table and a join is performed.
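 *
 * For example, if the big table's sorted keys are {1, 1, 2, 3} and a small
 * table's are {1, 2, 2}: the operator buffers the rows for key 1 from both
 * sides, joins them, then repeatedly advances whichever input currently holds
 * the smallest key until all inputs are exhausted.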
 */

public class CommonMergeJoinOperator extends AbstractMapJoinOperator<CommonMergeJoinDesc> implements
    Serializable {

  private static final long serialVersionUID = 1L;
  private boolean isBigTableWork;
  private static final Log LOG = LogFactory.getLog(CommonMergeJoinOperator.class.getName());
  transient List<Object>[] keyWritables;
  transient List<Object>[] nextKeyWritables;
  transient RowContainer<List<Object>>[] nextGroupStorage;
  transient RowContainer<List<Object>>[] candidateStorage;

  transient String[] tagToAlias;
  private transient boolean[] fetchDone;
  private transient boolean[] foundNextKeyGroup;
  transient boolean firstFetchHappened = false;
  transient boolean localWorkInited = false;
  transient boolean initDone = false;
  transient List<Object> otherKey = null;
  transient List<Object> values = null;
  transient RecordSource[] sources;
  transient WritableComparator[][] keyComparators;

  transient List<Operator<? extends OperatorDesc>> originalParents =
      new ArrayList<Operator<? extends OperatorDesc>>();
  transient Set<Integer> fetchInputAtClose;

  public CommonMergeJoinOperator() {
    super();
  }

  @SuppressWarnings("unchecked")
  @Override
  public Collection<Future<?>> initializeOp(Configuration hconf) throws HiveException {
    Collection<Future<?>> result = super.initializeOp(hconf);
    firstFetchHappened = false;
    fetchInputAtClose = getFetchInputAtCloseList();

    int maxAlias = 0;
    for (byte pos = 0; pos < order.length; pos++) {
      if (pos > maxAlias) {
        maxAlias = pos;
      }
    }
    maxAlias += 1;

    nextGroupStorage = new RowContainer[maxAlias];
    candidateStorage = new RowContainer[maxAlias];
    keyWritables = new ArrayList[maxAlias];
    nextKeyWritables = new ArrayList[maxAlias];
    fetchDone = new boolean[maxAlias];
    foundNextKeyGroup = new boolean[maxAlias];
    keyComparators = new WritableComparator[maxAlias][];

    for (Entry<Byte, List<ExprNodeDesc>> entry : conf.getKeys().entrySet()) {
      keyComparators[entry.getKey().intValue()] = new WritableComparator[entry.getValue().size()];
    }

    int bucketSize;

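    // Backward compatibility: honor the old map-join bucket cache size if it was
    // changed from its default of 100; otherwise use the SMB join cache size.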
    int oldVar = HiveConf.getIntVar(hconf, HiveConf.ConfVars.HIVEMAPJOINBUCKETCACHESIZE);
    if (oldVar != 100) {
      bucketSize = oldVar;
    } else {
      bucketSize = HiveConf.getIntVar(hconf, HiveConf.ConfVars.HIVESMBJOINCACHEROWS);
    }

    for (byte pos = 0; pos < order.length; pos++) {
      RowContainer<List<Object>> rc =
          JoinUtil.getRowContainer(hconf, rowContainerStandardObjectInspectors[pos], pos,
              bucketSize, spillTableDesc, conf, !hasFilter(pos), reporter);
      nextGroupStorage[pos] = rc;
      RowContainer<List<Object>> candidateRC =
          JoinUtil.getRowContainer(hconf, rowContainerStandardObjectInspectors[pos], pos,
              bucketSize, spillTableDesc, conf, !hasFilter(pos), reporter);
      candidateStorage[pos] = candidateRC;
    }

    for (byte pos = 0; pos < order.length; pos++) {
      if (pos != posBigTable) {
        fetchDone[pos] = false;
      }
      foundNextKeyGroup[pos] = false;
    }

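    // one record source per input tag, supplied by the Tez processor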
    sources = ((TezContext) MapredContext.get()).getRecordSources();
    return result;
  }

  /*
   * In case of outer joins, we need to push records through even if one of the sides is done
   * sending records. For example, in a full outer join the right side needs to send in data
   * for the join even after the left side has completed sending all of its records. This set
   * can be computed once at initialize time, and at close these tags will keep forwarding
   * records until they have no more to send. Subsequent joins need to fetch their data as
   * well, since any join following the outer join could produce results with one of the outer
   * sides depending on the join condition. We could optimize the inner-join case here in the
   * future.
   */
  private Set<Integer> getFetchInputAtCloseList() {
    Set<Integer> retval = new TreeSet<Integer>();
    for (JoinCondDesc joinCondDesc : conf.getConds()) {
      retval.add(joinCondDesc.getLeft());
      retval.add(joinCondDesc.getRight());
    }

    return retval;
  }

  @Override
  public void endGroup() throws HiveException {
    // we do not want the end group to cause a checkAndGenObject
    defaultEndGroup();
  }

  @Override
  public void startGroup() throws HiveException {
    // we do not want the start group to clear the storage
    defaultStartGroup();
  }


  /*
   * (non-Javadoc)
   *
   * @see org.apache.hadoop.hive.ql.exec.Operator#processOp(java.lang.Object, int)
   *
   * This operator has a push-pull model: the first call to this method is a push,
   * but the rest of the records are pulled until we run out of them.
   */
  @Override
  public void process(Object row, int tag) throws HiveException {
    posBigTable = (byte) conf.getBigTablePosition();

    byte alias = (byte) tag;
    List<Object> value = getFilteredValue(alias, row);
    // compute keys and values as StandardObjects
    List<Object> key = mergeJoinComputeKeys(row, alias);
    if (!firstFetchHappened) {
      firstFetchHappened = true;
      // fetch the first group for all small table aliases
      for (byte pos = 0; pos < order.length; pos++) {
        if (pos != posBigTable) {
          fetchNextGroup(pos);
        }
      }
    }

    // have we reached a new key group?
    boolean nextKeyGroup = processKey(alias, key);
    if (nextKeyGroup) {
      //assert this.nextGroupStorage[alias].size() == 0;
      this.nextGroupStorage[alias].addRow(value);
      foundNextKeyGroup[tag] = true;
      if (tag != posBigTable) {
        return;
      }
    } else {
      if ((tag == posBigTable) && (candidateStorage[tag].rowCount() == joinEmitInterval)) {
        boolean canEmit = true;
        for (byte i = 0; i < foundNextKeyGroup.length; i++) {
          if (i == posBigTable) {
            continue;
          }

          if (foundNextKeyGroup[i] == false) {
            canEmit = false;
            break;
          }

          if (compareKeys(i, key, keyWritables[i]) != 0) {
            canEmit = false;
            break;
          }
        }
        // we can save ourselves from spilling once we have join emit interval worth of rows.
        if (canEmit) {
          LOG.info("We are emitting rows since we hit the join emit interval of "
              + joinEmitInterval);
          joinOneGroup(false);
          candidateStorage[tag].clearRows();
          storage[tag].clearRows();
        }
      }
    }

    reportProgress();
    numMapRowsRead++;

    // the big table has reached a new key group. try to let the small tables
    // catch up with the big table.
    if (nextKeyGroup) {
      assert tag == posBigTable;
      List<Byte> smallestPos = null;
      do {
        smallestPos = joinOneGroup();
        //jump out the loop if we need input from the big table
      } while (smallestPos != null && smallestPos.size() > 0
          && !smallestPos.contains(this.posBigTable));

      return;
    }

    assert !nextKeyGroup;
    candidateStorage[tag].addRow(value);

  }

  private List<Byte> joinOneGroup() throws HiveException {
    return joinOneGroup(true);
  }

  private List<Byte> joinOneGroup(boolean clear) throws HiveException {
    int[] smallestPos = findSmallestKey();
    List<Byte> listOfNeedFetchNext = null;
    if (smallestPos != null) {
      listOfNeedFetchNext = joinObject(smallestPos, clear);
      if ((listOfNeedFetchNext.size() > 0) && clear) {
        // listOfNeedFetchNext contains all tables that we have joined data in their
        // candidateStorage, and we need to clear candidate storage and promote their
        // nextGroupStorage to candidateStorage and fetch data until we reach a
        // new group.
        for (Byte b : listOfNeedFetchNext) {
          try {
            fetchNextGroup(b);
          } catch (Exception e) {
            throw new HiveException(e);
          }
        }
      }
    }
    return listOfNeedFetchNext;
  }

  private List<Byte> joinObject(int[] smallestPos, boolean clear) throws HiveException {
    List<Byte> needFetchList = new ArrayList<Byte>();
    byte index = (byte) (smallestPos.length - 1);
    for (; index >= 0; index--) {
      if (smallestPos[index] > 0 || keyWritables[index] == null) {
        putDummyOrEmpty(index);
        continue;
      }
      storage[index] = candidateStorage[index];
      if (clear) {
        needFetchList.add(index);
      }
      if (smallestPos[index] < 0) {
        break;
      }
    }
    for (index--; index >= 0; index--) {
      putDummyOrEmpty(index);
    }
    checkAndGenObject();
    if (clear) {
      for (Byte pos : needFetchList) {
        this.candidateStorage[pos].clearRows();
        this.keyWritables[pos] = null;
      }
    }
    return needFetchList;
  }

  private void putDummyOrEmpty(Byte i) {
    // put an empty list (no outer joins) or a dummy null-padded row (outer joins)
    if (noOuterJoin) {
      storage[i] = emptyList;
    } else {
      storage[i] = dummyObjVectors[i];
    }
  }

  private int[] findSmallestKey() {
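    // Scan the aliases left to right, tracking the smallest key seen so far:
    // result[pos] < 0 means alias pos holds a key strictly smaller than every
    // earlier key, 0 means it ties the smallest so far. Aliases with no current
    // key are skipped; returns null if no alias has a key.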
    int[] result = new int[order.length];
    List<Object> smallestOne = null;

    for (byte pos = 0; pos < order.length; pos++) {
      List<Object> key = keyWritables[pos];
      if (key == null) {
        continue;
      }
      if (smallestOne == null) {
        smallestOne = key;
        result[pos] = -1;
        continue;
      }
      result[pos] = compareKeys(pos, key, smallestOne);
      if (result[pos] < 0) {
        smallestOne = key;
      }
    }
    return smallestOne == null ? null : result;
  }

  private void fetchNextGroup(Byte t) throws HiveException {
    if (foundNextKeyGroup[t]) {
      // first promote the next group to be the current group if we reached a
      // new group in the previous fetch
      if (this.nextKeyWritables[t] != null) {
        promoteNextGroupToCandidate(t);
      } else {
        this.keyWritables[t] = null;
        this.candidateStorage[t] = null;
        this.nextGroupStorage[t] = null;
      }
      foundNextKeyGroup[t] = false;
    }
    // for the big table, we only need to promote the next group to the current group.
    if (t == posBigTable) {
      return;
    }

    // for tables other than the big table, we need to fetch more data until reach a new group or
    // done.
    while (!foundNextKeyGroup[t]) {
      if (fetchDone[t]) {
        break;
      }
      fetchOneRow(t);
    }
    if (!foundNextKeyGroup[t] && fetchDone[t]) {
      this.nextKeyWritables[t] = null;
    }
  }

  @Override
  public void closeOp(boolean abort) throws HiveException {
    joinFinalLeftData();

    super.closeOp(abort);

    // clean up
    for (int pos = 0; pos < order.length; pos++) {
      if (pos != posBigTable) {
        fetchDone[pos] = false;
      }
      foundNextKeyGroup[pos] = false;
    }
  }

  private void fetchOneRow(byte tag) throws HiveException {
    try {
      fetchDone[tag] = !sources[tag].pushRecord();
      if (sources[tag].isGrouped()) {
        // instead of maintaining complex state for the fetch of the next group,
        // we know for sure that at the end of all the values for a given key,
        // we will definitely reach the next key group.
        foundNextKeyGroup[tag] = true;
      }
    } catch (Exception e) {
      throw new HiveException(e);
    }
  }

  private void joinFinalLeftData() throws HiveException {
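    // This runs at close time in three phases: drain any rows still buffered for
    // the big table, keep joining until all inputs are exhausted (pulling the
    // outer-join sides as needed), then flush whatever remains in the caches.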
    @SuppressWarnings("rawtypes")
    RowContainer bigTblRowContainer = this.candidateStorage[this.posBigTable];

    boolean allFetchDone = allFetchDone();
    // if the data remaining in the small tables has keys less than or equal to
    // the data remaining in the big table, let them catch up
    while (bigTblRowContainer != null && bigTblRowContainer.rowCount() > 0 && !allFetchDone) {
      joinOneGroup();
      bigTblRowContainer = this.candidateStorage[this.posBigTable];
      allFetchDone = allFetchDone();
    }

    while (!allFetchDone) {
      List<Byte> ret = joinOneGroup();
      for (int i = 0; i < fetchDone.length; i++) {
        // the big table needs no more fetching: since we are in the close op
        // phase, we have definitely exhausted its input
        if (i == posBigTable) {
          fetchDone[i] = true;
          continue;
        }

        // in case of outer joins, we need to pull in records from the sides (apart
        // from the big table) we still need to produce output for, e.g. a full outer join
        if ((fetchInputAtClose.contains(i)) && (fetchDone[i] == false)) {
          // if we have never fetched, we need to fetch before we can do the join
          if (firstFetchHappened == false) {
            // we need to fetch all the needed ones at least once to ensure bootstrapping
            if (i == (fetchDone.length - 1)) {
              firstFetchHappened = true;
            }
            // This is a bootstrap. The joinOneGroup automatically fetches the next rows.
            fetchNextGroup((byte) i);
          }
          // Do the join. It does fetching of next row groups itself.
          if (i == (fetchDone.length - 1)) {
            ret = joinOneGroup();
          }
        }
      }

      if (ret == null || ret.size() == 0) {
        break;
      }

      reportProgress();
      numMapRowsRead++;
      allFetchDone = allFetchDone();
    }

    boolean dataInCache = true;
    while (dataInCache) {
      for (byte pos = 0; pos < order.length; pos++) {
        if (this.foundNextKeyGroup[pos] && this.nextKeyWritables[pos] != null) {
          promoteNextGroupToCandidate(pos);
        }
      }
      joinOneGroup();
      dataInCache = false;
      for (byte pos = 0; pos < order.length; pos++) {
        if (candidateStorage[pos] == null) {
          continue;
        }
        if (this.candidateStorage[pos].hasRows()) {
          dataInCache = true;
          break;
        }
      }
    }
  }

  private boolean allFetchDone() {
    boolean allFetchDone = true;
    for (byte pos = 0; pos < order.length; pos++) {
      if (pos == posBigTable) {
        continue;
      }
      allFetchDone = allFetchDone && fetchDone[pos];
    }
    return allFetchDone;
  }

  private void promoteNextGroupToCandidate(Byte t) throws HiveException {
    this.keyWritables[t] = this.nextKeyWritables[t];
    this.nextKeyWritables[t] = null;
    RowContainer<List<Object>> oldRowContainer = this.candidateStorage[t];
    oldRowContainer.clearRows();
    this.candidateStorage[t] = this.nextGroupStorage[t];
    this.nextGroupStorage[t] = oldRowContainer;
  }

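  /*
   * Returns true when the incoming key starts a new key group for this alias;
   * the new key is then remembered in nextKeyWritables.
   */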
  private boolean processKey(byte alias, List<Object> key) throws HiveException {
    List<Object> keyWritable = keyWritables[alias];
    if (keyWritable == null) {
      // the first group.
      keyWritables[alias] = key;
      keyComparators[alias] = new WritableComparator[key.size()];
      return false;
    } else {
      int cmp = compareKeys(alias, key, keyWritable);
      if (cmp != 0) {
        nextKeyWritables[alias] = key;
        return true;
      }
      return false;
    }
  }

  @SuppressWarnings("rawtypes")
  private int compareKeys(byte alias, List<Object> k1, List<Object> k2) {
    final WritableComparator[] comparators = keyComparators[alias];

    // do the join keys have different sizes?
    if (k1.size() != k2.size()) {
      return k1.size() - k2.size();
    }

    if (comparators.length == 0) {
      // cross-product - no keys really
      return 0;
    }

    if (comparators.length > 1) {
      // rare case
      return compareKeysMany(comparators, k1, k2);
    } else {
      return compareKey(comparators, 0,
          (WritableComparable) k1.get(0),
          (WritableComparable) k2.get(0),
          nullsafes != null ? nullsafes[0]: false);
    }
  }

  @SuppressWarnings("rawtypes")
  private int compareKeysMany(WritableComparator[] comparators,
      final List<Object> k1,
      final List<Object> k2) {
    // invariant: k1.size == k2.size
    int ret = 0;
    final int size = k1.size();
    for (int i = 0; i < size; i++) {
      WritableComparable key_1 = (WritableComparable) k1.get(i);
      WritableComparable key_2 = (WritableComparable) k2.get(i);
      ret = compareKey(comparators, i, key_1, key_2,
          nullsafes != null ? nullsafes[i] : false);
      if (ret != 0) {
        return ret;
      }
    }
    return ret;
  }

  @SuppressWarnings("rawtypes")
  private int compareKey(final WritableComparator comparators[], final int pos,
      final WritableComparable key_1,
      final WritableComparable key_2,
      final boolean nullsafe) {

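    // Nulls sort first. Two nulls compare as equal only for a null-safe join key
    // (the <=> operator); otherwise they compare as -1 so that null keys never
    // match each other.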
    if (key_1 == null && key_2 == null) {
      if (nullsafe) {
        return 0;
      } else {
        return -1;
      }
    } else if (key_1 == null) {
      return -1;
    } else if (key_2 == null) {
      return 1;
    }

    if (comparators[pos] == null) {
      comparators[pos] = WritableComparator.get(key_1.getClass());
    }
    return comparators[pos].compare(key_1, key_2);
  }

  @SuppressWarnings("unchecked")
  private List<Object> mergeJoinComputeKeys(Object row, Byte alias) throws HiveException {
    if ((joinKeysObjectInspectors != null) && (joinKeysObjectInspectors[alias] != null)) {
      return JoinUtil.computeKeys(row, joinKeys[alias], joinKeysObjectInspectors[alias]);
    } else {
      row =
          ObjectInspectorUtils.copyToStandardObject(row, inputObjInspectors[alias],
              ObjectInspectorCopyOption.WRITABLE);
      StructObjectInspector soi = (StructObjectInspector) inputObjInspectors[alias];
      StructField sf = soi.getStructFieldRef(Utilities.ReduceField.KEY.toString());
      return (List<Object>) soi.getStructFieldData(row, sf);
    }
  }

  @Override
  public String getName() {
    return getOperatorName();
  }

  static public String getOperatorName() {
    return "MERGEJOIN";
  }

  @Override
  public OperatorType getType() {
    return OperatorType.MERGEJOIN;
  }

  @Override
  public void initializeLocalWork(Configuration hconf) throws HiveException {
    Operator<? extends OperatorDesc> parent = null;

    for (Operator<? extends OperatorDesc> parentOp : parentOperators) {
      if (parentOp != null) {
        parent = parentOp;
        break;
      }
    }

    if (parent == null) {
      throw new HiveException("No valid parents.");
    }
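    // Hook up the dummy store operators that have no children yet (the
    // non-big-table inputs in a Tez merge join) as additional parents of this
    // operator, and register this operator as their child.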
    Map<Integer, DummyStoreOperator> dummyOps =
        ((TezContext) (MapredContext.get())).getDummyOpsMap();
    for (Entry<Integer, DummyStoreOperator> connectOp : dummyOps.entrySet()) {
      if (connectOp.getValue().getChildOperators() == null
          || connectOp.getValue().getChildOperators().isEmpty()) {
        parentOperators.add(connectOp.getKey(), connectOp.getValue());
        connectOp.getValue().getChildOperators().add(this);
      }
    }
    super.initializeLocalWork(hconf);
    return;
  }

  public boolean isBigTableWork() {
    return isBigTableWork;
  }

  public void setIsBigTableWork(boolean bigTableWork) {
    this.isBigTableWork = bigTableWork;
  }

  public int getTagForOperator(Operator<? extends OperatorDesc> op) {
    return originalParents.indexOf(op);
  }

  public void cloneOriginalParentsList(List<Operator<? extends OperatorDesc>> opList) {
    originalParents.addAll(opList);
  }
}