ai.djl.timeseries.model.deepar.DeepARNetwork

/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package ai.djl.timeseries.model.deepar;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.recurrent.LSTM;
import ai.djl.timeseries.block.FeatureEmbedder;
import ai.djl.timeseries.block.MeanScaler;
import ai.djl.timeseries.block.NopScaler;
import ai.djl.timeseries.block.Scaler;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.distribution.output.DistributionOutput;
import ai.djl.timeseries.distribution.output.StudentTOutput;
import ai.djl.timeseries.timefeature.Lag;
import ai.djl.timeseries.timefeature.TimeFeature;
import ai.djl.timeseries.transform.ExpectedNumInstanceSampler;
import ai.djl.timeseries.transform.InstanceSampler;
import ai.djl.timeseries.transform.PredictionSplitSampler;
import ai.djl.timeseries.transform.TimeSeriesTransform;
import ai.djl.timeseries.transform.convert.AsArray;
import ai.djl.timeseries.transform.convert.VstackFeatures;
import ai.djl.timeseries.transform.feature.AddAgeFeature;
import ai.djl.timeseries.transform.feature.AddObservedValuesIndicator;
import ai.djl.timeseries.transform.feature.AddTimeFeature;
import ai.djl.timeseries.transform.field.RemoveFields;
import ai.djl.timeseries.transform.field.SelectField;
import ai.djl.timeseries.transform.field.SetField;
import ai.djl.timeseries.transform.split.InstanceSplit;
import ai.djl.training.ParameterStore;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * Implements the DeepAR model.
 *
 * <p>This closely follows Salinas et al. 2020 and its GluonTS implementation.
 */
public abstract class DeepARNetwork extends AbstractBlock {

    private static final String[] TRAIN_INPUT_FIELDS = {
        FieldName.FEAT_STATIC_CAT.name(),
        FieldName.FEAT_STATIC_REAL.name(),
        "PAST_" + FieldName.FEAT_TIME.name(),
        "PAST_" + FieldName.TARGET.name(),
        "PAST_" + FieldName.OBSERVED_VALUES.name(),
        "PAST_" + FieldName.IS_PAD.name(),
        "FUTURE_" + FieldName.FEAT_TIME.name(),
        "FUTURE_" + FieldName.TARGET.name(),
        "FUTURE_" + FieldName.OBSERVED_VALUES.name()
    };

    private static final String[] PRED_INPUT_FIELDS = {
        FieldName.FEAT_STATIC_CAT.name(),
        FieldName.FEAT_STATIC_REAL.name(),
        "PAST_" + FieldName.FEAT_TIME.name(),
        "PAST_" + FieldName.TARGET.name(),
        "PAST_" + FieldName.OBSERVED_VALUES.name(),
        "FUTURE_" + FieldName.FEAT_TIME.name(),
        "PAST_" + FieldName.IS_PAD.name()
    };

    protected String freq;
    protected int historyLength;
    protected int contextLength;
    protected int predictionLength;

    protected boolean useFeatDynamicReal;
    protected boolean useFeatStaticCat;
    protected boolean useFeatStaticReal;

    protected DistributionOutput distrOutput;
    protected List<Integer> cardinality;
    protected List<Integer> embeddingDimension;
    protected List<Integer> lagsSeq;
    protected int numParallelSamples;

    protected FeatureEmbedder embedder;
    protected Block paramProj;
    protected LSTM rnn;
    protected Scaler scaler;

    DeepARNetwork(Builder builder) {
        freq = builder.freq;
        predictionLength = builder.predictionLength;
        contextLength = builder.contextLength != 0 ? builder.contextLength : predictionLength;
        distrOutput = builder.distrOutput;
        cardinality = builder.cardinality;
        useFeatStaticReal = builder.useFeatStaticReal;
        useFeatDynamicReal = builder.useFeatDynamicReal;
        useFeatStaticCat = builder.useFeatStaticCat;
        numParallelSamples = builder.numParallelSamples;

        paramProj = addChildBlock("param_proj", distrOutput.getArgsProj());
        if (builder.embeddingDimension != null || builder.cardinality == null) {
            embeddingDimension = builder.embeddingDimension;
        } else {
            embeddingDimension = new ArrayList<>();
            for (int cat : cardinality) {
                embeddingDimension.add(Math.min(50, (cat + 1) / 2));
            }
        }
        lagsSeq = builder.lagsSeq == null ? Lag.getLagsForFreq(builder.freq) : builder.lagsSeq;
        historyLength = contextLength + lagsSeq.stream().max(Comparator.naturalOrder()).get();

        embedder =
                addChildBlock(
                        "feature_embedder",
                        FeatureEmbedder.builder()
                                .setCardinalities(cardinality)
                                .setEmbeddingDims(embeddingDimension)
                                .build());
        if (builder.scaling) {
            scaler =
                    addChildBlock(
                            "scaler",
                            MeanScaler.builder()
                                    .setDim(1)
                                    .optKeepDim(true)
                                    .optMinimumScale(1e-10f)
                                    .build());
        } else {
            scaler =
                    addChildBlock(
                            "scaler", NopScaler.builder().setDim(1).optKeepDim(true).build());
        }

        rnn =
                addChildBlock(
                        "rnn_lstm",
                        LSTM.builder()
                                .setNumLayers(builder.numLayers)
                                .setStateSize(builder.hiddenSize)
                                .optDropRate(builder.dropRate)
                                .optBatchFirst(true)
                                .optReturnState(true)
                                .build());
    }

    /** {@inheritDoc} */
    @Override
    protected void initializeChildBlocks(
            NDManager manager, DataType dataType, Shape... inputShapes) {
        Shape targetShape = inputShapes[3].slice(2);
        Shape contextShape = new Shape(1, contextLength).addAll(targetShape);
        scaler.initialize(manager, dataType, contextShape, contextShape);
        long scaleSize =
                scaler.getOutputShapes(new Shape[] {contextShape, contextShape})[1].get(1);

        embedder.initialize(manager, dataType, inputShapes[0]);
        long embeddedCatSize = embedder.getOutputShapes(new Shape[] {inputShapes[0]})[0].get(1);

        Shape inputShape = new Shape(1, contextLength * 2L - 1).addAll(targetShape);
        Shape lagsShape = inputShape.add(lagsSeq.size());
        long featSize =
                inputShapes[2].get(2) + embeddedCatSize + inputShapes[1].get(1) + scaleSize;
        Shape rnnInputShape =
                lagsShape.slice(0, lagsShape.dimension() - 1).add(lagsShape.tail() + featSize);
        rnn.initialize(manager, dataType, rnnInputShape);
        Shape rnnOutShape = rnn.getOutputShapes(new Shape[] {rnnInputShape})[0];

        paramProj.initialize(manager, dataType, rnnOutShape);
    }

    /**
     * Applies the underlying RNN to the provided target data and covariates.
     *
     * @param ps the parameter store
     * @param inputs the input NDList
     * @param training true for a training forward pass
     * @return a {@link NDList} containing the arguments of the output distribution, the scaling
     *     factor, the raw rnn output, the static rnn input, and the rnn output state
     */
    protected NDList unrollLaggedRnn(ParameterStore ps, NDList inputs, boolean training) {
        try (NDManager scope = inputs.getManager().newSubManager()) {
            scope.tempAttachAll(inputs);
            NDArray featStaticCat = inputs.get(0);
            NDArray featStaticReal = inputs.get(1);
            NDArray pastTimeFeat = inputs.get(2);
            NDArray pastTarget = inputs.get(3);
            NDArray pastObservedValues = inputs.get(4);
            NDArray futureTimeFeat = inputs.get(5);
            NDArray futureTarget = inputs.size() > 6 ? inputs.get(6) : null;

            NDArray context = pastTarget.get(":,{}:", -contextLength);
            NDArray observedContext = pastObservedValues.get(":,{}:", -contextLength);
            NDArray scale =
                    scaler.forward(ps, new NDList(context, observedContext), training).get(1);

            NDArray priorSequence = pastTarget.get(":,:{}", -contextLength).div(scale);
            NDArray sequence =
                    futureTarget != null
                            ? context.concat(futureTarget.get(":, :-1"), 1).div(scale)
                            : context.div(scale);

            NDArray embeddedCat =
                    embedder.forward(ps, new NDList(featStaticCat), training).singletonOrThrow();
            NDArray staticFeat =
                    NDArrays.concat(new NDList(embeddedCat, featStaticReal, scale.log()), 1);
            NDArray expandedStaticFeat =
                    staticFeat.expandDims(1).repeat(1, sequence.getShape().get(1));

            NDArray timeFeat =
                    futureTimeFeat != null
                            ? pastTimeFeat
                                    .get(":, {}:", -contextLength + 1)
                                    .concat(futureTimeFeat, 1)
                            : pastTimeFeat.get(":, {}:", -contextLength + 1);
            NDArray features = expandedStaticFeat.concat(timeFeat, -1);
            NDArray lags = laggedSequenceValues(lagsSeq, priorSequence, sequence);

            NDArray rnnInput = lags.concat(features, -1);
            NDList outputs = rnn.forward(ps, new NDList(rnnInput), training);
            NDArray output = outputs.get(0);
            NDArray hiddenState = outputs.get(1);
            NDArray cellState = outputs.get(2);

            NDList params = paramProj.forward(ps, new NDList(output), training);
            scale.setName("scale");
            output.setName("output");
            staticFeat.setName("static_feat");
            hiddenState.setName("hidden_state");
            cellState.setName("cell_state");
            return scope.ret(
                    params.addAll(
                            new NDList(scale, output, staticFeat, hiddenState, cellState)));
        }
    }

    /**
     * Constructs an {@link NDArray} of lagged values from a given sequence.
     *
     * @param indices indices of lagged observations
     * @param priorSequence the input sequence prior to the time range for which the output is
     *     required
     * @param sequence the input sequence in the time range where the output is required
     * @return the lagged values
     */
    protected NDArray laggedSequenceValues(
            List<Integer> indices, NDArray priorSequence, NDArray sequence) {
        if (Collections.max(indices) > (int) priorSequence.getShape().get(1)) {
            throw new IllegalArgumentException(
                    String.format(
                            "lags cannot go further than prior sequence length, found lag %d while"
                                    + " prior sequence is only %d-long",
                            Collections.max(indices), priorSequence.getShape().get(1)));
        }
        try (NDManager scope = NDManager.subManagerOf(priorSequence)) {
            scope.tempAttachAll(priorSequence, sequence);
            NDArray fullSequence = priorSequence.concat(sequence, 1);

            NDList lagsValues = new NDList(indices.size());
            for (int lagIndex : indices) {
                long begin = -lagIndex - sequence.getShape().get(1);
                long end = -lagIndex;
                lagsValues.add(
                        end < 0
                                ? fullSequence.get(":, {}:{}", begin, end)
                                : fullSequence.get(":, {}:", begin));
            }

            NDArray lags = NDArrays.stack(lagsValues, -1);
            return scope.ret(lags.reshape(lags.getShape().get(0), lags.getShape().get(1), -1));
        }
    }

    /**
     * Returns the context length.
     *
     * @return the context length
     */
    public int getContextLength() {
        return contextLength;
    }

    /**
     * Returns the history length.
     *
     * @return the history length
     */
    public int getHistoryLength() {
        return historyLength;
    }

    /**
     * Constructs the training transformation of the DeepAR model.
     *
     * @param manager the {@link NDManager} to create values
     * @return the transformation
     */
    public List<TimeSeriesTransform> createTrainingTransformation(NDManager manager) {
        List<TimeSeriesTransform> transformation = createTransformation(manager);

        InstanceSampler sampler = new ExpectedNumInstanceSampler(0, 0, predictionLength, 1.0);
        transformation.add(
                new InstanceSplit(
                        FieldName.TARGET,
                        FieldName.IS_PAD,
                        FieldName.START,
                        FieldName.FORECAST_START,
                        sampler,
                        historyLength,
                        predictionLength,
                        new FieldName[] {FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES},
                        distrOutput.getValueInSupport()));

        transformation.add(new SelectField(TRAIN_INPUT_FIELDS));
        return transformation;
    }

    /**
     * Constructs the prediction transformation of the DeepAR model.
     *
     * @param manager the {@link NDManager} to create values
     * @return the transformation
     */
    public List<TimeSeriesTransform> createPredictionTransformation(NDManager manager) {
        List<TimeSeriesTransform> transformation = createTransformation(manager);

        InstanceSampler sampler = PredictionSplitSampler.newValidationSplitSampler();
        transformation.add(
                new InstanceSplit(
                        FieldName.TARGET,
                        FieldName.IS_PAD,
                        FieldName.START,
                        FieldName.FORECAST_START,
                        sampler,
                        historyLength,
                        predictionLength,
                        new FieldName[] {FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES},
                        distrOutput.getValueInSupport()));

        transformation.add(new SelectField(PRED_INPUT_FIELDS));
        return transformation;
    }

    private List<TimeSeriesTransform> createTransformation(NDManager manager) {
        List<TimeSeriesTransform> transformation = new ArrayList<>();
        List<FieldName> removeFieldNames = new ArrayList<>();
        removeFieldNames.add(FieldName.FEAT_DYNAMIC_CAT);
        if (!useFeatStaticReal) {
            removeFieldNames.add(FieldName.FEAT_STATIC_REAL);
        }
        if (!useFeatDynamicReal) {
            removeFieldNames.add(FieldName.FEAT_DYNAMIC_REAL);
        }
        transformation.add(new RemoveFields(removeFieldNames));
        if (!useFeatStaticCat) {
            transformation.add(
                    new SetField(FieldName.FEAT_STATIC_CAT, manager.zeros(new Shape(1))));
        }
        if (!useFeatStaticReal) {
            transformation.add(
                    new SetField(FieldName.FEAT_STATIC_REAL, manager.zeros(new Shape(1))));
        }
        transformation.add(new AsArray(FieldName.FEAT_STATIC_CAT, 1, DataType.INT32));
        transformation.add(new AsArray(FieldName.FEAT_STATIC_REAL, 1));
        transformation.add(
                new AddObservedValuesIndicator(FieldName.TARGET, FieldName.OBSERVED_VALUES));
        transformation.add(
                new AddTimeFeature(
                        FieldName.START,
                        FieldName.TARGET,
                        FieldName.FEAT_TIME,
                        TimeFeature.timeFeaturesFromFreqStr(freq),
                        predictionLength,
                        freq));
        transformation.add(
                new AddAgeFeature(FieldName.TARGET, FieldName.FEAT_AGE, predictionLength, true));

        FieldName[] inputFields;
        if (!useFeatDynamicReal) {
            inputFields = new FieldName[] {FieldName.FEAT_TIME, FieldName.FEAT_AGE};
        } else {
            inputFields =
                    new FieldName[] {
                        FieldName.FEAT_TIME, FieldName.FEAT_AGE, FieldName.FEAT_DYNAMIC_REAL
                    };
        }
        transformation.add(new VstackFeatures(FieldName.FEAT_TIME, inputFields));
        return transformation;
    }

    /**
     * Creates a builder to build a {@code DeepARTrainingNetwork} or {@code
     * DeepARPredictionNetwork}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * The builder to construct a {@code DeepARTrainingNetwork} or {@code DeepARPredictionNetwork}
     * type of {@link ai.djl.nn.Block}.
     */
    public static final class Builder {

        private String freq;
        private int contextLength;
        private int predictionLength;
        private int numParallelSamples = 100;
        private int numLayers = 2;
        private int hiddenSize = 40;
        private float dropRate = 0.1f;

        private boolean useFeatDynamicReal;
        private boolean useFeatStaticCat;
        private boolean useFeatStaticReal;
        private boolean scaling = true;

        private DistributionOutput distrOutput = new StudentTOutput();
        private List<Integer> cardinality;
        private List<Integer> embeddingDimension;
        private List<Integer> lagsSeq;

        /**
         * Sets the prediction frequency.
         *
         * @param freq the frequency
         * @return this builder
         */
        public Builder setFreq(String freq) {
            this.freq = freq;
            return this;
        }

        /**
         * Sets the prediction length.
         *
         * @param predictionLength the prediction length
         * @return this builder
         */
        public Builder setPredictionLength(int predictionLength) {
            this.predictionLength = predictionLength;
            return this;
        }

        /**
         * Sets the cardinality of the static categorical features.
         *
         * @param cardinality the cardinality
         * @return this builder
         */
        public Builder setCardinality(List<Integer> cardinality) {
            this.cardinality = cardinality;
            return this;
        }

        /**
         * Sets the optional {@link DistributionOutput}, default {@link StudentTOutput}.
         *
         * @param distrOutput the {@link DistributionOutput}
         * @return this builder
         */
        public Builder optDistrOutput(DistributionOutput distrOutput) {
            this.distrOutput = distrOutput;
            return this;
        }

        /**
         * Sets the optional context length.
         *
         * @param contextLength the context length
         * @return this builder
         */
        public Builder optContextLength(int contextLength) {
            this.contextLength = contextLength;
            return this;
        }

        /**
         * Sets the optional number of parallel samples.
         *
         * @param numParallelSamples the number of parallel samples
         * @return this builder
         */
        public Builder optNumParallelSamples(int numParallelSamples) {
            this.numParallelSamples = numParallelSamples;
            return this;
        }

        /**
         * Sets the optional number of rnn layers.
         *
         * @param numLayers the number of rnn layers
         * @return this builder
         */
        public Builder optNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return this;
        }

        /**
         * Sets the optional rnn hidden size.
         *
         * @param hiddenSize the rnn hidden size
         * @return this builder
         */
        public Builder optHiddenSize(int hiddenSize) {
            this.hiddenSize = hiddenSize;
            return this;
        }

        /**
         * Sets the optional rnn dropout rate.
         *
         * @param dropRate the rnn dropout rate
         * @return this builder
         */
        public Builder optDropRate(float dropRate) {
            this.dropRate = dropRate;
            return this;
        }

        /**
         * Sets the optional embedding dimensions.
         *
         * @param embeddingDimension the embedding dimensions
         * @return this builder
         */
        public Builder optEmbeddingDimension(List<Integer> embeddingDimension) {
            this.embeddingDimension = embeddingDimension;
            return this;
        }

        /**
         * Sets the optional lags sequence, by default generated from the frequency.
         *
         * @param lagsSeq the lags sequence
         * @return this builder
         */
        public Builder optLagsSeq(List<Integer> lagsSeq) {
            this.lagsSeq = lagsSeq;
            return this;
        }

        /**
         * Sets whether to use the dynamic real features.
         *
         * @param useFeatDynamicReal whether to use the dynamic real features
         * @return this builder
         */
        public Builder optUseFeatDynamicReal(boolean useFeatDynamicReal) {
            this.useFeatDynamicReal = useFeatDynamicReal;
            return this;
        }

        /**
         * Sets whether to use the static categorical features.
         *
         * @param useFeatStaticCat whether to use the static categorical features
         * @return this builder
         */
        public Builder optUseFeatStaticCat(boolean useFeatStaticCat) {
            this.useFeatStaticCat = useFeatStaticCat;
            return this;
        }

        /**
         * Sets whether to use the static real features.
         *
         * @param useFeatStaticReal whether to use the static real features
         * @return this builder
         */
        public Builder optUseFeatStaticReal(boolean useFeatStaticReal) {
            this.useFeatStaticReal = useFeatStaticReal;
            return this;
        }

        /**
         * Builds a {@link DeepARTrainingNetwork} block.
         *
         * @return the {@link DeepARTrainingNetwork} block
         */
        public DeepARTrainingNetwork buildTrainingNetwork() {
            return new DeepARTrainingNetwork(this);
        }

        /**
         * Builds a {@link DeepARPredictionNetwork} block.
         *
         * @return the {@link DeepARPredictionNetwork} block
         */
        public DeepARPredictionNetwork buildPredictionNetwork() {
            return new DeepARPredictionNetwork(this);
        }
    }
}
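Usage sketch (not part of the source file above): the snippet below shows one way the Builder and the transformation factories might be wired up. The frequency string "H", the prediction length of 28, and the single cardinality value of 5 are illustrative assumptions, not values taken from the file; only the builder methods and accessors defined above are used, and a DJL engine (e.g. PyTorch) is assumed to be on the classpath for NDManager.newBaseManager().

import ai.djl.ndarray.NDManager;
import ai.djl.timeseries.model.deepar.DeepARNetwork;
import ai.djl.timeseries.transform.TimeSeriesTransform;

import java.util.Arrays;
import java.util.List;

public class DeepARBuilderSketch {

    public static void main(String[] args) {
        // Illustrative settings: an hourly series ("H"), a 28-step forecast horizon,
        // and one static categorical feature with cardinality 5.
        DeepARNetwork trainingNetwork =
                DeepARNetwork.builder()
                        .setFreq("H")
                        .setPredictionLength(28)
                        .setCardinality(Arrays.asList(5))
                        .optUseFeatStaticCat(true)
                        .buildTrainingNetwork();

        try (NDManager manager = NDManager.newBaseManager()) {
            // The matching data pipeline (field selection, observed-values indicator,
            // time/age features, instance splitting) assembled by the network itself.
            List<TimeSeriesTransform> transforms =
                    trainingNetwork.createTrainingTransformation(manager);

            System.out.println("context length: " + trainingNetwork.getContextLength());
            System.out.println("history length: " + trainingNetwork.getHistoryLength());
            System.out.println("pipeline steps: " + transforms.size());
        }
    }
}

A prediction-time network would be obtained the same way via buildPredictionNetwork(), paired with createPredictionTransformation(manager) instead of the training transformation.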




