All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.hadoop.hive.ql.ppd.SyntheticJoinPredicate Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.ppd;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;

import com.facebook.presto.hive.$internal.org.apache.commons.logging.Log;
import com.facebook.presto.hive.$internal.org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.GraphWalker;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.PreOrderWalker;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDynamicListDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.FilterDesc;
import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;

/**
 * creates synthetic predicates that represent "IN (keylist other table)"
 */
public class SyntheticJoinPredicate implements Transform {

  private static transient Log LOG = LogFactory.getLog(SyntheticJoinPredicate.class.getName());

  @Override
  public ParseContext transform(ParseContext pctx) throws SemanticException {

    if (!pctx.getConf().getVar(ConfVars.HIVE_EXECUTION_ENGINE).equals("tez")
        || !pctx.getConf().getBoolVar(ConfVars.TEZ_DYNAMIC_PARTITION_PRUNING)) {
      return pctx;
    }

    Map opRules = new LinkedHashMap();
    opRules.put(new RuleRegExp("R1", "(" +
        TableScanOperator.getOperatorName() + "%" + ".*" +
        ReduceSinkOperator.getOperatorName() + "%" +
        JoinOperator.getOperatorName() + "%)"), new JoinSynthetic());

    // The dispatcher fires the processor corresponding to the closest matching
    // rule and passes the context along
    SyntheticContext context = new SyntheticContext(pctx);
    Dispatcher disp = new DefaultRuleDispatcher(null, opRules, context);
    GraphWalker ogw = new PreOrderWalker(disp);

    // Create a list of top op nodes
    List topNodes = new ArrayList();
    topNodes.addAll(pctx.getTopOps().values());
    ogw.startWalking(topNodes, null);

    return pctx;
  }

  // insert filter operator between target(child) and input(parent)
  private static Operator createFilter(Operator target, Operator parent,
      RowSchema parentRS, ExprNodeDesc filterExpr) {
    Operator filter = OperatorFactory.get(new FilterDesc(filterExpr, false),
        new RowSchema(parentRS.getSignature()));
    filter.getParentOperators().add(parent);
    filter.getChildOperators().add(target);
    parent.replaceChild(target, filter);
    target.replaceParent(parent, filter);
    return filter;
  }

  private static class SyntheticContext implements NodeProcessorCtx {

    ParseContext parseContext;

    public SyntheticContext(ParseContext pCtx) {
      parseContext = pCtx;
    }

    public ParseContext getParseContext() {
      return parseContext;
    }
  }

  private static class JoinSynthetic implements NodeProcessor {
    @Override
    public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx,
        Object... nodeOutputs) throws SemanticException {

      ParseContext pCtx = ((SyntheticContext) procCtx).getParseContext();

      @SuppressWarnings("unchecked")
      CommonJoinOperator join = (CommonJoinOperator) nd;

      ReduceSinkOperator source = (ReduceSinkOperator) stack.get(stack.size() - 2);
      int srcPos = join.getParentOperators().indexOf(source);

      List> parents = join.getParentOperators();

      int[][] targets = getTargets(join);

      Operator parent = source.getParentOperators().get(0);
      RowSchema parentRS = parent.getSchema();

      // don't generate for null-safes.
      if (join.getConf().getNullSafes() != null) {
        for (boolean b : join.getConf().getNullSafes()) {
          if (b) {
            return null;
          }
        }
      }

      for (int targetPos: targets[srcPos]) {
        if (srcPos == targetPos) {
          continue;
        }

        if (LOG.isDebugEnabled()) {
          LOG.debug("Synthetic predicate: " + srcPos + " --> " + targetPos);
        }
        ReduceSinkOperator target = (ReduceSinkOperator) parents.get(targetPos);
        List sourceKeys = source.getConf().getKeyCols();
        List targetKeys = target.getConf().getKeyCols();

        if (sourceKeys.size() < 1) {
          continue;
        }

        ExprNodeDesc syntheticExpr = null;

        for (int i = 0; i < sourceKeys.size(); ++i) {
          List inArgs = new ArrayList();
          inArgs.add(sourceKeys.get(i));

          ExprNodeDynamicListDesc dynamicExpr =
              new ExprNodeDynamicListDesc(targetKeys.get(i).getTypeInfo(), target, i);

          inArgs.add(dynamicExpr);

          ExprNodeDesc syntheticInExpr =
              ExprNodeGenericFuncDesc.newInstance(FunctionRegistry.getFunctionInfo("in")
                  .getGenericUDF(), inArgs);

          if (syntheticExpr != null) {
            List andArgs = new ArrayList();
            andArgs.add(syntheticExpr);
            andArgs.add(syntheticInExpr);

            syntheticExpr =
                ExprNodeGenericFuncDesc.newInstance(FunctionRegistry.getFunctionInfo("and")
                    .getGenericUDF(), andArgs);
          } else {
            syntheticExpr = syntheticInExpr;
          }
        }

        Operator newFilter = createFilter(source, parent, parentRS, syntheticExpr);
        parent = newFilter;
      }

      return null;
    }

    // calculate filter propagation directions for each alias
    // L<->R for inner/semi join, L<-R for left outer join, R<-L for right outer
    // join
    private int[][] getTargets(CommonJoinOperator join) {
      JoinCondDesc[] conds = join.getConf().getConds();

      int aliases = conds.length + 1;
      Vectors vector = new Vectors(aliases);
      for (JoinCondDesc cond : conds) {
        int left = cond.getLeft();
        int right = cond.getRight();
        switch (cond.getType()) {
        case JoinDesc.INNER_JOIN:
        case JoinDesc.LEFT_SEMI_JOIN:
          vector.add(left, right);
          vector.add(right, left);
          break;
        case JoinDesc.LEFT_OUTER_JOIN:
          vector.add(right, left);
          break;
        case JoinDesc.RIGHT_OUTER_JOIN:
          vector.add(left, right);
          break;
        case JoinDesc.FULL_OUTER_JOIN:
          break;
        }
      }
      int[][] result = new int[aliases][];
      for (int pos = 0 ; pos < aliases; pos++) {
        // find all targets recursively
        result[pos] = vector.traverse(pos);
      }
      return result;
    }
  }

  private static class Vectors {

    private final Set[] vector;

    @SuppressWarnings("unchecked")
    public Vectors(int length) {
      vector = new Set[length];
    }

    public void add(int from, int to) {
      if (vector[from] == null) {
        vector[from] = new HashSet();
      }
      vector[from].add(to);
    }

    public int[] traverse(int pos) {
      Set targets = new HashSet();
      traverse(targets, pos);
      return toArray(targets);
    }

    private int[] toArray(Set values) {
      int index = 0;
      int[] result = new int[values.size()];
      for (int value : values) {
        result[index++] = value;
      }
      return result;
    }

    private void traverse(Set targets, int pos) {
      if (vector[pos] == null) {
        return;
      }
      for (int target : vector[pos]) {
        if (targets.add(target)) {
          traverse(targets, target);
        }
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy