All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.graph.gsa.GatherSumApplyIteration Maven / Gradle / Ivy

There is a newer version: 1.3.3
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.graph.gsa;

import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsSecond;
import org.apache.flink.api.java.operators.CustomUnaryOperation;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.util.Collector;
import java.util.Map;

import com.google.common.base.Preconditions;

/**
 * This class represents iterative graph computations, programmed in a gather-sum-apply perspective.
 *
 * @param  The type of the vertex key in the graph
 * @param  The type of the vertex value in the graph
 * @param  The type of the edge value in the graph
 * @param  The intermediate type used by the gather, sum and apply functions
 */
public class GatherSumApplyIteration implements CustomUnaryOperation,
		Vertex> {

	private DataSet> vertexDataSet;
	private DataSet> edgeDataSet;

	private final GatherFunction gather;
	private final SumFunction sum;
	private final ApplyFunction apply;
	private final int maximumNumberOfIterations;
	private EdgeDirection direction = EdgeDirection.OUT;

	private GSAConfiguration configuration;

	// ----------------------------------------------------------------------------------

	private GatherSumApplyIteration(GatherFunction gather, SumFunction sum,
			ApplyFunction apply, DataSet> edges, int maximumNumberOfIterations) {

		Preconditions.checkNotNull(gather);
		Preconditions.checkNotNull(sum);
		Preconditions.checkNotNull(apply);
		Preconditions.checkNotNull(edges);
		Preconditions.checkArgument(maximumNumberOfIterations > 0, "The maximum number of iterations must be at least one.");

		this.gather = gather;
		this.sum = sum;
		this.apply = apply;
		this.edgeDataSet = edges;
		this.maximumNumberOfIterations = maximumNumberOfIterations;
	}

	// --------------------------------------------------------------------------------------------
	//  Custom Operator behavior
	// --------------------------------------------------------------------------------------------

	/**
	 * Sets the input data set for this operator. In the case of this operator this input data set represents
	 * the set of vertices with their initial state.
	 *
	 * @param dataSet The input data set, which in the case of this operator represents the set of
	 *                vertices with their initial state.
	 */
	@Override
	public void setInput(DataSet> dataSet) {
		this.vertexDataSet = dataSet;
	}

	/**
	 * Computes the results of the gather-sum-apply iteration
	 *
	 * @return The resulting DataSet
	 */
	@Override
	public DataSet> createResult() {
		if (vertexDataSet == null) {
			throw new IllegalStateException("The input data set has not been set.");
		}

		// Prepare type information
		TypeInformation keyType = ((TupleTypeInfo) vertexDataSet.getType()).getTypeAt(0);
		TypeInformation messageType = TypeExtractor.createTypeInfo(gather, GatherFunction.class, gather.getClass(), 2);
		TypeInformation> innerType = new TupleTypeInfo>(keyType, messageType);
		TypeInformation> outputType = vertexDataSet.getType();

		// create a graph
		Graph graph =
				Graph.fromDataSet(vertexDataSet, edgeDataSet, vertexDataSet.getExecutionEnvironment());

		// check whether the numVertices option is set and, if so, compute the total number of vertices
		// and set it within the gather, sum and apply functions
		if (this.configuration != null && this.configuration.isOptNumVertices()) {
			try {
				long numberOfVertices = graph.numberOfVertices();
				gather.setNumberOfVertices(numberOfVertices);
				sum.setNumberOfVertices(numberOfVertices);
				apply.setNumberOfVertices(numberOfVertices);
			} catch (Exception e) {
				e.printStackTrace();
			}
		}

		// Prepare UDFs
		GatherUdf gatherUdf = new GatherUdf(gather, innerType);
		SumUdf sumUdf = new SumUdf(sum, innerType);
		ApplyUdf applyUdf = new ApplyUdf(apply, outputType);

		final int[] zeroKeyPos = new int[] {0};
		final DeltaIteration, Vertex> iteration =
				vertexDataSet.iterateDelta(vertexDataSet, maximumNumberOfIterations, zeroKeyPos);

		// set up the iteration operator
		if (this.configuration != null) {

			iteration.name(this.configuration.getName(
					"Gather-sum-apply iteration (" + gather + " | " + sum + " | " + apply + ")"));
			iteration.parallelism(this.configuration.getParallelism());
			iteration.setSolutionSetUnManaged(this.configuration.isSolutionSetUnmanagedMemory());

			// register all aggregators
			for (Map.Entry> entry : this.configuration.getAggregators().entrySet()) {
				iteration.registerAggregator(entry.getKey(), entry.getValue());
			}
		}
		else {
			// no configuration provided; set default name
			iteration.name("Gather-sum-apply iteration (" + gather + " | " + sum + " | " + apply + ")");
		}

		// Prepare the neighbors
		if(this.configuration != null) {
			direction = this.configuration.getDirection();
		}
		DataSet>> neighbors;
		switch(direction) {
			case OUT:
				neighbors = iteration
				.getWorkset().join(edgeDataSet)
				.where(0).equalTo(0).with(new ProjectKeyWithNeighborOUT());
				break;
			case IN:
				neighbors = iteration
				.getWorkset().join(edgeDataSet)
				.where(0).equalTo(1).with(new ProjectKeyWithNeighborIN());
				break;
			case ALL:
				neighbors =  iteration
						.getWorkset().join(edgeDataSet)
						.where(0).equalTo(0).with(new ProjectKeyWithNeighborOUT()).union(iteration
								.getWorkset().join(edgeDataSet)
								.where(0).equalTo(1).with(new ProjectKeyWithNeighborIN()));
				break;
			default:
				neighbors = iteration
						.getWorkset().join(edgeDataSet)
						.where(0).equalTo(0).with(new ProjectKeyWithNeighborOUT());
				break;
		}

		// Gather, sum and apply
		MapOperator>, Tuple2> gatherMapOperator = neighbors.map(gatherUdf);

		// configure map gather function with name and broadcast variables
		gatherMapOperator = gatherMapOperator.name("Gather");

		if (this.configuration != null) {
			for (Tuple2> e : this.configuration.getGatherBcastVars()) {
				gatherMapOperator = gatherMapOperator.withBroadcastSet(e.f1, e.f0);
			}
		}
		DataSet> gatheredSet = gatherMapOperator;

		ReduceOperator> sumReduceOperator = gatheredSet.groupBy(0).reduce(sumUdf);

		// configure reduce sum function with name and broadcast variables
		sumReduceOperator = sumReduceOperator.name("Sum");

		if (this.configuration != null) {
			for (Tuple2> e : this.configuration.getSumBcastVars()) {
				sumReduceOperator = sumReduceOperator.withBroadcastSet(e.f1, e.f0);
			}
		}
		DataSet> summedSet = sumReduceOperator;

		JoinOperator> appliedSet = summedSet
				.join(iteration.getSolutionSet())
				.where(0)
				.equalTo(0)
				.with(applyUdf);

		// configure join apply function with name and broadcast variables
		appliedSet = appliedSet.name("Apply");

		if (this.configuration != null) {
			for (Tuple2> e : this.configuration.getApplyBcastVars()) {
				appliedSet = appliedSet.withBroadcastSet(e.f1, e.f0);
			}
		}

		// let the operator know that we preserve the key field
		appliedSet.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0");

		return iteration.closeWith(appliedSet, appliedSet);
	}

	/**
	 * Creates a new gather-sum-apply iteration operator for graphs
	 *
	 * @param edges The edge DataSet
	 *
	 * @param gather The gather function of the GSA iteration
	 * @param sum The sum function of the GSA iteration
	 * @param apply The apply function of the GSA iteration
	 *
	 * @param maximumNumberOfIterations The maximum number of iterations executed
	 *
	 * @param  The type of the vertex key in the graph
	 * @param  The type of the vertex value in the graph
	 * @param  The type of the edge value in the graph
	 * @param  The intermediate type used by the gather, sum and apply functions
	 *
	 * @return An in stance of the gather-sum-apply graph computation operator.
	 */
	public static final  GatherSumApplyIteration
		withEdges(DataSet> edges, GatherFunction gather,
		SumFunction sum, ApplyFunction apply, int maximumNumberOfIterations) {

		return new GatherSumApplyIteration(gather, sum, apply, edges, maximumNumberOfIterations);
	}

	// --------------------------------------------------------------------------------------------
	//  Wrapping UDFs
	// --------------------------------------------------------------------------------------------

	@SuppressWarnings("serial")
	@ForwardedFields("f0")
	private static final class GatherUdf extends RichMapFunction>,
			Tuple2> implements ResultTypeQueryable> {

		private final GatherFunction gatherFunction;
		private transient TypeInformation> resultType;

		private GatherUdf(GatherFunction gatherFunction, TypeInformation> resultType) {
			this.gatherFunction = gatherFunction;
			this.resultType = resultType;
		}

		@Override
		public Tuple2 map(Tuple2> neighborTuple) {
			M result = this.gatherFunction.gather(neighborTuple.f1);
			return new Tuple2(neighborTuple.f0, result);
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
				this.gatherFunction.init(getIterationRuntimeContext());
			}
			this.gatherFunction.preSuperstep();
		}

		@Override
		public void close() throws Exception {
			this.gatherFunction.postSuperstep();
		}

		@Override
		public TypeInformation> getProducedType() {
			return this.resultType;
		}
	}

	@SuppressWarnings("serial")
	private static final class SumUdf extends RichReduceFunction>
			implements ResultTypeQueryable>{

		private final SumFunction sumFunction;
		private transient TypeInformation> resultType;

		private SumUdf(SumFunction sumFunction, TypeInformation> resultType) {
			this.sumFunction = sumFunction;
			this.resultType = resultType;
		}

		@Override
		public Tuple2 reduce(Tuple2 arg0, Tuple2 arg1) throws Exception {
			K key = arg0.f0;
			M result = this.sumFunction.sum(arg0.f1, arg1.f1);
			return new Tuple2(key, result);
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
				this.sumFunction.init(getIterationRuntimeContext());
			}
			this.sumFunction.preSuperstep();
		}

		@Override
		public void close() throws Exception {
			this.sumFunction.postSuperstep();
		}

		@Override
		public TypeInformation> getProducedType() {
			return this.resultType;
		}
	}

	@SuppressWarnings("serial")
	private static final class ApplyUdf extends RichFlatJoinFunction,
			Vertex, Vertex> implements ResultTypeQueryable> {

		private final ApplyFunction applyFunction;
		private transient TypeInformation> resultType;

		private ApplyUdf(ApplyFunction applyFunction, TypeInformation> resultType) {
			this.applyFunction = applyFunction;
			this.resultType = resultType;
		}

		@Override
		public void join(Tuple2 newValue, final Vertex currentValue, final Collector> out) throws Exception {

			this.applyFunction.setOutput(currentValue, out);
			this.applyFunction.apply(newValue.f1, currentValue.getValue());
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
				this.applyFunction.init(getIterationRuntimeContext());
			}
			this.applyFunction.preSuperstep();
		}

		@Override
		public void close() throws Exception {
			this.applyFunction.postSuperstep();
		}

		@Override
		public TypeInformation> getProducedType() {
			return this.resultType;
		}
	}

	@SuppressWarnings("serial")
	@ForwardedFieldsSecond("f1->f0")
	private static final class ProjectKeyWithNeighborOUT implements FlatJoinFunction<
			Vertex, Edge, Tuple2>> {

		public void join(Vertex vertex, Edge edge, Collector>> out) {
			out.collect(new Tuple2>(
					edge.getTarget(), new Neighbor(vertex.getValue(), edge.getValue())));
		}
	}

	@SuppressWarnings("serial")
	@ForwardedFieldsSecond({"f0"})
	private static final class ProjectKeyWithNeighborIN implements FlatJoinFunction<
			Vertex, Edge, Tuple2>> {

		public void join(Vertex vertex, Edge edge, Collector>> out) {
			out.collect(new Tuple2>(
					edge.getSource(), new Neighbor(vertex.getValue(), edge.getValue())));
		}
	}




	/**
	 * Configures this gather-sum-apply iteration with the provided parameters.
	 *
	 * @param parameters the configuration parameters
	 */
	public void configure(GSAConfiguration parameters) {
		this.configuration = parameters;
	}

	/**
	 * @return the configuration parameters of this gather-sum-apply iteration
	 */
	public GSAConfiguration getIterationConfiguration() {
		return this.configuration;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy