io.mxnet.caffetranslator.Solver Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of caffe-translator Show documentation
Show all versions of caffe-translator Show documentation
Translate Caffe code to MXNet Python code
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file Solver.java
* \brief Model for the Caffe solver prototxt
*/
package io.mxnet.caffetranslator;
import lombok.Getter;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class Solver {
private final String solverPath;
private boolean parseDone;
private Map> properties;
/**
* Fields corresponding to keys that can be present in the solver prototxt. 'setFields' sets these
* using reflection after parsing the solver prototxt. A solver object is passed to string templates
* and the templates read these fields.
*/
@Getter
private String base_lr, momentum, weight_decay, lr_policy, gamma, stepsize, stepvalue, max_iter,
solver_mode, snapshot, snapshot_prefix, test_iter, test_interval, display, type, delta,
momentum2, rms_decay, solver_type;
public Solver(String solverPath) {
this.solverPath = solverPath;
properties = new HashMap<>();
}
public boolean parsePrototxt() {
CharStream cs = null;
try {
FileInputStream fis = new FileInputStream(new File(solverPath));
cs = CharStreams.fromStream(fis, StandardCharsets.UTF_8);
} catch (IOException e) {
System.err.println("Unable to read prototxt " + solverPath);
return false;
}
CaffePrototxtLexer lexer = new CaffePrototxtLexer(cs);
CommonTokenStream tokens = new CommonTokenStream(lexer);
CaffePrototxtParser parser = new CaffePrototxtParser(tokens);
SolverListener solverListener = new SolverListener();
parser.addParseListener(solverListener);
parser.solver();
properties = solverListener.getProperties();
setFields(properties);
parseDone = true;
return true;
}
private void setFields(Map> properties) {
Class> cls = getClass();
for (Map.Entry> entry : properties.entrySet()) {
String key = entry.getKey();
try {
Field field = cls.getDeclaredField(key);
field.set(this, entry.getValue().get(0));
} catch (NoSuchFieldException e) {
// Just ignore
} catch (IllegalAccessException e) {
/**
* This shouldn't happen. If it does happen because we overlooked something, print
* it in the console so we can investigate it.
*/
e.printStackTrace();
}
}
setDefaults();
}
private void setDefaults() {
if (type == null) {
type = "SGD";
}
if (delta == null) {
delta = "1e-8";
}
if (momentum2 == null) {
momentum2 = "0.999";
}
if (rms_decay == null) {
rms_decay = "0.99";
}
}
public String getProperty(String key) {
List list = getProperties(key);
if (list == null) {
return null;
}
return getProperties(key).get(0);
}
public List getProperties(String key) {
if (!parseDone) {
parsePrototxt();
}
return properties.get(key);
}
public String getProperty(String key, String defaultValue) {
String value = getProperty(key);
if (value == null) {
return defaultValue;
} else {
return value;
}
}
}