All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.cas.Vector.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see https://www.gnu.org/licenses/.
 */

package smile.cas

/**
 * A symbolic vector, i.e. a tensor of rank 1.
 *
 * Implementations supply the size, evaluation against an environment and
 * the two derivative forms; the arithmetic operators below build the
 * corresponding expression node and immediately simplify it.
 */
trait Vector extends Tensor {
  /** The size of vector. */
  def size: IntScalar

  /** A vector always has rank 1. */
  override def rank: Option[Int] = Some(1)

  /** The shape of a vector is the single-element array holding its size. */
  override def shape: Option[Array[IntScalar]] = Some(Array(size))

  /**
   * Applies the expression.
   *
   * @param env the bindings from symbol names to tensors.
   */
  def apply(env: Map[String, Tensor]): Vector

  /**
   * Applies the expression. The pairs are collected into a map and
   * handed to the map-based overload.
   */
  def apply(env: (String, Tensor)*): Vector = apply(env.toMap)

  /** Simplify the expression. By default an expression is its own simplest form. */
  def simplify: Vector = this

  /** Returns the partial derivative with respect to the scalar variable `dx`. */
  def d(dx: Var): Vector

  /** Returns the Jacobian matrix with respect to the vector variable `dx`. */
  def d(dx: VectorVar): Matrix

  /** Vector addition, simplified. */
  def + (y: Vector): Vector = {
    val sum = AddVector(this, y)
    sum.simplify
  }

  /** Vector subtraction, expressed as the addition of the negated operand. */
  def - (y: Vector): Vector = {
    val difference = AddVector(this, NegVector(y))
    difference.simplify
  }

  /** Multiplication by a scalar. */
  def * (y: Scalar): Vector = {
    val scaled = ScalarVectorProduct(y, this)
    scaled.simplify
  }

  /** Division by a scalar, expressed as multiplication by `1/y`. */
  def / (y: Scalar): Vector = {
    val scaled = ScalarVectorProduct(1/y, this)
    scaled.simplify
  }

  /** Inner product with another vector. */
  def * (y: Vector): Scalar = {
    val dot = InnerProduct(this, y)
    dot.simplify
  }

  /** Outer product with another vector. */
  def *~ (y: Vector): Matrix = {
    val outer = OuterProduct(this, y)
    outer.simplify
  }

  /** Unary plus: the simplified form of this vector. */
  def unary_+ : Vector = simplify

  /** Unary minus: the simplified negation of this vector. */
  def unary_- : Vector = {
    val negated = NegVector(this)
    negated.simplify
  }
}

/**
 * Vector of all 0's.
 *
 * @param size the size of the vector; defaults to the integer constant "n".
 */
case class ZeroVector(size: IntScalar = IntConst("n")) extends Vector {
  /** A zero vector is printed as "0" regardless of its size. */
  override def toString: String = "0"

  /** Evaluation leaves a zero vector unchanged; the environment is not consulted. */
  override def apply(env: Map[String, Tensor]): Vector = this

  /** The derivative with respect to a scalar variable is this same zero vector. */
  override def d(dx: Var): Vector = this

  /** The Jacobian is the zero matrix built from this vector's size and the size of `dx`. */
  override def d(dx: VectorVar): Matrix = ZeroMatrix(size, dx.size)
}

/**
 * Vector of all 1's.
 *
 * @param size the size of the vector; defaults to the integer constant "n".
 */
case class OneVector(size: IntScalar = IntConst("n")) extends Vector {
  /** A one vector is printed as "1" regardless of its size. */
  override def toString: String = "1"

  /** Evaluation leaves a one vector unchanged; the environment is not consulted. */
  override def apply(env: Map[String, Tensor]): Vector = this

  /** The derivative with respect to a scalar variable is the zero vector of the same size. */
  override def d(dx: Var): Vector = ZeroVector(size)

  /** The Jacobian is the zero matrix built from this vector's size and the size of `dx`. */
  override def d(dx: VectorVar): Matrix = ZeroMatrix(size, dx.size)
}

