org.apache.flink.ml.common.optimizer.SGD

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.common.optimizer;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationConfig;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
import org.apache.flink.ml.common.lossfunc.LossFunc;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.regression.linearregression.LinearRegression;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;

import org.apache.commons.collections.IteratorUtils;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

/**
 * Stochastic Gradient Descent (SGD) is the most widely used optimizer for training machine
 * learning models. It iteratively makes small adjustments to the model according to the
 * gradient at each step, in order to decrease the model's error.
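 *
 * <p>Concretely, each round of the iteration below computes gradients over a global mini-batch,
 * sums them across all parallel workers with an all-reduce, and applies the update (shown here
 * as illustrative pseudocode, not code taken from this class):
 *
 * <pre>{@code
 * coefficient -= (learningRate / totalWeight) * summedGradients; // regularization then applied
 * }</pre>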
 *
 * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
 */
@Internal
public class SGD implements Optimizer {
    /** Params for SGD optimizer. */
    private final SGDParams params;

    public SGD(
            int maxIter,
            double learningRate,
            int globalBatchSize,
            double tol,
            double reg,
            double elasticNet) {
        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
    }

    @Override
    public DataStream<DenseVector> optimize(
            DataStream<DenseVector> initModelData,
            DataStream<LabeledPointWithWeight> trainData,
            LossFunc lossFunc) {
        DataStreamList resultList =
                Iterations.iterateBoundedStreamsUntilTermination(
                        DataStreamList.of(
                                initModelData.broadcast().map(modelVec -> modelVec.values)),
                        ReplayableDataStreamList.notReplay(trainData.rebalance().map(x -> x)),
                        IterationConfig.newBuilder().build(),
                        new TrainIterationBody(lossFunc, params));
        return resultList.get(0);
    }

    /** The iteration implementation for the training process. */
    private static class TrainIterationBody implements IterationBody {
        private final LossFunc lossFunc;
        private final SGDParams params;

        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
            this.lossFunc = lossFunc;
            this.params = params;
        }

        @Override
        public IterationBodyResult process(
                DataStreamList variableStreams, DataStreamList dataStreams) {
            // The variable stream at the first iteration is the initialized model data.
            // In the following iterations, it contains: [the model update, totalWeight, and
            // totalLoss].
            DataStream<double[]> variableStream = variableStreams.get(0);
            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
            final OutputTag<DenseVector> modelDataOutputTag =
                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};

            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
                    trainData
                            .connect(variableStream)
                            .transform(
                                    "CacheDataAndDoTrain",
                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));

            DataStreamList feedbackVariableStream =
                    IterationBody.forEachRound(
                            DataStreamList.of(modelUpdateAndWeightAndLoss),
                            input -> {
                                DataStream<double[]> feedback =
                                        DataStreamUtils.allReduceSum(input.get(0));
                                return DataStreamList.of(feedback);
                            });

            DataStream<Integer> terminationCriteria =
                    feedbackVariableStream
                            .get(0)
                            .map(
                                    reducedUpdateAndWeightAndLoss -> {
                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
                                        return value[value.length - 1] / value[value.length - 2];
                                    })
                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));

            return new IterationBodyResult(
                    DataStreamList.of(feedbackVariableStream.get(0)),
                    DataStreamList.of(
                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
                    terminationCriteria);
        }
    }
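
    /*
     * Note on the feedback layout (added commentary, derived from the operator below): the
     * double[] exchanged on each round has length coefficientDim + 2 and is interpreted as
     *
     *     [0, coefficientDim)   -> sum of gradients (all-reduced across workers)
     *     [coefficientDim]      -> total weight of the examples in the global mini-batch
     *     [coefficientDim + 1]  -> total loss over the global mini-batch
     *
     * so the convergence check above divides totalLoss by totalWeight, and updateModel() applies
     * coefficient -= (learningRate / totalWeight) * gradientSum via BLAS.axpy.
     */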

    /**
     * A stream operator that caches the training data in the first iteration and updates the model
     * iteratively. The first input is the training data, and the second input is the initial model
     * data or feedback of model update, totalWeight, and totalLoss.
     */
    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
                    IterationListener<double[]> {
        /** Optimizer-related parameters. */
        private final SGDParams params;

        /** The loss function to optimize. */
        private final LossFunc lossFunc;

        /** The outputTag to output the model data when iteration ends. */
        private final OutputTag<DenseVector> modelDataOutputTag;

        /** The cached training data. */
        private List<LabeledPointWithWeight> trainData;

        private ListState<LabeledPointWithWeight> trainDataState;

        /** The start index (offset) of the next mini-batch data for training. */
        private int nextBatchOffset = 0;

        private ListState<Integer> nextBatchOffsetState;

        /** The model coefficient. */
        private DenseVector coefficient;

        private ListState<DenseVector> coefficientState;

        /** The dimension of the coefficient. */
        private int coefficientDim;

        /**
         * The double array to sync among all workers. For example, when training {@link
         * LinearRegression}, this double array consists of [modelUpdate, totalWeight, totalLoss].
         */
        private double[] feedbackArray;

        private ListState<double[]> feedbackArrayState;

        /** The batch size on this partition. */
        private int localBatchSize;

        private CacheDataAndDoTrain(
                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
            this.lossFunc = lossFunc;
            this.params = params;
            this.modelDataOutputTag = modelDataOutputTag;
        }

        @Override
        public void open() {
            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int taskId = getRuntimeContext().getIndexOfThisSubtask();
            localBatchSize = params.globalBatchSize / numTasks;
            if (params.globalBatchSize % numTasks > taskId) {
                localBatchSize++;
            }
        }

        private double getTotalWeight() {
            return feedbackArray[coefficientDim];
        }

        private void setTotalWeight(double totalWeight) {
            feedbackArray[coefficientDim] = totalWeight;
        }

        private double getTotalLoss() {
            return feedbackArray[coefficientDim + 1];
        }

        private void setTotalLoss(double totalLoss) {
            feedbackArray[coefficientDim + 1] = totalLoss;
        }

        private void updateModel() {
            if (getTotalWeight() > 0) {
                BLAS.axpy(
                        -params.learningRate / getTotalWeight(),
                        new DenseVector(feedbackArray),
                        coefficient,
                        coefficientDim);
                double regLoss =
                        RegularizationUtils.regularize(
                                coefficient, params.reg, params.elasticNet, params.learningRate);
                setTotalLoss(getTotalLoss() + regLoss);
            }
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, Context context, Collector<double[]> collector)
                throws Exception {
            if (epochWatermark == 0) {
                coefficient = new DenseVector(feedbackArray);
                coefficientDim = coefficient.size();
                feedbackArray = new double[coefficient.size() + 2];
            } else {
                updateModel();
            }

            if (trainData == null) {
                trainData = IteratorUtils.toList(trainDataState.get().iterator());
            }

            // TODO: supports efficient shuffle of training set on each partition.
            if (trainData.size() > 0) {
                List<LabeledPointWithWeight> miniBatchData =
                        trainData.subList(
                                nextBatchOffset,
                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));
                nextBatchOffset += localBatchSize;
                nextBatchOffset = nextBatchOffset >= trainData.size() ? 0 : nextBatchOffset;

                // Does the training.
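                // Added commentary: the loop below accumulates, into feedbackArray (wrapped as
                // cumGradientsWrapper), the sum of the gradients of this worker's mini-batch
                // slice, along with the slice's total loss and total weight. The per-worker
                // arrays are then summed across all workers by the allReduceSum step in
                // TrainIterationBody before the next model update.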
                Arrays.fill(feedbackArray, 0);
                double totalLoss = 0;
                double totalWeight = 0;
                DenseVector cumGradientsWrapper = new DenseVector(feedbackArray);
                for (LabeledPointWithWeight dataPoint : miniBatchData) {
                    totalLoss += lossFunc.computeLoss(dataPoint, coefficient);
                    lossFunc.computeGradient(dataPoint, coefficient, cumGradientsWrapper);
                    totalWeight += dataPoint.getWeight();
                }
                setTotalLoss(totalLoss);
                setTotalWeight(totalWeight);

                collector.collect(feedbackArray);
            }
        }

        @Override
        public void onIterationTerminated(Context context, Collector<double[]> collector) {
            trainDataState.clear();
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                updateModel();
                context.output(modelDataOutputTag, coefficient);
            }
        }

        @Override
        public void processElement1(StreamRecord<LabeledPointWithWeight> streamRecord)
                throws Exception {
            trainDataState.add(streamRecord.getValue());
        }

        @Override
        public void processElement2(StreamRecord<double[]> streamRecord) {
            feedbackArray = streamRecord.getValue();
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            coefficientState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "coefficientState", DenseVectorTypeInfo.INSTANCE));
            OperatorStateUtils.getUniqueElement(coefficientState, "coefficientState")
                    .ifPresent(x -> coefficient = x);
            if (coefficient != null) {
                coefficientDim = coefficient.size();
            }

            feedbackArrayState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "feedbackArrayState",
                                            PrimitiveArrayTypeInfo
                                                    .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
            OperatorStateUtils.getUniqueElement(feedbackArrayState, "feedbackArrayState")
                    .ifPresent(x -> feedbackArray = x);

            trainDataState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "trainDataState",
                                            TypeInformation.of(LabeledPointWithWeight.class)));

            nextBatchOffsetState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "nextBatchOffsetState", BasicTypeInfo.INT_TYPE_INFO));
            nextBatchOffset =
                    OperatorStateUtils.getUniqueElement(
                                    nextBatchOffsetState, "nextBatchOffsetState")
                            .orElse(0);
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            coefficientState.clear();
            if (coefficient != null) {
                coefficientState.add(coefficient);
            }

            feedbackArrayState.clear();
            if (feedbackArray != null) {
                feedbackArrayState.add(feedbackArray);
            }

            nextBatchOffsetState.clear();
            nextBatchOffsetState.add(nextBatchOffset);
        }
    }

    /** Parameters for {@link SGD}. */
    private static class SGDParams implements Serializable {
        public final int maxIter;
        public final double learningRate;
        public final int globalBatchSize;
        public final double tol;
        public final double reg;
        public final double elasticNet;

        private SGDParams(
                int maxIter,
                double learningRate,
                int globalBatchSize,
                double tol,
                double reg,
                double elasticNet) {
            this.maxIter = maxIter;
            this.learningRate = learningRate;
            this.globalBatchSize = globalBatchSize;
            this.tol = tol;
            this.reg = reg;
            this.elasticNet = elasticNet;
        }
    }
}
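
For context, here is a minimal sketch of how this optimizer can be driven. Note that SGD is annotated @Internal and is normally invoked by Flink ML's own trainers (such as LinearRegression) rather than by user code, so treat this purely as an illustration: the environment setup, toy data, and hyperparameter values are assumptions, and LeastSquareLoss is assumed to be one of the LossFunc implementations in org.apache.flink.ml.common.lossfunc.

// Illustrative usage sketch (not part of SGD.java); a hypothetical driver class.
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.common.lossfunc.LeastSquareLoss;
import org.apache.flink.ml.common.optimizer.SGD;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class SGDUsageSketch {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        // A toy weighted training set; each point is (features, label, weight).
        DataStream<LabeledPointWithWeight> trainData =
                env.fromElements(
                        new LabeledPointWithWeight(Vectors.dense(1.0, 2.0), 5.0, 1.0),
                        new LabeledPointWithWeight(Vectors.dense(2.0, 3.0), 8.0, 1.0));

        // A single zero-initialized coefficient vector whose dimension matches the features.
        DataStream<DenseVector> initModelData = env.fromElements(Vectors.dense(0.0, 0.0));

        // Arguments: maxIter, learningRate, globalBatchSize, tol, reg, elasticNet.
        SGD sgd = new SGD(100, 0.1, 2, 1e-6, 0.0, 0.0);

        // Runs the bounded iteration; the result stream carries the final coefficients.
        DataStream<DenseVector> modelData =
                sgd.optimize(initModelData, trainData, LeastSquareLoss.INSTANCE);

        modelData.print();
        env.execute("SGD usage sketch");
    }
}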




