All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.mlp.NeuralNetworkFunctions Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.classifier.mlp;

import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;

/**
 * The activation functions, cost functions and their derivatives used by NeuralNetwork, together
 * with name-based lookup methods for them.
 *
 * @deprecated as of 0.10.0.
 */
@Deprecated
public class NeuralNetworkFunctions {

  /**
   * The derivative of the identity function {@code f(x) = x}, i.e. the constant {@code f'(x) = 1}.
   */
  public static final DoubleFunction derivativeIdentityFunction = new DoubleFunction() {
    @Override
    public double apply(double x) {
      return 1;
    }
  };

  /**
   * The derivative of the minus squared function {@code f(t, o) = (o - t)^2} with respect to the
   * output {@code o}: {@code f'(t, o) = 2 * (o - t)}, where {@code t} is the target.
   */
  public static final DoubleDoubleFunction derivativeMinusSquared = new DoubleDoubleFunction() {
    @Override
    public double apply(double target, double output) {
      return 2 * (output - target);
    }
  };

  /**
   * The cross entropy function {@code f(t, o) = -t * log(o) - (1 - t) * log(1 - o)}, where
   * {@code t} is the target and {@code o} is the output.
   *
   * <p>A term whose coefficient is zero contributes 0 (the usual {@code 0 * log(0) = 0}
   * convention). Evaluating it literally would give {@code 0 * -Infinity = NaN}, so a perfect
   * prediction such as {@code (t, o) = (1, 1)} or {@code (0, 0)} would otherwise yield NaN.
   */
  public static final DoubleDoubleFunction crossEntropy = new DoubleDoubleFunction() {
    @Override
    public double apply(double target, double output) {
      double loss = 0;
      if (target != 0) {
        loss -= target * Math.log(output);
      }
      if (target != 1) {
        loss -= (1 - target) * Math.log(1 - output);
      }
      return loss;
    }
  };

  /**
   * The derivative of the cross entropy function
   * {@code f(t, o) = -t * log(o) - (1 - t) * log(1 - o)} with respect to the output {@code o}:
   * {@code f'(t, o) = -t / o + (1 - t) / (1 - o)}.
   *
   * <p>To avoid division by zero, an output of exactly 1 or 0 is replaced by 0.999 or 0.001
   * respectively. A target of exactly 1 or 0 is likewise replaced by 0.999 or 0.001.
   */
  public static final DoubleDoubleFunction derivativeCrossEntropy = new DoubleDoubleFunction() {
    @Override
    public double apply(double target, double output) {
      double adjustedTarget = target;
      double adjustedActual = output;
      // Keep the output strictly inside (0, 1) so neither denominator below is zero.
      if (adjustedActual == 1) {
        adjustedActual = 0.999;
      } else if (adjustedActual == 0) {
        adjustedActual = 0.001;
      }
      // NOTE(review): the target is never a denominator, so this adjustment is not needed for
      // numerical safety; it slightly smooths the gradient for hard 0/1 labels. Kept as-is to
      // preserve existing training behavior.
      if (adjustedTarget == 1) {
        adjustedTarget = 0.999;
      } else if (adjustedTarget == 0) {
        adjustedTarget = 0.001;
      }
      return -adjustedTarget / adjustedActual + (1 - adjustedTarget) / (1 - adjustedActual);
    }
  };

  /**
   * Get the activation function with the given name (case-insensitive).
   * Currently supports: "Identity", "Sigmoid".
   *
   * @param function The name of the function.
   * @return The corresponding double function ({@code Functions.IDENTITY} or
   *     {@code Functions.SIGMOID}).
   * @throws IllegalArgumentException if the name is {@code null} or not supported.
   */
  public static DoubleFunction getDoubleFunction(String function) {
    // Literal-first comparison so a null name reaches the IllegalArgumentException below.
    if ("Identity".equalsIgnoreCase(function)) {
      return Functions.IDENTITY;
    } else if ("Sigmoid".equalsIgnoreCase(function)) {
      return Functions.SIGMOID;
    } else {
      throw new IllegalArgumentException("Function not supported.");
    }
  }

  /**
   * Get the derivative of the activation function with the given name (case-insensitive).
   * Currently supports: "Identity", "Sigmoid".
   *
   * @param function The name of the function.
   * @return {@link #derivativeIdentityFunction} for "Identity", or
   *     {@code Functions.SIGMOIDGRADIENT} for "Sigmoid".
   * @throws IllegalArgumentException if the name is {@code null} or not supported.
   */
  public static DoubleFunction getDerivativeDoubleFunction(String function) {
    if ("Identity".equalsIgnoreCase(function)) {
      return derivativeIdentityFunction;
    } else if ("Sigmoid".equalsIgnoreCase(function)) {
      return Functions.SIGMOIDGRADIENT;
    } else {
      throw new IllegalArgumentException("Function not supported.");
    }
  }

  /**
   * Get the cost (double-double) function with the given name (case-insensitive).
   * Currently supports: "Minus_Squared", "Cross_Entropy".
   *
   * @param function The name of the function.
   * @return {@code Functions.MINUS_SQUARED} for "Minus_Squared", or {@link #crossEntropy} for
   *     "Cross_Entropy".
   * @throws IllegalArgumentException if the name is {@code null} or not supported.
   */
  public static DoubleDoubleFunction getDoubleDoubleFunction(String function) {
    if ("Minus_Squared".equalsIgnoreCase(function)) {
      return Functions.MINUS_SQUARED;
    } else if ("Cross_Entropy".equalsIgnoreCase(function)) {
      // Previously returned derivativeCrossEntropy by mistake; the derivative lookup is
      // getDerivativeDoubleDoubleFunction.
      return crossEntropy;
    } else {
      throw new IllegalArgumentException("Function not supported.");
    }
  }

  /**
   * Get the derivative of the cost (double-double) function with the given name
   * (case-insensitive). Currently supports: "Minus_Squared", "Cross_Entropy".
   *
   * @param function The name of the function.
   * @return {@link #derivativeMinusSquared} for "Minus_Squared", or
   *     {@link #derivativeCrossEntropy} for "Cross_Entropy".
   * @throws IllegalArgumentException if the name is {@code null} or not supported.
   */
  public static DoubleDoubleFunction getDerivativeDoubleDoubleFunction(String function) {
    if ("Minus_Squared".equalsIgnoreCase(function)) {
      return derivativeMinusSquared;
    } else if ("Cross_Entropy".equalsIgnoreCase(function)) {
      return derivativeCrossEntropy;
    } else {
      throw new IllegalArgumentException("Function not supported.");
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy