All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.cas.Matrix.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.cas

/** A symbolic matrix, i.e. a tensor of rank 2. */
trait Matrix extends Tensor {
  /** The size of matrix as a (rows, columns) pair. */
  def size: (IntScalar, IntScalar)

  /** The rank of a matrix is always 2. */
  override def rank: Option[Int] = Some(2)

  /** The shape: the row count followed by the column count. */
  override def shape: Option[Array[IntScalar]] = {
    val (rows, cols) = size
    Some(Array(rows, cols))
  }

  /** Returns the partial derivative with respect to `dx`. */
  def d(dx: Var): Matrix

  /** Applies the expression with the given symbol bindings. */
  def apply(env: Map[String, Tensor]): Matrix

  /** Applies the expression; the bindings are collected into a map first. */
  def apply(env: (String, Tensor)*): Matrix = {
    val bindings: Map[String, Tensor] = Map(env: _*)
    apply(bindings)
  }

  /** Simplifies the expression. By default there is nothing to simplify. */
  def simplify: Matrix = this

  /** Matrix addition. */
  def + (y: Matrix): Matrix = AddMatrix(this, y).simplify

  /** Matrix subtraction, expressed as the addition of the negated operand. */
  def - (y: Matrix): Matrix = AddMatrix(this, NegMatrix(y)).simplify

  /** Multiplication by a scalar. */
  def * (y: Scalar): Matrix = ScalarMatrixProduct(y, this).simplify

  /** Multiplication by a vector. */
  def * (y: Vector): Vector = MatrixVectorProduct(this, y).simplify

  /** Multiplication by a matrix. */
  def * (y: Matrix): Matrix = MatrixProduct(this, y).simplify

  /** Unary plus: the simplified expression itself. */
  def unary_+ : Matrix = simplify

  /** Unary minus: the negated expression. */
  def unary_- : Matrix = NegMatrix(this).simplify

  /** The transpose of matrix. */
  def t : Matrix = MatrixTranspose(this).simplify

  /**
   * The inverse of matrix.
   *
   * Throws UnsupportedOperationException for the all-zero and all-one
   * matrices, and for any matrix whose row and column sizes differ.
   */
  def inv : Matrix = this match {
    // The rank check comes before the shape check.
    case _: ZeroMatrix | _: OneMatrix =>
      throw new UnsupportedOperationException("Matrix is not of full rank")

    case _ =>
      val (rows, cols) = size
      if (rows != cols) throw new UnsupportedOperationException("Inverse of non square matrix")

      MatrixInverse(this).simplify
  }
}

/** The matrix whose entries are all 0. The default size is the symbolic m x n. */
case class ZeroMatrix(size: (IntScalar, IntScalar) = (IntConst("m"), IntConst("n"))) extends Matrix {
  /** Substituting bindings leaves the zero matrix unchanged. */
  override def apply(env: Map[String, Tensor]): Matrix = this

  /** The derivative of the zero matrix is the zero matrix itself. */
  override def d(dx: Var): Matrix = this

  /** Rendered as "0". */
  override def toString: String = "0"
}

/** Matrix of all 1's. The default size is the symbolic m x n. */
case class OneMatrix(size: (IntScalar, IntScalar) = (IntConst("m"), IntConst("n"))) extends Matrix {
  /** Rendered as "1". */
  override def toString: String = "1"

  /**
   * The all-ones matrix is constant, so its derivative with respect to any
   * variable is the zero matrix of the same size (not the matrix itself).
   */
  override def d(dx: Var): Matrix = ZeroMatrix(size)

  /** Substituting bindings leaves the matrix unchanged. */
  override def apply(env: Map[String, Tensor]): Matrix = this
}

/** Identity matrix. The default size is the symbolic n x n. */
case class IdentityMatrix(size: (IntScalar, IntScalar) = (IntConst("n"), IntConst("n"))) extends Matrix {
  /** Rendered as "I". */
  override def toString: String = "I"

  /**
   * The identity matrix is constant, so its derivative with respect to any
   * variable is the zero matrix of the same size (not the identity itself).
   */
  override def d(dx: Var): Matrix = ZeroMatrix(size)

  /** Substituting bindings leaves the matrix unchanged. */
  override def apply(env: Map[String, Tensor]): Matrix = this
}

/** Constant matrix identified by a symbol. The default size is the symbolic m x n. */
case class ConstMatrix(symbol: String, size: (IntScalar, IntScalar) = (IntConst("m"), IntConst("n"))) extends Matrix {
  /** Rendered as its symbol. */
  override def toString: String = symbol

  /** The derivative of a constant matrix is the zero matrix of the same size. */
  override def d(dx: Var): Matrix = ZeroMatrix(size)

  /**
   * Applies the expression.
   *
   * If the symbol is not bound in `env`, the constant is returned unchanged,
   * so that expressions can be partially evaluated. A bound symbol must be
   * one of the zero, one or identity matrices; anything else raises
   * IllegalArgumentException.
   */
  override def apply(env: Map[String, Tensor]): Matrix = env.get(symbol) match {
    case None => this
    case Some(x: ZeroMatrix) => x
    case Some(x: OneMatrix) => x
    case Some(x: IdentityMatrix) => x
    // Report the class of the bound value, not of the enclosing Option.
    case Some(x) => throw new IllegalArgumentException(s"Invalid type: ${x.getClass}, expected 0 | 1 | I")
  }
}

/** Abstract matrix variable identified by a symbol. The default size is the symbolic m x n. */
case class MatrixVar(symbol: String, size: (IntScalar, IntScalar) = (IntConst("m"), IntConst("n"))) extends Matrix {
  /** Rendered as its symbol. */
  override def toString: String = symbol

  /** The derivative of an abstract variable is the symbolic tangent matrix. */
  override def d(dx: Var): Matrix = TangentMatrix(this, dx)

  /**
   * Applies the expression.
   *
   * Returns the variable itself when the symbol is unbound, the bound matrix
   * when there is one, and raises IllegalArgumentException when the symbol
   * is bound to a tensor that is not a matrix.
   */
  override def apply(env: Map[String, Tensor]): Matrix = env.get(symbol) match {
    case None => this
    case Some(x : Matrix) => x
    // Report the class of the bound value, not of the enclosing Option.
    case Some(x) => throw new IllegalArgumentException(s"Invalid type: ${x.getClass}, expected Matrix")
  }
}

/** A square matrix built from the given diagonal entries. */
case class DiagonalMatrix(x: Scalar*) extends Matrix {
  /** Square, with as many rows and columns as there are diagonal entries. */
  override def size: (IntScalar, IntScalar) = {
    val n = x.length
    (IntVal(n), IntVal(n))
  }

  /** Rendered as `diag[e1, e2, ...]`. */
  override def toString: String = x.mkString("diag[", ", ", "]")

  /** Substitutes the bindings into every diagonal entry, then simplifies. */
  override def apply(env: Map[String, Tensor]): Matrix = {
    val entries = x.map(entry => entry(env))
    DiagonalMatrix(entries: _*).simplify
  }

  /** Differentiates every diagonal entry, then simplifies. */
  override def d(dx: Var): Matrix = {
    val partials = x.map(entry => entry.d(dx))
    DiagonalMatrix(partials: _*).simplify
  }

  /**
   * Collapses to the identity matrix when every entry equals 1, to the zero
   * matrix when every entry equals 0, and is left as is otherwise.
   */
  override def simplify: Matrix = {
    val one = Val(1)
    val zero = Val(0)

    if (x.forall(entry => entry == one)) IdentityMatrix(size)
    else if (x.forall(entry => entry == zero)) ZeroMatrix(size)
    else this
  }
}

/** Row-wise matrix: each given vector is one row of the matrix. */
case class RowMatrix(x: Vector*) extends Matrix {
  /** Rendered as `row[v1, v2, ...]`. */
  override def toString: String = x.mkString("row[", ", ", "]")

  /**
   * The number of rows is the number of vectors; the number of columns is
   * the size of a row vector (taken from the first row). A matrix without
   * any row is 0 x 0.
   */
  override def size: (IntScalar, IntScalar) = {
    val columns: IntScalar = x.headOption.fold[IntScalar](IntVal(0))(_.size)
    (IntVal(x.length), columns)
  }

  /** Differentiates every row, then simplifies. */
  override def d(dx: Var): Matrix = {
    RowMatrix(x.map(_.d(dx)): _*).simplify
  }

  /** Substitutes the bindings into every row, then simplifies. */
  override def apply(env: Map[String, Tensor]): Matrix = RowMatrix(x.map(_(env)): _*).simplify

  /**
   * Collapses to the all-ones matrix when every row is the all-ones vector,
   * to the zero matrix when every row is the zero vector, and is left as is
   * otherwise.
   */
  override def simplify: Matrix = {
    // Rows are compared against vectors of the column size, not the row count.
    val columns = size._2
    if (x.forall(_ == OneVector(columns))) OneMatrix(size)
    else if (x.forall(_ == ZeroVector(columns))) ZeroMatrix(size)
    else this
  }
}

/** The derivative of a matrix y with respect to a scalar x. */
case class TangentMatrix(y: MatrixVar, x: Var) extends Matrix {
  /** Rendered with partial derivative signs and the two symbols. */
  override def toString: String = s"\u2202${y.symbol}/\u2202${x.symbol}"

  /** Same size as the differentiated matrix. */
  override def size: (IntScalar, IntScalar) = y.size

  /** Higher order derivatives are not supported. */
  override def d(dx: Var): Matrix = throw new UnsupportedOperationException("derivative of tangent matrix")

  /**
   * Applies the expression.
   *
   * The matrix to differentiate is the binding of `y` when it is bound to a
   * general matrix expression, and `y` itself when it is unbound or bound to
   * a zero, one, identity or constant matrix. The variable to differentiate
   * by is the binding of `x` when it is bound to a `Var`, and `x` itself
   * when unbound. Bindings of any other type raise IllegalArgumentException.
   */
  override def apply(env: Map[String, Tensor]): Matrix = {
    val yv = env.get(y.symbol) match {
      case None | Some(ZeroMatrix(_)) | Some(OneMatrix(_)) | Some(IdentityMatrix(_)) | Some(ConstMatrix(_, _)) => y
      case Some(m: Matrix) => m
      // A matrix is expected here; report the class of the bound value.
      case Some(other) => throw new IllegalArgumentException(s"Invalid type: ${other.getClass}, expected Matrix")
    }

    env.get(x.symbol) match {
      case None => yv.d(x)
      case Some(v: Var) => yv.d(v)
      case Some(other) => throw new IllegalArgumentException(s"Invalid type: ${other.getClass}, expected Var")
    }
  }
}

/**
 * A + B
 *
 * Construction fails with IllegalArgumentException when the operand sizes
 * differ. Equality is commutative: A + B equals B + A.
 */
case class AddMatrix(A: Matrix, B: Matrix) extends Matrix {
  if (A.size != B.size) throw new IllegalArgumentException(s"Matrix sizes mismatch: ${A.size} vs ${B.size}")

  /**
   * Renders the sum, printing the addition of a negated operand as a
   * subtraction. A subtracted sum is parenthesized so that A - (B + C) is
   * not printed as A - B + C.
   */
  override def toString: String = B match {
    case NegMatrix(inner @ AddMatrix(_, _)) => s"$A - ($inner)"
    case NegMatrix(inner) => s"$A - $inner"
    case _ => s"$A + $B"
  }

  /** Two sums are equal when their operands are equal in either order. */
  override def equals(o: Any): Boolean = o match {
    case AddMatrix(a, b) => (A == a && B == b) || (A == b && B == a)
    case _ => false
  }

  /**
   * Order-independent hash code, consistent with the commutative equals:
   * A + B and B + A must hash alike.
   */
  override def hashCode: Int = A.hashCode + B.hashCode

  /** Same size as the operands. */
  override def size: (IntScalar, IntScalar) = A.size

  /** Applies both operands and adds the results. */
  override def apply(env: Map[String, Tensor]): Matrix = A(env) + B(env)

  /** The derivative of a sum is the sum of the derivatives. */
  override def d(dx: Var): Matrix = {
    A.d(dx) + B.d(dx)
  }

  /**
   * Algebraic simplification: drops zero operands, cancels double negations
   * and opposite terms, and factors out common scalar or matrix factors.
   */
  override def simplify: Matrix = (A, B) match {
    case (ZeroMatrix(_), b) => b
    case (a, ZeroMatrix(_)) => a
    case (a, NegMatrix(NegMatrix(b))) => a + b
    case (NegMatrix(a), NegMatrix(b)) => -(a + b)
    case (NegMatrix(a), b) if a == b => ZeroMatrix(a.size)
    case (a, NegMatrix(b)) if a == b => ZeroMatrix(a.size)
    case (ScalarMatrixProduct(a, b), ScalarMatrixProduct(c, d)) if b == d => (a + c) * b
    case (ScalarMatrixProduct(a, b), ScalarMatrixProduct(c, d)) if a == c => a * (b + d)
    case (ScalarMatrixProduct(a, b), c) if b == c => (a + 1) * b
    case (a, ScalarMatrixProduct(b, c)) if a == c => (b + 1) * a
    case (ScalarMatrixProduct(a, b), NegMatrix(ScalarMatrixProduct(c, d))) if b == d => (a - c) * b
    case (ScalarMatrixProduct(a, b), NegMatrix(ScalarMatrixProduct(c, d))) if a == c => a * (b - d)
    case (NegMatrix(ScalarMatrixProduct(a, b)), c) if b == c => (1 - a) * b
    case (a, NegMatrix(ScalarMatrixProduct(b, c))) if a == c => (1 - b) * a
    case (MatrixProduct(a, b), MatrixProduct(c, d)) if b == d => (a + c) * b
    case (MatrixProduct(a, b), MatrixProduct(c, d)) if a == c => a * (b + d)
    case (a, b) if a == b => 2 * a
    case (a, b) => AddMatrix(a, b)
  }
}

/** The negation -A of a matrix A. */
case class NegMatrix(A: Matrix) extends Matrix {
  /** Same size as the negated matrix. */
  override def size: (IntScalar, IntScalar) = A.size

  /** Renders the negation; sums and outer products are parenthesized. */
  override def toString: String = {
    val compound = A match {
      case AddMatrix(_, _) | OuterProduct(_, _) => true
      case _ => false
    }

    if (compound) s"-($A)" else s"-$A"
  }

  /** Applies the operand, then negates the result. */
  override def apply(env: Map[String, Tensor]): Matrix = {
    val value = A(env)
    -value
  }

  /** The derivative of a negation is the negated derivative. */
  override def d(dx: Var): Matrix = {
    val tangent = A.d(dx)
    -tangent
  }

  /** Negated zero is zero, and a double negation cancels out. */
  override def simplify: Matrix = A match {
    case zero: ZeroMatrix => zero
    case NegMatrix(inner) => inner
    case _ => this
  }
}

/** The transpose A' of a matrix A. */
case class MatrixTranspose(A: Matrix) extends Matrix {
  /** The size of the operand with rows and columns swapped. */
  override def size: (IntScalar, IntScalar) = {
    val (rows, cols) = A.size
    (cols, rows)
  }

  /** Renders the transpose; sums and outer products are parenthesized. */
  override def toString: String = {
    val compound = A match {
      case AddMatrix(_, _) | OuterProduct(_, _) => true
      case _ => false
    }

    if (compound) s"($A)'" else s"$A'"
  }

  /** Applies the operand, then transposes the result. */
  override def apply(env: Map[String, Tensor]): Matrix = {
    val value = A(env)
    value.t
  }

  /** The derivative of a transpose is the transposed derivative. */
  override def d(dx: Var): Matrix = {
    val tangent = A.d(dx)
    tangent.t
  }

  /**
   * The transpose of a zero matrix is the zero matrix of the transposed
   * size, and a double transpose cancels out.
   */
  override def simplify: Matrix = A match {
    case _: ZeroMatrix => ZeroMatrix(size)
    case MatrixTranspose(inner) => inner
    case _ => this
  }
}

/** The inverse inv(A) of a matrix A. */
case class MatrixInverse(A: Matrix) extends Matrix {
  /** Same size as the inverted matrix. */
  override def size: (IntScalar, IntScalar) = A.size

  /** Rendered as `inv(A)`. */
  override def toString: String = s"inv($A)"

  /** Applies the operand, then inverts the result. */
  override def apply(env: Map[String, Tensor]): Matrix = {
    val value = A(env)
    value.inv
  }

  /** Computed as -(inv(A) * dA * inv(A)). */
  override def d(dx: Var): Matrix = {
    val left = A.inv * A.d(dx)
    -(left * A.inv)
  }
}

/** The product a * A of a scalar a and a matrix A. */
case class ScalarMatrixProduct(a: Scalar, A: Matrix) extends Matrix {
  /** Same size as the matrix operand. */
  override def size: (IntScalar, IntScalar) = A.size

  /**
   * Renders the product as `a * A`, parenthesizing compound operands:
   * matrix sums and outer products on the matrix side; sums, differences,
   * powers and inner products on the scalar side.
   */
  override def toString: String = {
    val scalarText = a match {
      case Add(_, _) | Sub(_, _) | Power(_, _) | InnerProduct(_, _) => s"($a)"
      case _ => a.toString
    }

    val matrixText = A match {
      case AddMatrix(_, _) | OuterProduct(_, _) => s"($A)"
      case _ => A.toString
    }

    s"$scalarText * $matrixText"
  }

  /** Applies both operands and multiplies the results. */
  override def apply(env: Map[String, Tensor]): Matrix = {
    val scalar = a(env)
    scalar * A(env)
  }

  /** Product rule: a * dA + da * A. */
  override def d(dx: Var): Matrix = {
    val first = a * A.d(dx)
    val second = a.d(dx) * A
    first + second
  }

  /**
   * Algebraic simplification: a zero factor gives the zero matrix, the
   * factors 1 and -1 are dropped, negations are pulled out of the product,
   * and nested scalar products are merged.
   */
  override def simplify: Matrix = (a, A) match {
    case (Val(0), _) => ZeroMatrix(size)
    case (_, ZeroMatrix(_)) => ZeroMatrix(size)
    case (Val(1), m) => m
    case (Val(-1), m) => -m
    case (Neg(s), NegMatrix(m)) => s * m
    case (Neg(s), m) => -(s * m)
    case (s, NegMatrix(m)) => -(s * m)
    case (s, ScalarMatrixProduct(t, m)) => (s * t) * m
    case _ => this
  }
}

/**
 * Matrix vector multiplication (A * x).
 *
 * Construction fails with IllegalArgumentException when the number of
 * columns of A differs from the size of x.
 */
case class MatrixVectorProduct(A: Matrix, x: Vector) extends Vector {
  if (A.size._2 != x.size) throw new IllegalArgumentException(s"Matrix and vector sizes mismatch: ${A.size} vs ${x.size}")

  /** The result has as many elements as A has rows. */
  override def size: IntScalar = A.size._1

  /**
   * Renders the product as `A * x`, parenthesizing a matrix sum or outer
   * product on the left and a vector sum on the right.
   */
  override def toString: String = {
    val matrixText = A match {
      case AddMatrix(_, _) | OuterProduct(_, _) => s"($A)"
      case _ => A.toString
    }

    val vectorText = x match {
      case AddVector(_, _) => s"($x)"
      case _ => x.toString
    }

    s"$matrixText * $vectorText"
  }

  /** Applies both operands and multiplies the results. */
  override def apply(env: Map[String, Tensor]): Vector = {
    val matrix = A(env)
    matrix * x(env)
  }

  /** Product rule: A * dx + dA * x. */
  override def d(dx: Var): Vector = {
    val first = A * x.d(dx)
    val second = A.d(dx) * x
    first + second
  }

  /** The derivative with respect to a vector variable is not supported. */
  override def d(dx: VectorVar): Matrix = {
    throw new UnsupportedOperationException
  }

  /**
   * Algebraic simplification: a zero operand gives the zero vector, the
   * identity matrix is dropped, and negations are pulled out of the product.
   */
  override def simplify: Vector = (A, x) match {
    case (ZeroMatrix(_), _) => ZeroVector(size)
    case (_, ZeroVector(_)) => ZeroVector(size)
    case (IdentityMatrix(_), v) => v
    case (NegMatrix(m), NegVector(v)) => m * v
    case (NegMatrix(m), v) => -(m * v)
    case (m, NegVector(v)) => -(m * v)
    case _ => this
  }
}

/**
 * Matrix multiplication (A * B).
 *
 * Construction fails with IllegalArgumentException when the number of
 * columns of A differs from the number of rows of B.
 */
case class MatrixProduct(A: Matrix, B: Matrix) extends Matrix {
  if (A.size._2 != B.size._1) throw new IllegalArgumentException(s"Matrix multiplication size mismatch: ${A.size} x ${B.size}")

  /**
   * Renders the product by juxtaposition of the operands, parenthesizing
   * operands that are matrix sums or outer products.
   */
  override def toString: String = {
    val xs = A match {
      case AddMatrix(_, _) | OuterProduct(_, _) => s"($A)"
      case _ => A.toString
    }

    val ys = B match {
      case AddMatrix(_, _) | OuterProduct(_, _) => s"($B)"
      case _ => B.toString
    }

    s"$xs$ys"
  }

  /** Rows of A by columns of B. */
  override def size: (IntScalar, IntScalar) = (A.size._1, B.size._2)

  /** Applies both operands and multiplies the results. */
  override def apply(env: Map[String, Tensor]): Matrix = A(env) * B(env)

  /** Product rule: dA * B + A * dB. */
  override def d(dx: Var): Matrix = {
    (A.d(dx) * B) + (A * B.d(dx))
  }

  /**
   * Algebraic simplification: a zero operand gives the zero matrix, identity
   * operands are dropped, a matrix multiplied by its own inverse gives the
   * identity, negations are pulled out of the product, and a product of
   * inverses becomes the inverse of the reversed product.
   */
  override def simplify: Matrix = (A, B) match {
    case (ZeroMatrix(_), _) => ZeroMatrix(size)
    case (_, ZeroMatrix(_)) => ZeroMatrix(size)
    case (IdentityMatrix(_), b) => b
    case (a, IdentityMatrix(_)) => a
    case (NegMatrix(IdentityMatrix(_)), b) => -b
    // A * (-I) = -A: the sign must be kept, as in the mirrored case above.
    case (a, NegMatrix(IdentityMatrix(_))) => -a
    case (a, MatrixInverse(b)) if a == b => IdentityMatrix(size)
    case (MatrixInverse(a), b) if a == b => IdentityMatrix(size)
    case (a, NegMatrix(MatrixInverse(b))) if a == b => -IdentityMatrix(size)
    case (NegMatrix(MatrixInverse(a)), b) if a == b => -IdentityMatrix(size)
    case (NegMatrix(a), NegMatrix(b)) => a * b
    case (a, NegMatrix(b)) => -(a * b)
    case (NegMatrix(a), b) => -(a * b)
    case (MatrixInverse(a), MatrixInverse(b)) => MatrixInverse(b * a)
    case _ => this
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy