All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.hadoop.hive.ql.optimizer.physical.CrossProductCheck Maven / Gradle / Ivy

There is a newer version: 4.0.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.optimizer.physical;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.AbstractMapJoinOperator;
import org.apache.hadoop.hive.ql.exec.ConditionalTask;
import org.apache.hadoop.hive.ql.exec.CommonMergeJoinOperator;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.mr.MapRedTask;
import org.apache.hadoop.hive.ql.exec.tez.TezTask;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.GraphWalker;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.MapredWork;
import org.apache.hadoop.hive.ql.plan.MergeJoinWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.TableScanDesc;
import org.apache.hadoop.hive.ql.plan.TezWork;
import org.apache.hadoop.hive.ql.session.SessionState;

/*
 * Check each MapJoin and ShuffleJoin Operator to see whether it performs a cross product.
 * If it does, print a warning to the Session's console.
 * The following checks are made:
 * 1. MR, Shuffle Join:
 * Check the parent ReduceSinkOp of the JoinOp. If its keys list is size = 0, then
 * this is a cross product.
 * The parent ReduceSinkOp is in the MapWork for the same Stage.
 * 2. MR, MapJoin:
 * If the keys expr list on the mapJoin Desc is an empty list for any input,
 * this implies a cross product.
 * 3. Tez, Shuffle Join:
 * Check the parent ReduceSinkOp of the JoinOp. If its keys list is size = 0, then
 * this is a cross product.
 * The parent ReduceSinkOp checked is based on the ReduceWork.tagToInput map on the
 * reduceWork that contains the JoinOp.
 * 4. Tez, Map Join:
 * If the keys expr list on the mapJoin Desc is an empty list for any input,
 * this implies a cross product.
 */
public class CrossProductCheck implements PhysicalPlanResolver, Dispatcher {

  protected static transient final Log LOG = LogFactory
      .getLog(CrossProductCheck.class);

  /**
   * Walks every task reachable from the plan's root tasks and prints a console
   * warning for each join found to be a cross product. The plan is returned as-is.
   *
   * @param pctx physical plan context whose root tasks are walked
   * @return {@code pctx}, unmodified
   * @throws SemanticException if walking the task graph fails
   */
  @Override
  public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
    TaskGraphWalker ogw = new TaskGraphWalker(this);

    ArrayList<Node> topNodes = new ArrayList<Node>();
    topNodes.addAll(pctx.getRootTasks());

    ogw.startWalking(topNodes, null);
    return pctx;
  }

  /**
   * Checks a single task of the task graph.
   *
   * <p>MapReduce and Tez tasks are checked for map-join and shuffle-join cross
   * products; a conditional task is checked by dispatching each of its candidate
   * tasks in turn. Any other task type is ignored.
   *
   * @return always {@code null}
   */
  @Override
  public Object dispatch(Node nd, Stack<Node> stack, Object... nodeOutputs)
      throws SemanticException {
    @SuppressWarnings("unchecked")
    Task<? extends Serializable> currTask = (Task<? extends Serializable>) nd;
    if (currTask instanceof MapRedTask) {
      MapRedTask mrTsk = (MapRedTask) currTask;
      checkMapJoins(mrTsk);
      checkMRReducer(currTask.toString(), mrTsk.getWork());
    } else if (currTask instanceof ConditionalTask) {
      List<Task<? extends Serializable>> taskListInConditionalTask =
          ((ConditionalTask) currTask).getListTasks();
      for (Task<? extends Serializable> tsk : taskListInConditionalTask) {
        dispatch(tsk, stack, nodeOutputs);
      }
    } else if (currTask instanceof TezTask) {
      TezWork tzWrk = ((TezTask) currTask).getWork();
      checkMapJoins(tzWrk);
      checkTezReducer(tzWrk);
    }
    return null;
  }

  /** Prints {@code msg}, prefixed with "Warning: ", to the session console's info stream. */
  private void warn(String msg) {
    SessionState.getConsole().getInfoStream().println(
        String.format("Warning: %s", msg));
  }

  /** Prints each message in {@code msgs}, in list order, via {@link #warn(String)}. */
  private void warnAll(List<String> msgs) {
    for (String msg : msgs) {
      warn(msg);
    }
  }

  /**
   * Runs {@link MapJoinCheck} over the map side of an MR task and, when present,
   * over its reduce side, printing every cross-product warning found.
   */
  private void checkMapJoins(MapRedTask mrTsk) throws SemanticException {
    MapredWork mrWrk = mrTsk.getWork();
    warnAll(new MapJoinCheck(mrTsk.toString()).analyze(mrWrk.getMapWork()));
    ReduceWork redWork = mrWrk.getReduceWork();
    if (redWork != null) {
      warnAll(new MapJoinCheck(mrTsk.toString()).analyze(redWork));
    }
  }

  /**
   * Runs {@link MapJoinCheck} over every work unit of a Tez DAG, printing every
   * cross-product warning found. For a {@link MergeJoinWork} only its main work
   * is analyzed.
   */
  private void checkMapJoins(TezWork tzWrk) throws SemanticException {
    for (BaseWork vertex : tzWrk.getAllWork()) {
      BaseWork wrk = vertex;
      if (wrk instanceof MergeJoinWork) {
        wrk = ((MergeJoinWork) wrk).getMainWork();
      }
      warnAll(new MapJoinCheck(wrk.getName()).analyze(wrk));
    }
  }

  /**
   * For every Tez reduce work whose reducer is a join, gathers the ReduceSink
   * information from each parent vertex listed in its tag-to-input map and checks
   * the join for a cross product.
   */
  private void checkTezReducer(TezWork tzWrk) throws SemanticException {
    for (BaseWork vertex : tzWrk.getAllWork()) {
      BaseWork wrk = vertex;
      if (wrk instanceof MergeJoinWork) {
        wrk = ((MergeJoinWork) wrk).getMainWork();
      }
      if (!(wrk instanceof ReduceWork)) {
        continue;
      }
      ReduceWork rWork = (ReduceWork) wrk;
      Operator<? extends OperatorDesc> reducer = rWork.getReducer();
      if (reducer instanceof JoinOperator || reducer instanceof CommonMergeJoinOperator) {
        Map<Integer, ExtractReduceSinkInfo.Info> rsInfo =
            new HashMap<Integer, ExtractReduceSinkInfo.Info>();
        for (String parentVertex : rWork.getTagToInput().values()) {
          rsInfo.putAll(getReducerInfo(tzWrk, rWork.getName(), parentVertex));
        }
        checkForCrossProduct(rWork.getName(), reducer, rsInfo);
      }
    }
  }

  /**
   * If the MR task's reducer is a join, checks it for a cross product using the
   * ReduceSinks found in the same task's map work. Does nothing when the task has
   * no reduce work.
   */
  private void checkMRReducer(String taskName, MapredWork mrWrk) throws SemanticException {
    ReduceWork rWrk = mrWrk.getReduceWork();
    if (rWrk == null) {
      return;
    }
    Operator<? extends OperatorDesc> reducer = rWrk.getReducer();
    if (reducer instanceof JoinOperator || reducer instanceof CommonMergeJoinOperator) {
      BaseWork prntWork = mrWrk.getMapWork();
      checkForCrossProduct(taskName, reducer,
          new ExtractReduceSinkInfo(null).analyze(prntWork));
    }
  }

  /**
   * Prints a warning if the shuffle join {@code reducer} is a cross product.
   *
   * <p>The decision is made from the first entry of {@code rsInfo}: if that
   * ReduceSink has no key columns, the join is reported together with the input
   * aliases of all ReduceSinks in {@code rsInfo}. Does nothing when
   * {@code rsInfo} is empty.
   *
   * @param taskName stage (MR) or vertex (Tez) name used in the message
   * @param reducer the join operator being checked
   * @param rsInfo ReduceSink information keyed by tag
   */
  private void checkForCrossProduct(String taskName,
      Operator<? extends OperatorDesc> reducer,
      Map<Integer, ExtractReduceSinkInfo.Info> rsInfo) {
    if (rsInfo.isEmpty()) {
      return;
    }
    Iterator<ExtractReduceSinkInfo.Info> it = rsInfo.values().iterator();
    ExtractReduceSinkInfo.Info info = it.next();
    if (info.keyCols.size() != 0) {
      return;
    }
    List<String> iAliases = new ArrayList<String>(info.inputAliases);
    while (it.hasNext()) {
      iAliases.addAll(it.next().inputAliases);
    }
    String warning = String.format(
        "Shuffle Join %s[tables = %s] in Stage '%s' is a cross product",
        reducer.toString(),
        iAliases,
        taskName);
    warn(warning);
  }

  /**
   * Extracts the ReduceSinks in vertex {@code prntVertex} that output to
   * {@code vertex}. Returns an empty map when {@code prntVertex} is not present
   * in the DAG's work map instead of failing with a NullPointerException.
   */
  private Map<Integer, ExtractReduceSinkInfo.Info> getReducerInfo(TezWork tzWrk,
      String vertex, String prntVertex) throws SemanticException {
    BaseWork prntWork = tzWrk.getWorkMap().get(prntVertex);
    if (prntWork == null) {
      return new HashMap<Integer, ExtractReduceSinkInfo.Info>();
    }
    return new ExtractReduceSinkInfo(vertex).analyze(prntWork);
  }

  /**
   * Walks all root operators of {@code work}. Operators matching the rule
   * {@code operatorName + "%"} are handed to {@code proc}; everything else goes to
   * a {@link NoopProcessor}.
   */
  private static void walkOperators(BaseWork work, String operatorName,
      NodeProcessor proc, NodeProcessorCtx ctx) throws SemanticException {
    Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
    opRules.put(new RuleRegExp("R1", operatorName + "%"), proc);
    Dispatcher disp = new DefaultRuleDispatcher(new NoopProcessor(), opRules, ctx);
    GraphWalker ogw = new DefaultGraphWalker(disp);
    ArrayList<Node> topNodes = new ArrayList<Node>();
    topNodes.addAll(work.getAllRootOperators());
    ogw.startWalking(topNodes, null);
  }

  /**
   * Given a Work descriptor and the TaskName for the work, checks each MapJoinOp
   * for cross products. The analyze call returns the list of warnings.
   *
   * <p>For MR the taskname is the StageName, for Tez it is the vertex name.
   */
  public static class MapJoinCheck implements NodeProcessor, NodeProcessorCtx {

    final List<String> warnings;
    final String taskName;

    MapJoinCheck(String taskName) {
      this.taskName = taskName;
      warnings = new ArrayList<String>();
    }

    /**
     * Visits every MapJoin operator in {@code work}.
     *
     * @return the warnings accumulated by this checker (one per cross-product map join)
     */
    List<String> analyze(BaseWork work) throws SemanticException {
      walkOperators(work, MapJoinOperator.getOperatorName(), this, this);
      return warnings;
    }

    /** Records a warning if the visited map join has an empty key list for any input. */
    @Override
    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
        Object... nodeOutputs) throws SemanticException {
      @SuppressWarnings("unchecked")
      AbstractMapJoinOperator<? extends MapJoinDesc> mjOp =
          (AbstractMapJoinOperator<? extends MapJoinDesc>) nd;
      MapJoinDesc mjDesc = mjOp.getConf();

      if (hasEmptyJoinKeys(mjDesc)) {
        warnings.add(
            String.format("Map Join %s[bigTable=%s] in task '%s' is a cross product",
                mjOp.toString(), bigTableAlias(mjOp, mjDesc), taskName));
      }
      return null;
    }

    /**
     * Returns true when the key expression list of any join input is empty.
     * A desc with no key lists at all yields false rather than an exception.
     */
    private static boolean hasEmptyJoinKeys(MapJoinDesc mjDesc) {
      for (List<ExprNodeDesc> joinExprs : mjDesc.getKeys().values()) {
        if (joinExprs.size() == 0) {
          return true;
        }
      }
      return false;
    }

    /**
     * Returns the big-table alias from the desc; if unset, the alias of the last
     * TableScan parent of {@code mjOp}; otherwise "?".
     */
    private static String bigTableAlias(AbstractMapJoinOperator<? extends MapJoinDesc> mjOp,
        MapJoinDesc mjDesc) {
      String alias = mjDesc.getBigTableAlias();
      if (alias == null) {
        TableScanOperator lastScan = null;
        for (Operator<? extends OperatorDesc> op : mjOp.getParentOperators()) {
          if (op instanceof TableScanOperator) {
            lastScan = (TableScanOperator) op;
          }
        }
        if (lastScan != null) {
          TableScanDesc tDesc = lastScan.getConf();
          alias = tDesc.getAlias();
        }
      }
      return alias == null ? "?" : alias;
    }
  }

  /**
   * For a given Work Descriptor, extracts information about the ReduceSinkOps in
   * the Work. For Tez, it can be restricted to ReduceSinks for a particular output
   * vertex.
   */
  public static class ExtractReduceSinkInfo implements NodeProcessor, NodeProcessorCtx {

    /** Key columns and input aliases of one ReduceSink. */
    static class Info {
      List<ExprNodeDesc> keyCols;
      List<String> inputAliases;

      Info(List<ExprNodeDesc> keyCols, List<String> inputAliases) {
        this.keyCols = keyCols;
        this.inputAliases = inputAliases == null ? new ArrayList<String>() : inputAliases;
      }

      Info(List<ExprNodeDesc> keyCols, String[] inputAliases) {
        this.keyCols = keyCols;
        this.inputAliases = inputAliases == null
            ? new ArrayList<String>() : Arrays.asList(inputAliases);
      }
    }

    final String outputTaskName;
    final Map<Integer, Info> reduceSinkInfo;

    /**
     * @param parentTaskName if non-null, only ReduceSinks whose output name equals
     *     this value are recorded; if null, all ReduceSinks are recorded
     */
    ExtractReduceSinkInfo(String parentTaskName) {
      this.outputTaskName = parentTaskName;
      reduceSinkInfo = new HashMap<Integer, Info>();
    }

    /**
     * Visits every ReduceSink operator in {@code work}.
     *
     * @return the recorded ReduceSink information keyed by tag
     */
    Map<Integer, Info> analyze(BaseWork work) throws SemanticException {
      walkOperators(work, ReduceSinkOperator.getOperatorName(), this, this);
      return reduceSinkInfo;
    }

    /** Records the visited ReduceSink's key columns and input aliases under its tag. */
    @Override
    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
        Object... nodeOutputs) throws SemanticException {
      ReduceSinkOperator rsOp = (ReduceSinkOperator) nd;
      ReduceSinkDesc rsDesc = rsOp.getConf();

      // equals(null) is false, so a ReduceSink without an output name is skipped too.
      if (outputTaskName != null && !outputTaskName.equals(rsDesc.getOutputName())) {
        return null;
      }

      reduceSinkInfo.put(rsDesc.getTag(), new Info(rsDesc.getKeyCols(), rsOp.getInputAliases()));
      return null;
    }
  }

  /** Fallback processor for operators matched by no rule; returns the node unchanged. */
  static class NoopProcessor implements NodeProcessor {
    @Override
    public final Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
        Object... nodeOutputs) throws SemanticException {
      return nd;
    }
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy