All downloads are free. Search and download functionality uses the official Maven repository.

org.datacleaner.components.machinelearning.impl.SmileClassifier Maven / Gradle / Ivy

There is a newer version: 5.8.1
Show newest version
/**
 * DataCleaner (community edition)
 * Copyright (C) 2014 Free Software Foundation, Inc.
 *
 * This copyrighted material is made available to anyone wishing to use, modify,
 * copy, or redistribute it subject to the terms and conditions of the GNU
 * Lesser General Public License, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License
 * for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this distribution; if not, write to:
 * Free Software Foundation, Inc.
 * 51 Franklin Street, Fifth Floor
 * Boston, MA  02110-1301  USA
 */
package org.datacleaner.components.machinelearning.impl;

import java.util.Objects;

import org.apache.metamodel.util.SerializableRef;
import org.datacleaner.components.machinelearning.api.MLClassification;
import org.datacleaner.components.machinelearning.api.MLClassificationMetadata;

import smile.classification.Classifier;
import smile.classification.SoftClassifier;

public class SmileClassifier extends AbstractClassifier {

    private static final long serialVersionUID = 1L;
    
    private final SerializableRef> smileClassifierRef;

    public SmileClassifier(final Classifier smileClassifier,
            MLClassificationMetadata classificationMetadata) {
        super(classificationMetadata);
        this.smileClassifierRef = new SerializableRef<>(smileClassifier);
    }

    @Override
    public MLClassification classify(double[] featureValues) {
        final Classifier classifier = smileClassifierRef.get();
        if (classifier instanceof SoftClassifier) {
            final SoftClassifier softClassifier = (SoftClassifier) classifier;

            final double[] posteriori = new double[getMetadata().getClassCount()];
            softClassifier.predict(featureValues, posteriori);
            return new MLConfidenceClassification(posteriori);
        }

        final int prediction = classifier.predict(featureValues);
        return new MLSimpleClassification(prediction);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy