All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.flink.optimizer.postpass.JavaApiPostPass Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.optimizer.postpass;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.GenericDataSourceBase;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.base.BulkIterationBase;
import org.apache.flink.api.common.operators.base.DeltaIterationBase;
import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.runtime.RuntimeComparatorFactory;
import org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory;
import org.apache.flink.api.java.typeutils.runtime.RuntimeSerializerFactory;
import org.apache.flink.optimizer.CompilerException;
import org.apache.flink.optimizer.CompilerPostPassException;
import org.apache.flink.optimizer.plan.BulkIterationPlanNode;
import org.apache.flink.optimizer.plan.BulkPartialSolutionPlanNode;
import org.apache.flink.optimizer.plan.Channel;
import org.apache.flink.optimizer.plan.DualInputPlanNode;
import org.apache.flink.optimizer.plan.NAryUnionPlanNode;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.PlanNode;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.plan.SinkPlanNode;
import org.apache.flink.optimizer.plan.SolutionSetPlanNode;
import org.apache.flink.optimizer.plan.SourcePlanNode;
import org.apache.flink.optimizer.plan.WorksetIterationPlanNode;
import org.apache.flink.optimizer.plan.WorksetPlanNode;
import org.apache.flink.optimizer.util.NoOpUnaryUdfOp;
import org.apache.flink.runtime.operators.DriverStrategy;

/**
 * The post-optimizer plan traversal. This traversal fills in the API specific utilities (serializers and
 * comparators).
 */
public class JavaApiPostPass implements OptimizerPostPass {
	
	/** Plan nodes that {@link #traverse(PlanNode)} has already visited; guards against repeated work. */
	private final Set<PlanNode> alreadyDone = new HashSet<PlanNode>();

	/** Execution config of the original plan; assigned at the start of {@link #postPass(OptimizedPlan)}. */
	private ExecutionConfig executionConfig = null;
	
	/**
	 * Entry point of the post pass. Stores the execution config of the original plan and then
	 * starts a traversal from every data sink of the optimized plan.
	 *
	 * @param plan the optimized plan whose nodes and channels are to be parameterized
	 */
	@Override
	public void postPass(OptimizedPlan plan) {
		// the config must be in place first: serializers and comparators are created with it
		this.executionConfig = plan.getOriginalPlan().getExecutionConfig();

		for (SinkPlanNode dataSink : plan.getDataSinks()) {
			traverse(dataSink);
		}
	}
	

	protected void traverse(PlanNode node) {
		if (!alreadyDone.add(node)) {
			// already worked on that one
			return;
		}
		
		// distinguish the node types
		if (node instanceof SinkPlanNode) {
			// descend to the input channel
			SinkPlanNode sn = (SinkPlanNode) node;
			Channel inchannel = sn.getInput();
			traverseChannel(inchannel);
		}
		else if (node instanceof SourcePlanNode) {
			TypeInformation typeInfo = getTypeInfoFromSource((SourcePlanNode) node);
			((SourcePlanNode) node).setSerializer(createSerializer(typeInfo));
		}
		else if (node instanceof BulkIterationPlanNode) {
			BulkIterationPlanNode iterationNode = (BulkIterationPlanNode) node;

			if (iterationNode.getRootOfStepFunction() instanceof NAryUnionPlanNode) {
				throw new CompilerException("Optimizer cannot compile an iteration step function where next partial solution is created by a Union node.");
			}
			
			// traverse the termination criterion for the first time. create schema only, no utilities. Needed in case of intermediate termination criterion
			if (iterationNode.getRootOfTerminationCriterion() != null) {
				SingleInputPlanNode addMapper = (SingleInputPlanNode) iterationNode.getRootOfTerminationCriterion();
				traverseChannel(addMapper.getInput());
			}

			BulkIterationBase operator = (BulkIterationBase) iterationNode.getProgramOperator();

			// set the serializer
			iterationNode.setSerializerForIterationChannel(createSerializer(operator.getOperatorInfo().getOutputType()));

			// done, we can now propagate our info down
			traverseChannel(iterationNode.getInput());
			traverse(iterationNode.getRootOfStepFunction());
		}
		else if (node instanceof WorksetIterationPlanNode) {
			WorksetIterationPlanNode iterationNode = (WorksetIterationPlanNode) node;
			
			if (iterationNode.getNextWorkSetPlanNode() instanceof NAryUnionPlanNode) {
				throw new CompilerException("Optimizer cannot compile a workset iteration step function where the next workset is produced by a Union node.");
			}
			if (iterationNode.getSolutionSetDeltaPlanNode() instanceof NAryUnionPlanNode) {
				throw new CompilerException("Optimizer cannot compile a workset iteration step function where the solution set delta is produced by a Union node.");
			}
			
			DeltaIterationBase operator = (DeltaIterationBase) iterationNode.getProgramOperator();
			
			// set the serializers and comparators for the workset iteration
			iterationNode.setSolutionSetSerializer(createSerializer(operator.getOperatorInfo().getFirstInputType()));
			iterationNode.setWorksetSerializer(createSerializer(operator.getOperatorInfo().getSecondInputType()));
			iterationNode.setSolutionSetComparator(createComparator(operator.getOperatorInfo().getFirstInputType(),
					iterationNode.getSolutionSetKeyFields(), getSortOrders(iterationNode.getSolutionSetKeyFields(), null)));
			
			// traverse the inputs
			traverseChannel(iterationNode.getInput1());
			traverseChannel(iterationNode.getInput2());
			
			// traverse the step function
			traverse(iterationNode.getSolutionSetDeltaPlanNode());
			traverse(iterationNode.getNextWorkSetPlanNode());
		}
		else if (node instanceof SingleInputPlanNode) {
			SingleInputPlanNode sn = (SingleInputPlanNode) node;
			
			if (!(sn.getOptimizerNode().getOperator() instanceof SingleInputOperator)) {
				
				// Special case for delta iterations
				if(sn.getOptimizerNode().getOperator() instanceof NoOpUnaryUdfOp) {
					traverseChannel(sn.getInput());
					return;
				} else {
					throw new RuntimeException("Wrong operator type found in post pass.");
				}
			}
			
			SingleInputOperator singleInputOperator = (SingleInputOperator) sn.getOptimizerNode().getOperator();
			
			// parameterize the node's driver strategy
			for(int i=0;i dualInputOperator = (DualInputOperator) dn.getOptimizerNode().getOperator();
			
			// parameterize the node's driver strategy
			if (dn.getDriverStrategy().getNumRequiredComparators() > 0) {
				dn.setComparator1(createComparator(dualInputOperator.getOperatorInfo().getFirstInputType(), dn.getKeysForInput1(),
					getSortOrders(dn.getKeysForInput1(), dn.getSortOrders())));
				dn.setComparator2(createComparator(dualInputOperator.getOperatorInfo().getSecondInputType(), dn.getKeysForInput2(),
						getSortOrders(dn.getKeysForInput2(), dn.getSortOrders())));

				dn.setPairComparator(createPairComparator(dualInputOperator.getOperatorInfo().getFirstInputType(),
						dualInputOperator.getOperatorInfo().getSecondInputType()));
				
			}
						
			traverseChannel(dn.getInput1());
			traverseChannel(dn.getInput2());
			
			// don't forget the broadcast inputs
			for (Channel c: dn.getBroadcastInputs()) {
				traverseChannel(c);
			}
			
		}
		// catch the sources of the iterative step functions
		else if (node instanceof BulkPartialSolutionPlanNode ||
				node instanceof SolutionSetPlanNode ||
				node instanceof WorksetPlanNode)
		{
			// Do nothing :D
		}
		else if (node instanceof NAryUnionPlanNode){
			// Traverse to all child channels
			for (Channel channel : node.getInputs()) {
				traverseChannel(channel);
			}
		}
		else {
			throw new CompilerPostPassException("Unknown node type encountered: " + node.getClass().getName());
		}
	}
	
	/**
	 * Parameterizes one channel: sets its serializer, sets a comparator for the ship strategy
	 * and for the local strategy where the respective strategy requires one, and then
	 * continues the traversal at the channel's source node.
	 *
	 * @param channel the channel to parameterize
	 */
	private void traverseChannel(Channel channel) {
		final PlanNode producer = channel.getSource();
		final TypeInformation recordType = resolveChannelType(producer, producer.getProgramOperator());

		// the serializer always exists
		channel.setSerializer(createSerializer(recordType));

		// parameterize the ship strategy
		if (channel.getShipStrategy().requiresComparator()) {
			final boolean[] shipOrders =
					getSortOrders(channel.getShipStrategyKeys(), channel.getShipStrategySortOrder());
			channel.setShipStrategyComparator(
					createComparator(recordType, channel.getShipStrategyKeys(), shipOrders));
		}

		// parameterize the local strategy
		if (channel.getLocalStrategy().requiresComparator()) {
			final boolean[] localOrders =
					getSortOrders(channel.getLocalStrategyKeys(), channel.getLocalStrategySortOrder());
			channel.setLocalStrategyComparator(
					createComparator(recordType, channel.getLocalStrategyKeys(), localOrders));
		}

		// descend to the channel's source
		traverse(channel.getSource());
	}

	/**
	 * Determines the type of the records travelling over a channel fed by the given node.
	 * This is normally the output type of the node's program operator. For a group-reduce
	 * operator whose node runs with one of the combine driver strategies checked below, the
	 * output type of the operator's input is used instead.
	 *
	 * @param producer the plan node feeding the channel
	 * @param programOp the program operator of that node
	 * @return the type information used for the channel's serializer and comparators
	 */
	private static TypeInformation resolveChannelType(PlanNode producer, Operator programOp) {
		// read the declared output type first, exactly as the inline code did
		final TypeInformation declaredOutput = programOp.getOperatorInfo().getOutputType();

		if (programOp instanceof GroupReduceOperatorBase
				&& (producer.getDriverStrategy() == DriverStrategy.SORTED_GROUP_COMBINE
					|| producer.getDriverStrategy() == DriverStrategy.ALL_GROUP_REDUCE_COMBINE)) {
			return ((GroupReduceOperatorBase) programOp).getInput().getOperatorInfo().getOutputType();
		}
		if (programOp instanceof PlanUnwrappingReduceGroupOperator
				&& producer.getDriverStrategy().equals(DriverStrategy.SORTED_GROUP_COMBINE)) {
			return ((PlanUnwrappingReduceGroupOperator) programOp).getInput().getOperatorInfo().getOutputType();
		}
		return declaredOutput;
	}
	
	
	/**
	 * Reads the output type of the operator behind a source plan node.
	 *
	 * @param node the source plan node
	 * @return the output type declared by the node's data source operator
	 * @throws RuntimeException if the node's operator is not a {@link GenericDataSourceBase}
	 */
	@SuppressWarnings("unchecked")
	private static TypeInformation getTypeInfoFromSource(SourcePlanNode node) {
		final Operator sourceOperator = node.getOptimizerNode().getOperator();

		// reject anything that is not a data source up front
		if (!(sourceOperator instanceof GenericDataSourceBase)) {
			throw new RuntimeException("Wrong operator type found in post pass.");
		}

		final GenericDataSourceBase dataSource = (GenericDataSourceBase) sourceOperator;
		return dataSource.getOperatorInfo().getOutputType();
	}
	
	/**
	 * Builds a serializer factory for the given type. The serializer is created with the
	 * execution config stored by the post pass and wrapped, together with the type's class,
	 * in a {@link RuntimeSerializerFactory}.
	 *
	 * @param typeInfo the type to create the serializer factory for
	 * @return the factory wrapping the newly created serializer
	 */
	private TypeSerializerFactory createSerializer(TypeInformation typeInfo) {
		return new RuntimeSerializerFactory(
				typeInfo.createSerializer(executionConfig),
				typeInfo.getTypeClass());
	}
	
	/**
	 * Builds a comparator factory for the given type and key fields.
	 * Composite types get a comparator over all given keys and sort orders; atomic types
	 * get a comparator that uses only the first sort order.
	 *
	 * @param typeInfo the type whose records are compared
	 * @param keys the key fields; only consulted for composite types
	 * @param sortOrder the sort order per key
	 * @return the factory wrapping the newly created comparator
	 * @throws RuntimeException if the type is neither a composite nor an atomic type
	 */
	@SuppressWarnings("unchecked")
	private TypeComparatorFactory createComparator(TypeInformation typeInfo, FieldList keys, boolean[] sortOrder) {

		if (typeInfo instanceof CompositeType) {
			final CompositeType compositeInfo = (CompositeType) typeInfo;
			return new RuntimeComparatorFactory(
					compositeInfo.createComparator(keys.toArray(), sortOrder, 0, executionConfig));
		}

		if (typeInfo instanceof AtomicType) {
			// handle grouping of atomic types
			final AtomicType atomicInfo = (AtomicType) typeInfo;
			return new RuntimeComparatorFactory(
					atomicInfo.createComparator(sortOrder[0], executionConfig));
		}

		throw new RuntimeException("Unrecognized type: " + typeInfo);
	}
	
	/**
	 * Creates the pair comparator factory for a dual-input node. Neither type argument is
	 * consulted; a new {@link RuntimePairComparatorFactory} is returned on every call.
	 *
	 * @param typeInfo1 type of the first input (unused)
	 * @param typeInfo2 type of the second input (unused)
	 * @return a new pair comparator factory
	 */
	private static TypePairComparatorFactory createPairComparator(TypeInformation typeInfo1, TypeInformation typeInfo2) {
		return new RuntimePairComparatorFactory();
	}
	
	/**
	 * Returns the sort orders to use for the given keys.
	 *
	 * @param keys the key fields; only their count is used, and only when no orders are given
	 * @param orders the explicitly specified sort orders, or {@code null}
	 * @return {@code orders} itself if it is non-null, otherwise a new array with one
	 *         {@code true} entry per key
	 */
	private static final boolean[] getSortOrders(FieldList keys, boolean[] orders) {
		if (orders != null) {
			return orders;
		}

		// nothing specified: default every key to 'true'
		final boolean[] defaults = new boolean[keys.size()];
		Arrays.fill(defaults, true);
		return defaults;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy