org.deeplearning4j.arbiter.saver.local.LocalFileNetResultReference Maven / Gradle / Ivy
/*-
*
* * Copyright 2016 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.arbiter.saver.local;
import lombok.AllArgsConstructor;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.arbiter.DL4JConfiguration;
import org.deeplearning4j.arbiter.GraphConfiguration;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.nio.charset.StandardCharsets;
/**
 * Result reference for MultiLayerNetworks and ComputationGraphs saved to local file system.
 *
 * <p>Holds the file locations for a single saved candidate; {@link #getResult()} reloads the
 * model, its score and any optional side files from disk each time it is called.
 */
@AllArgsConstructor
public class LocalFileNetResultReference implements ResultReference {
    /** Candidate index, passed through unchanged to the {@link OptimizationResult}. */
    private final int index;
    /** Presumably the directory this result was saved in; only used by {@link #toString()} here. */
    private final String dir;
    /** {@code true} if {@link #modelFile} holds a ComputationGraph, {@code false} for a MultiLayerNetwork. */
    private final boolean isGraph;
    /** Saved model, restored via {@link ModelSerializer} (updater state is not loaded). */
    private final File modelFile;
    /** Text file whose (trimmed) contents are parsed as a {@code double} score. */
    private final File scoreFile;
    /** Optional Java-serialized additional results; may be {@code null} or missing on disk. */
    private final File additionalResultsFile;
    /** Optional Java-serialized {@link EarlyStoppingConfiguration}; may be {@code null} or missing on disk. */
    private final File esConfigFile;
    /** Optional text file holding the epoch count; 1 is used when {@code null} or missing on disk. */
    private final File numEpochsFile;
    /** Candidate that produced this result, passed through to the {@link OptimizationResult}. */
    private final Candidate candidate;

    /**
     * Reloads the saved model and its associated files from disk.
     *
     * @return an {@link OptimizationResult} containing the candidate, the restored model, the score,
     *         the index and the additional results ({@code null} if that file is absent)
     * @throws IOException if the model, score or any present optional file cannot be read
     * @throws NumberFormatException if the score or epoch-count file does not hold a valid number
     * @throws RuntimeException wrapping a {@link ClassNotFoundException} if a serialized side file
     *         references a class that is not on the classpath
     */
    @Override
    public OptimizationResult getResult() throws IOException {
        Model m;
        if (isGraph) {
            m = ModelSerializer.restoreComputationGraph(modelFile, false);
        } else {
            m = ModelSerializer.restoreMultiLayerNetwork(modelFile, false);
        }

        //TODO: properly parsing. Probably want to store additional info other than just score...
        double score = Double.parseDouble(readTrimmed(scoreFile));

        EarlyStoppingConfiguration earlyStoppingConfiguration = null;
        if (isPresent(esConfigFile)) {
            earlyStoppingConfiguration = (EarlyStoppingConfiguration) readSerialized(esConfigFile,
                    "Error loading early stopping configuration");
        }

        int nEpochs = 1;
        if (isPresent(numEpochsFile)) {
            // Trimmed: unlike Double.parseDouble, Integer.parseInt rejects a trailing newline.
            nEpochs = Integer.parseInt(readTrimmed(numEpochsFile));
        }

        // NOTE(review): this configuration is built but never used - it is not passed to the
        // OptimizationResult below. Kept so the side files are still parsed/validated as before;
        // confirm whether it should be attached to the result or removed along with those reads.
        Object dl4jConfiguration;
        if (isGraph) {
            dl4jConfiguration = new GraphConfiguration(((ComputationGraph) m).getConfiguration(),
                    earlyStoppingConfiguration, nEpochs);
        } else {
            dl4jConfiguration = new DL4JConfiguration(((MultiLayerNetwork) m).getLayerWiseConfigurations(),
                    earlyStoppingConfiguration, nEpochs);
        }

        // Null-checked like the other optional files (the original dereferenced it unconditionally).
        Object additionalResults = null;
        if (isPresent(additionalResultsFile)) {
            additionalResults = readSerialized(additionalResultsFile, "Error loading additional results");
        }

        return new OptimizationResult(candidate, m, score, index, additionalResults, null, this);
    }

    /** Returns {@code true} if the optional file reference is set and the file exists on disk. */
    private static boolean isPresent(File file) {
        return file != null && file.exists();
    }

    /** Reads a small text file as UTF-8 and strips surrounding whitespace such as a trailing newline. */
    private static String readTrimmed(File file) throws IOException {
        return FileUtils.readFileToString(file, StandardCharsets.UTF_8).trim();
    }

    /**
     * Reads a single Java-serialized object from {@code file}.
     *
     * <p>Both streams are declared as separate resources so the {@link FileInputStream} is closed even
     * when the {@link ObjectInputStream} constructor fails (e.g. on a corrupt stream header).
     *
     * <p>NOTE(review): this is native Java deserialization; it is only safe if these files are
     * trusted (written by this application). Consider an ObjectInputFilter or a non-native format.
     *
     * @param file         file to read
     * @param errorMessage message for the RuntimeException thrown if a class cannot be resolved
     * @return the deserialized object
     * @throws IOException if the file cannot be opened or read
     */
    private static Object readSerialized(File file, String errorMessage) throws IOException {
        try (FileInputStream fis = new FileInputStream(file);
             ObjectInputStream ois = new ObjectInputStream(fis)) {
            return ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(errorMessage, e);
        }
    }

    @Override
    public String toString() {
        return "LocalFileNetResultReference(" + dir + ")";
    }
}