
org.apache.flink.runtime.iterative.task.IterationHeadTask Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.iterative.task;

import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.operators.util.JoinHashMap;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.io.disk.InputViewIterator;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent;
import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
import org.apache.flink.runtime.io.network.api.writer.RecordWriterBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannel;
import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannelBroker;
import org.apache.flink.runtime.iterative.concurrent.Broker;
import org.apache.flink.runtime.iterative.concurrent.IterationAggregatorBroker;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetBroker;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetUpdateBarrier;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetUpdateBarrierBroker;
import org.apache.flink.runtime.iterative.concurrent.SuperstepBarrier;
import org.apache.flink.runtime.iterative.concurrent.SuperstepKickoffLatch;
import org.apache.flink.runtime.iterative.concurrent.SuperstepKickoffLatchBroker;
import org.apache.flink.runtime.iterative.event.AllWorkersDoneEvent;
import org.apache.flink.runtime.iterative.event.TerminationEvent;
import org.apache.flink.runtime.iterative.event.WorkerDoneEvent;
import org.apache.flink.runtime.iterative.io.SerializedUpdateBuffer;
import org.apache.flink.runtime.operators.BatchTask;
import org.apache.flink.runtime.operators.Driver;
import org.apache.flink.runtime.operators.hash.CompactingHashTable;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * The head is responsible for coordinating an iteration and can run a
 * {@link Driver} inside. It will read
 * the initial input and establish a {@link BlockingBackChannel} to the iteration's tail. After successfully processing
 * the input, it will send EndOfSuperstep events to its outputs. It must also be connected to a
 * synchronization task and after each superstep, it will wait
 * until it receives an {@link AllWorkersDoneEvent} from the sync, which signals that all other heads have also finished
 * their iteration. Starting with
 * the second iteration, the input for the head is the output of the tail, transmitted through the backchannel. Once the
 * iteration is done, the head
 * will send a {@link TerminationEvent} to all its connected tasks, signaling them to shutdown.
 *
 * <p>Assumption on the ordering of the outputs:
 *
 * <ul>
 *   <li>The first n output gates write to channels that go to the tasks of the step function.</li>
 *   <li>The next m output gates go to the tasks that consume the final solution.</li>
 *   <li>The last output gate connects to the synchronization task.</li>
 * </ul>
 *
 * @param <X> The type of the bulk partial solution / solution set and the final output.
 * @param <Y> The type of the feed-back data set (bulk partial solution / workset). For bulk
 *            iterations, {@code Y} is the same as {@code X}.
 */
public class IterationHeadTask<X, Y, S extends Function, OT> extends AbstractIterativeTask<S, OT> {

    private static final Logger log = LoggerFactory.getLogger(IterationHeadTask.class);

    private Collector<X> finalOutputCollector;

    private TypeSerializerFactory<Y> feedbackTypeSerializer;

    private TypeSerializerFactory<X> solutionTypeSerializer;

    private RecordWriter<IOReadableWritable> toSync;

    private ResultPartitionID toSyncPartitionId;

    private int feedbackDataInput; // workset or bulk partial solution

    // --------------------------------------------------------------------------------------------

    /**
     * Create an Invokable task and set its environment.
     *
     * @param environment The environment assigned to this invokable.
     */
    public IterationHeadTask(Environment environment) {
        super(environment);
    }

    // --------------------------------------------------------------------------------------------

    @Override
    protected int getNumTaskInputs() {
        // this task has an additional input in the workset case for the initial solution set
        boolean isWorkset = config.getIsWorksetIteration();
        return driver.getNumberOfInputs() + (isWorkset ? 1 : 0);
    }

    @Override
    protected void initOutputs() throws Exception {
        // initialize the regular outputs first (the ones into the step function).
        super.initOutputs();

        // at this time, the outputs to the step function are created
        // add the outputs for the final solution
        List<RecordWriter<?>> finalOutputWriters = new ArrayList<RecordWriter<?>>();
        final TaskConfig finalOutConfig = this.config.getIterationHeadFinalOutputConfig();
        final ClassLoader userCodeClassLoader = getUserCodeClassLoader();
        this.finalOutputCollector = BatchTask.getOutputCollector(this, finalOutConfig,
            userCodeClassLoader, finalOutputWriters, config.getNumOutputs(), finalOutConfig.getNumOutputs());

        // sanity check the setup
        final int writersIntoStepFunction = this.eventualOutputs.size();
        final int writersIntoFinalResult = finalOutputWriters.size();
        final int syncGateIndex = this.config.getIterationHeadIndexOfSyncOutput();

        if (writersIntoStepFunction + writersIntoFinalResult != syncGateIndex) {
            throw new Exception("Error: Inconsistent head task setup - wrong mapping of output gates.");
        }
        // now, we can instantiate the sync gate
        this.toSync = new RecordWriterBuilder<>().build(getEnvironment().getWriter(syncGateIndex));
        this.toSyncPartitionId = getEnvironment().getWriter(syncGateIndex).getPartitionId();
    }

    /**
     * The iteration head prepares the backchannel: it allocates memory, instantiates a
     * {@link BlockingBackChannel} and hands it to the iteration tail via a {@link Broker} singleton.
     */
    private BlockingBackChannel initBackChannel() throws Exception {

        /* get the size of the memory available to the backchannel */
        int backChannelMemoryPages = getMemoryManager().computeNumberOfPages(this.config.getRelativeBackChannelMemory());

        /* allocate the memory available to the backchannel */
        List<MemorySegment> segments = new ArrayList<MemorySegment>();
        int segmentSize = getMemoryManager().getPageSize();
        getMemoryManager().allocatePages(this, segments, backChannelMemoryPages);

        /* instantiate the backchannel */
        BlockingBackChannel backChannel = new BlockingBackChannel(new SerializedUpdateBuffer(segments, segmentSize,
            getIOManager()));

        /* hand the backchannel over to the iteration tail */
        Broker<BlockingBackChannel> broker = BlockingBackChannelBroker.instance();
        broker.handIn(brokerKey(), backChannel);

        return backChannel;
    }

    private <BT> CompactingHashTable<BT> initCompactingHashTable() throws Exception {
        // get some memory
        double hashjoinMemorySize = config.getRelativeSolutionSetMemory();
        final ClassLoader userCodeClassLoader = getUserCodeClassLoader();

        TypeSerializerFactory<BT> solutionTypeSerializerFactory = config.getSolutionSetSerializer(userCodeClassLoader);
        TypeComparatorFactory<BT> solutionTypeComparatorFactory = config.getSolutionSetComparator(userCodeClassLoader);

        TypeSerializer<BT> solutionTypeSerializer = solutionTypeSerializerFactory.getSerializer();
        TypeComparator<BT> solutionTypeComparator = solutionTypeComparatorFactory.createComparator();

        CompactingHashTable<BT> hashTable = null;
        List<MemorySegment> memSegments = null;
        boolean success = false;
        try {
            int numPages = getMemoryManager().computeNumberOfPages(hashjoinMemorySize);
            memSegments = getMemoryManager().allocatePages(getContainingTask(), numPages);
            hashTable = new CompactingHashTable<BT>(solutionTypeSerializer, solutionTypeComparator, memSegments);
            success = true;
            return hashTable;
        } finally {
            if (!success) {
                if (hashTable != null) {
                    try {
                        hashTable.close();
                    } catch (Throwable t) {
                        log.error("Error closing the solution set hash table after unsuccessful creation.", t);
                    }
                }
                if (memSegments != null) {
                    try {
                        getMemoryManager().release(memSegments);
                    } catch (Throwable t) {
                        log.error("Error freeing memory after error during solution set hash table creation.", t);
                    }
                }
            }
        }
    }

    private <BT> JoinHashMap<BT> initJoinHashMap() {
        TypeSerializerFactory<BT> solutionTypeSerializerFactory = config.getSolutionSetSerializer(getUserCodeClassLoader());
        TypeComparatorFactory<BT> solutionTypeComparatorFactory = config.getSolutionSetComparator(getUserCodeClassLoader());

        TypeSerializer<BT> solutionTypeSerializer = solutionTypeSerializerFactory.getSerializer();
        TypeComparator<BT> solutionTypeComparator = solutionTypeComparatorFactory.createComparator();

        return new JoinHashMap<BT>(solutionTypeSerializer, solutionTypeComparator);
    }

    private void readInitialSolutionSet(
            CompactingHashTable<X> solutionSet,
            MutableObjectIterator<X> solutionSetInput) throws IOException {
        solutionSet.open();
        solutionSet.buildTableWithUniqueKey(solutionSetInput);
    }

    private void readInitialSolutionSet(
            JoinHashMap<X> solutionSet,
            MutableObjectIterator<X> solutionSetInput) throws IOException {
        TypeSerializer<X> serializer = solutionTypeSerializer.getSerializer();

        X next;
        while ((next = solutionSetInput.next(serializer.createInstance())) != null) {
            solutionSet.insertOrReplace(next);
        }
    }

    private SuperstepBarrier initSuperstepBarrier() {
        SuperstepBarrier barrier = new SuperstepBarrier(getUserCodeClassLoader());
        TaskEventDispatcher taskEventDispatcher = getEnvironment().getTaskEventDispatcher();
        ResultPartitionID partitionId = toSyncPartitionId;
        taskEventDispatcher.subscribeToEvent(partitionId, barrier, AllWorkersDoneEvent.class);
        taskEventDispatcher.subscribeToEvent(partitionId, barrier, TerminationEvent.class);
        return barrier;
    }

    @Override
    public void run() throws Exception {
        final String brokerKey = brokerKey();
        final int workerIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask();

        final boolean objectSolutionSet = config.isSolutionSetUnmanaged();

        CompactingHashTable<X> solutionSet = null; // if workset iteration
        JoinHashMap<X> solutionSetObjectMap = null; // if workset iteration with unmanaged solution set

        boolean waitForSolutionSetUpdate = config.getWaitForSolutionSetUpdate();
        boolean isWorksetIteration = config.getIsWorksetIteration();

        try {
            /* used for receiving the current iteration result from iteration tail */
            SuperstepKickoffLatch nextStepKickoff = new SuperstepKickoffLatch();
            SuperstepKickoffLatchBroker.instance().handIn(brokerKey, nextStepKickoff);

            BlockingBackChannel backChannel = initBackChannel();
            SuperstepBarrier barrier = initSuperstepBarrier();
            SolutionSetUpdateBarrier solutionSetUpdateBarrier = null;

            feedbackDataInput = config.getIterationHeadPartialSolutionOrWorksetInputIndex();
            feedbackTypeSerializer = this.<Y>getInputSerializer(feedbackDataInput);
            excludeFromReset(feedbackDataInput);

            int initialSolutionSetInput;
            if (isWorksetIteration) {
                initialSolutionSetInput = config.getIterationHeadSolutionSetInputIndex();
                solutionTypeSerializer = config.getSolutionSetSerializer(getUserCodeClassLoader());

                // setup the index for the solution set
                @SuppressWarnings("unchecked")
                MutableObjectIterator<X> solutionSetInput = (MutableObjectIterator<X>)
                    createInputIterator(inputReaders[initialSolutionSetInput], solutionTypeSerializer);

                // read the initial solution set
                if (objectSolutionSet) {
                    solutionSetObjectMap = initJoinHashMap();
                    readInitialSolutionSet(solutionSetObjectMap, solutionSetInput);
                    SolutionSetBroker.instance().handIn(brokerKey, solutionSetObjectMap);
                } else {
                    solutionSet = initCompactingHashTable();
                    readInitialSolutionSet(solutionSet, solutionSetInput);
                    SolutionSetBroker.instance().handIn(brokerKey, solutionSet);
                }

                if (waitForSolutionSetUpdate) {
                    solutionSetUpdateBarrier = new SolutionSetUpdateBarrier();
                    SolutionSetUpdateBarrierBroker.instance().handIn(brokerKey, solutionSetUpdateBarrier);
                }
            } else {
                // bulk iteration case
                @SuppressWarnings("unchecked")
                TypeSerializerFactory<X> solSer = (TypeSerializerFactory<X>) feedbackTypeSerializer;
                solutionTypeSerializer = solSer;

                // = termination Criterion tail
                if (waitForSolutionSetUpdate) {
                    solutionSetUpdateBarrier = new SolutionSetUpdateBarrier();
                    SolutionSetUpdateBarrierBroker.instance().handIn(brokerKey, solutionSetUpdateBarrier);
                }
            }

            // instantiate all aggregators and register them at the iteration global registry
            RuntimeAggregatorRegistry aggregatorRegistry =
                new RuntimeAggregatorRegistry(config.getIterationAggregators(getUserCodeClassLoader()));
            IterationAggregatorBroker.instance().handIn(brokerKey, aggregatorRegistry);

            DataInputView superstepResult = null;

            while (this.running && !terminationRequested()) {

                if (log.isInfoEnabled()) {
                    log.info(formatLogString("starting iteration [" + currentIteration() + "]"));
                }

                barrier.setup();

                if (waitForSolutionSetUpdate) {
                    solutionSetUpdateBarrier.setup();
                }

                if (!inFirstIteration()) {
                    feedBackSuperstepResult(superstepResult);
                }

                super.run();

                // signal to connected tasks that we are done with the superstep
                sendEndOfSuperstepToAllIterationOutputs();

                if (waitForSolutionSetUpdate) {
                    solutionSetUpdateBarrier.waitForSolutionSetUpdate();
                }

                // blocking call to wait for the result
                superstepResult = backChannel.getReadEndAfterSuperstepEnded();
                if (log.isInfoEnabled()) {
                    log.info(formatLogString("finishing iteration [" + currentIteration() + "]"));
                }

                sendEventToSync(new WorkerDoneEvent(workerIndex, aggregatorRegistry.getAllAggregators()));

                if (log.isInfoEnabled()) {
                    log.info(formatLogString("waiting for other workers in iteration [" + currentIteration() + "]"));
                }

                barrier.waitForOtherWorkers();

                if (barrier.terminationSignaled()) {
                    if (log.isInfoEnabled()) {
                        log.info(formatLogString("head received termination request in iteration ["
                            + currentIteration() + "]"));
                    }
                    requestTermination();
                    nextStepKickoff.signalTermination();
                } else {
                    incrementIterationCounter();

                    String[] globalAggregateNames = barrier.getAggregatorNames();
                    Value[] globalAggregates = barrier.getAggregates();
                    aggregatorRegistry.updateGlobalAggregatesAndReset(globalAggregateNames, globalAggregates);

                    nextStepKickoff.triggerNextSuperstep();
                }
            }

            if (log.isInfoEnabled()) {
                log.info(formatLogString("streaming out final result after [" + currentIteration() + "] iterations"));
            }

            if (isWorksetIteration) {
                if (objectSolutionSet) {
                    streamSolutionSetToFinalOutput(solutionSetObjectMap);
                } else {
                    streamSolutionSetToFinalOutput(solutionSet);
                }
            } else {
                streamOutFinalOutputBulk(new InputViewIterator<X>(superstepResult, this.solutionTypeSerializer.getSerializer()));
            }

            this.finalOutputCollector.close();

        } finally {
            // make sure we unregister everything from the broker:
            // - backchannel
            // - aggregator registry
            // - solution set index
            IterationAggregatorBroker.instance().remove(brokerKey);
            BlockingBackChannelBroker.instance().remove(brokerKey);
            SuperstepKickoffLatchBroker.instance().remove(brokerKey);
            SolutionSetBroker.instance().remove(brokerKey);
            SolutionSetUpdateBarrierBroker.instance().remove(brokerKey);

            if (solutionSet != null) {
                solutionSet.close();
            }
        }
    }

    private void streamOutFinalOutputBulk(MutableObjectIterator<X> results) throws IOException {
        final Collector<X> out = this.finalOutputCollector;
        X record = this.solutionTypeSerializer.getSerializer().createInstance();

        while ((record = results.next(record)) != null) {
            out.collect(record);
        }
    }

    private void streamSolutionSetToFinalOutput(CompactingHashTable<X> hashTable) throws IOException {
        final MutableObjectIterator<X> results = hashTable.getEntryIterator();
        final Collector<X> output = this.finalOutputCollector;
        X record = solutionTypeSerializer.getSerializer().createInstance();

        while ((record = results.next(record)) != null) {
            output.collect(record);
        }
    }

    @SuppressWarnings("unchecked")
    private void streamSolutionSetToFinalOutput(JoinHashMap<X> solutionSet) throws IOException {
        final Collector<X> output = this.finalOutputCollector;
        for (Object e : solutionSet.values()) {
            output.collect((X) e);
        }
    }

    private void feedBackSuperstepResult(DataInputView superstepResult) {
        this.inputs[this.feedbackDataInput] =
            new InputViewIterator<Y>(superstepResult, this.feedbackTypeSerializer.getSerializer());
    }

    private void sendEndOfSuperstepToAllIterationOutputs() throws IOException, InterruptedException {
        if (log.isDebugEnabled()) {
            log.debug(formatLogString("Sending end-of-superstep to all iteration outputs."));
        }

        for (RecordWriter<?> eventualOutput : this.eventualOutputs) {
            eventualOutput.broadcastEvent(EndOfSuperstepEvent.INSTANCE);
        }
    }

    private void sendEventToSync(WorkerDoneEvent event) throws IOException, InterruptedException {
        if (log.isInfoEnabled()) {
            log.info(formatLogString("sending " + WorkerDoneEvent.class.getSimpleName() + " to sync"));
        }
        this.toSync.broadcastEvent(event);
    }
}
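
For readers tracing the backchannel handoff that initBackChannel() performs: the following is a minimal, self-contained sketch of the head/tail handshake, assuming a single-JVM setup. The names MiniBroker and MiniBackChannel and the broker-key format are hypothetical and are not part of Flink's API; Flink's real Broker and BlockingBackChannel exchange a memory-backed SerializedUpdateBuffer rather than plain strings.

import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;

public class BackChannelSketch {

    /** Hypothetical stand-in for Flink's Broker singleton: a blocking key/value handover. */
    static final class MiniBroker<V> {
        private final Map<String, BlockingQueue<V>> mediations = new ConcurrentHashMap<>();

        private BlockingQueue<V> queueFor(String key) {
            return mediations.computeIfAbsent(key, k -> new ArrayBlockingQueue<>(1));
        }

        void handIn(String key, V value) {
            if (!queueFor(key).offer(value)) {
                throw new IllegalStateException("already handed in for key " + key);
            }
        }

        V getAndRemove(String key) throws InterruptedException {
            V value = queueFor(key).take();
            mediations.remove(key);
            return value;
        }
    }

    /** Hypothetical stand-in for the BlockingBackChannel: one superstep result at a time. */
    static final class MiniBackChannel {
        private final BlockingQueue<String> channel = new ArrayBlockingQueue<>(1);

        void writeSuperstepResult(String result) throws InterruptedException {
            channel.put(result); // tail publishes the superstep output
        }

        String waitForSuperstepResult() throws InterruptedException {
            return channel.take(); // head blocks until the tail is done
        }
    }

    public static void main(String[] args) throws Exception {
        MiniBroker<MiniBackChannel> broker = new MiniBroker<>();
        String brokerKey = "iteration-0#subtask-0"; // hypothetical key format

        // Head: creates the back channel, hands it to the tail via the broker,
        // then blocks until the tail reports the superstep result.
        Thread head = new Thread(() -> {
            try {
                MiniBackChannel backChannel = new MiniBackChannel();
                broker.handIn(brokerKey, backChannel);
                System.out.println("head received: " + backChannel.waitForSuperstepResult());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });

        // Tail: looks up the shared back channel and writes its superstep output.
        Thread tail = new Thread(() -> {
            try {
                MiniBackChannel backChannel = broker.getAndRemove(brokerKey);
                backChannel.writeSuperstepResult("partial solution of superstep 1");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });

        head.start();
        tail.start();
        head.join();
        tail.join();
    }
}

The capacity-1 queue is what makes the channel "blocking" in both directions: the head cannot read past the end of a superstep, and the tail cannot publish a second result before the first is consumed, mirroring the read-end/write-end alternation in the class above.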
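
The class comment also describes the superstep synchronization: every head sends a WorkerDoneEvent to the sync task, then waits for either an AllWorkersDoneEvent or a TerminationEvent. Below is a minimal sketch of that protocol, assuming all heads run as threads in one JVM and modeling the sync task as a CyclicBarrier action. The names and the fixed-iteration termination criterion are hypothetical; Flink's sync task is a separate vertex that also merges the per-worker aggregators.

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

public class SuperstepSyncSketch {

    public static void main(String[] args) throws Exception {
        final int numWorkers = 3;
        final int maxSupersteps = 4; // stand-in for a real convergence criterion
        final AtomicInteger superstep = new AtomicInteger(1);
        final AtomicBoolean terminated = new AtomicBoolean(false);

        // The barrier action plays the role of the sync task: it runs once all
        // workers have "sent" their WorkerDoneEvent for the current superstep.
        CyclicBarrier allWorkersDone = new CyclicBarrier(numWorkers, () -> {
            if (superstep.get() >= maxSupersteps) {
                terminated.set(true); // corresponds to broadcasting a TerminationEvent
            } else {
                superstep.incrementAndGet(); // corresponds to AllWorkersDoneEvent
            }
        });

        Runnable head = () -> {
            try {
                while (!terminated.get()) {
                    int step = superstep.get();
                    // ... run the step function and emit end-of-superstep here ...
                    System.out.println(Thread.currentThread().getName() + " finished superstep " + step);
                    allWorkersDone.await(); // wait for AllWorkersDoneEvent / TerminationEvent
                }
            } catch (InterruptedException | BrokenBarrierException e) {
                Thread.currentThread().interrupt();
            }
        };

        Thread[] heads = new Thread[numWorkers];
        for (int i = 0; i < numWorkers; i++) {
            heads[i] = new Thread(head, "head-" + i);
            heads[i].start();
        }
        for (Thread t : heads) {
            t.join();
        }
        System.out.println("terminated after " + superstep.get() + " supersteps");
    }
}

Because the barrier action runs before any waiting thread is released, the termination decision is visible to every worker when it re-checks the loop condition, just as the heads above only proceed after barrier.waitForOtherWorkers() returns.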