All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.classifiers.LogisticRegressionClassifier Maven / Gradle / Ivy

package ml.classifiers;

import datasets.VectorDouble;
import optimization.ISupervisedOptimizer;
import datastructs.I2DDataSet;
import datastructs.IVector;
import maths.functions.IVectorRealFunction;

public class LogisticRegressionClassifier>,
                                HypothesisType extends IVectorRealFunction>> extends ClassifierBase {

    /**
     * Constructor.
     *
     * @param hypothesis A hypothesis type
     * @param optimizer An optimizer
     */
    public LogisticRegressionClassifier(HypothesisType hypothesis, ISupervisedOptimizer optimizer){
        super();
        this.hypothesis = hypothesis;
        this.optimizer = optimizer;
    }

    /**
     * Train the model using the provided dataset
     */
    @Override
    public  OutputType train(final DataSetType dataSet, final VectorDouble labels){
        return this.optimizer.optimize(dataSet, labels, this.hypothesis);
    }

    /**
     * Predict the class of the given data point
     */
    @Override
    public  Integer  predict(PointType point){
        VectorDouble vec = (VectorDouble) point;
        double hypothesisVal = this.hypothesis.evaluate(vec);

        if(hypothesisVal >= 0.5){
            return 1;
        }
        return 0;
    }

    /**
     * The hypothesis function assumed by the regressor
     */
    protected HypothesisType hypothesis;
    protected ISupervisedOptimizer optimizer;
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy