All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.hadoop.hive.ql.exec.tez.DynamicPartitionPruner Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.exec.tez;

import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;

import com.facebook.presto.hive.$internal.com.google.common.annotations.VisibleForTesting;
import com.facebook.presto.hive.$internal.com.google.common.base.Preconditions;

import com.facebook.presto.hive.$internal.org.apache.commons.lang3.mutable.MutableInt;
import com.facebook.presto.hive.$internal.org.apache.commons.logging.Log;
import com.facebook.presto.hive.$internal.org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.hive.serde2.Deserializer;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.runtime.api.InputInitializerContext;
import org.apache.tez.runtime.api.events.InputInitializerEvent;

/**
 * DynamicPartitionPruner takes a list of assigned partitions at runtime (split
 * generation) and prunes them using events generated during execution of the
 * dag.
 *
 */
public class DynamicPartitionPruner {

  private static final Log LOG = LogFactory.getLog(DynamicPartitionPruner.class);

  private final InputInitializerContext context;
  private final MapWork work;
  private final JobConf jobConf;


  private final Map> sourceInfoMap =
      new HashMap>();

  private final BytesWritable writable = new BytesWritable();

  /* Keeps track of all events that need to be processed - irrespective of the source */
  private final BlockingQueue queue = new LinkedBlockingQueue();

  /* Keeps track of vertices from which events are expected */
  private final Set sourcesWaitingForEvents = new HashSet();

  // Stores negative values to count columns. Eventually set to #tasks X #columns after the source vertex completes.
  private final Map numExpectedEventsPerSource = new HashMap<>();
  private final Map numEventsSeenPerSource = new HashMap<>();

  private int sourceInfoCount = 0;

  private final Object endOfEvents = new Object();

  private int totalEventCount = 0;

  public DynamicPartitionPruner(InputInitializerContext context, MapWork work, JobConf jobConf) throws
      SerDeException {
    this.context = context;
    this.work = work;
    this.jobConf = jobConf;
    synchronized (this) {
      initialize();
    }
  }

  public void prune()
      throws SerDeException, IOException,
      InterruptedException, HiveException {

    synchronized(sourcesWaitingForEvents) {

      if (sourcesWaitingForEvents.isEmpty()) {
        return;
      }

      Set states = Collections.singleton(VertexState.SUCCEEDED);
      for (String source : sourcesWaitingForEvents) {
        // we need to get state transition updates for the vertices that will send
        // events to us. once we have received all events and a vertex has succeeded,
        // we can move to do the pruning.
        context.registerForVertexStateUpdates(source, states);
      }
    }

    LOG.info("Waiting for events (" + sourceInfoCount + " sources) ...");
    // synchronous event processing loop. Won't return until all events have
    // been processed.
    this.processEvents();
    this.prunePartitions();
    LOG.info("Ok to proceed.");
  }

  public BlockingQueue getQueue() {
    return queue;
  }

  private void clear() {
    sourceInfoMap.clear();
    sourceInfoCount = 0;
  }

  private void initialize() throws SerDeException {
    this.clear();
    Map columnMap = new HashMap();
    // sources represent vertex names
    Set sources = work.getEventSourceTableDescMap().keySet();

    sourcesWaitingForEvents.addAll(sources);

    for (String s : sources) {
      // Set to 0 to start with. This will be decremented for all columns for which events
      // are generated by this source - which is eventually used to determine number of expected
      // events for the source. #colums X #tasks
      numExpectedEventsPerSource.put(s, new MutableInt(0));
      numEventsSeenPerSource.put(s, new MutableInt(0));
      // Virtual relation generated by the reduce sync
      List tables = work.getEventSourceTableDescMap().get(s);
      // Real column name - on which the operation is being performed
      List columnNames = work.getEventSourceColumnNameMap().get(s);
      // Expression for the operation. e.g. N^2 > 10
      List partKeyExprs = work.getEventSourcePartKeyExprMap().get(s);
      // eventSourceTableDesc, eventSourceColumnName, evenSourcePartKeyExpr move in lock-step.
      // One entry is added to each at the same time

      Iterator cit = columnNames.iterator();
      Iterator pit = partKeyExprs.iterator();
      // A single source can process multiple columns, and will send an event for each of them.
      for (TableDesc t : tables) {
        numExpectedEventsPerSource.get(s).decrement();
        ++sourceInfoCount;
        String columnName = cit.next();
        ExprNodeDesc partKeyExpr = pit.next();
        SourceInfo si = createSourceInfo(t, partKeyExpr, columnName, jobConf);
        if (!sourceInfoMap.containsKey(s)) {
          sourceInfoMap.put(s, new ArrayList());
        }
        List sis = sourceInfoMap.get(s);
        sis.add(si);

        // We could have multiple sources restrict the same column, need to take
        // the union of the values in that case.
        if (columnMap.containsKey(columnName)) {
          // All Sources are initialized up front. Events from different sources will end up getting added to the same list.
          // Pruning is disabled if either source sends in an event which causes pruning to be skipped
          si.values = columnMap.get(columnName).values;
          si.skipPruning = columnMap.get(columnName).skipPruning;
        }
        columnMap.put(columnName, si);
      }
    }
  }

  private void prunePartitions() throws HiveException {
    int expectedEvents = 0;
    for (Map.Entry> entry : this.sourceInfoMap.entrySet()) {
      String source = entry.getKey();
      for (SourceInfo si : entry.getValue()) {
        int taskNum = context.getVertexNumTasks(source);
        LOG.info("Expecting " + taskNum + " events for vertex " + source + ", for column " + si.columnName);
        expectedEvents += taskNum;
        prunePartitionSingleSource(source, si);
      }
    }

    // sanity check. all tasks must submit events for us to succeed.
    if (expectedEvents != totalEventCount) {
      LOG.error("Expecting: " + expectedEvents + ", received: " + totalEventCount);
      throw new HiveException("Incorrect event count in dynamic partition pruning");
    }
  }

  @VisibleForTesting
  protected void prunePartitionSingleSource(String source, SourceInfo si)
      throws HiveException {

    if (si.skipPruning.get()) {
      // in this case we've determined that there's too much data
      // to prune dynamically.
      LOG.info("Skip pruning on " + source + ", column " + si.columnName);
      return;
    }

    Set values = si.values;
    String columnName = si.columnName;

    if (LOG.isDebugEnabled()) {
      StringBuilder sb = new StringBuilder("Pruning ");
      sb.append(columnName);
      sb.append(" with ");
      for (Object value : values) {
        sb.append(value == null ? null : value.toString());
        sb.append(", ");
      }
      LOG.debug(sb.toString());
    }

    ObjectInspector oi =
        PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(TypeInfoFactory
            .getPrimitiveTypeInfo(si.fieldInspector.getTypeName()));

    Converter converter =
        ObjectInspectorConverters.getConverter(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, oi);

    StructObjectInspector soi =
        ObjectInspectorFactory.getStandardStructObjectInspector(
            Collections.singletonList(columnName), Collections.singletonList(oi));

    @SuppressWarnings("rawtypes")
    ExprNodeEvaluator eval = ExprNodeEvaluatorFactory.get(si.partKey);
    eval.initialize(soi);

    applyFilterToPartitions(converter, eval, columnName, values);
  }

  @SuppressWarnings("rawtypes")
  private void applyFilterToPartitions(Converter converter, ExprNodeEvaluator eval,
      String columnName, Set values) throws HiveException {

    Object[] row = new Object[1];

    Iterator it = work.getPathToPartitionInfo().keySet().iterator();
    while (it.hasNext()) {
      String p = it.next();
      PartitionDesc desc = work.getPathToPartitionInfo().get(p);
      Map spec = desc.getPartSpec();
      if (spec == null) {
        throw new IllegalStateException("No partition spec found in dynamic pruning");
      }

      String partValueString = spec.get(columnName);
      if (partValueString == null) {
        throw new IllegalStateException("Could not find partition value for column: " + columnName);
      }

      Object partValue = converter.convert(partValueString);
      if (LOG.isDebugEnabled()) {
        LOG.debug("Converted partition value: " + partValue + " original (" + partValueString + ")");
      }

      row[0] = partValue;
      partValue = eval.evaluate(row);
      if (LOG.isDebugEnabled()) {
        LOG.debug("part key expr applied: " + partValue);
      }

      if (!values.contains(partValue)) {
        LOG.info("Pruning path: " + p);
        it.remove();
        work.getPathToAliases().remove(p);
        work.getPaths().remove(p);
        work.getPartitionDescs().remove(desc);
      }
    }
  }

  @VisibleForTesting
  protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, String columnName,
                                        JobConf jobConf) throws
      SerDeException {
    return new SourceInfo(t, partKeyExpr, columnName, jobConf);

  }

  @SuppressWarnings("deprecation")
  @VisibleForTesting
  static class SourceInfo {
    public final ExprNodeDesc partKey;
    public final Deserializer deserializer;
    public final StructObjectInspector soi;
    public final StructField field;
    public final ObjectInspector fieldInspector;
    /* List of partitions that are required - populated from processing each event */
    public Set values = new HashSet();
    /* Whether to skipPruning - depends on the payload from an event which may signal skip - if the event payload is too large */
    public AtomicBoolean skipPruning = new AtomicBoolean();
    public final String columnName;

    @VisibleForTesting // Only used for testing.
    SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, JobConf jobConf, Object forTesting) {
      this.partKey = partKey;
      this.columnName = columnName;
      this.deserializer = null;
      this.soi = null;
      this.field = null;
      this.fieldInspector = null;
    }

    public SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, JobConf jobConf)
        throws SerDeException {

      this.skipPruning.set(false);

      this.partKey = partKey;

      this.columnName = columnName;

      deserializer = ReflectionUtils.newInstance(table.getDeserializerClass(), null);
      deserializer.initialize(jobConf, table.getProperties());

      ObjectInspector inspector = deserializer.getObjectInspector();
      LOG.debug("Type of obj insp: " + inspector.getTypeName());

      soi = (StructObjectInspector) inspector;
      List fields = soi.getAllStructFieldRefs();
      if (fields.size() > 1) {
        LOG.error("expecting single field in input");
      }

      field = fields.get(0);

      fieldInspector =
          ObjectInspectorUtils.getStandardObjectInspector(field.getFieldObjectInspector());
    }
  }

  private void processEvents() throws SerDeException, IOException, InterruptedException {
    int eventCount = 0;

    while (true) {
      Object element = queue.take();

      if (element == endOfEvents) {
        // we're done processing events
        break;
      }

      InputInitializerEvent event = (InputInitializerEvent) element;

      LOG.info("Input event: " + event.getTargetInputName() + ", " + event.getTargetVertexName()
          + ", " + (event.getUserPayload().limit() - event.getUserPayload().position()));
      processPayload(event.getUserPayload(), event.getSourceVertexName());
      eventCount += 1;
    }
    LOG.info("Received events: " + eventCount);
  }

  @SuppressWarnings("deprecation")
  @VisibleForTesting
  protected String processPayload(ByteBuffer payload, String sourceName) throws SerDeException,
      IOException {

    DataInputStream in = new DataInputStream(new ByteBufferBackedInputStream(payload));
    try {
      String columnName = in.readUTF();

      LOG.info("Source of event: " + sourceName);

      List infos = this.sourceInfoMap.get(sourceName);
      if (infos == null) {
        throw new IllegalStateException("no source info for event source: " + sourceName);
      }

      SourceInfo info = null;
      for (SourceInfo si : infos) {
        if (columnName.equals(si.columnName)) {
          info = si;
          break;
        }
      }

      if (info == null) {
        throw new IllegalStateException("no source info for column: " + columnName);
      }

      if (info.skipPruning.get()) {
        // Marked as skipped previously. Don't bother processing the rest of the payload.
      } else {
        boolean skip = in.readBoolean();
        if (skip) {
          info.skipPruning.set(true);
        } else {
          while (payload.hasRemaining()) {
            writable.readFields(in);

            Object row = info.deserializer.deserialize(writable);

            Object value = info.soi.getStructFieldData(row, info.field);
            value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector);

            if (LOG.isDebugEnabled()) {
              LOG.debug("Adding: " + value + " to list of required partitions");
            }
            info.values.add(value);
          }
        }
      }
    } finally {
      if (in != null) {
        in.close();
      }
    }
    return sourceName;
  }

  private static class ByteBufferBackedInputStream extends InputStream {

    ByteBuffer buf;

    public ByteBufferBackedInputStream(ByteBuffer buf) {
      this.buf = buf;
    }

    @Override
    public int read() throws IOException {
      if (!buf.hasRemaining()) {
        return -1;
      }
      return buf.get() & 0xFF;
    }

    @Override
    public int read(byte[] bytes, int off, int len) throws IOException {
      if (!buf.hasRemaining()) {
        return -1;
      }

      len = Math.min(len, buf.remaining());
      buf.get(bytes, off, len);
      return len;
    }
  }

  public void addEvent(InputInitializerEvent event) {
    synchronized(sourcesWaitingForEvents) {
      if (sourcesWaitingForEvents.contains(event.getSourceVertexName())) {
        ++totalEventCount;
        numEventsSeenPerSource.get(event.getSourceVertexName()).increment();
        queue.offer(event);
        checkForSourceCompletion(event.getSourceVertexName());
      }
    }
  }

  public void processVertex(String name) {
    LOG.info("Vertex succeeded: " + name);
    synchronized(sourcesWaitingForEvents) {
      // Get a deterministic count of number of tasks for the vertex.
      MutableInt prevVal = numExpectedEventsPerSource.get(name);
      int prevValInt = prevVal.intValue();
      Preconditions.checkState(prevValInt < 0,
          "Invalid value for numExpectedEvents for source: " + name + ", oldVal=" + prevValInt);
      prevVal.setValue((-1) * prevValInt * context.getVertexNumTasks(name));
      checkForSourceCompletion(name);
    }
  }

  private void checkForSourceCompletion(String name) {
    int expectedEvents = numExpectedEventsPerSource.get(name).getValue();
    if (expectedEvents < 0) {
      // Expected events not updated yet - vertex SUCCESS notification not received.
      return;
    } else {
      int processedEvents = numEventsSeenPerSource.get(name).getValue();
      if (processedEvents == expectedEvents) {
        sourcesWaitingForEvents.remove(name);
        if (sourcesWaitingForEvents.isEmpty()) {
          // we've got what we need; mark the queue
          queue.offer(endOfEvents);
        } else {
          LOG.info("Waiting for " + sourcesWaitingForEvents.size() + " sources.");
        }
      } else if (processedEvents > expectedEvents) {
        throw new IllegalStateException(
            "Received too many events for " + name + ", Expected=" + expectedEvents +
                ", Received=" + processedEvents);
      }
      return;
    }
  }
}