All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.cas.Scalar.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see .
 */

package smile.cas

/** Scalar: rank-0 tensor. */
trait Scalar extends Tensor {
  override def rank: Option[Int] = Some(0)
  override def shape: Option[Array[IntScalar]] = Some(Array())

  /** Applies the expression. */
  def apply(env: Map[String, Tensor]): Scalar
  /** Applies the expression. */
  def apply(env: (String, Tensor)*): Scalar = apply(Map(env: _*))
  /** Simplify the expression. */
  def simplify: Scalar = this

  /** Returns the derivative. */
  def d(dx: Var): Scalar
  /** Returns the gradient vector. */
  def d(dx: Var*): Vector = Vars(dx.map(d): _*).simplify
  /** Returns the gradient vector. */
  def d(dx: VectorVar): Vector

  def + (y: Scalar): Scalar = Add(this, y).simplify
  def - (y: Scalar): Scalar = Sub(this, y).simplify
  def * (y: Scalar): Scalar = Mul(this, y).simplify
  def / (y: Scalar): Scalar = Div(this, y).simplify
  def ** (y: Scalar): Scalar = Power(this, y).simplify
  def unary_+ : Scalar = simplify
  def unary_- : Scalar = Neg(this).simplify

  def * (y: Vector): Vector = ScalarVectorProduct(this, y).simplify
  def * (y: Matrix): Matrix = ScalarMatrixProduct(this, y).simplify
}

/** Scalar: rank-0 tensor. */
trait IntScalar extends Tensor {
  override def rank: Option[Int] = Some(0)
  override def shape: Option[Array[IntScalar]] = Some(Array())

  /** Explicit conversion of int to float. */
  def toScalar: Scalar = Int2Scalar(this)
  /** Applies the expression. */
  def apply(env: Map[String, Tensor]): IntScalar
  /** Applies the expression. */
  def apply(env: (String, Tensor)*): IntScalar = apply(Map(env: _*))
  /** Simplify the expression. */
  def simplify: IntScalar = this

  def + (y: IntScalar): IntScalar = IntAdd(this, y).simplify
  def - (y: IntScalar): IntScalar = IntSub(this, y).simplify
  def * (y: IntScalar): IntScalar = IntMul(this, y).simplify
  def / (y: IntScalar): IntScalar = IntDiv(this, y).simplify
  def % (y: IntScalar): IntScalar = Mod(this, y).simplify
  def ** (y: IntScalar): IntScalar = IntPower(this, y).simplify
  def unary_+ : IntScalar = simplify
  def unary_- : IntScalar = IntNeg(this).simplify
}

/** Explicit conversion of int to float. */
case class Int2Scalar(x: IntScalar) extends Scalar {
  override def toString: String = x.toString

  override def apply(env: Map[String, Tensor]): Scalar = Int2Scalar(x.apply(env))
  override def simplify: Scalar = Int2Scalar(x.simplify)

  override def d(dx: Var): Scalar = Val(0)
  override def d(dx: VectorVar): Vector = ZeroVector(dx.size)
}

/** Scalar value. */
case class Val(x: Double) extends Scalar {
  override def toString: String = x.toString

  override def apply(env: Map[String, Tensor]): Val = this

  override def d(dx: Var): Scalar = Val(0)
  override def d(dx: VectorVar): Vector = ZeroVector(dx.size)
}

/** Integer scalar value. */
case class IntVal(x: Int) extends IntScalar {
  override def toString: String = x.toString

  override def toScalar: Scalar = Val(x)

  override def apply(env: Map[String, Tensor]): IntVal = this
}

/** Constant value. */
case class Const(symbol: String) extends Scalar {
  override def toString: String = symbol

  override def apply(env: Map[String, Tensor]): Scalar = env.get(symbol) match {
    case None => this
    case Some(x : Val) => x
    case x => throw new IllegalArgumentException(s"Invalid type: ${x.getClass}, expected Val")
  }

  override def d(dx: Var): Scalar = Val(0)
  override def d(dx: VectorVar): Vector = ZeroVector(dx.size)
}

