// org.apache.flink.runtime.iterative.task.IterationHeadTask Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.iterative.task;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
import org.apache.flink.runtime.operators.Driver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.operators.util.JoinHashMap;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.disk.InputViewIterator;
import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannel;
import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannelBroker;
import org.apache.flink.runtime.iterative.concurrent.Broker;
import org.apache.flink.runtime.iterative.concurrent.IterationAggregatorBroker;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetBroker;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetUpdateBarrier;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetUpdateBarrierBroker;
import org.apache.flink.runtime.iterative.concurrent.SuperstepBarrier;
import org.apache.flink.runtime.iterative.concurrent.SuperstepKickoffLatch;
import org.apache.flink.runtime.iterative.concurrent.SuperstepKickoffLatchBroker;
import org.apache.flink.runtime.iterative.event.AllWorkersDoneEvent;
import org.apache.flink.runtime.iterative.event.TerminationEvent;
import org.apache.flink.runtime.iterative.event.WorkerDoneEvent;
import org.apache.flink.runtime.iterative.io.SerializedUpdateBuffer;
import org.apache.flink.runtime.operators.BatchTask;
import org.apache.flink.runtime.operators.hash.CompactingHashTable;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;
/**
* The head is responsible for coordinating an iteration and can run a
* {@link Driver} inside. It will read
* the initial input and establish a {@link BlockingBackChannel} to the iteration's tail. After successfully processing
* the input, it will send EndOfSuperstep events to its outputs. It must also be connected to a
* synchronization task and after each superstep, it will wait
* until it receives an {@link AllWorkersDoneEvent} from the sync, which signals that all other heads have also finished
* their iteration. Starting with
* the second iteration, the input for the head is the output of the tail, transmitted through the backchannel. Once the
* iteration is done, the head
* will send a {@link TerminationEvent} to all it's connected tasks, signaling them to shutdown.
*
* Assumption on the ordering of the outputs: - The first n output gates write to channels that go to the tasks of the
* step function. - The next m output gates to to the tasks that consume the final solution. - The last output gate
* connects to the synchronization task.
*
* @param
* The type of the bulk partial solution / solution set and the final output.
* @param
* The type of the feed-back data set (bulk partial solution / workset). For bulk iterations, {@code Y} is the
* same as {@code X}
*/
public class IterationHeadTask extends AbstractIterativeTask {
private static final Logger log = LoggerFactory.getLogger(IterationHeadTask.class);
private Collector finalOutputCollector;
private TypeSerializerFactory feedbackTypeSerializer;
private TypeSerializerFactory solutionTypeSerializer;
private ResultPartitionWriter toSync;
private int feedbackDataInput; // workset or bulk partial solution
// --------------------------------------------------------------------------------------------
@Override
protected int getNumTaskInputs() {
// this task has an additional input in the workset case for the initial solution set
boolean isWorkset = config.getIsWorksetIteration();
return driver.getNumberOfInputs() + (isWorkset ? 1 : 0);
}
@Override
protected void initOutputs() throws Exception {
// initialize the regular outputs first (the ones into the step function).
super.initOutputs();
// at this time, the outputs to the step function are created
// add the outputs for the final solution
List> finalOutputWriters = new ArrayList>();
final TaskConfig finalOutConfig = this.config.getIterationHeadFinalOutputConfig();
final ClassLoader userCodeClassLoader = getUserCodeClassLoader();
AccumulatorRegistry.Reporter reporter = getEnvironment().getAccumulatorRegistry().getReadWriteReporter();
this.finalOutputCollector = BatchTask.getOutputCollector(this, finalOutConfig,
userCodeClassLoader, finalOutputWriters, config.getNumOutputs(), finalOutConfig.getNumOutputs(), reporter);
// sanity check the setup
final int writersIntoStepFunction = this.eventualOutputs.size();
final int writersIntoFinalResult = finalOutputWriters.size();
final int syncGateIndex = this.config.getIterationHeadIndexOfSyncOutput();
if (writersIntoStepFunction + writersIntoFinalResult != syncGateIndex) {
throw new Exception("Error: Inconsistent head task setup - wrong mapping of output gates.");
}
// now, we can instantiate the sync gate
this.toSync = getEnvironment().getWriter(syncGateIndex);
}
/**
* the iteration head prepares the backchannel: it allocates memory, instantiates a {@link BlockingBackChannel} and
* hands it to the iteration tail via a {@link Broker} singleton
**/
private BlockingBackChannel initBackChannel() throws Exception {
/* get the size of the memory available to the backchannel */
int backChannelMemoryPages = getMemoryManager().computeNumberOfPages(this.config.getRelativeBackChannelMemory());
/* allocate the memory available to the backchannel */
List segments = new ArrayList();
int segmentSize = getMemoryManager().getPageSize();
getMemoryManager().allocatePages(this, segments, backChannelMemoryPages);
/* instantiate the backchannel */
BlockingBackChannel backChannel = new BlockingBackChannel(new SerializedUpdateBuffer(segments, segmentSize,
getIOManager()));
/* hand the backchannel over to the iteration tail */
Broker broker = BlockingBackChannelBroker.instance();
broker.handIn(brokerKey(), backChannel);
return backChannel;
}
private CompactingHashTable initCompactingHashTable() throws Exception {
// get some memory
double hashjoinMemorySize = config.getRelativeSolutionSetMemory();
final ClassLoader userCodeClassLoader = getUserCodeClassLoader();
TypeSerializerFactory solutionTypeSerializerFactory = config.getSolutionSetSerializer(userCodeClassLoader);
TypeComparatorFactory solutionTypeComparatorFactory = config.getSolutionSetComparator(userCodeClassLoader);
TypeSerializer solutionTypeSerializer = solutionTypeSerializerFactory.getSerializer();
TypeComparator solutionTypeComparator = solutionTypeComparatorFactory.createComparator();
CompactingHashTable hashTable = null;
List memSegments = null;
boolean success = false;
try {
int numPages = getMemoryManager().computeNumberOfPages(hashjoinMemorySize);
memSegments = getMemoryManager().allocatePages(getContainingTask(), numPages);
hashTable = new CompactingHashTable(solutionTypeSerializer, solutionTypeComparator, memSegments);
success = true;
return hashTable;
} finally {
if (!success) {
if (hashTable != null) {
try {
hashTable.close();
} catch (Throwable t) {
log.error("Error closing the solution set hash table after unsuccessful creation.", t);
}
}
if (memSegments != null) {
try {
getMemoryManager().release(memSegments);
} catch (Throwable t) {
log.error("Error freeing memory after error during solution set hash table creation.", t);
}
}
}
}
}
private JoinHashMap initJoinHashMap() {
TypeSerializerFactory solutionTypeSerializerFactory = config.getSolutionSetSerializer
(getUserCodeClassLoader());
TypeComparatorFactory solutionTypeComparatorFactory = config.getSolutionSetComparator
(getUserCodeClassLoader());
TypeSerializer solutionTypeSerializer = solutionTypeSerializerFactory.getSerializer();
TypeComparator solutionTypeComparator = solutionTypeComparatorFactory.createComparator();
return new JoinHashMap(solutionTypeSerializer, solutionTypeComparator);
}
private void readInitialSolutionSet(CompactingHashTable solutionSet, MutableObjectIterator solutionSetInput) throws IOException {
solutionSet.open();
solutionSet.buildTableWithUniqueKey(solutionSetInput);
}
private void readInitialSolutionSet(JoinHashMap solutionSet, MutableObjectIterator solutionSetInput) throws IOException {
TypeSerializer serializer = solutionTypeSerializer.getSerializer();
X next;
while ((next = solutionSetInput.next(serializer.createInstance())) != null) {
solutionSet.insertOrReplace(next);
}
}
private SuperstepBarrier initSuperstepBarrier() {
SuperstepBarrier barrier = new SuperstepBarrier(getUserCodeClassLoader());
this.toSync.subscribeToEvent(barrier, AllWorkersDoneEvent.class);
this.toSync.subscribeToEvent(barrier, TerminationEvent.class);
return barrier;
}
@Override
public void run() throws Exception {
final String brokerKey = brokerKey();
final int workerIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask();
final boolean objectSolutionSet = config.isSolutionSetUnmanaged();
CompactingHashTable solutionSet = null; // if workset iteration
JoinHashMap solutionSetObjectMap = null; // if workset iteration with unmanaged solution set
boolean waitForSolutionSetUpdate = config.getWaitForSolutionSetUpdate();
boolean isWorksetIteration = config.getIsWorksetIteration();
try {
/* used for receiving the current iteration result from iteration tail */
SuperstepKickoffLatch nextStepKickoff = new SuperstepKickoffLatch();
SuperstepKickoffLatchBroker.instance().handIn(brokerKey, nextStepKickoff);
BlockingBackChannel backChannel = initBackChannel();
SuperstepBarrier barrier = initSuperstepBarrier();
SolutionSetUpdateBarrier solutionSetUpdateBarrier = null;
feedbackDataInput = config.getIterationHeadPartialSolutionOrWorksetInputIndex();
feedbackTypeSerializer = this.getInputSerializer(feedbackDataInput);
excludeFromReset(feedbackDataInput);
int initialSolutionSetInput;
if (isWorksetIteration) {
initialSolutionSetInput = config.getIterationHeadSolutionSetInputIndex();
solutionTypeSerializer = config.getSolutionSetSerializer(getUserCodeClassLoader());
// setup the index for the solution set
@SuppressWarnings("unchecked")
MutableObjectIterator solutionSetInput = (MutableObjectIterator) createInputIterator(inputReaders[initialSolutionSetInput], solutionTypeSerializer);
// read the initial solution set
if (objectSolutionSet) {
solutionSetObjectMap = initJoinHashMap();
readInitialSolutionSet(solutionSetObjectMap, solutionSetInput);
SolutionSetBroker.instance().handIn(brokerKey, solutionSetObjectMap);
} else {
solutionSet = initCompactingHashTable();
readInitialSolutionSet(solutionSet, solutionSetInput);
SolutionSetBroker.instance().handIn(brokerKey, solutionSet);
}
if (waitForSolutionSetUpdate) {
solutionSetUpdateBarrier = new SolutionSetUpdateBarrier();
SolutionSetUpdateBarrierBroker.instance().handIn(brokerKey, solutionSetUpdateBarrier);
}
}
else {
// bulk iteration case
@SuppressWarnings("unchecked")
TypeSerializerFactory solSer = (TypeSerializerFactory) feedbackTypeSerializer;
solutionTypeSerializer = solSer;
// = termination Criterion tail
if (waitForSolutionSetUpdate) {
solutionSetUpdateBarrier = new SolutionSetUpdateBarrier();
SolutionSetUpdateBarrierBroker.instance().handIn(brokerKey, solutionSetUpdateBarrier);
}
}
// instantiate all aggregators and register them at the iteration global registry
RuntimeAggregatorRegistry aggregatorRegistry = new RuntimeAggregatorRegistry(config.getIterationAggregators
(getUserCodeClassLoader()));
IterationAggregatorBroker.instance().handIn(brokerKey, aggregatorRegistry);
DataInputView superstepResult = null;
while (this.running && !terminationRequested()) {
if (log.isInfoEnabled()) {
log.info(formatLogString("starting iteration [" + currentIteration() + "]"));
}
barrier.setup();
if (waitForSolutionSetUpdate) {
solutionSetUpdateBarrier.setup();
}
if (!inFirstIteration()) {
feedBackSuperstepResult(superstepResult);
}
super.run();
// signal to connected tasks that we are done with the superstep
sendEndOfSuperstepToAllIterationOutputs();
if (waitForSolutionSetUpdate) {
solutionSetUpdateBarrier.waitForSolutionSetUpdate();
}
// blocking call to wait for the result
superstepResult = backChannel.getReadEndAfterSuperstepEnded();
if (log.isInfoEnabled()) {
log.info(formatLogString("finishing iteration [" + currentIteration() + "]"));
}
sendEventToSync(new WorkerDoneEvent(workerIndex, aggregatorRegistry.getAllAggregators()));
if (log.isInfoEnabled()) {
log.info(formatLogString("waiting for other workers in iteration [" + currentIteration() + "]"));
}
barrier.waitForOtherWorkers();
if (barrier.terminationSignaled()) {
if (log.isInfoEnabled()) {
log.info(formatLogString("head received termination request in iteration ["
+ currentIteration()
+ "]"));
}
requestTermination();
nextStepKickoff.signalTermination();
} else {
incrementIterationCounter();
String[] globalAggregateNames = barrier.getAggregatorNames();
Value[] globalAggregates = barrier.getAggregates();
aggregatorRegistry.updateGlobalAggregatesAndReset(globalAggregateNames, globalAggregates);
nextStepKickoff.triggerNextSuperstep();
}
}
if (log.isInfoEnabled()) {
log.info(formatLogString("streaming out final result after [" + currentIteration() + "] iterations"));
}
if (isWorksetIteration) {
if (objectSolutionSet) {
streamSolutionSetToFinalOutput(solutionSetObjectMap);
} else {
streamSolutionSetToFinalOutput(solutionSet);
}
} else {
streamOutFinalOutputBulk(new InputViewIterator(superstepResult, this.solutionTypeSerializer.getSerializer()));
}
this.finalOutputCollector.close();
} finally {
// make sure we unregister everything from the broker:
// - backchannel
// - aggregator registry
// - solution set index
IterationAggregatorBroker.instance().remove(brokerKey);
BlockingBackChannelBroker.instance().remove(brokerKey);
SuperstepKickoffLatchBroker.instance().remove(brokerKey);
SolutionSetBroker.instance().remove(brokerKey);
SolutionSetUpdateBarrierBroker.instance().remove(brokerKey);
if (solutionSet != null) {
solutionSet.close();
}
}
}
private void streamOutFinalOutputBulk(MutableObjectIterator results) throws IOException {
final Collector out = this.finalOutputCollector;
X record = this.solutionTypeSerializer.getSerializer().createInstance();
while ((record = results.next(record)) != null) {
out.collect(record);
}
}
private void streamSolutionSetToFinalOutput(CompactingHashTable hashTable) throws IOException {
final MutableObjectIterator results = hashTable.getEntryIterator();
final Collector output = this.finalOutputCollector;
X record = solutionTypeSerializer.getSerializer().createInstance();
while ((record = results.next(record)) != null) {
output.collect(record);
}
}
@SuppressWarnings("unchecked")
private void streamSolutionSetToFinalOutput(JoinHashMap soluionSet) throws IOException {
final Collector output = this.finalOutputCollector;
for (Object e : soluionSet.values()) {
output.collect((X) e);
}
}
private void feedBackSuperstepResult(DataInputView superstepResult) {
this.inputs[this.feedbackDataInput] =
new InputViewIterator(superstepResult, this.feedbackTypeSerializer.getSerializer());
}
private void sendEndOfSuperstepToAllIterationOutputs() throws IOException, InterruptedException {
if (log.isDebugEnabled()) {
log.debug(formatLogString("Sending end-of-superstep to all iteration outputs."));
}
for (RecordWriter eventualOutput : this.eventualOutputs) {
eventualOutput.sendEndOfSuperstep();
}
}
private void sendEventToSync(WorkerDoneEvent event) throws IOException, InterruptedException {
if (log.isInfoEnabled()) {
log.info(formatLogString("sending " + WorkerDoneEvent.class.getSimpleName() + " to sync"));
}
this.toSync.writeEventToAllChannels(event);
}
}