/**
 * Vector value: a vector of concrete numbers.
 *
 * Equality compares the elements, not the array reference, and `hashCode`
 * is derived from the elements so that equal values hash alike.
 *
 * @param x the elements of the vector.
 */
case class VectorVal(x: Array[Double]) extends Vector {
  override def toString: String = x.mkString("[", ", ", "]")

  /** Two vector values are equal when they hold the same elements in the same order. */
  override def equals(that: Any): Boolean = that match {
    case VectorVal(y) => x.sameElements(y)
    case _ => false
  }

  /**
   * Hash of the elements. The hash generated for a case class would use
   * the identity hash of the array and thus disagree with `equals`.
   * A `Seq` hash is used because it hashes elements consistently with
   * `==` (in particular 0.0 and -0.0 hash alike).
   */
  override def hashCode: Int = x.toSeq.hashCode

  override def size: IntScalar = IntVal(x.length)

  /** A concrete vector does not depend on any variable. */
  override def d(dx: Var): Vector = ZeroVector(size)

  /** The Jacobian of a concrete vector is a zero matrix. */
  override def d(dx: VectorVar): Matrix = ZeroMatrix(size, dx.size)

  /** Evaluation leaves a concrete vector unchanged. */
  override def apply(env: Map[String, Tensor]): VectorVal = this

  /** Collapses an all-zero or all-one array into `ZeroVector` or `OneVector`. */
  override def simplify: Vector = {
    if (x.forall(_ == 0)) ZeroVector(IntVal(x.length))
    else if (x.forall(_ == 1)) OneVector(IntVal(x.length))
    else this
  }
}

/**
 * Constant vector. Different from VectorVal that has concrete values,
 * this is of constant yet abstract value.
 *
 * @param symbol the name of the constant.
 * @param size the size of the vector; defaults to the integer constant "n".
 */
case class ConstVector(symbol: String, size: IntScalar = IntConst("n")) extends Vector {
  override def toString: String = symbol

  /** A constant does not depend on any variable. */
  override def d(dx: Var): Vector = ZeroVector(size)

  /** The Jacobian of a constant is a zero matrix. */
  override def d(dx: VectorVar): Matrix = ZeroMatrix(size, dx.size)

  /**
   * Substitutes the value bound to `symbol`, if any.
   *
   * @param env the bindings from symbol names to tensors.
   * @return the bound `ZeroVector`, `OneVector` or `VectorVal`, or this
   *         constant itself when the environment has no binding for it.
   * @throws IllegalArgumentException if the symbol is bound to any other tensor.
   */
  override def apply(env: Map[String, Tensor]): Vector = env.get(symbol) match {
    // An unbound constant stays symbolic, so that an expression may be
    // evaluated with only some of its symbols bound.
    case None => this
    case Some(value: ZeroVector) => value
    case Some(value: OneVector) => value
    case Some(value: VectorVal) => value
    // Report the class of the offending tensor rather than of the Option wrapper.
    case Some(other) => throw new IllegalArgumentException(s"Invalid type: ${other.getClass}, expected 0 | 1 | VectorVal")
  }
}

/**
 * A vector given component by component as a sequence of scalar expressions.
 *
 * @param x the components of the vector.
 */
case class Vars(x: Scalar*) extends Vector {
  override def toString: String = x.mkString("[", ", ", "]")

  override def size: IntScalar = IntVal(x.length)

  /** Differentiates every component with respect to `dx`. */
  override def d(dx: Var): Vector = {
    val partials = x.map(component => component.d(dx))
    Vars(partials: _*).simplify
  }

  /** Builds the Jacobian row by row from the simplified derivative of each component. */
  override def d(dx: VectorVar): Matrix = {
    val rows = x.map(component => component.d(dx).simplify)
    RowMatrix(rows: _*).simplify
  }

  /** Evaluates every component against the environment. */
  override def apply(env: Map[String, Tensor]): Vector = {
    val evaluated = x.map(component => component(env))
    Vars(evaluated: _*).simplify
  }

  /** Collapses into a (simplified) `VectorVal` once every component is a `Val`. */
  override def simplify: Vector = {
    val values = x.collect { case v: Val => v.x }
    // Every component was a Val exactly when nothing was filtered out.
    if (values.length == x.length) VectorVal(values.toArray).simplify
    else this
  }
}

/**
 * Abstract vector variable.
 *
 * @param symbol the name of the variable.
 * @param size the size of the vector; defaults to the integer constant "n".
 */
case class VectorVar(symbol: String, size: IntScalar = IntConst("n")) extends Vector {
  override def toString: String = symbol

  /** The derivative with respect to a scalar variable stays symbolic as a `TangentVector`. */
  override def d(dx: Var): Vector = TangentVector(this, dx)

  /**
   * Returns the Jacobian matrix: the identity when `dx` has the same
   * symbol as this variable, and otherwise a zero matrix built from this
   * vector's size and the size of `dx`.
   */
  override def d(dx: VectorVar): Matrix = {
    if (symbol == dx.symbol) IdentityMatrix(size, size)
    // The columns follow dx, matching the zero Jacobians of the other vector types.
    else ZeroMatrix(size, dx.size)
  }

  /**
   * Substitutes the vector bound to `symbol`, if any.
   *
   * @param env the bindings from symbol names to tensors.
   * @return the bound vector, or this variable when it is unbound.
   * @throws IllegalArgumentException if the symbol is bound to a tensor that is not a vector.
   */
  override def apply(env: Map[String, Tensor]): Vector = env.get(symbol) match {
    case None => this
    case Some(value: Vector) => value
    // Report the class of the offending tensor rather than of the Option wrapper.
    case Some(other) => throw new IllegalArgumentException(s"Invalid type: ${other.getClass}, expected Vector")
  }
}

/**
 * The derivative of a scalar y with respect to a vector x.
 *
 * @param y the scalar variable being differentiated.
 * @param x the vector variable to differentiate against.
 */
case class GradientVector(y: Var, x: VectorVar) extends Vector {
  override def toString: String = s"\u2202${y.symbol}/\u2202${x.symbol}"

  override def size: IntScalar = x.size

  /** Not supported: always throws `UnsupportedOperationException`. */
  override def d(dx: Var): Vector = throw new UnsupportedOperationException("derivative of gradient vector")

  /** Not supported: always throws `UnsupportedOperationException`. */
  override def d(dx: VectorVar): Matrix = throw new UnsupportedOperationException("derivative of gradient vector")

  /**
   * Resolves `y` and `x` against the environment and differentiates.
   *
   * @param env the bindings from symbol names to tensors.
   * @return the derivative of the resolved scalar with respect to the resolved variable(s).
   * @throws IllegalArgumentException if `y` is bound to a non-scalar, or `x` is bound
   *         to anything but a `VectorVar` or a `Vars` made only of `Var`s.
   */
  override def apply(env: Map[String, Tensor]): Vector = {
    val yv = env.get(y.symbol) match {
      // A binding to a concrete value is ignored: y stays symbolic.
      case None | Some(Val(_)) => y
      case Some(expr: Scalar) => expr
      // Report the class of the offending tensor rather than of the Option wrapper.
      case Some(other) => throw new IllegalArgumentException(s"Invalid type: ${other.getClass}, expected Scalar")
    }

    env.get(x.symbol) match {
      case None => yv.d(x)
      case Some(v: VectorVar) => yv.d(v)
      // The guard ensures every component is a Var, so collect drops nothing.
      case Some(Vars(xs @ _*)) if xs.forall(_.isInstanceOf[Var]) => yv.d(xs.collect { case v: Var => v }: _*)
      case Some(other) => throw new IllegalArgumentException(s"Invalid type: ${other.getClass}, expected VectorVar or Vars")
    }
  }
}

/**
 * The derivative of a vector y with respect to a scalar x.
 *
 * @param y the vector variable being differentiated.
 * @param x the scalar variable to differentiate against.
 */
case class TangentVector(y: VectorVar, x: Var) extends Vector {
  override def toString: String = s"\u2202${y.symbol}/\u2202${x.symbol}"

  override def size: IntScalar = y.size

  /** Not supported: always throws `UnsupportedOperationException`. */
  override def d(dx: Var): Vector = throw new UnsupportedOperationException("derivative of tangent vector")

  /** Not supported: always throws `UnsupportedOperationException`. */
  override def d(dx: VectorVar): Matrix = throw new UnsupportedOperationException("derivative of tangent vector")

  /**
   * Resolves `y` and `x` against the environment and differentiates.
   *
   * @param env the bindings from symbol names to tensors.
   * @return the derivative of the resolved vector with respect to the resolved variable.
   * @throws IllegalArgumentException if `y` is bound to a non-vector, or `x` is bound
   *         to anything but a `Var`.
   */
  override def apply(env: Map[String, Tensor]): Vector = {
    val yv = env.get(y.symbol) match {
      // Bindings to zero, one, concrete or constant vectors are ignored: y stays symbolic.
      case None | Some(ZeroVector(_)) | Some(OneVector(_)) | Some(VectorVal(_)) | Some(ConstVector(_, _)) => y
      case Some(expr: Vector) => expr
      // Report the class of the offending tensor rather than of the Option wrapper.
      case Some(other) => throw new IllegalArgumentException(s"Invalid type: ${other.getClass}, expected Vector")
    }

    env.get(x.symbol) match {
      case None => yv.d(x)
      case Some(v: Var) => yv.d(v)
      case Some(other) => throw new IllegalArgumentException(s"Invalid type: ${other.getClass}, expected Var")
    }
  }
}

/**
 * x + y
 *
 * Equality is commutative: `AddVector(a, b)` equals `AddVector(b, a)`,
 * and `hashCode` is symmetric in the operands to agree with it.
 *
 * @param x the left operand.
 * @param y the right operand.
 * @throws IllegalArgumentException if the operand sizes differ.
 */
case class AddVector(x: Vector, y: Vector) extends Vector {
  if (x.size != y.size) throw new IllegalArgumentException(s"Vector sizes mismatch: ${x.size} vs ${y.size}")

  /** Prints `x + (-y)` as a subtraction, parenthesizing a subtracted sum. */
  override def toString: String = y match {
    // Without the parentheses, x - (a + b) would read as x - a + b.
    case NegVector(inner @ AddVector(_, _)) => s"$x - ($inner)"
    case NegVector(inner) => s"$x - $inner"
    case _ => s"$x + $y"
  }

  override def equals(o: Any): Boolean = o match {
    case AddVector(a, b) => (x == a && y == b) || (x == b && y == a)
    case _ => false
  }

  /**
   * Order-independent hash. The hash generated for a case class depends
   * on the order of the operands and would disagree with `equals`.
   */
  override def hashCode: Int = x.## + y.##

  override def size: IntScalar = x.size

  /** Evaluates both operands and adds the results. */
  override def apply(env: Map[String, Tensor]): Vector = x(env) + y(env)

  /** The derivative of a sum is the sum of the derivatives. */
  override def d(dx: Var): Vector = {
    x.d(dx) + y.d(dx)
  }

  /** The Jacobian of a sum is the sum of the Jacobians. */
  override def d(dx: VectorVar): Matrix = {
    x.d(dx) + y.d(dx)
  }

  /**
   * Rewrites the sum into a simpler equivalent where a rule applies:
   * folding concrete values, dropping zero operands, cancelling negations
   * and factoring out common scalar or vector factors.
   */
  override def simplify: Vector = (x, y) match {
    case (VectorVal(a), VectorVal(b)) => VectorVal(a.zip(b).map { case (x, y) => x + y })
    case (ZeroVector(_), b) => b
    case (a, ZeroVector(_)) => a
    case (a, NegVector(NegVector(b))) => a + b
    case (NegVector(a), NegVector(b)) => -(a + b)
    case (NegVector(a), b) if a == b => ZeroVector(a.size)
    case (a, NegVector(b)) if a == b => ZeroVector(a.size)
    case (ScalarVectorProduct(a, b), ScalarVectorProduct(c, d)) if b == d => (a + c) * b
    case (ScalarVectorProduct(a, b), ScalarVectorProduct(c, d)) if a == c => a * (b + d)
    case (ScalarVectorProduct(a, b), c) if b == c => (a + 1) * b
    case (a, ScalarVectorProduct(b, c)) if a == c => (b + 1) * a
    case (ScalarVectorProduct(a, b), NegVector(ScalarVectorProduct(c, d))) if b == d => (a - c) * b
    case (ScalarVectorProduct(a, b), NegVector(ScalarVectorProduct(c, d))) if a == c => a * (b - d)
    case (NegVector(ScalarVectorProduct(a, b)), c) if b == c => (1 - a) * b
    case (a, NegVector(ScalarVectorProduct(b, c))) if a == c => (1 - b) * a
    case (a, b) if a == b => 2 * a
    case _ => this
  }
}

/**
 * -x
 *
 * @param x the vector being negated.
 */
case class NegVector(x: Vector) extends Vector {
  /** A negated sum is parenthesized; anything else is simply prefixed with a minus sign. */
  override def toString: String = x match {
    case _: AddVector => s"-($x)"
    case _ => s"-$x"
  }

  override def size: IntScalar = x.size

  /** Evaluates the operand and negates the result. */
  override def apply(env: Map[String, Tensor]): Vector = {
    val evaluated = x(env)
    -evaluated
  }

  /** The derivative of a negation is the negated derivative. */
  override def d(dx: Var): Vector = {
    val partial = x.d(dx)
    -partial
  }

  /** The Jacobian of a negation is the negated Jacobian. */
  override def d(dx: VectorVar): Matrix = {
    val jacobian = x.d(dx)
    -jacobian
  }

  /**
   * Negating zero gives zero, a concrete vector is negated element-wise,
   * and a double negation cancels. Anything else is left as is.
   */
  override def simplify: Vector = x match {
    case zero: ZeroVector => zero
    case VectorVal(values) => VectorVal(values.map(v => -v))
    case NegVector(inner) => inner
    case _ => this
  }
}

/**
 * a * x
 *
 * @param a the scalar factor.
 * @param x the vector factor.
 */
case class ScalarVectorProduct(a: Scalar, x: Vector) extends Vector {
  /** Prints `a * x`, parenthesizing a compound factor on either side. */
  override def toString: String = {
    val xs = x match {
      case AddVector(_, _) => s"($x)"
      case _ => x.toString
    }

    val as = a match {
      case Add(_, _) | Sub(_, _) | Power(_, _) | InnerProduct(_, _) => s"($a)"
      case _ => a.toString
    }

    s"$as * $xs"
  }

  override def size: IntScalar = x.size

  /** Evaluates both factors and multiplies the results. */
  override def apply(env: Map[String, Tensor]): Vector = a(env) * x(env)

  /** Product rule: a * dx/dt + da/dt * x. */
  override def d(dx: Var): Vector = {
    a * x.d(dx) + a.d(dx) * x
  }

  /**
   * Product rule for the Jacobian. Entry (i, j) is
   * a * d(x_i)/d(dx_j) + x_i * d(a)/d(dx_j), so the second term is the
   * outer product of x with the gradient of a, in that order: the rows
   * follow the components of x and the columns those of dx.
   */
  override def d(dx: VectorVar): Matrix = {
    a * x.d(dx) + (x *~ a.d(dx))
  }

  /**
   * Rewrites the product into a simpler equivalent where a rule applies:
   * a zero factor gives zero, a factor of 1 or -1 is dropped, concrete
   * values are folded, negations are pulled out and nested scalar
   * factors are merged.
   */
  override def simplify: Vector = (a, x) match {
    case (Val(0), _) => ZeroVector(size)
    case (_, ZeroVector(_)) => ZeroVector(size)
    case (Val(1), b) => b
    case (Val(-1), b) => -b
    case (Val(a), VectorVal(b)) => VectorVal(b.map(_ * a))
    case (Neg(a), NegVector(b)) => a * b
    case (Neg(a), b) => -(a * b)
    case (a, NegVector(b)) => -(a * b)
    case (a, ScalarVectorProduct(b, c)) => (a * b) * c
    case _ => this
  }
}

/**
 * Inner product (x * y)
 *
 * @param x the left operand.
 * @param y the right operand.
 * @throws IllegalArgumentException if the operand sizes differ.
 */
case class InnerProduct(x: Vector, y: Vector) extends Scalar {
  if (x.size != y.size) throw new IllegalArgumentException(s"Vector sizes mismatch: ${x.size} vs ${y.size}")

  override def toString: String = s"<$x, $y>"

  /** Evaluates both operands and takes the inner product of the results. */
  override def apply(env: Map[String, Tensor]): Scalar = {
    val left = x(env)
    val right = y(env)
    left * right
  }

  /** Product rule with respect to a scalar variable. */
  override def d(dx: Var): Scalar = {
    val fromLeft = x.d(dx) * y
    val fromRight = x * y.d(dx)
    fromLeft + fromRight
  }

  /**
   * Product rule with respect to a vector variable: each operand's
   * Jacobian is multiplied by the other operand and the results are added.
   *
   * NOTE(review): the Jacobians are multiplied by the vectors as they
   * are; confirm against the Matrix * Vector definition whether a
   * transpose is needed here.
   */
  override def d(dx: VectorVar): Vector = {
    val fromLeft = x.d(dx) * y
    val fromRight = y.d(dx) * x
    fromLeft + fromRight
  }

  /**
   * A zero operand gives 0, two concrete vectors are folded into a
   * number, concrete scalar factors are pulled in front and negations
   * are cancelled or pulled out. Anything else is left as is.
   */
  override def simplify: Scalar = (x, y) match {
    case (ZeroVector(_), _) | (_, ZeroVector(_)) => Val(0)
    case (VectorVal(lhs), VectorVal(rhs)) => Val(lhs.zip(rhs).map { case (u, v) => u * v }.sum)
    case (ScalarVectorProduct(Val(k), u), v) => k * (u * v)
    case (u, ScalarVectorProduct(Val(k), v)) => k * (u * v)
    case (NegVector(u), NegVector(v)) => u * v
    case (u, NegVector(v)) => -(u * v)
    case (NegVector(u), v) => -(u * v)
    case _ => this
  }
}

/**
 * Outer product (x *~ y)
 *
 * The operands may have different sizes: the result is a matrix of
 * size (x.size, y.size).
 *
 * @param x the left operand, which gives the number of rows.
 * @param y the right operand, which gives the number of columns.
 */
case class OuterProduct(x: Vector, y: Vector) extends Matrix {
  /** Prints `x ⊗ y`, parenthesizing an operand that is a sum or a scalar product. */
  override def toString: String = {
    val xs = x match {
      case AddVector(_, _) | ScalarVectorProduct(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case AddVector(_, _) | ScalarVectorProduct(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs \u2297 $ys"
  }

  override def size: (IntScalar, IntScalar) = (x.size, y.size)

  /** Evaluates both operands and takes the outer product of the results. */
  override def apply(env: Map[String, Tensor]): Matrix = x(env) *~ y(env)

  /** Product rule with respect to a scalar variable. */
  override def d(dx: Var): Matrix = {
    (x.d(dx) *~ y) + (x *~ y.d(dx))
  }

  /**
   * A zero operand gives a zero matrix, concrete scalar factors are
   * pulled in front and negations are cancelled or pulled out.
   * Anything else is left as is.
   */
  override def simplify: Matrix = (x, y) match {
    case (ZeroVector(_), _) => ZeroMatrix(size)
    case (_, ZeroVector(_)) => ZeroMatrix(size)
    case (ScalarVectorProduct(Val(a), b), c) => a * (b *~ c)
    case (a, ScalarVectorProduct(Val(b), c)) => b * (a *~ c)
    case (NegVector(a), NegVector(b)) => a *~ b
    case (a, NegVector(b)) => -(a *~ b)
    case (NegVector(a), b) => -(a *~ b)
    case _ => this
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy