All Downloads are FREE. Search and download functionality uses the official Maven repository.

tri.covid19.coda.forecast.ForecastPanelModel.kt Maven / Gradle / Ivy

/*-
 * #%L
 * coda-app
 * --
 * Copyright (C) 2020 - 2022 Elisha Peterson
 * --
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

package tri.covid19.coda.forecast

import javafx.beans.property.SimpleObjectProperty
import javafx.beans.property.SimpleStringProperty
import javafx.scene.control.Alert
import org.apache.commons.math3.exception.NoBracketingException
import org.apache.commons.math3.exception.TooManyEvaluationsException
import tornadofx.*
import tri.area.USA
import tri.area.usa.UsaAreaLookup
import tri.area.usa.Usa
import tri.covid19.coda.data.CovidForecasts
import tri.covid19.coda.data.CovidTimeSeriesSources
import tri.covid19.coda.history.METRIC_OPTIONS
import tri.covid19.coda.history.changeDoublingDataSeries
import tri.covid19.coda.history.hubbertSeries
import tri.covid19.coda.utils.ChartDataSeries
import tri.covid19.data.*
import tri.util.math.Sigmoid
import tri.timeseries.forecast.Forecast
import tri.timeseries.TimeSeries
import tri.util.DateRange
import tri.util.minus
import tri.util.userFormat
import java.time.LocalDate
import java.util.*
import kotlin.reflect.KMutableProperty1

/**
 * Config for logistic projection.
 *
 * Model behind the forecast panel: holds the UI-bound selections (area, metric, display flags), the user-adjustable
 * curve in [curveFitter], and the set of external forecast models to show, and assembles the chart series drawn from
 * them.
 *
 * @property listener callback invoked after a bound property change has been applied to the model's data
 */
class ForecastPanelModel(var listener: () -> Unit = {}) {

    //region UI BOUND PROPERTIES

    // metric selection
    internal var areaId by property(USA.id)
    internal var selectedMetric by property(METRIC_OPTIONS[0])
    internal var perCapita by property(false)
    internal var smooth by property(true)
    internal var showLogisticPrediction by property(true)

    // user forecast
    internal val curveFitter = ForecastCurveFitter()
    /** Forecasts added to the panel's table by [save] and [saveExternalForecastsToTable]. */
    internal val forecastInfoList = observableListOf<ForecastStats>()

    internal var showForecast by property(true)
    /** True when the fitter's curve is [Sigmoid.GEN_LOGISTIC]; updated whenever a curve property changes. */
    private var vActive by property(false)

    // labels describing the current curve; refreshed by updateEquation()
    internal val _manualEquation = SimpleStringProperty(curveFitter.equation)
    internal val _manualPeak = SimpleStringProperty("")
    internal val _manualLogCumStdErr = SimpleStringProperty("")
    internal val _manualDeltaStdErr = SimpleStringProperty("")

    // other forecasts
    // updateData() must run on change as well: it only loads forecasts whose model is in this list, so re-adding a
    // model would otherwise show nothing until some other property happened to change.
    internal val otherForecasts = observableListOf(IHME, YYG).apply {
        onChange {
            updateData()
            listener()
        }
    }

    var showConfidence by property(true)
    var firstForecastDay: Number by property(90)
    var lastForecastDay: Number by property(120)
    /** Range of forecast dates to display, converted from [firstForecastDay]..[lastForecastDay] by [curveFitter]. */
    private val forecastDateRange: DateRange
        get() = DateRange(curveFitter.numberToDate(firstForecastDay), curveFitter.numberToDate(lastForecastDay))

//    // historical forecasts
//    private var movingAverage by property(4)
//    private var projectionDays by property(10)

    //endregion

    //region JAVAFX UI PROPERTIES

    /** Gets the JavaFX property backing [prop] on this model; on change, refreshes data and notifies [listener]. */
    private fun <T> property(prop: KMutableProperty1<*, T>) = getProperty(prop).apply {
        addListener { _ ->
            updateData()
            listener()
        }
    }

    /**
     * Gets the JavaFX property backing [prop] on [curveFitter]; on change, refreshes data and the equation labels,
     * recomputes [vActive], and notifies [listener].
     */
    private fun <T> forecastProperty(prop: KMutableProperty1<*, T>) = curveFitter.getProperty(prop).apply {
        addListener { _ ->
            updateData()
            updateEquation()
            vActive = curveFitter.curve == Sigmoid.GEN_LOGISTIC
            listener()
        }
    }

    internal val _region = property(ForecastPanelModel::areaId)
    internal val _selectedMetric = property(ForecastPanelModel::selectedMetric)
    internal val _perCapita = property(ForecastPanelModel::perCapita)
    internal val _smooth = property(ForecastPanelModel::smooth)
    internal val _showLogisticPrediction = property(ForecastPanelModel::showLogisticPrediction)

    internal val _showForecast = property(ForecastPanelModel::showForecast)
    internal val _vActive = property(ForecastPanelModel::vActive)
//
//    internal val _movingAverage = property(ForecastPanelModel::movingAverage)
//    internal val _projectionDays = property(ForecastPanelModel::projectionDays)

    internal val _fitLabel = forecastProperty(ForecastCurveFitter::fitLabel)
    internal val _curve = forecastProperty(ForecastCurveFitter::curve)
    internal val _l = forecastProperty(ForecastCurveFitter::l)
    internal val _k = forecastProperty(ForecastCurveFitter::k)
    internal val _x0 = forecastProperty(ForecastCurveFitter::x0)
    internal val _v = forecastProperty(ForecastCurveFitter::v)

    // changing the fit window or fit mode re-runs the curve fit
    internal val _firstFitDay = forecastProperty(ForecastCurveFitter::firstFitDay).apply { onChange { autofit() }}
    internal val _lastFitDay = forecastProperty(ForecastCurveFitter::lastFitDay).apply { onChange { autofit() }}
    internal val _fitCumulative = forecastProperty(ForecastCurveFitter::fitCumulative).apply { onChange { autofit() }}

    // changing the evaluation window only recomputes the error labels
    internal val _firstEvalDay = curveFitter.getProperty(ForecastCurveFitter::firstEvalDay).apply { onChange { calcErrors() }}
    internal val _lastEvalDay = curveFitter.getProperty(ForecastCurveFitter::lastEvalDay).apply { onChange { calcErrors() }}

    internal val _showConfidence = property(ForecastPanelModel::showConfidence)
    internal val _firstForecastDay = property(ForecastPanelModel::firstForecastDay)
    internal val _lastForecastDay = property(ForecastPanelModel::lastForecastDay)

    //endregion

    //region DATA FOR PROJECTION PLOT

    /** Ids of areas available for the panel: areas with local data plus areas appearing in any forecast. */
    val areas: SortedSet<String> by lazy {
        val dataAreas = LocalCovidDataQuery.allDataAreas().map { it.id }
        val forecastAreas = CovidForecasts.allForecasts.map { it.areaId }.toSet()
        (dataAreas + forecastAreas).toSortedSet()
    }

    /** Domain for raw data: the main series' domain after `shift(0, 30)` (presumably extending its end 30 days). */
    var domain: DateRange? = null

    /** The primary time series for the selected metric. */
    val mainSeries = SimpleObjectProperty<TimeSeries>()
    /** User's projection, evaluated from [curveFitter] over [domain]; null when hidden or there is no data. */
    var userForecast: TimeSeries? = null

    /** Past forecasts. */
    var pastForecasts = PastForecasts()
    /** Other forecasts. */
    var externalForecasts = ExternalForecasts()

    /** Reloads the main series, user curve, and past/external forecasts for the current selections. */
    private fun updateData() {
        val areaMetrics = CovidTimeSeriesSources.dailyReports(UsaAreaLookup.area(areaId, assumeUsState = true), selectedMetric)
        mainSeries.value = areaMetrics.firstOrNull { it.metric == selectedMetric }?.restrictNumberOfStartingZerosTo(0)
        domain = mainSeries.value?.domain?.shift(0, 30)

        // NOTE(review): -3.5 is half of the 7-day window used by maybeSmoothed(); presumably re-centres the curve
        val shift = if (smooth) -3.5 else 0.0
        // local copy so the mutable property can be smart-cast instead of using `!!`
        val curveDomain = domain
        userForecast = if (!showForecast || curveDomain == null) null else
            TimeSeries("User Forecast", areaId, "$selectedMetric (curve)", "", false, 0.0, curveDomain.start,
                    curveDomain.map { d -> curveFitter(d, shift) })

        pastForecasts.metrics = areaMetrics.filter { showLogisticPrediction && ("predicted" in it.metric || "peak" in it.metric) }
        externalForecasts.forecasts = CovidForecasts.allForecasts
                .filter { it.model in otherForecasts }
                .filter { it.areaId == areaId && it.metric == selectedMetric }
                .filter { it.forecastDate in forecastDateRange }
    }

    //endregion

    //region SERIES BUILDERS

    /** Chart series of cumulative values: data, user curve, past and external forecasts. */
    internal fun cumulativeDataSeries() = dataseries {
        series(mainSeries.value?.maybeSmoothed())
        series(userForecast)
        series(pastForecasts.cumulative)
        series(externalForecasts.cumulative)
    }

    /** Chart series of per-day changes: data, user curve, past and external forecasts. */
    internal fun dailyDataSeries() = dataseries {
        series(mainSeries.value?.deltas()?.maybeSmoothed())
        series(userForecast?.deltas())
        series(pastForecasts.deltas)
        series(externalForecasts.deltas)
    }

    /** Hubbert-style (x, y) series for data, user curve and external forecasts. */
    internal fun hubbertDataSeries() = dataseries {
        series(mainSeries.value?.hubbertSeries(7))
        series(userForecast?.hubbertSeries(1))
        series(externalForecasts.cumulative.map { it.hubbertSeries(1) })
    }

    /** Change vs. doubling-time (x, y) series for data, user curve and external forecasts. */
    internal fun changeDoublingDataSeries() = dataseries {
        series(mainSeries.value?.changeDoublingDataSeries(7))
        series(userForecast?.changeDoublingDataSeries(1))
        series(externalForecasts.cumulative.map { it.changeDoublingDataSeries(1) })
    }

    /** Residuals (observed daily minus forecast daily) for the user curve and external forecasts. */
    internal fun residualDataSeries() = dataseries {
        val daily = mainSeries.value?.deltas()?.maybeSmoothed()
        series(userForecast?.deltas()?.residuals(daily))
        series(externalForecasts.deltas.mapNotNull { it.residuals(daily) })
    }

    /**
     * Computes `empirical - this` on the dates both series cover.
     * @return null if [empirical] is null or the two domains do not intersect
     */
    private fun TimeSeries.residuals(empirical: TimeSeries?): TimeSeries? {
        empirical ?: return null
        // `this.domain` is the series' own domain, not the panel's [domain] property
        val commonDomain = this.domain.intersect(empirical.domain) ?: return null
        return copy(start = commonDomain.start, values = commonDomain.map { empirical[it] - get(it) })
    }

    /** Builds a list of chart series with the [series] helpers below. */
    private fun dataseries(op: MutableList<ChartDataSeries>.() -> Unit) = mutableListOf<ChartDataSeries>().apply { op() }
    /** Adds [s] as a chart series over [domain], if both are present. */
    private fun MutableList<ChartDataSeries>.series(s: TimeSeries?) { series(listOfNotNull(s)) }
    /** Adds each of [s] as a chart series over [domain]; adds nothing while [domain] is null. */
    private fun MutableList<ChartDataSeries>.series(s: List<TimeSeries>) {
        domain?.let { domain -> s.forEach { this += tri.covid19.coda.utils.series(it.metric, domain, it) } }
    }
    /** Adds an (x, y) pair of series; the label is taken from the first series if [idFirst], else the second. */
    private fun MutableList<ChartDataSeries>.series(xy: Pair<TimeSeries, TimeSeries>?, idFirst: Boolean = true) { series(listOfNotNull(xy), idFirst) }
    /** Adds each (x, y) pair of series over [domain]; adds nothing while [domain] is null. */
    private fun MutableList<ChartDataSeries>.series(xyxy: List<Pair<TimeSeries, TimeSeries>>, idFirst: Boolean = true) {
        domain?.let { domain -> xyxy.forEach { this += tri.covid19.coda.utils.series(if (idFirst) it.first.metric else it.second.metric, domain, it.first, it.second) } }
    }

    /** 7-day moving average when [smooth] is enabled, otherwise the series itself. */
    private fun TimeSeries.maybeSmoothed() = if (smooth) movingAverage(7) else this

    //endregion

    //region FORECAST CURVE

    /** Updates the equation, peak, and RMSE labels from the current curve and main series. */
    private fun updateEquation() {
        _manualEquation.value = curveFitter.equation
        _manualPeak.value = try {
            val (x, y) = curveFitter.equationPeak()
            "${y.userFormat()} on day ${curveFitter.numberToDate(x) ?: "?"}"
        } catch (x: NoBracketingException) {
            // peak solver could not bracket a root; leave the label blank
            ""
        }

        val se1 = curveFitter.cumulativeRmse(empirical = mainSeries.value)
        val se2 = curveFitter.deltaRmse(empirical = mainSeries.value)

        _manualLogCumStdErr.value = "RMSE = ${se1?.userFormat() ?: "?"} (totals)"
        _manualDeltaStdErr.value = "RMSE = ${se2?.userFormat() ?: "?"} (per day)"
    }

    //endregion

    //region ACTIONS

    /** Load the next US state in alphabetical order (wrapping around), then re-fit the curve. */
    fun goToNextUsState() {
        val states = Usa.stateNames.toSortedSet()
        areaId = when {
            states.contains(areaId) -> states.rollAfter(areaId)
            else -> states.first()
        }
        autofit()
    }

    /** Load the previous US state in alphabetical order (wrapping around), then re-fit the curve. */
    fun goToPreviousUsState() {
        val states = Usa.stateNames.toSortedSet()
        areaId = when {
            states.contains(areaId) -> states.rollBefore(areaId)
            else -> states.last()
        }
        autofit()
    }

    /** Element after [x] (which must be in the set), wrapping to the first element. */
    private fun <X> SortedSet<X>.rollAfter(x: X) = tailSet(x).elementAtOrNull(1) ?: first()
    /** Element before [x], wrapping to the last element. */
    private fun <X> SortedSet<X>.rollBefore(x: X) = headSet(x).lastOrNull() ?: last()

    /** Runs autofit using current config; shows an error alert if the optimizer exceeds its evaluation limit. */
    fun autofit() {
        try {
            curveFitter.autofit(mainSeries.value)
        } catch (x: TooManyEvaluationsException) {
            alert(Alert.AlertType.ERROR, "Too many evaluations during curve fit.")
        }
    }

    /** Recalculates errors. */
    fun calcErrors() {
        updateEquation()
    }

    /**
     * Loads selected forecast: area, metric, curve type, curve parameters and fit window.
     * Curve parameters are left unchanged when [f] has no sigmoid parameters.
     */
    fun load(f: ForecastStats) {
        areaId = f.region.id
        selectedMetric = f.metric
        curveFitter.curve = f.sigmoidCurve
        // skip (rather than fail a null-to-Number cast) when the forecast carries no sigmoid parameters
        f.sigmoidParameters?.let { params ->
            curveFitter.l = params.load as Number
            curveFitter.k = params.k as Number
            curveFitter.x0 = params.x0 as Number
            curveFitter.v = params.v as Number
        }

        // fit days are stored as offsets from today
        f.fitDateRange?.run {
            curveFitter.firstFitDay = start.minus(LocalDate.now())
            curveFitter.lastFitDay = endInclusive.minus(LocalDate.now())
        }
    }

    /** Save current config as new forecast; does nothing while there is no main series. */
    fun save() {
        val empirical = mainSeries.value
        if (empirical != null) {
            forecastInfoList.add(curveFitter.userForecastInfo(empirical))
        }
    }

    /** Save all currently selected external forecasts to the table; does nothing while there is no main series. */
    fun saveExternalForecastsToTable() {
        val empirical = mainSeries.value
        if (empirical != null) {
            externalForecasts.forecasts.filter { it.model in otherForecasts }
                    .forEach { forecastInfoList.add(curveFitter.forecastStats(it, empirical)) }
        }
    }

    //endregion

    //region DATA MANAGEMENT

    /** Provides access to past forecasts. */
    class PastForecasts(var metrics: List<TimeSeries> = listOf())
    /** Provides access to external forecasts. */
    class ExternalForecasts(var forecasts: List<Forecast> = listOf())

    /** Past cumulative predictions: metrics containing "predicted" but not "peak". */
    val PastForecasts.cumulative
        get() = metrics.filter { "predicted" in it.metric && "peak" !in it.metric }
    /** Past predicted peaks: metrics containing "predicted peak". */
    val PastForecasts.deltas
        get() = metrics.filter { "predicted peak" in it.metric }
    /** Past metrics whose name contains "days". */
    val PastForecasts.peakDays
        get() = metrics.filter { "days" in it.metric }

    /** Series of the selected external models; "lower"/"upper" bands are dropped unless [showConfidence]. */
    val ExternalForecasts.filtered
        get() = forecasts.filter { it.model in otherForecasts }.flatMap {
            it.data.filter { showConfidence || ("lower" !in it.metric && "upper" !in it.metric) }
        }
    val ExternalForecasts.cumulative
        get() = filtered.toMutableList()
    val ExternalForecasts.deltas
        get() = filtered.map { it.deltas() }

    //endregion


}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy