ml.shifu.guagua.example.nn.NNUtils Maven / Gradle / Ivy
/*
* Copyright [2013-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.guagua.example.nn;
import org.encog.Encog;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.mathutil.randomize.NguyenWidrowRandomizer;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
/**
 * Helper class for NN distributed training.
 *
 * <p>Contains only static utilities; it cannot be instantiated.
 */
public final class NNUtils {

    private NNUtils() {
        // utility class: no instances
    }

    /**
     * Generates a basic three-layer feed-forward network.
     *
     * <p>Layer layout:
     * <ul>
     * <li>input: {@code in} neurons, linear activation, with a bias neuron;</li>
     * <li>hidden: {@code hidden} neurons, sigmoid activation, with a bias neuron;</li>
     * <li>output: {@code out} neurons, sigmoid activation, without a bias neuron.</li>
     * </ul>
     * The structure is finalized and {@link BasicNetwork#reset()} is called before returning, so the network comes
     * back with initialized weights.
     *
     * @param in
     *            number of neurons in the input layer
     * @param hidden
     *            number of neurons in the hidden layer
     * @param out
     *            number of neurons in the output layer
     * @return the finalized and reset network
     */
    public static BasicNetwork generateNetwork(int in, int hidden, int out) {
        final BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(new ActivationLinear(), true, in));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, hidden));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, out));
        network.getStructure().finalizeStructure();
        network.reset();
        return network;
    }

    /**
     * Determines the sign of the value.
     *
     * <p>Any value whose magnitude is below {@link Encog#DEFAULT_DOUBLE_EQUAL} is treated as zero. {@code NaN}
     * fails both comparisons below and therefore yields -1.
     *
     * @param value
     *            The value to check.
     * @return 0 if {@code value} is zero within {@link Encog#DEFAULT_DOUBLE_EQUAL}, 1 if it is greater than zero,
     *         -1 otherwise.
     */
    public static int sign(final double value) {
        if(Math.abs(value) < Encog.DEFAULT_DOUBLE_EQUAL) {
            return 0;
        } else if(value > 0) {
            return 1;
        } else {
            return -1;
        }
    }

    /**
     * Fills {@code weights} in place with random values, using a {@link NguyenWidrowRandomizer} constructed with
     * the range [-1, 1].
     *
     * <p>The randomizer's random generator is seeded with {@code seed}. The same seed and the same array length
     * therefore always produce the same weights, which keeps initialization reproducible across runs.
     *
     * <p>NOTE(review): Nguyen-Widrow scaling needs the network's layer structure. Applied to a flat array, it
     * presumably reduces to plain range randomization in [-1, 1]. Confirm this against the Encog version in use.
     *
     * @param seed
     *            seed for the randomizer's random generator
     * @param weights
     *            array to overwrite with random weights
     */
    public static void randomize(int seed, double[] weights) {
        NguyenWidrowRandomizer randomizer = new NguyenWidrowRandomizer(-1, 1);
        // Without this the randomizer keeps its default generator and the seed argument would be ignored.
        randomizer.getRandom().setSeed(seed);
        randomizer.randomize(weights);
    }
}