All Downloads are FREE. Search and download functionalities are using the official Maven repository.

lucuma.core.util.TimestampUnion.scala Maven / Gradle / Ivy

There is a newer version: 0.108.0
Show newest version
// Copyright (c) 2016-2023 Association of Universities for Research in Astronomy, Inc. (AURA)
// For license information see LICENSE or https://opensource.org/licenses/BSD-3-Clause

package lucuma.core.util

import cats.Eq
import cats.Order.catsKernelOrderingForOrder
import cats.kernel.BoundedSemilattice
import cats.syntax.eq.*
import cats.syntax.foldable.*
import cats.syntax.option.*
import cats.syntax.traverse.*

import scala.collection.immutable.SortedSet

sealed class TimestampUnion private (val intervals: SortedSet[TimestampInterval]) {

  import TimestampUnion.Empty

  /**
   * Adds the specified interval, which will be merged into the collection
   * of intervals.
   */
  def add(n: TimestampInterval): TimestampUnion = {
    val (u, nʹ) = intervals.foldLeft((Empty, n)) { case ((res, nʹ), ti) =>
      if (nʹ.abuts(ti) || nʹ.intersects(ti)) (res, nʹ.span(ti)) else (res + ti, nʹ)
    }
    new TimestampUnion(u.intervals + nʹ)
  }

  /**
   * Alias for add.
   */
  def +(n: TimestampInterval): TimestampUnion =
    add(n)

  def ++(ns: IterableOnce[TimestampInterval]): TimestampUnion =
    ns.iterator.foldLeft(this)(_.add(_))

  /**
   * Removes the specified interval, which will be clipped out of the
   * collection of intervals.
   */
  def remove(del: TimestampInterval): TimestampUnion =
    intervals.foldLeft(Empty) { (res, ti) => res ++ ti.minus(del) }

  /**
   * Alias for remove.
   */
  def -(del: TimestampInterval): TimestampUnion =
    remove(del)

  def --(ns: IterableOnce[TimestampInterval]): TimestampUnion =
    ns.iterator.foldLeft(this)(_.remove(_))

  /**
   * Intersection of two TimestampUnions, creating a new union with intervals
   * whose times are covered in both.
   */
  def intersect(other: TimestampUnion): TimestampUnion = {
    def collectPoints(init: Vector[Timestamp], intervals: SortedSet[TimestampInterval]): Vector[Timestamp] =
      intervals.foldLeft(init) { (v, ti) => v.appended(ti.start).appended(ti.end) }

    val points = collectPoints(collectPoints(Vector.empty, intervals), other.intervals).sorted

    if (points.isEmpty)
      Empty
    else
      points.zip(points.tail).foldLeft(Empty) { case (res, (a, b)) =>
        if (contains(a) && other.contains(a)) res.add(TimestampInterval.between(a, b)) else res
      }
  }

  /**
   * Alias for intersect.
   */
  def ∩(other: TimestampUnion): TimestampUnion =
    intersect(other)

  def ++(other: TimestampUnion): TimestampUnion =
    ++(other.intervals)

  def --(other: TimestampUnion): TimestampUnion =
    --(other.intervals)

  def contains(timestamp: Timestamp): Boolean =
    intervals.exists(_.contains(timestamp))

  /**
   * Sums the time spanned by all the intervals in this union, assuming the
   * result fits in a `TimeSpan`.
   */
  def sum: Option[TimeSpan] =
    intervals.toList.traverse(_.timeSpan).flatMap { ts =>
      ts.foldLeft(TimeSpan.Zero.some) { case (s, t) => s.flatMap(_.add(t)) }
    }

  /**
   * Sums the time spanned by all the intervals in this union, capping the
   * result at `TimeSpan.Max`.  This is sufficient for most uses and more
   * convenient than `sum`.
   */
  def boundedSum: TimeSpan =
    intervals.foldMap(_.boundedTimeSpan)

  def isEmpty: Boolean =
    intervals.isEmpty

  override def equals(that: Any): Boolean =
    that match {
      case u: TimestampUnion => intervals === u.intervals
      case _                 => false
    }

  override def hashCode: Int =
    intervals.hashCode() * 31

  override def toString: String =
    s"{${intervals.mkString(", ")}}"

}

object TimestampUnion {

  /**
   * The empty TimestampUnion.
   */
  val Empty: TimestampUnion =
    new TimestampUnion(SortedSet.empty)

  def apply(intervals: TimestampInterval*): TimestampUnion =
    fromIntervals(intervals)

  def fromIntervals(intervals: IterableOnce[TimestampInterval]): TimestampUnion =
    Empty ++ intervals

  given Eq[TimestampUnion] =
    Eq.by(_.intervals)

  given BoundedSemilattice[TimestampUnion] =
    BoundedSemilattice.instance(Empty, _ ++ _)

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy