All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.inverse.InvertMatrix Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.inverse;

import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.CheckUtil;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helpers for computing inverses and pseudo inverses of rank-2 {@link INDArray} matrices.
 */
public class InvertMatrix {


    /**
     * Inverts a square matrix.
     *
     * <p>A {@code [1,1]} matrix is inverted directly as {@code 1/x}; any other square matrix is
     * delegated to {@code Nd4j.linalg().matrixInverse(arr)}.
     *
     * @param arr the square, rank-2 array to invert
     * @param inPlace whether to store the result in {@code arr}; when {@code true} the returned
     *                array is {@code arr} itself
     * @return the inverted matrix
     * @throws IllegalArgumentException if {@code arr} is not a square matrix
     */
    public static INDArray invert(INDArray arr, boolean inPlace) {
        if (arr.rank() == 2 && arr.length() == 1) {
            // [1,1] edge case. Matrix inversion: [x] * [1/x] = [1]
            return inPlace ? arr.rdivi(1.0) : arr.rdiv(1.0);
        }
        if (!arr.isSquare()) {
            throw new IllegalArgumentException("invalid array: must be square matrix");
        }

        final INDArray inverse = Nd4j.linalg().matrixInverse(arr);
        if (inPlace) {
            // Do not rely on matrixInverse aliasing arr: copy the result back explicitly so the
            // documented in-place contract holds (assign returns arr itself).
            return arr.assign(inverse);
        }
        return inverse;
    }

    /**
     * Calculates the pseudo inverse of a matrix using QR decomposition (currently computed via
     * Apache Commons Math).
     *
     * @param arr the array to invert
     * @param inPlace whether to also copy the result into {@code arr}; this requires the result
     *                to have the same shape as {@code arr}
     * @return the pseudo inverted matrix (a newly created array; its values are also copied into
     *         {@code arr} when {@code inPlace} is {@code true})
     * @throws IllegalArgumentException if the QR decomposition reports the matrix as singular
     */
    public static INDArray pinvert(INDArray arr, boolean inPlace) {

        // TODO : do it natively instead of relying on commons-maths

        final RealMatrix realMatrix = CheckUtil.convertToApacheMatrix(arr);
        // Singularity threshold 0: a diagonal element of R only counts as zero if it is exactly 0.
        final QRDecomposition decomposition = new QRDecomposition(realMatrix, 0);
        final DecompositionSolver solver = decomposition.getSolver();

        if (!solver.isNonSingular()) {
            throw new IllegalArgumentException("invalid array: must be non-singular matrix");
        }

        final RealMatrix pinvRM = solver.getInverse();
        final INDArray pseudoInverse = CheckUtil.convertFromApacheMatrix(pinvRM, arr.dataType());

        if (inPlace) {
            arr.assign(pseudoInverse);
        }
        return pseudoInverse;
    }

    /**
     * Computes the left pseudo inverse {@code (A^T A)^-1 A^T}. The input matrix must have full
     * column rank.
     *
     * <p>See also: Moore–Penrose inverse
     *
     * @param arr Input matrix {@code A}
     * @param inPlace Whether to also copy the result into {@code arr}. The result has the shape
     *                of {@code A^T}, so this is only meaningful when {@code arr} is square.
     * @return Left pseudo inverse of {@code arr} (a newly created array)
     * @throws IllegalArgumentException Input matrix {@code arr} did not have full column rank.
     */
    public static INDArray pLeftInvert(INDArray arr, boolean inPlace) {
        final INDArray arrT = arr.transpose();
        try {
            // The Gram matrix A^T A is a temporary, so it is always inverted out of place.
            final INDArray inv = invert(arrT.mmul(arr), false).mmul(arrT);
            if (inPlace) {
                arr.assign(inv);
            }
            return inv;
        } catch (SingularMatrixException e) {
            // NOTE(review): SingularMatrixException is the Commons Math type; confirm the
            // Nd4j.linalg() backend used by invert actually raises it, otherwise this never fires.
            throw new IllegalArgumentException(
                    "Full column rank condition for left pseudo inverse was not met.", e);
        }
    }

    /**
     * Computes the right pseudo inverse {@code A^T (A A^T)^-1}. The input matrix must have full
     * row rank.
     *
     * <p>See also: Moore–Penrose inverse
     *
     * @param arr Input matrix {@code A}
     * @param inPlace Whether to also copy the result into {@code arr}. The result has the shape
     *                of {@code A^T}, so this is only meaningful when {@code arr} is square.
     * @return Right pseudo inverse of {@code arr} (a newly created array)
     * @throws IllegalArgumentException Input matrix {@code arr} did not have full row rank.
     */
    public static INDArray pRightInvert(INDArray arr, boolean inPlace) {
        final INDArray arrT = arr.transpose();
        try {
            // The Gram matrix A A^T is a temporary, so it is always inverted out of place.
            final INDArray inv = arrT.mmul(invert(arr.mmul(arrT), false));
            if (inPlace) {
                arr.assign(inv);
            }
            return inv;
        } catch (SingularMatrixException e) {
            // NOTE(review): SingularMatrixException is the Commons Math type; confirm the
            // Nd4j.linalg() backend used by invert actually raises it, otherwise this never fires.
            throw new IllegalArgumentException(
                    "Full row rank condition for right pseudo inverse was not met.", e);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy