All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.mxnet.caffetranslator.Solver Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file Solver.java
 * \brief Model for the Caffe solver prototxt
 */

package io.mxnet.caffetranslator;

import lombok.Getter;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Solver {

    private final String solverPath;
    private boolean parseDone;
    private Map> properties;
    /**
     * Fields corresponding to keys that can be present in the solver prototxt. 'setFields' sets these
     * using reflection after parsing the solver prototxt. A solver object is passed to string templates
     * and the templates read these fields.
     */
    @Getter
    private String base_lr, momentum, weight_decay, lr_policy, gamma, stepsize, stepvalue, max_iter,
            solver_mode, snapshot, snapshot_prefix, test_iter, test_interval, display, type, delta,
            momentum2, rms_decay, solver_type;

    public Solver(String solverPath) {
        this.solverPath = solverPath;
        properties = new HashMap<>();
    }

    public boolean parsePrototxt() {
        CharStream cs = null;
        try {
            FileInputStream fis = new FileInputStream(new File(solverPath));
            cs = CharStreams.fromStream(fis, StandardCharsets.UTF_8);
        } catch (IOException e) {
            System.err.println("Unable to read prototxt " + solverPath);
            return false;
        }

        CaffePrototxtLexer lexer = new CaffePrototxtLexer(cs);
        CommonTokenStream tokens = new CommonTokenStream(lexer);
        CaffePrototxtParser parser = new CaffePrototxtParser(tokens);

        SolverListener solverListener = new SolverListener();
        parser.addParseListener(solverListener);
        parser.solver();

        properties = solverListener.getProperties();

        setFields(properties);

        parseDone = true;
        return true;
    }

    private void setFields(Map> properties) {
        Class cls = getClass();

        for (Map.Entry> entry : properties.entrySet()) {
            String key = entry.getKey();
            try {
                Field field = cls.getDeclaredField(key);
                field.set(this, entry.getValue().get(0));
            } catch (NoSuchFieldException e) {
                // Just ignore
            } catch (IllegalAccessException e) {
                /**
                 * This shouldn't happen. If it does happen because we overlooked something, print
                 * it in the console so we can investigate it.
                 */
                e.printStackTrace();
            }
        }

        setDefaults();
    }

    private void setDefaults() {
        if (type == null) {
            type = "SGD";
        }
        if (delta == null) {
            delta = "1e-8";
        }
        if (momentum2 == null) {
            momentum2 = "0.999";
        }
        if (rms_decay == null) {
            rms_decay = "0.99";
        }
    }

    public String getProperty(String key) {
        List list = getProperties(key);
        if (list == null) {
            return null;
        }
        return getProperties(key).get(0);
    }

    public List getProperties(String key) {
        if (!parseDone) {
            parsePrototxt();
        }

        return properties.get(key);
    }

    public String getProperty(String key, String defaultValue) {
        String value = getProperty(key);
        if (value == null) {
            return defaultValue;
        } else {
            return value;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy