org.ojalgo.ann.ArtificialNeuralNetwork Maven / Gradle / Ivy

oj! Algorithms - ojAlgo - is Open Source Java code that has to do with mathematics, linear algebra and optimisation.
/*
 * Copyright 1997-2020 Optimatika
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
package org.ojalgo.ann;

import static org.ojalgo.function.constant.PrimitiveMath.*;

import java.io.*;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

import org.ojalgo.function.BasicFunction;
import org.ojalgo.function.PrimitiveFunction;
import org.ojalgo.function.aggregator.Aggregator;
import org.ojalgo.function.constant.PrimitiveMath;
import org.ojalgo.function.special.MissingMath;
import org.ojalgo.matrix.store.MatrixStore;
import org.ojalgo.matrix.store.PhysicalStore;
import org.ojalgo.matrix.store.Primitive32Store;
import org.ojalgo.matrix.store.Primitive64Store;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Structure2D;

public final class ArtificialNeuralNetwork implements BasicFunction.PlainUnary<Access1D<Double>, MatrixStore<Double>> {

    /**
     * https://en.wikipedia.org/wiki/Activation_function
     *
     * @author apete
     */
    public enum Activator {

        /**
         * (-∞,+∞)
         */
        IDENTITY(args -> (arg -> arg), arg -> ONE, true),
        /**
         * ReLU: [0,+∞)
         */
        RECTIFIER(args -> (arg -> Math.max(ZERO, arg)), arg -> arg > ZERO ? ONE : ZERO, true),
        /**
         * [0,1]
         */
        SIGMOID(args -> (LOGISTIC), arg -> arg * (ONE - arg), true),
        /**
         * [0,1] <br>
         * Currently this can only be used in the final layer in combination with
         * {@link ArtificialNeuralNetwork.Error#CROSS_ENTROPY}. All other usage will give incorrect network
         * training.
         */
        SOFTMAX(args -> {
            PhysicalStore<Double> parts = args.copy();
            parts.modifyAll(EXP);
            final double total = parts.aggregateAll(Aggregator.SUM);
            return arg -> EXP.invoke(arg) / total;
        }, arg -> ONE, false),
        /**
         * [-1,1]
         */
        TANH(args -> PrimitiveMath.TANH, arg -> ONE - (arg * arg), true);

        private final PrimitiveFunction.Unary myDerivativeInTermsOfOutput;
        private final Function<PhysicalStore<Double>, PrimitiveFunction.Unary> myFunction;
        private final boolean mySingleFolded;

        Activator(final Function<PhysicalStore<Double>, PrimitiveFunction.Unary> function, final PrimitiveFunction.Unary derivativeInTermsOfOutput,
                final boolean singleFolded) {
            myFunction = function;
            myDerivativeInTermsOfOutput = derivativeInTermsOfOutput;
            mySingleFolded = singleFolded;
        }

        PrimitiveFunction.Unary getDerivativeInTermsOfOutput() {
            return myDerivativeInTermsOfOutput;
        }

        PrimitiveFunction.Unary getFunction(final PhysicalStore<Double> arguments) {
            return myFunction.apply(arguments);
        }

        PrimitiveFunction.Unary getFunction(final PhysicalStore<Double> arguments, final double probabilityToKeep) {
            if ((ZERO < probabilityToKeep) && (probabilityToKeep <= ONE)) {
                return new NodeDroppingActivatorFunction(probabilityToKeep, this.getFunction(arguments));
            } else {
                throw new IllegalArgumentException();
            }
        }

        boolean isSingleFolded() {
            return mySingleFolded;
        }

    }

    public enum Error implements PrimitiveFunction.Binary {

        /**
         * Currently this can only be used in combination with {@link Activator#SOFTMAX} in the final
         * layer. All other usage will give incorrect network training.
         */
        CROSS_ENTROPY((target, current) -> -target * Math.log(current), (target, current) -> (current - target)),
        /**
         * Half the squared difference, 0.5 * (target - current)^2 – the conventional quadratic error.
         */
        HALF_SQUARED_DIFFERENCE((target, current) -> HALF * (target - current) * (target - current), (target, current) -> (current - target));

        private final PrimitiveFunction.Binary myDerivative;
        private final PrimitiveFunction.Binary myFunction;

        Error(final PrimitiveFunction.Binary function, final PrimitiveFunction.Binary derivative) {
            myFunction = function;
            myDerivative = derivative;
        }

        public double invoke(final Access1D<?> target, final Access1D<?> current) {
            int limit = MissingMath.toMinIntExact(target.count(), current.count());
            double retVal = ZERO;
            for (int i = 0; i < limit; i++) {
                retVal += myFunction.invoke(target.doubleValue(i), current.doubleValue(i));
            }
            return retVal;
        }

        public double invoke(final double target, final double current) {
            return myFunction.invoke(target, current);
        }

        PrimitiveFunction.Binary getDerivative() {
            return myDerivative;
        }

    }

    public static NetworkBuilder builder(final int numberOfNetworkInputNodes) {
        return ArtificialNeuralNetwork.builder(Primitive64Store.FACTORY, numberOfNetworkInputNodes);
    }

    /**
     * @deprecated Use {@link #builder(int)} instead
     */
    @Deprecated
    public static NetworkTrainer builder(final int numberOfInputNodes, final int... nodesPerCalculationLayer) {
        return ArtificialNeuralNetwork.builder(Primitive64Store.FACTORY, numberOfInputNodes, nodesPerCalculationLayer);
    }

    public static NetworkBuilder builder(final PhysicalStore.Factory<Double, ?> factory, final int numberOfNetworkInputNodes) {
        return new NetworkBuilder(factory, numberOfNetworkInputNodes);
    }

    /**
     * @deprecated Use {@link #builder(org.ojalgo.matrix.store.PhysicalStore.Factory, int)} instead
     */
    @Deprecated
    public static NetworkTrainer builder(final PhysicalStore.Factory<Double, ?> factory, final int numberOfInputNodes,
            final int... nodesPerCalculationLayer) {
        NetworkBuilder builder = ArtificialNeuralNetwork.builder(factory, numberOfInputNodes);
        for (int i = 0; i < nodesPerCalculationLayer.length; i++) {
            builder.layer(nodesPerCalculationLayer[i]);
        }
        return builder.get().newTrainer();
    }

    /**
     * Read (reconstruct) an ANN from the specified input previously written by {@link #writeTo(DataOutput)}.
     */
    public static ArtificialNeuralNetwork from(final DataInput input) throws IOException {
        return FileFormat.read(null, input);
    }

    /**
     * @see #from(DataInput)
     */
    public static ArtificialNeuralNetwork from(final File file) {
        return ArtificialNeuralNetwork.from(null, file);
    }

    /**
     * @see #from(DataInput)
     */
    public static ArtificialNeuralNetwork from(final Path path, final OpenOption... options) {
        return ArtificialNeuralNetwork.from(null, path, options);
    }

    /**
     * Read (reconstruct) an ANN from the specified input previously written by {@link #writeTo(DataOutput)}.
     */
    public static ArtificialNeuralNetwork from(final PhysicalStore.Factory<Double, ?> factory, final DataInput input) throws IOException {
        return FileFormat.read(factory, input);
    }

    /**
     * @see #from(DataInput)
     */
    public static ArtificialNeuralNetwork from(final PhysicalStore.Factory<Double, ?> factory, final File file) {
        try (DataInputStream input = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
            return ArtificialNeuralNetwork.from(factory, input);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * @see #from(DataInput)
     */
    public static ArtificialNeuralNetwork from(final PhysicalStore.Factory<Double, ?> factory, final Path path, final OpenOption... options) {
        try (DataInputStream input = new DataInputStream(new BufferedInputStream(Files.newInputStream(path, options)))) {
            return ArtificialNeuralNetwork.from(factory, input);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    private transient TrainingConfiguration myConfiguration = null;
    private final PhysicalStore.Factory<Double, ?> myFactory;
    private final CalculationLayer[] myLayers;

    ArtificialNeuralNetwork(final NetworkBuilder builder) {

        super();

        myFactory = builder.getFactory();

        List<LayerTemplate> templates = builder.getLayers();
        myLayers = new CalculationLayer[templates.size()];
        for (int i = 0; i < myLayers.length; i++) {
            LayerTemplate layerTemplate = templates.get(i);
            myLayers[i] = new CalculationLayer(myFactory, layerTemplate.inputs, layerTemplate.outputs, layerTemplate.activator);
        }
    }

    ArtificialNeuralNetwork(final PhysicalStore.Factory<Double, ?> factory, final int inputs, final int[] layers) {

        super();

        myFactory = factory;
        myLayers = new CalculationLayer[layers.length];
        int tmpIn = inputs;
        int tmpOut = inputs;
        for (int i = 0; i < layers.length; i++) {
            tmpIn = tmpOut;
            tmpOut = layers[i];
            myLayers[i] = new CalculationLayer(factory, tmpIn, tmpOut, ArtificialNeuralNetwork.Activator.SIGMOID);
        }
    }

    /**
     * @return The number of calculation layers
     */
    public int depth() {
        return myLayers.length;
    }

    @Override
    public boolean equals(final Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (!(obj instanceof ArtificialNeuralNetwork)) {
            return false;
        }
        ArtificialNeuralNetwork other = (ArtificialNeuralNetwork) obj;
        if (!Arrays.equals(myLayers, other.myLayers)) {
            return false;
        }
        return true;
    }

    public Activator getActivator(final int layer) {
        return myLayers[layer].getActivator();
    }

    public double getBias(final int layer, final int output) {
        return myLayers[layer].getBias(output);
    }

    public double getWeight(final int layer, final int input, final int output) {
        return myLayers[layer].getWeight(input, output);
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = (prime * result) + Arrays.hashCode(myLayers);
        return result;
    }

    /**
     * @deprecated v49 Use {@link #newInvoker()} and then {@link NetworkInvoker#invoke(Access1D)} instead
     */
    @Deprecated
    public MatrixStore<Double> invoke(final Access1D<Double> input) {
        return this.newInvoker().invoke(input);
    }

    /**
     * If you create multiple invokers you can use them in different threads simultaneously - the invoker
     * contains any/all invocation specific state.
     */
    public NetworkInvoker newInvoker() {
        return new NetworkInvoker(this);
    }

    public NetworkTrainer newTrainer() {
        NetworkTrainer trainer = new NetworkTrainer(this);
        // Pair the error function with the output activator (SOFTMAX requires CROSS_ENTROPY)
        if (this.getOutputActivator() == Activator.SOFTMAX) {
            trainer.error(Error.CROSS_ENTROPY);
        } else {
            trainer.error(Error.HALF_SQUARED_DIFFERENCE);
        }
        return trainer;
    }

    @Override
    public String toString() {
        StringBuilder tmpBuilder = new StringBuilder();
        tmpBuilder.append("ArtificialNeuralNetwork [Layers=");
        tmpBuilder.append(Arrays.toString(myLayers));
        tmpBuilder.append("]");
        return tmpBuilder.toString();
    }

    /**
     * @return The max number of nodes in any layer
     */
    public int width() {
        int retVal = myLayers[0].countInputNodes();
        for (CalculationLayer layer : myLayers) {
            retVal = Math.max(retVal, layer.countOutputNodes());
        }
        return retVal;
    }

    /**
     * Will write (save) the ANN to the specified output. Can then later be read back by using
     * {@link #from(DataInput)}.
     */
    public void writeTo(final DataOutput output) throws IOException {
        // File format version 2 marks networks backed by 32-bit (float) storage
        int version = (myFactory == Primitive32Store.FACTORY) ? 2 : 1;
        FileFormat.write(this, version, output);
    }

    /**
     * @see #writeTo(DataOutput)
     */
    public void writeTo(final File file) {
        try (DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
            this.writeTo(output);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * @see #writeTo(DataOutput)
     */
    public void writeTo(final Path path, final OpenOption... options) {
        try (DataOutputStream output = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(path, options)))) {
            this.writeTo(output);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    void adjust(final int layer, final Access1D<Double> input, final PhysicalStore<Double> output, final PhysicalStore<Double> upstreamGradient,
            final PhysicalStore<Double> downstreamGradient) {
        myLayers[layer].adjust(input, output, layer == 0 ? null : upstreamGradient, downstreamGradient, -myConfiguration.learningRate,
                myConfiguration.probabilityDidKeepInput(layer), myConfiguration.regularisation());
    }

    int countInputNodes(final int layer) {
        return myLayers[layer].countInputNodes();
    }

    int countOutputNodes(final int layer) {
        return myLayers[layer].countOutputNodes();
    }

    Activator getOutputActivator() {
        return myLayers[myLayers.length - 1].getActivator();
    }

    List<MatrixStore<Double>> getWeights() {
        final ArrayList<MatrixStore<Double>> retVal = new ArrayList<>();
        for (int i = 0; i < myLayers.length; i++) {
            retVal.add(myLayers[i].getLogicalWeights());
        }
        return retVal;
    }

    PhysicalStore<Double> invoke(final int layer, final Access1D<Double> input, final PhysicalStore<Double> output) {
        if (myConfiguration != null) {
            return myLayers[layer].invoke(input, output, myConfiguration.probabilityWillKeepOutput(layer, this.depth()));
        } else {
            return myLayers[layer].invoke(input, output);
        }
    }

    PhysicalStore<Double> newStore(final int rows, final int columns) {
        return myFactory.make(rows, columns);
    }

    void randomise() {
        for (int l = 0; l < myLayers.length; l++) {
            myLayers[l].randomise();
        }
    }

    void scale(final int layer, final double factor) {
        myLayers[layer].scale(factor);
    }

    void setActivator(final int layer, final Activator activator) {
        myLayers[layer].setActivator(activator);
    }

    void setBias(final int layer, final int output, final double bias) {
        myLayers[layer].setBias(output, bias);
    }

    void setConfiguration(final TrainingConfiguration configuration) {
        // When a training configuration (with dropout) is removed, rescale the affected layers
        if ((myConfiguration != null) && (configuration == null)) {
            for (int l = 1, limit = this.depth(); l < limit; l++) {
                this.scale(l, myConfiguration.probabilityDidKeepInput(l));
            }
        }
        myConfiguration = configuration;
    }

    void setWeight(final int layer, final int input, final int output, final double weight) {
        myLayers[layer].setWeight(input, output, weight);
    }

    Structure2D[] structure() {
        Structure2D[] retVal = new Structure2D[myLayers.length];
        for (int l = 0; l < retVal.length; l++) {
            retVal[l] = myLayers[l].getStructure();
        }
        return retVal;
    }

}
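
Usage note (not part of the source file above): the class is built via NetworkBuilder and driven through newTrainer() and newInvoker(). A minimal sketch follows. Only builder(int), layer(int), get(), newTrainer(), newInvoker() and invoke(Access1D) are confirmed by this file; NetworkTrainer.rate(double) and NetworkTrainer.train(Access1D, Access1D) are assumed from the companion NetworkTrainer class of this ojAlgo version and may differ in detail.

import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.ann.NetworkBuilder;
import org.ojalgo.ann.NetworkInvoker;
import org.ojalgo.ann.NetworkTrainer;
import org.ojalgo.matrix.store.MatrixStore;
import org.ojalgo.matrix.store.Primitive64Store;

public class AnnUsageSketch {

    public static void main(final String[] args) {

        // 2 input nodes -> 4-node hidden layer -> 1 output node.
        // layer(int) and get() are used exactly as in the deprecated builder method above.
        NetworkBuilder builder = ArtificialNeuralNetwork.builder(2);
        builder.layer(4);
        builder.layer(1);
        ArtificialNeuralNetwork network = builder.get();

        // newTrainer() selects HALF_SQUARED_DIFFERENCE here, since the output activator is not SOFTMAX.
        // rate(double) and train(Access1D, Access1D) are assumed NetworkTrainer methods.
        NetworkTrainer trainer = network.newTrainer().rate(0.1);

        double[][] features = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[][] labels = { { 0 }, { 1 }, { 1 }, { 0 } }; // XOR truth table

        for (int epoch = 0; epoch < 10_000; epoch++) {
            for (int i = 0; i < features.length; i++) {
                // rows(double[]...) builds a 1-row store from each training example
                trainer.train(Primitive64Store.FACTORY.rows(features[i]), Primitive64Store.FACTORY.rows(labels[i]));
            }
        }

        // Per the javadoc: each invoker holds its own invocation state, so use one per thread.
        NetworkInvoker invoker = network.newInvoker();
        MatrixStore<Double> prediction = invoker.invoke(Primitive64Store.FACTORY.rows(new double[] { 1, 0 }));
        System.out.println(prediction);
    }
}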
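
The writeTo/from pair round-trips a network through ojAlgo's binary file format. A sketch using only methods visible in this file; the file name is hypothetical:

import java.io.File;

import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.ann.NetworkBuilder;

public class AnnPersistenceSketch {

    public static void main(final String[] args) {

        NetworkBuilder builder = ArtificialNeuralNetwork.builder(2);
        builder.layer(1);
        ArtificialNeuralNetwork network = builder.get();

        File file = new File("network.ann"); // hypothetical path
        network.writeTo(file); // wraps any IOException in a RuntimeException
        ArtificialNeuralNetwork restored = ArtificialNeuralNetwork.from(file);

        // equals(Object) compares the calculation layers, so the round trip should preserve equality
        System.out.println(network.equals(restored));
    }
}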



