All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.hadoop.hive.ql.exec.tez.tools.KeyValueInputMerger Maven / Gradle / Ivy

There is a newer version: 4.0.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.exec.tez.tools;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.serde2.Deserializer;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.io.Writable;
import org.apache.tez.runtime.library.api.KeyValueReader;

/**
 * A KeyValuesReader implementation that returns a sorted stream of key-values
 * by doing a sorted merge of the key-value in LogicalInputs.
 * Tags are in the last byte of the key, so no special handling for tags is required.
 * Uses a priority queue to pick the KeyValuesReader of the input that is next in
 * sort order.
 */
@SuppressWarnings("deprecation")
public class KeyValueInputMerger extends KeyValueReader {

  public static final Log l4j = LogFactory.getLog(KeyValueInputMerger.class);
  private PriorityQueue pQueue = null;
  private KeyValueReader nextKVReader = null;
  private ObjectInspector[] inputObjInspectors = null;
  private Deserializer deserializer = null;
  private List structFields = null;
  private List fieldOIs = null;
  private final Map> kvReaderStandardObjMap =
      new HashMap>();

  public KeyValueInputMerger(List multiMRInputs, Deserializer deserializer,
      ObjectInspector[] inputObjInspectors, List sortCols) throws Exception {
    //get KeyValuesReaders from the LogicalInput and add them to priority queue
    int initialCapacity = multiMRInputs.size();
    pQueue = new PriorityQueue(initialCapacity, new KVReaderComparator());
    this.inputObjInspectors = inputObjInspectors;
    this.deserializer = deserializer;
    fieldOIs = new ArrayList();
    structFields = new ArrayList();
    StructObjectInspector structOI = (StructObjectInspector) inputObjInspectors[0];
    for (String field : sortCols) {
      StructField sf = structOI.getStructFieldRef(field);
      structFields.add(sf);
      ObjectInspector stdOI =
          ObjectInspectorUtils.getStandardObjectInspector(sf.getFieldObjectInspector());
      fieldOIs.add(stdOI);
    }
    l4j.info("Initialized the priority queue with multi mr inputs: " + multiMRInputs.size());
    for (KeyValueReader input : multiMRInputs) {
      addToQueue(input);
    }
  }

  /**
   * Add KeyValueReader to queue if it has more key-value
   *
   * @param kvReader
   * @throws IOException
   */
  private void addToQueue(KeyValueReader kvReader) throws IOException {
    if (kvReader.next()) {
      kvReaderStandardObjMap.remove(kvReader);
      pQueue.add(kvReader);
    }
  }

  /**
   * @return true if there are more key-values and advances to next key-values
   * @throws IOException
   */
  @Override
  public boolean next() throws IOException {
    //add the previous nextKVReader back to queue
    if(nextKVReader != null){
      addToQueue(nextKVReader);
    }

    //get the new nextKVReader with lowest key
    nextKVReader = pQueue.poll();
    return nextKVReader != null;
  }

  @Override
  public Object getCurrentKey() throws IOException {
    return nextKVReader.getCurrentKey();
  }

  @Override
  public Object getCurrentValue() throws IOException {
    return nextKVReader.getCurrentValue();
  }

  /**
   * Comparator that compares KeyValuesReader on their current key
   */
  class KVReaderComparator implements Comparator {

    @SuppressWarnings({ "unchecked" })
    @Override
    public int compare(KeyValueReader kvReadr1, KeyValueReader kvReadr2) {
      try {
        ObjectInspector oi = inputObjInspectors[0];
        List row1, row2;
        try {
          if (kvReaderStandardObjMap.containsKey(kvReadr1)) {
            row1 = kvReaderStandardObjMap.get(kvReadr1);
          } else {
            // we need to copy to standard object otherwise deserializer overwrites the values
            row1 =
                (List) ObjectInspectorUtils.copyToStandardObject(
                    deserializer.deserialize((Writable) kvReadr1.getCurrentValue()), oi,
                    ObjectInspectorCopyOption.WRITABLE);
            kvReaderStandardObjMap.put(kvReadr1, row1);
          }

          if (kvReaderStandardObjMap.containsKey(kvReadr2)) {
            row2 = kvReaderStandardObjMap.get(kvReadr2);
          } else {
            row2 =
                (List) ObjectInspectorUtils.copyToStandardObject(
                    deserializer.deserialize((Writable) kvReadr2.getCurrentValue()), oi,
                    ObjectInspectorCopyOption.WRITABLE);
            kvReaderStandardObjMap.put(kvReadr2, row2);
          }
        } catch (SerDeException e) {
          throw new IOException(e);
        }

        StructObjectInspector structOI = (StructObjectInspector) oi;
        int compare = 0;
        int index = 0;
        for (StructField sf : structFields) {
          int pos = structOI.getAllStructFieldRefs().indexOf(sf);
          Object key1 = row1.get(pos);
          Object key2 = row2.get(pos);
          ObjectInspector stdOI = fieldOIs.get(index);
          compare = ObjectInspectorUtils.compare(key1, stdOI, key2, stdOI);
          index++;
          if (compare != 0) {
            return compare;
          }
        }

        return compare;
      } catch (IOException e) {
        l4j.error("Caught exception while reading shuffle input", e);
        //die!
        throw new RuntimeException(e);
      }
    }
  }
}