/** Integer constant value. */
case class IntConst(symbol: String) extends IntScalar {
  override def toString: String = symbol

  override def apply(env: Map[String, Tensor]): IntScalar = env.get(symbol) match {
    case None => this
    case Some(x : IntVal) => x
    case x => throw new IllegalArgumentException(s"Invalid type: ${x.getClass}, expected IntVal")
  }
}

/** Scalar variable */
case class Var(symbol: String) extends Scalar {
  override def toString: String = symbol

  override def d(dx: Var): Scalar = {
    if (symbol.equals(dx.symbol)) Val(1) else Val(0)
  }

  override def d(dx: VectorVar): Vector = GradientVector(this, dx).simplify

  override def apply(env: Map[String, Tensor]): Scalar = env.get(symbol) match {
    case None => this
    case Some(x : Scalar) => x
    case x => throw new IllegalArgumentException(s"Invalid type: ${x.getClass}, expected Scalar")
  }
}

/** Integer scalar variable */
case class IntVar(symbol: String) extends IntScalar {
  override def toString: String = symbol

  override def apply(env: Map[String, Tensor]): IntScalar = env.get(symbol) match {
    case None => this
    case Some(x : IntScalar) => x
    case x => throw new IllegalArgumentException(s"Invalid type: ${x.getClass}, expected IntScalar")
  }
}

/** x + y */
case class Add(x: Scalar, y: Scalar) extends Scalar {
  override def toString: String = s"$x + $y"

  override def equals(o: Any): Boolean = o match {
    case Add(a, b) => (x == a && y == b) || (x == b && y == a)
    case _ => false
  }

  override def apply(env: Map[String, Tensor]): Scalar = x(env) + y(env)

  override def d(dx: Var): Scalar = {
    x.d(dx) + y.d(dx)
  }

  override def d(dx: VectorVar): Vector = {
    x.d(dx) + y.d(dx)
  }

  override def simplify: Scalar = (x, y) match {
    case (Val(a), Val(b)) => Val(a+b)
    case (Val(0), b) => b
    case (a, Val(0)) => a
    case (Neg(a), Neg(b)) => -(a + b)
    case (Neg(a), b) => b - a
    case (a, Neg(b)) => a - b
    case (Mul(a, b), Mul(c, d)) if a == c => a * (b + d)
    case (Mul(a, b), Mul(c, d)) if a == d => a * (b + c)
    case (Mul(a, b), Mul(c, d)) if b == c => b * (a + d)
    case (Mul(a, b), Mul(c, d)) if b == d => b * (a + c)
    case (Mul(a, b), c) if a == c => a * (b + 1)
    case (Mul(a, b), c) if b == c => b * (a + 1)
    case (a, Mul(b, c)) if a == b => a * (c + 1)
    case (a, Mul(b, c)) if a == c => a * (b + 1)
    case (Power(Sin(a), Val(2)), Power(Cos(b), Val(2))) if a == b => Val(1)
    case (Power(Cos(a), Val(2)), Power(Sin(b), Val(2))) if a == b => Val(1)
    case (Log(a), Log(b)) => Log(a * b)
    case (a, b) if a == b => 2 * a
    case _ => this
  }
}

/** x + y */
case class IntAdd(x: IntScalar, y: IntScalar) extends IntScalar {
  override def toString: String = {
    val xs = x match {
      case Mod(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case Mod(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs + $ys"
  }

  override def apply(env: Map[String, Tensor]): IntScalar = x(env) + y(env)

  override def simplify: IntScalar = (x, y) match {
    case (IntVal(a), IntVal(b)) => IntVal(a+b)
    case (IntVal(0), b) => b
    case (a, IntVal(0)) => a
    case (IntNeg(a), IntNeg(b)) => -(a + b)
    case (IntNeg(a), b) => b - a
    case (a, IntNeg(b)) => a - b
    case (IntMul(a, b), IntMul(c, d)) if a == c => a * (b + d)
    case (IntMul(a, b), IntMul(c, d)) if a == d => a * (b + c)
    case (IntMul(a, b), IntMul(c, d)) if b == c => b * (a + d)
    case (IntMul(a, b), IntMul(c, d)) if b == d => b * (a + c)
    case (IntMul(a, b), c) if a == c => a * (b + 1)
    case (IntMul(a, b), c) if b == c => b * (a + 1)
    case (a, IntMul(b, c)) if a == b => a * (c + 1)
    case (a, IntMul(b, c)) if a == c => a * (b + 1)
    case (a, b) if a == b => 2 * a
    case _ => this
  }
}

/** x - y */
case class Sub(x: Scalar, y: Scalar) extends Scalar {
  override def toString: String = s"$x - $y"

  override def apply(env: Map[String, Tensor]): Scalar = x(env) - y(env)

  override def d(dx: Var): Scalar = {
    x.d(dx) - y.d(dx)
  }

  override def d(dx: VectorVar): Vector = {
    x.d(dx) - y.d(dx)
  }

  override def simplify: Scalar = (x, y) match {
    case (Val(a), Val(b)) => Val(a-b)
    case (Val(0), b) => -b
    case (a, Val(0)) => a
    case (Neg(a), Neg(b)) => b - a
    case (a, Neg(b)) => a + b
    case (Neg(a), b) => -(a + b)
    case (Mul(a, b), Mul(c, d)) if a == c => a * (b - d)
    case (Mul(a, b), Mul(c, d)) if a == d => a * (b - c)
    case (Mul(a, b), Mul(c, d)) if b == c => b * (a - d)
    case (Mul(a, b), Mul(c, d)) if b == d => b * (a - c)
    case (Mul(a, b), c) if a == c => a * (b - 1)
    case (Mul(a, b), c) if b == c => b * (a - 1)
    case (a, Mul(b, c)) if a == b => a * (c - 1)
    case (a, Mul(b, c)) if a == c => a * (b - 1)
    case (Log(a), Log(b)) => Log(a / b)
    case (a, b) if a == b => Val(0)
    case _ => this
  }
}

/** x - y */
case class IntSub(x: IntScalar, y: IntScalar) extends IntScalar {
  override def toString: String = {
    val xs = x match {
      case Mod(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case Mod(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs - $ys"
  }

  override def apply(env: Map[String, Tensor]): IntScalar = x(env) - y(env)

  override def simplify: IntScalar = (x, y) match {
    case (IntVal(a), IntVal(b)) => IntVal(a-b)
    case (IntVal(0), b) => -b
    case (a, IntVal(0)) => a
    case (IntNeg(a), IntNeg(b)) => b - a
    case (a, IntNeg(b)) => a + b
    case (IntNeg(a), b) => -(a + b)
    case (IntMul(a, b), IntMul(c, d)) if a == c => a * (b - d)
    case (IntMul(a, b), IntMul(c, d)) if a == d => a * (b - c)
    case (IntMul(a, b), IntMul(c, d)) if b == c => b * (a - d)
    case (IntMul(a, b), IntMul(c, d)) if b == d => b * (a - c)
    case (IntMul(a, b), c) if a == c => a * (b - 1)
    case (IntMul(a, b), c) if b == c => b * (a - 1)
    case (a, IntMul(b, c)) if a == b => a * (c - 1)
    case (a, IntMul(b, c)) if a == c => a * (b - 1)
    case (a, b) if a == b => IntVal(0)
    case _ => this
  }
}

/** -x */
case class Neg(x: Scalar) extends Scalar {
  override def toString: String = x match {
    case Add(_, _) | Sub(_, _) | Power(_, _) => s"-($x)"
    case _ => s"-$x"
  }

  override def apply(env: Map[String, Tensor]): Scalar = -x(env)

  override def d(dx: Var): Scalar = {
    -x.d(dx)
  }

  override def d(dx: VectorVar): Vector = {
    -x.d(dx)
  }

  override def simplify: Scalar = x match {
    case a @ Val(0) => a
    case Val(a) => Val(-a)
    case Neg(a) => a
    case _ => this
  }
}

/** -x */
case class IntNeg(x: IntScalar) extends IntScalar {
  override def toString: String = x match {
    case IntAdd(_, _) | IntSub(_, _) | IntPower(_, _) | Mod(_, _) => s"-($x)"
    case _ => s"-$x"
  }

  override def apply(env: Map[String, Tensor]): IntScalar = -x(env)

  override def simplify: IntScalar = x match {
    case a @ IntVal(0) => a
    case IntVal(a) => IntVal(-a)
    case IntNeg(a) => a
    case _ => this
  }
}

/** x * y */
case class Mul(x: Scalar, y: Scalar) extends Scalar {
  override def toString: String = {
    val xs = x match {
      case Add(_, _) | Sub(_, _) | Power(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case Add(_, _) | Sub(_, _) | Power(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs * $ys"
  }

  override def apply(env: Map[String, Tensor]): Scalar = x(env) * y(env)

  override def d(dx: Var): Scalar = {
    y * x.d(dx) + x * y.d(dx)
  }

  override def d(dx: VectorVar): Vector = {
    y * x.d(dx) + x * y.d(dx)
  }

  override def simplify: Scalar = (x, y) match {
    case (Val(a), Val(b)) => Val(a*b)
    case (Val(1), b) => b
    case (Val(-1), b) => -b
    case (a, Val(1)) => a
    case (a, Val(-1)) => -a
    case (a @ Val(0), _) => a
    case (_, b @ Val(0)) => b
    case (Neg(a), Neg(b)) => a * b
    case (Neg(a), b) => -(a * b)
    case (a, Neg(b)) => -(a * b)
    case (Div(a, b), Div(c, d)) => (a * c) / (b * d)
    case (a, Div(b, c)) => (a * b) / c
    case (Div(a, b), c) => (a * c) / b
    case (a, Power(b, c)) if a == b => a ** (c + 1)
    case (Power(b, c), a) if a == b => a ** (c + 1)
    case (Power(a, b), Power(c, d)) if a == c => a ** (b + d)
    case (Exp(a), Exp(b)) => exp(a + b)
    case (Tan(a), Cot(b)) if a == b => Val(1)
    case (Cot(a), Tan(b)) if a == b => Val(1)
    case (Tan(a), Cos(b)) if a == b => sin(a)
    case (Cos(a), Tan(b)) if a == b => sin(a)
    case (Cot(a), Sin(b)) if a == b => cos(a)
    case (Sin(a), Cot(b)) if a == b => cos(a)
    case (a, Mul(b, c)) if a * b != Mul(a, b) => (a * b) * c
    case (a, Mul(b, c)) if a * c != Mul(a, c) => (a * c) * b
    case (Mul(a, b), c) if a * c != Mul(a, c) => (a * c) * b
    case (Mul(a, b), c) if b * c != Mul(b, c) => a * (b * c)
    case (a, b @ Val(_)) => b * a
    case (a, b) if a == b => a ** 2
    case _ => this
  }
}

/** x * y */
case class IntMul(x: IntScalar, y: IntScalar) extends IntScalar {
  override def toString: String = {
    val xs = x match {
      case IntAdd(_, _) | IntSub(_, _) | IntPower(_, _) | Mod(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case IntAdd(_, _) | IntSub(_, _) | IntPower(_, _) | Mod(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs * $ys"
  }

  override def apply(env: Map[String, Tensor]): IntScalar = x(env) * y(env)

  override def simplify: IntScalar = (x, y) match {
    case (IntVal(a), IntVal(b)) => IntVal(a*b)
    case (IntVal(1), b) => b
    case (IntVal(-1), b) => -b
    case (a, IntVal(1)) => a
    case (a, IntVal(-1)) => -a
    case (a @ IntVal(0), _) => a
    case (_, b @ IntVal(0)) => b
    case (IntNeg(a), IntNeg(b)) => a * b
    case (IntNeg(a), b) => -(a * b)
    case (a, IntNeg(b)) => -(a * b)
    case (IntDiv(a, b), IntDiv(c, d)) => (a * c) / (b * d)
    case (a, IntDiv(b, c)) => (a * b) / c
    case (IntDiv(a, b), c) => (a * c) / b
    case (a, IntPower(b, c)) if a == b => a ** (c + 1)
    case (IntPower(b, c), a) if a == b => a ** (c + 1)
    case (IntPower(a, b), IntPower(c, d)) if a == c => a ** (b + d)
    case (a, IntMul(b, c)) if a * b != IntMul(a, b) => (a * b) * c
    case (a, IntMul(b, c)) if a * c != IntMul(a, c) => (a * c) * b
    case (IntMul(a, b), c) if a * c != IntMul(a, c) => (a * c) * b
    case (IntMul(a, b), c) if b * c != IntMul(b, c) => a * (b * c)
    case (a, b @ IntVal(_)) => b * a
    case (a, b) if a == b => a ** 2
    case _ => this
  }
}

/** x / y */
case class Div(x: Scalar, y: Scalar) extends Scalar {
  override def toString: String = {
    val xs = x match {
      case Add(_, _) | Sub(_, _) | Mul(_, _) | Div(_, _) | Power(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case Add(_, _) | Sub(_, _) | Mul(_, _) | Div(_, _) | Power(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs / $ys"
  }

  override def apply(env: Map[String, Tensor]): Scalar = x(env) / y(env)

  override def d(dx: Var): Scalar = {
    val dy = y.d(dx)
    dy match {
      case Val(0) => x.d(dx) / y
      case _ => (y * x.d(dx) - x * dy) / (y ** 2)
    }
  }

  override def d(dx: VectorVar): Vector = {
    (y * x.d(dx) - x * y.d(dx)) / (y ** 2)
  }

  override def simplify: Scalar = (x, y) match {
    case (_, Val(0)) => throw new ArithmeticException("/ by zero")
    case (a @ Val(0), _) => a
    case (a, Val(1)) => a
    case (a, Val(-1)) => -a
    case (Val(a), Val(b)) => Val(a/b)
    case (Neg(a), Neg(b)) => a / b
    case (Neg(a), b) => -(a / b)
    case (a, Neg(b)) => -(a / b)
    case (Div(a, b), Div(c, d)) => (a * d) / (b * c)
    case (a, Div(b, c)) => (a * c) / b
    case (Div(a, b), c) => a / (b * c)
    case (a, Mul(b, c)) if a / b != Div(a, b) => (a / b) / c
    case (a, Mul(b, c)) if a / c != Div(a, c) => (a / c) / b
    case (Mul(a, b), c) if a / c != Div(a, c) => (a / c) * b
    case (Mul(a, b), c) if b / c != Div(b, c) => a * (b / c)
    case (Exp(a), Exp(b)) => Exp(a - b)
    case (a, Power(b, c)) if a * Power(b, -c) != Mul(a, Power(b, -c)) => a * Power(b, -c)
    case (Sin(a), Cos(b)) if a == b => tan(a)
    case (Cos(a), Sin(b)) if a == b => cot(a)
    case (a, Tan(b)) => a * cot(b)
    case (a, Cot(b)) => a * tan(b)
    case (a, b) if a == b => Val(1)
    case _ => this
  }
}

/** x / y */
case class IntDiv(x: IntScalar, y: IntScalar) extends IntScalar {
  override def toString: String = {
    val xs = x match {
      case IntAdd(_, _) | IntSub(_, _) | IntMul(_, _) | IntDiv(_, _) | IntPower(_, _) | Mod(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case IntAdd(_, _) | IntSub(_, _) | IntMul(_, _) | IntDiv(_, _) | IntPower(_, _) | Mod(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs / $ys"
  }

  override def apply(env: Map[String, Tensor]): IntScalar = x(env) / y(env)

  override def simplify: IntScalar = (x, y) match {
    case (_, IntVal(0)) => throw new ArithmeticException("/ by zero")
    case (a @ IntVal(0), _) => a
    case (a, IntVal(1)) => a
    case (a, IntVal(-1)) => -a
    case (IntVal(a), IntVal(b)) => IntVal(a/b)
    case (IntNeg(a), IntNeg(b)) => IntDiv(a, b)
    case (IntNeg(a), b) => IntNeg(IntDiv(a, b))
    case (a, IntNeg(b)) => IntNeg(IntDiv(a, b))
    case (IntDiv(a, b), IntDiv(c, d)) => (a * d) / (b * c)
    case (a, IntDiv(b, c)) => (a * c) / b
    case (IntDiv(a, b), c) => a / (b * c)
    case (a, IntMul(b, c)) if a / b != IntDiv(a, b) => (a / b) / c
    case (a, IntMul(b, c)) if a / c != IntDiv(a, c) => (a / c) / b
    case (IntMul(a, b), c) if a / c != IntDiv(a, c) => (a / c) * b
    case (IntMul(a, b), c) if b / c != IntDiv(b, c) => a * (b / c)
    case (a, b) if a == b => IntVal(1)
    case _ => this
  }
}

/** x ** y */
case class Power(x: Scalar, y: Scalar) extends Scalar {
  override def toString: String = {
    val xs = x match {
      case Add(_, _) | Sub(_, _) | Mul(_, _) | Div(_, _) | Power(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case Add(_, _) | Sub(_, _) | Mul(_, _) | Div(_, _) | Power(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs ** $ys"
  }

  override def apply(env: Map[String, Tensor]): Scalar = x(env) ** y(env)

  override def d(dx: Var): Scalar = {
    (y.d(dx), x.d(dx)) match {
      case (Val(0), Val(0)) => Val(0)
      case (_, Val(0)) => this * log(x)
      case (Val(0), dx) => y * (x ** (y - 1)) * dx
      case (dy, dx) => this * (dy * log(x) + (y / x) * dx)
    }
  }

  override def d(dx: VectorVar): Vector = {
    this * (y.d(dx) * log(x) + (y / x) * x.d(dx))
  }

  override def simplify: Scalar = (x, y) match {
    case (Val(0), Val(0)) => throw new ArithmeticException("0 ** 0")
    case (a @ Val(0), _) => a
    case (a @ Val(1), _) => a
    case (_, Val(0)) => Val(1)
    case (a, Val(1)) => a
    case (a, Val(-1)) => 1 / a
    case (Val(a), Val(b)) => Val(Math.pow(a, b))
    case (Val(Math.E), b) => exp(b)
    case (Power(a, b), c) => a ** (b * c)
    case _ => this
  }
}

/** x ** y */
case class IntPower(x: IntScalar, y: IntScalar) extends IntScalar {
  override def toString: String = {
    val xs = x match {
      case IntAdd(_, _) | IntSub(_, _) | IntMul(_, _) | IntDiv(_, _) | IntPower(_, _) | Mod(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case IntAdd(_, _) | IntSub(_, _) | IntMul(_, _) | IntDiv(_, _) | IntPower(_, _) | Mod(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs ** $ys"
  }

  override def apply(env: Map[String, Tensor]): IntScalar = x(env) ** y(env)

  override def simplify: IntScalar = (x, y) match {
    case (IntVal(0), IntVal(0)) => throw new ArithmeticException("0 ** 0")
    case (_, IntVal(b)) if b < 0 => throw new ArithmeticException("Negative exponent is not allowed for IntPower")
    case (a @ IntVal(0), _) => a
    case (a @ IntVal(1), _) => a
    case (_, IntVal(0)) => IntVal(1)
    case (a, IntVal(1)) => a
    case (IntVal(a), IntVal(b)) => IntVal(Math.pow(a, b).toInt)
    case (IntPower(a, b), c) => a ** (b * c)
    case _ => this
  }
}

/** x % y */
case class Mod(x: IntScalar, y: IntScalar) extends IntScalar {
  override def toString: String = {
    val xs = x match {
      case IntAdd(_, _) | IntSub(_, _) | IntMul(_, _) | IntDiv(_, _) | IntPower(_, _) | Mod(_, _) => s"($x)"
      case _ => x.toString
    }

    val ys = y match {
      case IntAdd(_, _) | IntSub(_, _) | IntMul(_, _) | IntDiv(_, _) | IntPower(_, _) | Mod(_, _) => s"($y)"
      case _ => y.toString
    }

    s"$xs % $ys"
  }

  override def apply(env: Map[String, Tensor]): IntScalar = x(env) % y(env)

  override def simplify: IntScalar = (x, y) match {
    case (_, IntVal(0)) => throw new ArithmeticException("% by zero")
    case (IntVal(0), _) => IntVal(0)
    case (_, IntVal(1)) => IntVal(0)
    case (IntVal(a), IntVal(b)) => IntVal(a % b)
    case (IntMul(a, b), c) if a % c == IntVal(0) || b % c == IntVal(0) => IntVal(0)
    case (a, b) if a == b => IntVal(0)
    case _ => this
  }
}

/** exp(x) */
case class Exp(x: Scalar) extends Scalar {
  override def toString: String = s"exp($x)"

  override def apply(env: Map[String, Tensor]): Scalar = exp(x(env))

  override def d(dx: Var): Scalar = {
    x.d(dx) * this
  }

  override def d(dx: VectorVar): Vector = {
    this * x.d(dx)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(Math.exp(a))
    case Log(a) => a
    case _ => this
  }
}

/** log(x) */
case class Log(x: Scalar) extends Scalar {
  override def toString: String = s"log($x)"

  override def apply(env: Map[String, Tensor]): Scalar = log(x(env))

  override def d(dx: Var): Scalar = {
    (1 / x) * x.d(dx)
  }

  override def d(dx: VectorVar): Vector = {
    (1 / x) * x.d(dx)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(Math.log(a))
    case Exp(a) => a
    case Power(a, b) => b * log(a)
    case _ => this
  }
}

/** sin(x) */
case class Sin(x: Scalar) extends Scalar {
  override def toString: String = s"sin($x)"

  override def apply(env: Map[String, Tensor]): Scalar = sin(x(env))

  override def d(dx: Var): Scalar = {
    x.d(dx) * cos(x)
  }

  override def d(dx: VectorVar): Vector = {
    cos(x) * x.d(dx)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(Math.sin(a))
    case ArcSin(a) => a
    case _ => this
  }
}

/** cos(x) */
case class Cos(x: Scalar) extends Scalar {
  override def toString: String = s"cos($x)"

  override def apply(env: Map[String, Tensor]): Scalar = cos(x(env))

  override def d(dx: Var): Scalar = {
    x.d(dx) * (-sin(x))
  }

  override def d(dx: VectorVar): Vector = {
    -sin(x) * x.d(dx)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(Math.cos(a))
    case ArcCos(a) => a
    case _ => this
  }
}

/** tan(x) */
case class Tan(x: Scalar) extends Scalar {
  override def toString: String = s"tan($x)"

  override def apply(env: Map[String, Tensor]): Scalar = tan(x(env))

  override def d(dx: Var): Scalar = {
    x.d(dx) / (cos(x) ** 2)
  }

  override def d(dx: VectorVar): Vector = {
    x.d(dx) / (cos(x) ** 2)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(Math.tan(a))
    case ArcTan(a) => a
    case _ => this
  }
}

/** cot(x) */
case class Cot(x: Scalar) extends Scalar {
  override def toString: String = s"cot($x)"

  override def apply(env: Map[String, Tensor]): Scalar = cot(x(env))

  override def d(dx: Var): Scalar = {
    -x.d(dx) / (sin(x) ** 2)
  }

  override def d(dx: VectorVar): Vector = {
    -x.d(dx) / (sin(x) ** 2)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(1.0 / Math.tan(a))
    case ArcCot(a) => a
    case _ => this
  }
}

/** asin(x) */
case class ArcSin(x: Scalar) extends Scalar {
  override def toString: String = s"asin($x)"

  override def apply(env: Map[String, Tensor]): Scalar = asin(x(env))

  override def d(dx: Var): Scalar = {
    x.d(dx) / sqrt(1.0 - x ** 2.0)
  }

  override def d(dx: VectorVar): Vector = {
    x.d(dx) / sqrt(1.0 - x ** 2.0)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(Math.asin(a))
    case Sin(a) => a
    case _ => this
  }
}

/** acos(x) */
case class ArcCos(x: Scalar) extends Scalar {
  override def toString: String = s"acos($x)"

  override def apply(env: Map[String, Tensor]): Scalar = acos(x(env))

  override def d(dx: Var): Scalar = {
    -x.d(dx) / sqrt(1.0 - x ** 2.0)
  }

  override def d(dx: VectorVar): Vector = {
    -x.d(dx) / sqrt(1.0 - x ** 2.0)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(Math.acos(a))
    case Cos(a) => a
    case _ => this
  }
}

/** atan(x) */
case class ArcTan(x: Scalar) extends Scalar {
  override def toString: String = s"atan($x)"

  override def apply(env: Map[String, Tensor]): Scalar = atan(x(env))

  override def d(dx: Var): Scalar = {
    x.d(dx) / (1.0 + x ** 2.0)
  }

  override def d(dx: VectorVar): Vector = {
    x.d(dx) / (1.0 + x ** 2.0)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(Math.atan(a))
    case Tan(a) => a
    case _ => this
  }
}

/** acot(x) */
case class ArcCot(x: Scalar) extends Scalar {
  override def toString: String = s"acot($x)"

  override def apply(env: Map[String, Tensor]): Scalar = acot(x(env))

  override def d(dx: Var): Scalar = {
    -x.d(dx) / (1.0 + x ** 2.0)
  }

  override def d(dx: VectorVar): Vector = {
    -x.d(dx) / (1.0 + x ** 2.0)
  }

  override def simplify: Scalar = x match {
    case Val(a) => Val(Math.PI * 0.5 - Math.atan(a))
    case Cot(a) => a
    case _ => this
  }
}

/** abs(x) */
case class Abs(x: Scalar) extends Scalar {
  override def toString: String = s"abs($x)"

  override def apply(env: Map[String, Tensor]): Scalar = abs(x(env))

  override def d(dx: Var): Scalar = {
    x.d(dx) * (this / x)
  }

  override def d(dx: VectorVar): Vector = {
    x.d(dx) * (this / x)
  }

  override def simplify: Scalar = x.simplify match {
    case Val(a) => Val(Math.abs(a))
    case _ => this
  }
}

/** ceil(x) */
case class Ceil(x: Scalar) extends IntScalar {
  override def toString: String = s"ceil($x)"

  override def apply(env: Map[String, Tensor]): IntScalar = ceil(x(env))

  override def simplify: IntScalar = x.simplify match {
    case Val(a) => IntVal(Math.ceil(a).toInt)
    case _ => this
  }
}

/** floor(x) */
case class Floor(x: Scalar) extends IntScalar {
  override def toString: String = s"floor($x)"

  override def apply(env: Map[String, Tensor]): IntScalar = floor(x(env))

  override def simplify: IntScalar = x.simplify match {
    case Val(a) => IntVal(Math.floor(a).toInt)
    case _ => this
  }
}

/** round(x) */
case class Round(x: Scalar) extends IntScalar {
  override def toString: String = s"round($x)"

  override def apply(env: Map[String, Tensor]): IntScalar = round(x(env))

  override def simplify: IntScalar = x.simplify match {
    case Val(a) => IntVal(Math.round(a).toInt)
    case _ => this
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy