// org.deeplearning4j.arbiter.saver.local.graph.LocalComputationGraphSaver Maven / Gradle / Ivy
/*
*
* * Copyright 2016 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.arbiter.saver.local.graph;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.arbiter.GraphConfiguration;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.nio.file.Files;
/**
 * Basic ComputationGraph saver. Saves config, parameters and score to: baseDir/0/, baseDir/1/, etc
 * where index is given by OptimizationResult.getIndex().
 *
 * <p>Files written into each result directory:
 * <ul>
 *   <li>{@code score.txt} - the result score, as text</li>
 *   <li>{@code config.json} - the candidate's configuration, as JSON</li>
 *   <li>{@code params.bin} - the network parameters (only if the result holds a model)</li>
 *   <li>{@code additionalResults.bin} - Java-serialized model specific results (only if non-null)</li>
 *   <li>{@code earlyStoppingConfig.bin} - Java-serialized early stopping configuration, if one is set;
 *       otherwise {@code numEpochs.txt} with the number of epochs, as text</li>
 * </ul>
 */
public class LocalComputationGraphSaver implements ResultSaver {
    private static final Logger log = LoggerFactory.getLogger(LocalComputationGraphSaver.class);

    /** Base directory, exactly as supplied by the caller. */
    private final String path;
    /** Base directory as a File; used for reporting the absolute path in {@link #toString()}. */
    private final File fPath;

    /**
     * @param path Existing directory under which one sub-directory per result is created
     * @throws NullPointerException     if path is null
     * @throws IllegalArgumentException if path does not denote an existing directory
     */
    public LocalComputationGraphSaver(String path) {
        if (path == null) {
            throw new NullPointerException("path must not be null");
        }
        this.path = path;
        this.fPath = new File(path);
        if (!fPath.isDirectory()) {
            throw new IllegalArgumentException("Invalid path: is not directory. " + path);
        }
        log.info("LocalComputationGraphSaver saving networks to local directory: {}",path);
    }

    /**
     * Write the given result to {@code <base directory>/<result index>/}.
     *
     * @param result Optimization result to save
     * @return Reference holding the result index, the result directory, the individual file locations
     *         and the candidate. Files for optional content that was absent are still referenced,
     *         but are not created by this method.
     * @throws IOException if the result directory cannot be created or any file cannot be written
     */
    @Override
    public ResultReference saveModel(OptimizationResult result) throws IOException {
        String dir = new File(path,result.getIndex() + "/").getAbsolutePath();

        // Fail early and explicitly if the directory cannot be created. The isDirectory() re-check
        // covers the case where the directory already exists (mkdirs() returns false then).
        File resultDir = new File(dir);
        if (!resultDir.mkdirs() && !resultDir.isDirectory()) {
            throw new IOException("Could not create result directory: " + dir);
        }

        File paramsFile = new File(FilenameUtils.concat(dir,"params.bin"));
        File jsonFile = new File(FilenameUtils.concat(dir,"config.json"));
        File scoreFile = new File(FilenameUtils.concat(dir,"score.txt"));
        File additionalResultsFile = new File(FilenameUtils.concat(dir,"additionalResults.bin"));
        File esConfigFile = new File(FilenameUtils.concat(dir,"earlyStoppingConfig.bin"));
        File numEpochsFile = new File(FilenameUtils.concat(dir,"numEpochs.txt"));

        FileUtils.writeStringToFile(scoreFile, String.valueOf(result.getScore()));

        String jsonConfig = result.getCandidate().getValue().getConfiguration().toJson();
        FileUtils.writeStringToFile(jsonFile, jsonConfig);

        // Parameters are only available when the result actually carries a model
        if (result.getResult() != null) {
            INDArray params = result.getResult().params();
            try (DataOutputStream dos = new DataOutputStream(Files.newOutputStream(paramsFile.toPath()))) {
                Nd4j.write(params, dos);
            }
        }

        // Held as Object: the value is only ever handed to Java serialization
        Object additionalResults = result.getModelSpecificResults();
        if (additionalResults != null) {
            writeSerialized(additionalResultsFile, additionalResults);
        }

        //Write early stopping configuration (if present) to file; otherwise record the number of epochs
        EarlyStoppingConfiguration esc = result.getCandidate().getValue().getEarlyStoppingConfiguration();
        if (esc != null) {
            writeSerialized(esConfigFile, esc);
        } else {
            int nEpochs = result.getCandidate().getValue().getNumEpochs();
            FileUtils.writeStringToFile(numEpochsFile,String.valueOf(nEpochs));
        }

        log.debug("Deeplearning4j model result (id={}, score={}) saved to directory: {}",result.getIndex(), result.getScore(), dir);

        return new LocalFileGraphResultReference<>(result.getIndex(),dir,
                jsonFile,
                paramsFile,
                scoreFile,
                additionalResultsFile,
                esConfigFile,
                numEpochsFile,
                result.getCandidate());
    }

    /** Write a single object to the given file using Java serialization, closing the stream afterwards. */
    private static void writeSerialized(File target, Object value) throws IOException {
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(target))) {
            oos.writeObject(value);
        }
    }

    @Override
    public String toString(){
        return "LocalComputationGraphSaver(path=" + fPath.getAbsolutePath() + ")";
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy