All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.chen0040.mlp.ann.classifiers.MLPWithLabelOutput Maven / Gradle / Ivy

There is a newer version: 1.0.6
Show newest version
package com.github.chen0040.mlp.ann.classifiers;


import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.mlp.ann.MLP;
import com.github.chen0040.mlp.ann.MLPNet;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;


/**
 * Created by xschen on 5/9/15.
 */
public class MLPWithLabelOutput extends MLP {
    public Supplier> classLabelsModel;

    @Override
    public boolean isValidTrainingSample(DataRow tuple){
        return !tuple.getCategoricalTargetColumnNames().isEmpty();
    }

    @Override
    public double[] getTarget(DataRow tuple) {
        List labels = classLabelsModel.get();
        double[] target = new double[labels.size()];
        for (int i = 0; i < labels.size(); ++i) {
            target[i] = labels.get(i).equals(tuple.categoricalTarget()) ? 1 : 0;
        }
        return target;
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        MLPWithLabelOutput clone = (MLPWithLabelOutput)super.clone();
        clone.copy(this);
        return clone;
    }

    @Override
    public void copy(MLPNet rhs) throws CloneNotSupportedException {
        super.copy(rhs);

        MLPWithLabelOutput rhs2 = (MLPWithLabelOutput)rhs;
        classLabelsModel = rhs2.classLabelsModel;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy