All Downloads are FREE. The search and download functionality uses the official Maven repository.

org.nd4j.autodiff.listeners.checkpoint.CheckpointListener Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.listeners.checkpoint;


import org.nd4j.shade.guava.io.Files;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.dataset.api.MultiDataSet;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.TimeUnit;

@Slf4j
public class CheckpointListener extends BaseListener implements Serializable {

    /** Retention policy applied to previously saved checkpoint files after each new save. */
    private enum KeepMode {ALL, LAST, LAST_AND_EVERY};

    //----- Configuration, copied from the Builder -----
    //Directory that model files and the checkpoint record file are written to
    private File rootDir;
    //Optional prefix prepended to every checkpoint file name
    private String fileNamePrefix;
    //Which old checkpoints to delete after saving; null is treated the same as ALL
    private KeepMode keepMode;
    //Number of most recent checkpoints to keep (LAST and LAST_AND_EVERY modes)
    private int keepLast;
    //Keep every keepEvery-th checkpoint (LAST_AND_EVERY mode)
    private int keepEvery;
    //If true: log a message each time a checkpoint is saved
    private boolean logSaving;
    //If true: existing checkpoint data found in rootDir at construction is deleted instead of raising an exception
    private boolean deleteExisting;
    //Passed to SameDiff.save when writing each checkpoint
    private boolean saveUpdaterState;

    //----- Save triggers; a null value means that trigger is disabled -----
    private Integer saveEveryNEpochs;
    private Integer saveEveryNIterations;
    //If true: count iterations from the last saved checkpoint rather than using a fixed iteration multiple
    private boolean saveEveryNIterSinceLast;
    private Long saveEveryAmount;
    private TimeUnit saveEveryUnit;
    //saveEveryAmount converted to milliseconds (set in the constructor)
    private Long saveEveryMs;
    //If true: measure elapsed time from the last saved checkpoint
    private boolean saveEverySinceLast;

    //----- Runtime state -----
    //Number of the most recently saved checkpoint; incremented before each save, so the first checkpoint is 0
    private int lastCheckpointNum = -1;
    //"checkpointInfo.txt" in rootDir: header line followed by one line per saved checkpoint
    private File checkpointRecordFile;

    //Most recent checkpoint saved by this listener instance (null until the first save)
    private Checkpoint lastCheckpoint;
    //Wall-clock time and iteration number recorded on the first iterationDone call (-1 until then)
    private long startTime = -1;
    private int startIter = -1;
    //Time of the last time-triggered save when saveEverySinceLast is false (null until the first such save)
    private Long lastSaveEveryMsNoSinceLast;

    /**
     * Creates the listener from a configured {@link Builder}.
     * The root directory is created if it does not exist. If the directory already contains a non-empty
     * checkpoint record file, the existing checkpoint data is either deleted (deleteExisting == true) or an
     * exception is thrown.
     *
     * @param builder Builder holding the configuration
     * @throws IllegalStateException if the root directory cannot be created, or if existing checkpoint data is
     *                               found and deleteExisting is false
     */
    private CheckpointListener(Builder builder){
        this.rootDir = builder.rootDir;
        this.fileNamePrefix = builder.fileNamePrefix;
        this.keepMode = builder.keepMode;
        this.keepLast = builder.keepLast;
        this.keepEvery = builder.keepEvery;
        this.logSaving = builder.logSaving;
        this.deleteExisting = builder.deleteExisting;
        this.saveUpdaterState = builder.saveUpdaterState;

        this.saveEveryNEpochs = builder.saveEveryNEpochs;
        this.saveEveryNIterations = builder.saveEveryNIterations;
        this.saveEveryNIterSinceLast = builder.saveEveryNIterSinceLast;
        this.saveEveryAmount = builder.saveEveryAmount;
        this.saveEveryUnit = builder.saveEveryUnit;
        this.saveEverySinceLast = builder.saveEverySinceLast;

        if(saveEveryAmount != null){
            saveEveryMs = TimeUnit.MILLISECONDS.convert(saveEveryAmount, saveEveryUnit);
        }

        //mkdirs (not mkdir) so that a root directory whose parent directories do not exist yet can be used
        if(!rootDir.exists() && !rootDir.mkdirs() && !rootDir.isDirectory()){
            throw new IllegalStateException("Could not create checkpoint directory " + rootDir.getAbsolutePath());
        }

        this.checkpointRecordFile = new File(rootDir, "checkpointInfo.txt");
        if(this.checkpointRecordFile.exists() && this.checkpointRecordFile.length() > 0){
            if(!deleteExisting){
                throw new IllegalStateException("Detected existing checkpoint files at directory " + rootDir.getAbsolutePath() +
                        ". Use deleteExisting(true) to delete existing checkpoint files when present.");
            }

            //Remove the record file and every model file that belongs to it
            this.checkpointRecordFile.delete();
            File[] files = rootDir.listFiles();
            if(files != null){
                for(File f : files){
                    if(isCheckpointModelFile(f.getName())){
                        f.delete();
                    }
                }
            }
        }
    }

    /**
     * Decides whether a file name looks like a checkpoint model file that deleteExisting should remove.
     * Matches the names produced by this listener ("[prefix_]checkpoint-#_epoch-#_iter-#_date.bin") and, as
     * before, the "checkpoint_*MultiLayerNetwork.zip" / "checkpoint_*ComputationGraph.zip" names.
     */
    private static boolean isCheckpointModelFile(String name){
        boolean sameDiffCheckpoint = name.endsWith(".bin")
                && (name.startsWith("checkpoint-") || name.contains("_checkpoint-"))
                && name.contains("_epoch-")
                && name.contains("_iter-");
        boolean zipCheckpoint = name.startsWith("checkpoint_")
                && (name.endsWith("MultiLayerNetwork.zip") || name.endsWith("ComputationGraph.zip"));
        return sameDiffCheckpoint || zipCheckpoint;
    }

    /**
     * Saves a checkpoint at the end of the epoch when the "every N epochs" rule is enabled and the number of
     * completed epochs is a multiple of N. Iteration- and time-based rules are evaluated in iterationDone instead.
     *
     * @return Always {@link ListenerResponse#CONTINUE}
     */
    @Override
    public ListenerResponse epochEnd(SameDiff sameDiff, At at, LossCurve lossCurve, long epochTimeMillis) {
        if (saveEveryNEpochs == null) {
            //Epoch-based saving is not configured
            return ListenerResponse.CONTINUE;
        }

        //Epochs are numbered from 0, hence +1 to get the number of completed epochs
        boolean saveNow = (at.epoch() + 1) % saveEveryNEpochs == 0;
        if (saveNow) {
            saveCheckpoint(sameDiff, at);
        }
        return ListenerResponse.CONTINUE;
    }

    /**
     * This listener is only active for the training operation.
     *
     * @param operation Operation being performed
     * @return True if the operation is {@link Operation#TRAINING}
     */
    @Override
    public boolean isActive(Operation operation) {
        final boolean training = (Operation.TRAINING == operation);
        return training;
    }

    /**
     * Evaluates the iteration-based and then the time-based saving rules after each training iteration, and saves
     * at most one checkpoint per call. The first call only records the start time and start iteration.
     */
    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
        if (startTime < 0) {
            //First call: establish the reference point for the "since last" rules; nothing is saved
            startTime = System.currentTimeMillis();
            startIter = at.iteration();
            return;
        }

        //Rule 1: iteration count
        if (saveEveryNIterations != null) {
            final boolean iterationRuleTriggered;
            if (saveEveryNIterSinceLast) {
                //Count from the last saved checkpoint (or from the start if none saved yet)
                long referenceIter = (lastCheckpoint == null ? startIter : lastCheckpoint.getIteration());
                iterationRuleTriggered = at.iteration() - referenceIter >= saveEveryNIterations;
            } else {
                //Fixed multiples of N, irrespective of when the last checkpoint was saved
                iterationRuleTriggered = (at.iteration() + 1) % saveEveryNIterations == 0;
            }

            if (iterationRuleTriggered) {
                saveCheckpoint(sd, at);
                return;
            }
        }

        //Rule 2: elapsed time
        final long now = System.currentTimeMillis();
        if (saveEveryUnit == null) {
            return;
        }

        if (saveEverySinceLast) {
            //Measure from the last saved checkpoint (or from the start if none saved yet)
            long referenceTime = (lastCheckpoint == null ? startTime : lastCheckpoint.getTimestamp());
            if ((now - referenceTime) >= saveEveryMs) {
                saveCheckpoint(sd, at);
            }
        } else {
            //Measure from the last time-triggered save only
            long referenceTime = (lastSaveEveryMsNoSinceLast == null ? startTime : lastSaveEveryMsNoSinceLast);
            if ((now - referenceTime) > saveEveryMs) {
                saveCheckpoint(sd, at);
                lastSaveEveryMsNoSinceLast = now;
            }
        }
    }

    /**
     * Saves a checkpoint, converting any failure into an unchecked exception.
     *
     * @param sd SameDiff instance to save
     * @param at Current training position
     * @throws RuntimeException wrapping whatever exception the save raised
     */
    private void saveCheckpoint(SameDiff sd, At at) {
        try {
            saveCheckpointHelper(sd, at);
        } catch (Exception cause) {
            final String message = "Error saving checkpoint";
            throw new RuntimeException(message, cause);
        }
    }

    /**
     * Saves the model as a new checkpoint, appends it to the checkpoint record file, and then applies the
     * configured keep mode by deleting checkpoint files that are no longer required.
     *
     * @param model SameDiff instance to save
     * @param at    Current training position; its iteration and epoch are stored with the checkpoint
     * @throws Exception if writing the model or the record file fails
     */
    private void saveCheckpointHelper(SameDiff model, At at) throws Exception {
        if(!checkpointRecordFile.exists()){
            //First checkpoint: start the record file with its header line.
            //writeCheckpointInfo creates the file when it is missing, so no separate createNewFile is needed
            writeCheckpointInfo(Checkpoint.getFileHeader() + "\n", checkpointRecordFile);
        }

        Checkpoint c = new Checkpoint(++lastCheckpointNum, System.currentTimeMillis(), at.iteration(), at.epoch(),null);
        c.setFilename(getFileName(lastCheckpointNum, at, c.getTimestamp()));

        File saveFile = new File(rootDir, c.getFilename());
        model.save(saveFile, this.saveUpdaterState);

        writeCheckpointInfo(c.toFileString() + "\n", checkpointRecordFile);

        if(logSaving){
            log.info("Model checkpoint saved: epoch {}, iteration {}, path: {}", c.getEpoch(), c.getIteration(),
                    saveFile.getPath() );
        }
        this.lastCheckpoint = c;

        //Finally: determine if we should delete some old models...
        if(keepMode == null || keepMode == KeepMode.ALL){
            return;
        }

        //Checkpoints whose file still exists, in the order they were recorded (oldest first)
        List<Checkpoint> available = availableCheckpoints();
        if(keepMode == KeepMode.LAST){
            //Delete the oldest files until only keepLast remain
            int excess = available.size() - keepLast;
            for(int i = 0; i < excess; i++){
                deleteCheckpointFile(available.get(i));
            }
        } else {
            //Keep mode: last N and every M
            for(Checkpoint cp : available){
                //Checkpoint numbers start at 0, so the M-th checkpoint has number M-1
                boolean everyM = (cp.getCheckpointNum() + 1) % keepEvery == 0;
                //Example: latest is 5, keep last 2 -> keep checkpoints 4 and 5
                boolean lastN = cp.getCheckpointNum() > lastCheckpointNum - keepLast;
                if(!everyM && !lastN){
                    deleteCheckpointFile(cp);
                }
            }
        }
    }

    /** Deletes the model file of the given checkpoint, logging a warning if the file could not be removed. */
    private void deleteCheckpointFile(Checkpoint checkpoint){
        File f = getFileForCheckpoint(checkpoint);
        if(!f.delete()){
            log.warn("Could not delete old checkpoint file: {}", f.getAbsolutePath());
        }
    }

    /**
     * Builds the file name for a checkpoint.
     * Filename format: "[prefix_]checkpoint-#_epoch-#_iter-#_yyyy-MM-dd_HH-mm-ss.bin"
     *
     * @param checkpointNum Number of the checkpoint
     * @param at            Training position supplying the epoch and iteration numbers
     * @param time          Checkpoint timestamp in epoch milliseconds; formatted with the JVM's default time zone
     * @return File name (without directory) for the checkpoint
     */
    private String getFileName(int checkpointNum, At at, long time){
        StringBuilder sb = new StringBuilder();
        if(fileNamePrefix != null){
            sb.append(fileNamePrefix);
            if(!fileNamePrefix.endsWith("_")){
                sb.append("_");
            }
        }
        sb.append("checkpoint-")
                .append(checkpointNum)
                .append("_epoch-").append(at.epoch())
                .append("_iter-").append(at.iteration());

        //"yyyy" is the calendar year. "YYYY" is the week-based year, which produces the wrong year for dates
        //in the first/last days of a year (e.g. 30 Dec can be formatted with the following year)
        SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss");
        String date = sdf.format(new Date(time));
        sb.append("_").append(date)
                .append(".bin");

        return sb.toString();
    }

    /**
     * Appends text to the given file as UTF-8, creating the file first if it does not exist.
     *
     * @param str Text to append
     * @param f   File to append to
     * @return The text that was passed in
     * @throws RuntimeException wrapping any IOException raised while creating or writing the file
     */
    private static String writeCheckpointInfo(String str, File f){
        try {
            boolean missing = !f.exists();
            if (missing) {
                f.createNewFile();
            }
            Files.append(str, f, StandardCharsets.UTF_8);
            return str;
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * List all available checkpoints. A checkpoint is 'available' if the file can be loaded. Any checkpoint files that
     * have been automatically deleted (given the configuration) will not be returned here.
     * If no checkpoint record file exists yet in this listener's root directory, an empty list is returned.
     *
     * @return List of checkpoint files that can be loaded
     */
    public List<Checkpoint> availableCheckpoints(){
        if(!checkpointRecordFile.exists()){
            //Nothing has been saved yet
            return Collections.emptyList();
        }

        return availableCheckpoints(rootDir);
    }

    /**
     * List all available checkpoints. A checkpoint is 'available' if the file can be loaded. Any checkpoint files that
     * have been automatically deleted (given the configuration) will not be returned here.
     * Note that the checkpointInfo.txt file must exist, as this stores checkpoint information
     *
     * @return List of checkpoint files that can be loaded from the specified directory
     */
    public static List availableCheckpoints(File directory){
        File checkpointRecordFile = new File(directory, "checkpointInfo.txt");
        Preconditions.checkState(checkpointRecordFile.exists(), "Could not find checkpoint record file at expected path %s", checkpointRecordFile.getAbsolutePath());

        List lines;
        try(InputStream is = new BufferedInputStream(new FileInputStream(checkpointRecordFile))){
            lines = IOUtils.readLines(is);
        } catch (IOException e){
            throw new RuntimeException("Error loading checkpoint data from file: " + checkpointRecordFile.getAbsolutePath(), e);
        }

        List out = new ArrayList<>(lines.size()-1); //Assume first line is header
        for( int i=1; i all = availableCheckpoints(rootDir);
        if(all.isEmpty()){
            return null;
        }
        return all.get(all.size()-1);
    }

    /**
     * Look up the model file that belongs to a checkpoint, using the checkpoint's number.
     * The model file must exist.
     *
     * @param checkpoint Checkpoint whose model file is wanted
     * @return Model file for the checkpoint
     */
    public File getFileForCheckpoint(Checkpoint checkpoint){
        final int number = checkpoint.getCheckpointNum();
        return getFileForCheckpoint(number);
    }

    /**
     * Look up the model file for a checkpoint number within this listener's root directory.
     * The model file must exist.
     *
     * @param checkpointNum Number of the checkpoint whose model file is wanted
     * @return Model file for the checkpoint
     */
    public File getFileForCheckpoint(int checkpointNum) {
        final File modelFile = getFileForCheckpoint(rootDir, checkpointNum);
        return modelFile;
    }

    /**
     * Get the model file for the given checkpoint number in the given directory.
     * The directory is scanned for a file whose name matches the checkpoint filename pattern
     * "[prefix_]checkpoint-#_epoch-#_iter-#_yyyy-MM-dd_HH-mm-ss.bin".
     *
     * @param rootDir       Directory to search
     * @param checkpointNum Checkpoint number to get the model file for (must be >= 0)
     * @return Model file for the checkpoint
     * @throws IllegalArgumentException if the checkpoint number is negative
     * @throws IllegalStateException    if no matching file exists in the directory
     */
    public static File getFileForCheckpoint(File rootDir, int checkpointNum){
        if(checkpointNum < 0){
            throw new IllegalArgumentException("Invalid checkpoint number: " + checkpointNum);
        }

        //Without a file name prefix the name starts with "checkpoint-"; with one it contains "_checkpoint-".
        //The trailing "_epoch-" stops checkpoint 1 from matching checkpoint 10, 11, ...
        String unprefixed = "checkpoint-" + checkpointNum + "_epoch-";
        String prefixed = "_" + unprefixed;

        File[] allFiles = rootDir.listFiles();
        if(allFiles != null){
            for(File f : allFiles){
                //Match on the file name only: matching the absolute path could be fooled by a parent directory name
                String name = f.getName();
                if(name.startsWith(unprefixed) || name.contains(prefixed)){
                    return f;
                }
            }
        }

        throw new IllegalStateException("Model file for checkpoint " + checkpointNum + " does not exist");
    }

    /**
     * Load the model for a given checkpoint number from this listener's root directory
     *
     * @param checkpointNum    Checkpoint model number to load
     * @param loadUpdaterState If true: load the updater state. See {@link SameDiff#load(File, boolean)} for more details
     * @return The loaded model
     */
    public SameDiff loadCheckpoint(int checkpointNum, boolean loadUpdaterState){
        final SameDiff loaded = loadCheckpoint(rootDir, checkpointNum, loadUpdaterState);
        return loaded;
    }

    /**
     * Load the SameDiff model stored for a checkpoint number in the given root directory
     *
     * @param rootDir          Directory that the checkpoint resides in
     * @param checkpointNum    Checkpoint model number to load
     * @param loadUpdaterState If true: load the updater state. See {@link SameDiff#load(File, boolean)} for more details
     * @return The loaded model
     */
    public static SameDiff loadCheckpoint(File rootDir, int checkpointNum, boolean loadUpdaterState) {
        //Locating the file fails with an exception if the checkpoint does not exist
        return SameDiff.load(getFileForCheckpoint(rootDir, checkpointNum), loadUpdaterState);
    }

    /**
     * Load the last (most recent) checkpoint from the specified root directory
     *
     * @param rootDir          Root directory to load checkpoint from
     * @param loadUpdaterState If true: load the updater state. See {@link SameDiff#load(File, boolean)} for more details
     * @return SameDiff instance for the last checkpoint
     * @throws IllegalStateException if the directory contains no available checkpoint
     */
    public static SameDiff loadLastCheckpoint(File rootDir, boolean loadUpdaterState){
        Checkpoint last = lastCheckpoint(rootDir);
        if(last == null){
            //lastCheckpoint returns null when nothing is available; fail with a clear message instead of an NPE
            throw new IllegalStateException("No checkpoints available to load in directory " + rootDir.getAbsolutePath());
        }
        return loadCheckpoint(rootDir, last.getCheckpointNum(), loadUpdaterState);
    }

    /**
     * Start configuring a new CheckpointListener
     *
     * @param rootDir Root directory to save models to
     * @return A new builder for the given directory
     */
    public static Builder builder(@NonNull File rootDir){
        final Builder fresh = new Builder(rootDir);
        return fresh;
    }

    /**
     * Builder for {@link CheckpointListener}. At least one saving rule (every N epochs, every N iterations, or
     * every T time periods) must be configured before calling {@link #build()}.
     */
    public static class Builder {

        private File rootDir;
        private String fileNamePrefix = "SameDiff";
        private KeepMode keepMode;
        private int keepLast;
        private int keepEvery;
        private boolean saveUpdaterState = true;
        private boolean logSaving = true;
        private boolean deleteExisting = false;

        private Integer saveEveryNEpochs;
        private Integer saveEveryNIterations;
        private boolean saveEveryNIterSinceLast;
        private Long saveEveryAmount;
        private TimeUnit saveEveryUnit;
        private boolean saveEverySinceLast;

        /**
         * @param rootDir Root directory to save models to
         */
        public Builder(@NonNull String rootDir){
            this(new File(rootDir));
        }

        /**
         * @param rootDir Root directory to save models to
         */
        public Builder(@NonNull File rootDir){
            this.rootDir = rootDir;
        }

        /**
         * Set the prefix for checkpoint file names. If null, no prefix is used.
         *
         * @param fileNamePrefix Prefix to prepend to each checkpoint file name
         */
        public Builder fileNamePrefix(String fileNamePrefix){
            this.fileNamePrefix = fileNamePrefix;
            return this;
        }

        /**
         * Save a model at the end of every epoch
         */
        public Builder saveEveryEpoch(){
            return saveEveryNEpochs(1);
        }

        /**
         * Save a model at the end of every N epochs
         *
         * @param n Number of epochs between saves; must be > 0
         */
        public Builder saveEveryNEpochs(int n){
            if(n <= 0){
                //A value of 0 would otherwise cause a division by zero when the epoch rule is evaluated
                throw new IllegalArgumentException("Number of epochs between saves should be > 0 (got: " + n + ")");
            }
            this.saveEveryNEpochs = n;
            return this;
        }

        /**
         * Save a model every N iterations
         *
         * @param n Number of iterations between saves; must be > 0
         */
        public Builder saveEveryNIterations(int n){
            return saveEveryNIterations(n, false);
        }

        /**
         * Save a model every N iterations (if sinceLast == false), or if N iterations have passed since
         * the last model was saved (if sinceLast == true)
         *
         * @param n         Number of iterations between saves; must be > 0
         * @param sinceLast Whether to count iterations from the last saved model
         */
        public Builder saveEveryNIterations(int n, boolean sinceLast){
            if(n <= 0){
                //A value of 0 would otherwise cause a division by zero when the iteration rule is evaluated
                throw new IllegalArgumentException("Number of iterations between saves should be > 0 (got: " + n + ")");
            }
            this.saveEveryNIterations = n;
            this.saveEveryNIterSinceLast = sinceLast;
            return this;
        }

        /**
         * Save a model periodically
         *
         * @param amount   Quantity of the specified time unit
         * @param timeUnit Time unit
         */
        public Builder saveEvery(long amount, TimeUnit timeUnit){
            return saveEvery(amount, timeUnit, false);
        }

        /**
         * Save a model periodically (if sinceLast == false), or if the specified amount of time has elapsed since
         * the last model was saved (if sinceLast == true)
         *
         * @param amount    Quantity of the specified time unit
         * @param timeUnit  Time unit; must not be null
         * @param sinceLast Whether to measure time from the last saved model
         */
        public Builder saveEvery(long amount, TimeUnit timeUnit, boolean sinceLast){
            if(timeUnit == null){
                //The listener converts the amount to milliseconds using this unit, so it cannot be null
                throw new IllegalArgumentException("Time unit must not be null");
            }
            this.saveEveryAmount = amount;
            this.saveEveryUnit = timeUnit;
            this.saveEverySinceLast = sinceLast;
            return this;
        }

        /**
         * Keep all model checkpoints - i.e., don't delete any. Note that this is the default.
         */
        public Builder keepAll(){
            this.keepMode = KeepMode.ALL;
            return this;
        }

        /**
         * Keep only the last N most recent model checkpoint files. Older checkpoints will automatically be deleted.
         * @param n Number of most recent checkpoints to keep
         */
        public Builder keepLast(int n){
            if(n <= 0){
                throw new IllegalArgumentException("Number of model files to keep should be > 0 (got: " + n + ")");
            }
            this.keepMode = KeepMode.LAST;
            this.keepLast = n;
            return this;
        }

        /**
         * Keep the last N most recent model checkpoint files, and every M checkpoint files.<br>
         * For example: suppose you save every 100 iterations, for 2050 iteration, and use keepLastAndEvery(3,5).
         * This means after 2050 iterations you would have saved 20 checkpoints - some of which will be deleted.
         * Those remaining in this example: iterations 500, 1000, 1500, 1800, 1900, 2000.
         * @param nLast  Most recent checkpoints to keep
         * @param everyN Every N checkpoints to keep (regardless of age)
         */
        public Builder keepLastAndEvery(int nLast, int everyN){
            if(nLast <= 0){
                throw new IllegalArgumentException("Most recent number of model files to keep should be > 0 (got: "
                        + nLast + ")");
            }
            if(everyN <= 0){
                throw new IllegalArgumentException("Every n model files to keep should be > 0 (got: "
                        + everyN + ")");
            }

            this.keepMode = KeepMode.LAST_AND_EVERY;
            this.keepLast = nLast;
            this.keepEvery = everyN;
            return this;
        }

        /**
         * If true (the default) log a message every time a model is saved
         *
         * @param logSaving Whether checkpoint saves should be logged or not
         */
        public Builder logSaving(boolean logSaving){
            this.logSaving = logSaving;
            return this;
        }

        /**
         * Whether the updater state (history/state for Adam, Nesterov momentum, etc) should be saved with each checkpoint.<br>
         * Updater state is saved by default.
         * If you expect to continue training on any of the checkpoints, this should be set to true. However, it will increase
         * the file size.
         * @param saveUpdaterState If true: updater state will be saved with checkpoints. False: not saved.
         */
        public Builder saveUpdaterState(boolean saveUpdaterState){
            this.saveUpdaterState = saveUpdaterState;
            return this;
        }

        /**
         * If the checkpoint listener is set to save to a non-empty directory, should the CheckpointListener-related
         * content be deleted?<br>
         * This is disabled by default (and instead, an exception will be thrown if existing data is found)<br>
         * WARNING: Be careful when enabling this, as it deletes all saved checkpoint models in the specified directory!
         */
        public Builder deleteExisting(boolean deleteExisting){
            this.deleteExisting = deleteExisting;
            return this;
        }

        /**
         * Create the listener from the current configuration.
         *
         * @return The configured CheckpointListener
         * @throws IllegalStateException if no saving rule has been configured
         */
        public CheckpointListener build(){
            if(saveEveryNEpochs == null && saveEveryAmount == null && saveEveryNIterations == null){
                throw new IllegalStateException("Cannot construct listener: no models will be saved (must use at least" +
                        " one of: save every N epochs, every N iterations, or every T time periods)");
            }

            return new CheckpointListener(this);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy