All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.timeseries.transform.TimeSeriesTransform Maven / Gradle / Ivy

/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.timeseries.transform;

import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;

import java.util.ArrayList;
import java.util.List;

/** This interface is used for data transformation on the {@link TimeSeriesData}. */
public interface TimeSeriesTransform {

    /**
     * Applies this transformation to a {@link TimeSeriesData}.
     *
     * @param manager the default manager used to create any arrays needed by the transformation
     * @param data the data to be operated on
     * @param isTrain whether the transformation is being applied in training mode
     * @return the resulting {@link TimeSeriesData}
     */
    TimeSeriesData transform(NDManager manager, TimeSeriesData data, boolean isTrain);

    /**
     * Constructs a list containing a single {@link IdentityTransform}.
     *
     * <p>A new mutable {@link ArrayList} is returned on every call, so callers may append further
     * transforms to it without affecting other callers.
     *
     * @return a new mutable list holding one {@link IdentityTransform}
     */
    static List<TimeSeriesTransform> identityTransformation() {
        // Typed list (instead of a raw List) so callers get compile-time type checking
        // without unchecked-conversion warnings; erasure is unchanged.
        List<TimeSeriesTransform> ret = new ArrayList<>();
        ret.add(new IdentityTransform());
        return ret;
    }

    /**
     * A minimal "identity" transformation.
     *
     * <p>It does not alter the target values. It does, however, add two fields to the input:
     *
     * <ul>
     *   <li>{@code "PAST_" + FieldName.TARGET}: the same array stored under {@link
     *       FieldName#TARGET} (no copy is made).
     *   <li>{@code "FUTURE_" + FieldName.TARGET}: a new empty array of shape {@code (0)} created by
     *       the supplied manager.
     * </ul>
     *
     * <p>The {@code isTrain} flag is ignored. The input instance is modified in place and returned.
     */
    class IdentityTransform implements TimeSeriesTransform {

        /** {@inheritDoc} */
        @Override
        public TimeSeriesData transform(NDManager manager, TimeSeriesData data, boolean isTrain) {
            // The field key relies on FieldName.TARGET's string form (presumably "TARGET"; verify
            // against FieldName) so that downstream consumers can find the past/future targets.
            data.setField("PAST_" + FieldName.TARGET, data.get(FieldName.TARGET));
            // An empty future target: this transform produces no future window.
            data.setField("FUTURE_" + FieldName.TARGET, manager.create(new Shape(0)));
            return data;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy