All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.training.hyperparameter.EasyHpo Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.training.hyperparameter;

import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.hyperparameter.optimizer.HpORandom;
import ai.djl.training.hyperparameter.optimizer.HpOptimizer;
import ai.djl.training.hyperparameter.param.HpSet;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;

/**
 * Helper for easy training with hyperparameters.
 *
 * <p>Subclasses supply the dataset, the hyperparameter search space, the model, and the training
 * configuration through the abstract methods; {@link #fit()} then drives the search using an
 * {@link HpORandom} optimizer and retrains a final model with the best hyperparameters found.
 */
public abstract class EasyHpo {

    private static final Logger logger = LoggerFactory.getLogger(EasyHpo.class);

    /**
     * Fits the model given the implemented abstract methods.
     *
     * <p>Runs {@link #numHyperParameterTests()} trial trainings, closing each trial model once its
     * validation loss has been reported to the optimizer. It then trains one more model with the
     * best hyperparameters, passes it to {@link #saveModel(Model, TrainingResult)}, and returns it.
     *
     * <p>The returned {@link Model} is not closed by this method; the caller is responsible for
     * closing it.
     *
     * @return the best model and training results
     * @throws IOException for various exceptions depending on the dataset
     * @throws TranslateException if there is an error while processing input
     */
    public Pair<Model, TrainingResult> fit() throws IOException, TranslateException {

        // get training and validation dataset
        RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
        RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);

        HpSet hyperParams = setupHyperParams();
        HpOptimizer hpOptimizer = new HpORandom(hyperParams);

        final int hyperparameterTests = numHyperParameterTests();

        for (int i = 0; i < hyperparameterTests; i++) {
            HpSet hpVals = hpOptimizer.nextConfig();
            Pair<Model, TrainingResult> trained = train(hpVals, trainingSet, validateSet);
            // Only the training result is needed from a trial; release the model right away.
            trained.getKey().close();
            float loss = trained.getValue().getValidateLoss();
            hpOptimizer.update(hpVals, loss);
            // Report a one-based trial number so the final trial reads "N/N".
            logger.info(
                    "--------- hp test {}/{} - Loss {} - {}",
                    i + 1,
                    hyperparameterTests,
                    loss,
                    hpVals);
        }

        // Retrain with the best hyperparameters, since the trial models were already closed.
        HpSet bestHpVals = hpOptimizer.getBest().getKey();
        Pair<Model, TrainingResult> trained = train(bestHpVals, trainingSet, validateSet);
        TrainingResult result = trained.getValue();

        Model model = trained.getKey();
        boolean saved = false;
        try {
            saveModel(model, result);
            saved = true;
        } finally {
            if (!saved) {
                // The model never reaches the caller on failure, so release it here.
                model.close();
            }
        }
        return trained;
    }

    /**
     * Builds and trains a single model with the given hyperparameter values.
     *
     * <p>On success the returned model is still open and owned by the caller. If any step after
     * building the model fails, the model is closed before the exception propagates.
     *
     * @param hpVals the hyperparameter values to train with
     * @param trainingSet the dataset to train on
     * @param validateSet the dataset to validate on
     * @return the trained model and its training result
     * @throws IOException for various exceptions depending on the dataset
     * @throws TranslateException if there is an error while processing input
     */
    private Pair<Model, TrainingResult> train(
            HpSet hpVals, RandomAccessDataset trainingSet, RandomAccessDataset validateSet)
            throws IOException, TranslateException {

        // Construct neural network
        Model model = buildModel(hpVals);

        boolean trainedOk = false;
        try {
            // setup training configuration
            TrainingConfig config = setupTrainingConfig(hpVals);

            TrainingResult result;
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());

                // initialize trainer with proper input shape
                trainer.initialize(inputShape(hpVals));

                EasyTrain.fit(trainer, numEpochs(hpVals), trainingSet, validateSet);

                result = trainer.getTrainingResult();
            }

            trainedOk = true;
            return new Pair<>(model, result);
        } finally {
            if (!trainedOk) {
                // Nobody else holds a reference to the model yet; avoid leaking it.
                model.close();
            }
        }
    }

    /**
     * Returns the initial hyperparameters.
     *
     * @return the initial hyperparameters
     */
    protected abstract HpSet setupHyperParams();

    /**
     * Returns the dataset to train with.
     *
     * <p>{@link #fit()} calls this once with {@link Dataset.Usage#TRAIN} for the training set and
     * once with {@link Dataset.Usage#TEST} for the validation set.
     *
     * @param usage the usage of the dataset
     * @return the dataset to train with
     * @throws IOException if the dataset could not be loaded
     */
    protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException;

    /**
     * Returns the {@link ai.djl.training.TrainingConfig} to use to train each hyperparameter set.
     *
     * @param hpVals the hyperparameters to train with
     * @return the {@link ai.djl.training.TrainingConfig} to use to train each hyperparameter set
     */
    protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals);

    /**
     * Builds the {@link Model} and {@link ai.djl.nn.Block} to train.
     *
     * <p>Called once per training run; each returned model is closed by {@link #fit()} except the
     * final one, which is returned to the caller.
     *
     * @param hpVals the hyperparameter values to use for the model
     * @return the model to train
     */
    protected abstract Model buildModel(HpSet hpVals);

    /**
     * Returns the input shape for the model.
     *
     * @param hpVals the hyperparameter values for the model
     * @return returns the model input shape
     */
    protected abstract Shape inputShape(HpSet hpVals);

    /**
     * Returns the number of epochs to train for the current hyperparameter set.
     *
     * @param hpVals the current hyperparameter set
     * @return the number of epochs
     */
    protected abstract int numEpochs(HpSet hpVals);

    /**
     * Returns the number of hyperparameter sets to train with.
     *
     * @return the number of hyperparameter sets to train with
     */
    protected abstract int numHyperParameterTests();

    /**
     * Saves the best hyperparameter set.
     *
     * <p>The default implementation does nothing; override it to persist the model.
     *
     * @param model the model to save
     * @param result the training result for training with this model's hyperparameters
     * @throws IOException if the model could not be saved
     */
    protected void saveModel(Model model, TrainingResult result) throws IOException {}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy