All Downloads are FREE. Search and download functionalities are using the official Maven repository.

it.unibo.alchemist.boundary.extractors.MeanSquaredError.kt Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (C) 2010-2023, Danilo Pianini and contributors
 * listed, for each module, in the respective subproject's build.gradle.kts file.
 *
 * This file is part of Alchemist, and is distributed under the terms of the
 * GNU General Public License, with a linking exception,
 * as described in the file LICENSE in the Alchemist distribution's top directory.
 */

package it.unibo.alchemist.boundary.extractors

import it.unibo.alchemist.model.Actionable
import it.unibo.alchemist.model.Environment
import it.unibo.alchemist.model.Incarnation
import it.unibo.alchemist.model.Molecule
import it.unibo.alchemist.model.Time
import it.unibo.alchemist.util.StatUtil
import org.apache.commons.math3.stat.descriptive.UnivariateStatistic

/**
 * Exports the Mean Squared Error (MSE) of a per-node value against a single, global reference value.
 *
 * At every extraction:
 * 1. the reference reading ([localCorrectValueMolecule], optionally narrowed by
 *    [localCorrectValueProperty]) is taken from every node through the environment's incarnation;
 * 2. the [UnivariateStatistic] named by [statistics] collapses those readings into one reference value;
 * 3. the actual reading ([localValueMolecule], optionally narrowed by [localValueProperty]) is taken
 *    from every node, the reference value is subtracted from it, and the difference is squared;
 * 4. the squared differences are averaged, and the average is exported.
 *
 * The single exported column is named
 * `MSE(<statistics>(<referenceProperty>@<referenceMolecule>),<actualProperty>@<actualMolecule>)`;
 * each `<property>@` prefix is omitted when the corresponding property is empty.
 *
 * Construction fails with an [IllegalArgumentException] if no statistic can be created from [statistics].
 *
 * @param T concentration type
 * @param incarnation the incarnation used to create the two molecules from their string form
 * @param localCorrectValueMolecule name of the molecule carrying the correct (reference) value
 * @param localCorrectValueProperty property to read from the reference molecule (empty by default)
 * @param statistics name of the [UnivariateStatistic] used to aggregate the reference readings
 * @param localValueMolecule name of the molecule carrying the value under evaluation
 * @param localValueProperty property to read from the evaluated molecule (empty by default)
 * @param precision forwarded as-is to [AbstractDoubleExporter]
 */
class MeanSquaredError<T>
    @JvmOverloads
    constructor(
        incarnation: Incarnation<T, *>,
        localCorrectValueMolecule: String,
        localCorrectValueProperty: String = "",
        statistics: String,
        localValueMolecule: String,
        localValueProperty: String = "",
        precision: Int? = null,
    ) : AbstractDoubleExporter(precision) {
        /**
         * Positional shorthand that leaves both properties empty and requires an explicit [precision].
         *
         * @param incarnation the incarnation used to create the two molecules from their string form
         * @param localCorrectValueMolecule name of the molecule carrying the correct (reference) value
         * @param statistics name of the [UnivariateStatistic] used to aggregate the reference readings
         * @param localValueMolecule name of the molecule carrying the value under evaluation
         * @param precision forwarded as-is to [AbstractDoubleExporter]
         */
        constructor(
            incarnation: Incarnation<T, *>,
            localCorrectValueMolecule: String,
            statistics: String,
            localValueMolecule: String,
            precision: Int,
        ) : this(
            incarnation = incarnation,
            localCorrectValueMolecule = localCorrectValueMolecule,
            statistics = statistics,
            localValueMolecule = localValueMolecule,
            localValueProperty = "",
            precision = precision,
        )

        // Aggregator applied to the per-node reference readings; resolved once, at construction time.
        private val statistic: UnivariateStatistic =
            StatUtil
                .makeUnivariateStatistic(statistics)
                .orElseThrow { IllegalArgumentException("Could not create univariate statistic $statistics") }
        private val mReference: Molecule = incarnation.createMolecule(localCorrectValueMolecule)
        private val pReference: String = localCorrectValueProperty
        private val pActual: String = localValueProperty
        private val mActual: Molecule = incarnation.createMolecule(localValueMolecule)

        // Column header, e.g. MSE(mean(prop@reference),prop@actual); "prop@" is dropped for empty properties.
        private val name: String =
            buildString {
                append("MSE(").append(statistics).append('(')
                if (pReference.isNotEmpty()) {
                    append(pReference).append('@')
                }
                append(localCorrectValueMolecule).append("),")
                if (pActual.isNotEmpty()) {
                    append(pActual).append('@')
                }
                append(localValueMolecule).append(')')
            }

        override val columnNames: List<String> = listOf(name)

        /**
         * Computes the mean squared error over all the nodes currently in [environment].
         *
         * [reaction], [time] and [step] are not used by this exporter.
         *
         * @return a single-entry map from the column name to the MSE; the MSE is [Double.NaN] when the
         * environment contains no nodes.
         */
        override fun <T> extractData(
            environment: Environment<T, *>,
            reaction: Actionable<T>?,
            time: Time,
            step: Long,
        ): Map<String, Double> {
            // Readings go through the environment's own incarnation, not the one received at construction.
            val incarnation: Incarnation<T, *> = environment.incarnation
            val value: Double =
                statistic
                    .evaluate(
                        environment.nodes.map { incarnation.getProperty(it, mReference, pReference) }.toDoubleArray(),
                    )
            val mse: Double =
                environment.nodes
                    .parallelStream()
                    .mapToDouble { incarnation.getProperty(it, mActual, pActual) - value }
                    .map { it * it }
                    .average()
                    .orElse(Double.NaN)
            return mapOf(name to mse)
        }
    }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy