All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.mwg.mlx.algorithm.AbstractLinearRegressionNode Maven / Gradle / Ivy

package org.mwg.mlx.algorithm;

import org.mwg.Graph;
import org.mwg.ml.RegressionNode;
import org.mwg.plugin.Enforcer;
import org.mwg.plugin.NodeState;

public abstract class AbstractLinearRegressionNode extends AbstractAnySlidingWindowManagingRegressionNode implements RegressionNode {

    /**
     * Node-state key under which the regression coefficient vector is stored.
     * <p>
     * {@link #setProperty} silently ignores writes to this key. It is therefore expected
     * to be filled by the learning code, not by external callers (learning code not
     * visible here).
     */
    public static final String COEFFICIENTS_KEY = "regressionCoefficients";
    /**
     * Coefficient vector used by {@link #predictValue} when nothing is stored under
     * {@link #COEFFICIENTS_KEY}: an empty vector. Because it has zero length, sharing this
     * single instance is safe; there are no elements anyone could overwrite.
     */
    public static final double[] COEFFICIENTS_DEF = new double[0];
    /**
     * Node-state key under which the regression intercept is stored.
     * <p>
     * Like {@link #COEFFICIENTS_KEY}, writes through {@link #setProperty} are ignored.
     */
    public static final String INTERCEPT_KEY = "regressionIntercept";
    /**
     * Intercept used by {@link #predictValue} when nothing is stored under
     * {@link #INTERCEPT_KEY}.
     */
    public static final double INTERCEPT_DEF = 0.0;

    /**
     * Node-state key of the L2 regularization coefficient.
     * <p>
     * Writes through {@link #setProperty} are validated as a non-negative double
     * (see {@code alrEnforcer}).
     */
    public static final String L2_COEF_KEY = "L2Coefficient";
    /**
     * Default L2 regularization coefficient.
     * <p>
     * NOTE(review): presumably 0.0 means "no regularization". The training code that
     * reads this value is not visible here; confirm.
     */
    public static final double L2_COEF_DEF = 0.0;

    /**
     * Creates a linear regression node.
     * <p>
     * All arguments are forwarded unchanged to the superclass
     * {@code AbstractAnySlidingWindowManagingRegressionNode}; this constructor adds no
     * state of its own.
     *
     * @param p_world           presumably the world identifier of the node (defined by the superclass)
     * @param p_time            presumably the time point of the node (defined by the superclass)
     * @param p_id              presumably the node identifier (defined by the superclass)
     * @param p_graph           the graph this node belongs to
     * @param currentResolution passed through to the superclass; its meaning is defined there
     */
    public AbstractLinearRegressionNode(long p_world, long p_time, long p_id, Graph p_graph, long[] currentResolution) {
        super(p_world, p_time, p_id, p_graph, currentResolution);
    }

    /**
     * Validator that {@link #setProperty} applies to every property it accepts.
     * <p>
     * It requires {@link #L2_COEF_KEY} to hold a non-negative double. NOTE(review): other
     * keys are presumably passed through unchecked by {@code Enforcer.check}; confirm
     * against the Enforcer implementation.
     */
    private static final Enforcer alrEnforcer = new Enforcer()
            .asNonNegativeDouble(L2_COEF_KEY);

    /**
     * Sets a node property, with two exceptions.
     * <ul>
     *   <li>Writes to {@link #COEFFICIENTS_KEY} or {@link #INTERCEPT_KEY} are silently
     *       dropped. These hold learned model state and cannot be set from outside.</li>
     *   <li>Every other property is first validated by {@code alrEnforcer}, then
     *       delegated to the superclass.</li>
     * </ul>
     *
     * @param propertyName  name of the property to write (may be {@code null}; the key
     *                      comparisons are null-safe)
     * @param propertyType  type tag of the value
     * @param propertyValue the value to store
     */
    @Override
    public void setProperty(String propertyName, byte propertyType, Object propertyValue) {
        final boolean learnedState = COEFFICIENTS_KEY.equals(propertyName)
                || INTERCEPT_KEY.equals(propertyName);
        if (learnedState) {
            // Deliberately ignored: learned coefficients/intercept are read-only from outside.
            return;
        }
        alrEnforcer.check(propertyName, propertyType, propertyValue);
        super.setProperty(propertyName, propertyType, propertyValue);
    }

    /**
     * Bootstrap-mode hook; this implementation does nothing.
     * <p>
     * NOTE(review): the original author was unsure whether anything should happen here.
     * One open question is whether the stored coefficients/intercept should be touched
     * when bootstrap mode is entered. Confirm the intended behaviour.
     *
     * @param state the node state passed by the caller (unused here)
     */
    @Override
    protected void setBootstrapModeHook(NodeState state) {
        //No-op for now: open question whether bootstrap mode should affect the learned model.
        //TODO Decide whether anything needs to happen here.
    }

    /**
     * Predicts the response for a single input vector from the model stored in {@code state}.
     * <p>
     * If no model has been stored yet, the coefficients fall back to
     * {@link #COEFFICIENTS_DEF} (an empty vector) and the intercept to {@link #INTERCEPT_DEF}.
     *
     * @param state    node state holding the learned coefficients and intercept
     * @param curValue input feature vector
     * @return the value computed by {@link #predictValueInternal} for these inputs
     */
    @Override
    public double predictValue(NodeState state, double[] curValue) {
        final double[] storedCoefficients = state.getFromKeyWithDefault(COEFFICIENTS_KEY, COEFFICIENTS_DEF);
        final double storedIntercept = state.getFromKeyWithDefault(INTERCEPT_KEY, INTERCEPT_DEF);
        return predictValueInternal(curValue, storedCoefficients, storedIntercept);
    }

    protected double predictValueInternal(double curValue[], double coefs[], double intercept){
        double response = 0;
        for (int i=0;i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy