
cc.redberry.transformation.GetScalarDerivative Maven / Gradle / Ivy
/*
* Redberry: symbolic tensor computations.
*
* Copyright (c) 2010-2012:
* Stanislav Poslavsky
* Bolotin Dmitriy
*
* This file is part of Redberry.
*
* Redberry is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Redberry is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Redberry. If not, see <http://www.gnu.org/licenses/>.
*/
package cc.redberry.transformation;
import java.util.ArrayList;
import java.util.List;
import cc.redberry.core.number.ComplexElement;
import cc.redberry.core.tensor.AbstractScalarFunction;
import cc.redberry.core.tensor.Derivative;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.TensorIterator;
import cc.redberry.core.tensor.TensorNumber;
import cc.redberry.core.tensor.testing.TTest;
/**
 * Transformation that evaluates a scalar {@link Derivative} by differentiating
 * its target with respect to each of the derivative's variations in turn.
 *
 * <p>Tensors that are not scalar (according to {@link TTest#testIsScalar}) or
 * that are not {@link Derivative} instances are returned unchanged.
 *
 * <p>Internally, a {@code null} result of a differentiation step means "the
 * derivative is zero"; {@link #transform(Tensor)} converts it into an explicit
 * zero {@link TensorNumber} before returning.
 *
 * <p>The class declares no instance fields; use the shared {@link #INSTANCE}.
 *
 * @author Dmitry Bolotin
 * @author Stanislav Poslavsky
 */
public class GetScalarDerivative {

    /** The only instance of this transformation (the constructor is private). */
    public static final GetScalarDerivative INSTANCE = new GetScalarDerivative();

    private GetScalarDerivative() {
    }

    /**
     * Evaluates the given tensor if it is a scalar {@link Derivative}.
     *
     * @param tensor tensor to transform
     * @return {@code tensor} itself when it is not a scalar derivative;
     *         otherwise the result of differentiating a clone of the
     *         derivative's target once per variation, or a zero
     *         {@link TensorNumber} as soon as any step yields zero
     */
    public Tensor transform(Tensor tensor) {
        if (!TTest.testIsScalar(tensor)) {
            return tensor;
        }
        if (!(tensor instanceof Derivative)) {
            return tensor;
        }
        Derivative derivative = (Derivative) tensor;
        // Work on a clone so the target held by the derivative is not shared
        // with the tensors built below.
        Tensor target = derivative.getTarget().clone();
        for (int i = 0; i < derivative.getDerivativeOrder(); ++i) {
            target = getDerivative(target, derivative.getVariation(i));
            if (target == null) {
                // null encodes a vanishing derivative; further steps stay zero.
                return TensorNumber.createZERO();
            }
        }
        return target;
    }

    /**
     * Differentiates {@code target} once with respect to {@code var},
     * dispatching on the runtime type of {@code target}.
     *
     * @param target expression to differentiate
     * @param var    variable of differentiation
     * @return the derivative, or {@code null} when it is zero; tensor types
     *         not handled by any branch are also reported as {@code null}
     */
    private Tensor getDerivative(Tensor target, SimpleTensor var) {
        if (target instanceof Sum) {
            return differentiateSum((Sum) target, var);
        }
        if (target instanceof Product) {
            return differentiateProduct((Product) target, var);
        }
        // Exact class checks (not instanceof): subclasses of SimpleTensor /
        // TensorField are deliberately not matched by these two branches.
        if (target.getClass() == SimpleTensor.class) {
            return differentiateSimpleTensor((SimpleTensor) target, var);
        }
        if (target.getClass() == TensorField.class) {
            return differentiateField((TensorField) target, var);
        }
        if (target instanceof AbstractScalarFunction) {
            return differentiateFunction((AbstractScalarFunction) target, var);
        }
        return null;
    }

    /** Sum rule: collects the non-zero derivatives of the summands. */
    private Tensor differentiateSum(Sum sum, SimpleTensor var) {
        Sum result = new Sum();
        TensorIterator it = sum.iterator();
        while (it.hasNext()) {
            Tensor derivative = getDerivative(it.next(), var);
            if (derivative != null) {
                result.add(derivative);
            }
        }
        if (result.isEmpty()) {
            return null;
        }
        return result.equivalent();
    }

    /**
     * Product rule: for every factor with a non-zero derivative, builds a copy
     * of the product in which that factor is replaced by its derivative.
     */
    private Tensor differentiateProduct(Product product, SimpleTensor var) {
        List<Tensor> terms = new ArrayList<>();
        for (int i = 0; i < product.size(); ++i) {
            Tensor derivative = getDerivative(product.getElements().get(i), var);
            if (derivative == null) {
                continue;
            }
            Product term = (Product) product.clone();
            term.getElements().remove(i);
            // A derivative equal to one contributes nothing to the product.
            if (!isOne(derivative)) {
                term.add(derivative);
            }
            terms.add(term.equivalent());
        }
        if (terms.isEmpty()) {
            return null;
        }
        if (terms.size() == 1) {
            return terms.get(0);
        }
        return new Sum(terms);
    }

    /** A simple tensor differentiates to one w.r.t. itself and to zero otherwise. */
    private Tensor differentiateSimpleTensor(SimpleTensor target, SimpleTensor var) {
        // NOTE(review): '==' is only a value comparison if getName() returns a
        // primitive identifier; confirm against SimpleTensor.
        if (target.getName() == var.getName()) {
            return TensorNumber.createONE();
        }
        return null;
    }

    /**
     * A field whose arguments depend on {@code var} is not differentiated
     * further; a new {@link Derivative} of the whole field is created instead.
     */
    private Tensor differentiateField(TensorField field, SimpleTensor var) {
        Tensor[] args = field.getArgs();
        for (int i = 0; i < args.length; ++i) {
            if (getDerivative(args[i], var) != null) {
                return Derivative.createFromInversed(field, new SimpleTensor[]{var});
            }
        }
        return null;
    }

    /**
     * Chain rule for scalar functions: {@code func.derivative()} multiplied by
     * the derivative of the inner tensor (the factor is omitted when it is one).
     */
    private Tensor differentiateFunction(AbstractScalarFunction func, SimpleTensor var) {
        Tensor inner = getDerivative(func.getInnerTensor(), var);
        if (inner == null) {
            return null;
        }
        if (isOne(inner)) {
            return func.derivative();
        }
        return new Product(func.derivative(), inner);
    }

    /** Returns {@code true} iff {@code t} is a {@link TensorNumber} whose value equals {@link ComplexElement#ONE}. */
    private static boolean isOne(Tensor t) {
        return t instanceof TensorNumber
                && ((TensorNumber) t).getValue().equals(ComplexElement.ONE);
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy