All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.arbiter.saver.local.FileModelSaver Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta7
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.arbiter.saver.local;

import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.arbiter.DL4JConfiguration;
import org.deeplearning4j.arbiter.GraphConfiguration;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Basic MultiLayerNetwork saver. Saves config, parameters and score to: baseDir/0/, baseDir/1/, etc
 * where index is given by OptimizationResult.getIndex()
 *
 * @author Alex Black
 */
@Slf4j
@NoArgsConstructor
@AllArgsConstructor
@EqualsAndHashCode
public class FileModelSaver implements ResultSaver {
    @JsonProperty
    private String path;
    private File fPath;

    @JsonCreator
    public FileModelSaver(@NonNull String path) {
        this(new File(path));
    }

    public FileModelSaver(@NonNull File file){
        this.path = file.getPath();
        this.fPath = file;

        if(!fPath.exists()){
            fPath.mkdirs();
        } else if (!fPath.isDirectory()) {
            throw new IllegalArgumentException("Invalid path: exists and is not directory. " + path);
        }

        log.info("FileModelSaver saving networks to local directory: {}", path);
    }

    @Override
    public ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException {
        String dir = new File(path, result.getIndex() + "/").getAbsolutePath();

        File f = new File(dir);
        f.mkdir();

        File modelFile = new File(FilenameUtils.concat(dir, "model.bin"));
        File scoreFile = new File(FilenameUtils.concat(dir, "score.txt"));
        File additionalResultsFile = new File(FilenameUtils.concat(dir, "additionalResults.bin"));
        File esConfigFile = new File(FilenameUtils.concat(dir, "earlyStoppingConfig.bin"));
        File numEpochsFile = new File(FilenameUtils.concat(dir, "numEpochs.txt"));

        FileUtils.writeStringToFile(scoreFile, String.valueOf(result.getScore()));

        Model m = (Model) modelResult;
        ModelSerializer.writeModel(m, modelFile, true);


        Object additionalResults = result.getModelSpecificResults();
        if (additionalResults != null && additionalResults instanceof Serializable) {
            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(additionalResultsFile))) {
                oos.writeObject(additionalResults);
            }
        }

        //Write early stopping configuration (if present) to file:
        int nEpochs;
        EarlyStoppingConfiguration esc;
        if (result.getCandidate().getValue() instanceof DL4JConfiguration) {
            DL4JConfiguration c = ((DL4JConfiguration) result.getCandidate().getValue());
            esc = c.getEarlyStoppingConfiguration();
            nEpochs = c.getNumEpochs();
        } else {
            GraphConfiguration c = ((GraphConfiguration) result.getCandidate().getValue());
            esc = c.getEarlyStoppingConfiguration();
            nEpochs = c.getNumEpochs();
        }


        if (esc != null) {
            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(esConfigFile))) {
                oos.writeObject(esc);
            }
        } else {
            FileUtils.writeStringToFile(numEpochsFile, String.valueOf(nEpochs));
        }

        log.debug("Deeplearning4j model result (id={}, score={}) saved to directory: {}", result.getIndex(),
                        result.getScore(), dir);

        boolean isGraph = m instanceof ComputationGraph;
        return new LocalFileNetResultReference(result.getIndex(), dir, isGraph, modelFile, scoreFile,
                        additionalResultsFile, esConfigFile, numEpochsFile, result.getCandidate());
    }

    @Override
    public List> getSupportedCandidateTypes() {
        return Collections.>singletonList(Object.class);
    }

    @Override
    public List> getSupportedModelTypes() {
        return Arrays.>asList(MultiLayerNetwork.class, ComputationGraph.class);
    }

    @Override
    public String toString() {
        return "FileModelSaver(path=" + path + ")";
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy