com.accelad.math.nilgiri.autodiff.DifferentialFunctionFactory Maven / Gradle / Ivy
package com.accelad.math.nilgiri.autodiff;
import java.util.ArrayList;
import java.util.List;
import com.accelad.math.nilgiri.AbstractFactory;
import com.accelad.math.nilgiri.Field;
public class DifferentialFunctionFactory> {
protected AbstractFactory mFactory;
public DifferentialFunctionFactory(AbstractFactory mFactory) {
if (mFactory != null) {
this.mFactory = mFactory;
} else {
throw new IllegalArgumentException("Input not null value.");
}
}
public Constant val(X iX) {
return new Constant<>(iX, mFactory);
}
public ConstantVector val(X... iX) {
int size = iX.length;
ArrayList> list = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
list.add(val(iX[i]));
}
return new ConstantVector<>(mFactory, list);
}
// ZeroVector
public ConstantVector zero(int iSize) {
ArrayList> list = new ArrayList<>(iSize);
for (int i = 0; i < iSize; i++) {
list.add(zero());
}
return new ConstantVector<>(mFactory, list);
}
public Variable var(String iName, X iX, PreEvaluator preEvaluator) {
return new Variable<>(iName, iX, mFactory, preEvaluator);
}
public Variable var(String iName, X iX) {
return new Variable<>(iName, iX, mFactory);
}
public VariableVector var(String iName, X... iX) {
int size = iX.length;
ArrayList> list = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
list.add(var(iName + String.valueOf(i), iX[i]));
}
return new VariableVector<>(mFactory, list);
}
public VariableVector var(String iName, int iSize) {
ArrayList> list = new ArrayList<>(iSize);
for (int i = 0; i < iSize; i++) {
list.add(var(iName + String.valueOf(i), mFactory.zero()));
}
return new VariableVector<>(mFactory, list);
}
public DifferentialVectorFunction function(DifferentialFunction... iX) {
int size = iX.length;
ArrayList> list = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
list.add(iX[i]);
}
return new DifferentialVectorFunction<>(mFactory, list);
}
public Zero zero() {
return new Zero<>(mFactory);
}
public One one() {
return new One<>(mFactory);
}
public DifferentialFunction cos(DifferentialFunction iX) {
return new AbstractUnaryFunction(iX) {
@Override
public X getValue() {
return mFactory.cos(arg().getValue());
}
@Override
public double getReal() {
return Math.cos(arg().getReal());
}
@Override
public DifferentialFunction diff(Variable i_v) {
return (sin(arg()).mul(arg().diff(i_v))).negate();
}
@Override
public String toString() {
return "cos(" + arg().toString() + ")";
}
@Override
public String getFormula(List> variables) {
return "Math.cos(" + arg().getFormula(variables) + ")";
}
};
}
public DifferentialFunction sin(DifferentialFunction iX) {
return new AbstractUnaryFunction(iX) {
@Override
public X getValue() {
return mFactory.sin(arg().getValue());
}
@Override
public double getReal() {
return Math.sin(arg().getReal());
}
@Override
public DifferentialFunction diff(Variable i_v) {
return cos(arg()).mul(arg().diff(i_v));
}
@Override
public String toString() {
return "sin(" + arg().toString() + ")";
}
@Override
public String getFormula(List> variables) {
return "Math.sin(" + arg().getFormula(variables) + ")";
}
};
}
public DifferentialFunction tan(DifferentialFunction iX) {
return new AbstractUnaryFunction(iX) {
@Override
public X getValue() {
return mFactory.tan(arg().getValue());
}
@Override
public double getReal() {
return Math.tan(arg().getReal());
}
@Override
public DifferentialFunction diff(Variable i_v) {
return (new PolynomialTerm<>(1, cos(arg()), -2)).mul(arg().diff(i_v));
}
@Override
public String toString() {
return "tan(" + arg().toString() + ")";
}
@Override
public String getFormula(List> variables) {
return "Math.tan(" + arg().getFormula(variables) + ")";
}
};
}
public DifferentialFunction exp(DifferentialFunction iX) {
return new AbstractUnaryFunction(iX) {
@Override
public X getValue() {
return mFactory.exp(arg().getValue());
}
@Override
public double getReal() {
return Math.exp(arg().getReal());
}
@Override
public DifferentialFunction diff(Variable i_v) {
return exp(arg()).mul(arg().diff(i_v));
}
@Override
public String toString() {
return "exp(" + arg().toString() + ")";
}
@Override
public String getFormula(List> variables) {
return "Math.exp(" + arg().getFormula(variables) + ")";
}
};
}
public DifferentialFunction log(DifferentialFunction iX) {
return new AbstractUnaryFunction(iX) {
@Override
public X getValue() {
return mFactory.log(arg().getValue());
}
@Override
public double getReal() {
return Math.log(arg().getReal());
}
@Override
public DifferentialFunction diff(Variable i_v) {
return new Inverse<>(arg()).mul(arg().diff(i_v));
}
@Override
public String toString() {
return "log(" + arg().toString() + ")";
}
@Override
public String getFormula(List> variables) {
return "Math.log(" + arg().getFormula(variables) + ")";
}
};
}
public DifferentialFunction pow(DifferentialFunction iX, Constant i_y) {
return new AbstractBinaryFunction(iX, i_y) {
@Override
public X getValue() {
return mFactory.pow(larg().getValue(), rarg().getValue());
}
@Override
public double getReal() {
return Math.pow(larg().getReal(), rarg().getReal());
}
@Override
public DifferentialFunction diff(Variable i_v) {
Constant ym1 = DifferentialFunctionFactory.this
.val(rarg().getValue().minus(mFactory.one()));
return rarg().mul(DifferentialFunctionFactory.this.pow(larg(), ym1))
.mul(larg().diff(i_v));
}
@Override
public String toString() {
return "pow(" + larg().toString() + ", " + rarg().toString() + ")";
}
@Override
public String getFormula(List> variables) {
return "Math.pow(" + larg().getFormula(variables) + ","
+ rarg().getFormula(variables) + ")";
}
};
}
public DifferentialFunction sqrt(DifferentialFunction iX) {
return new AbstractUnaryFunction(iX) {
@Override
public X getValue() {
return mFactory.sqrt(arg().getValue());
}
@Override
public double getReal() {
return Math.sqrt(arg().getReal());
}
@Override
public DifferentialFunction diff(Variable i_v) {
return ((sqrt(arg()).inverse())
.div(DifferentialFunctionFactory.this.val(mFactory.one().mul(2L))))
.mul(arg().diff(i_v));
}
@Override
public String toString() {
return "sqrt(" + arg().toString() + ")";
}
@Override
public String getFormula(List> variables) {
return "Math.sqrt(" + arg().getFormula(variables) + ")";
}
};
}
public DifferentialFunction square(DifferentialFunction iX) {
return new AbstractUnaryFunction(iX) {
@Override
public X getValue() {
return mFactory.square(arg().getValue());
}
@Override
public double getReal() {
return Math.pow(arg().getReal(), 2);
}
@Override
public DifferentialFunction diff(Variable i_v) {
return arg().mul(DifferentialFunctionFactory.this.val(mFactory.one().mul(2L)))
.mul(arg().diff(i_v));
}
@Override
public String toString() {
return "square(" + arg().toString() + ")";
}
@Override
public String getFormula(List> variables) {
return "Math.pow(" + arg().getFormula(variables) + ", 2d )";
}
};
}
public DifferentialFunction floor(DifferentialFunction iX) {
return new AbstractUnaryFunction(iX) {
@Override
public X getValue() {
return mFactory.floor(arg().getValue());
}
@Override
public double getReal() {
return Math.floor(arg().getReal());
}
@Override
public DifferentialFunction diff(Variable i_v) {
throw new RuntimeException("not allowed");
}
@Override
public String toString() {
return "floor(" + arg().toString() + ")";
}
@Override
public String getFormula(List> variables) {
return "Math.floor(" + arg().getFormula(variables) + ")";
}
};
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy