All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.zoo.util.BaseLabels Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.zoo.util;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

/**
 * Base functionality for helper classes that return label descriptions.
 *
 * @author saudet
 */
public abstract class BaseLabels implements Labels {

    protected ArrayList labels = null;

    /** Override {@link #getLabels()} when using this constructor. */
    protected BaseLabels() throws IOException {
        this.labels = getLabels();
    }

    /**
     * No need to override anything with this constructor.
     *
     * @param textResource name of a resource containing labels as a list in a text file.
     * @throws IOException 
     */
    protected BaseLabels(String textResource) throws IOException {
        this.labels = getLabels(textResource);
    }

    /**
     * Override to return labels when not calling {@link #BaseLabels(String)}.
     */
    protected ArrayList getLabels() throws IOException {
        return null;
    }

    /**
     * Returns labels based on the text file resource.
     */
    protected ArrayList getLabels(String textResource) throws IOException {
        ArrayList labels = new ArrayList<>();
        try (Scanner s = new Scanner(this.getClass().getResourceAsStream(textResource))) {
            while (s.hasNextLine()) {
                labels.add(s.nextLine());
            }
        }
        return labels;
    }

    @Override
    public String getLabel(int n) {
        return labels.get(n);
    }

    @Override
    public List> decodePredictions(INDArray predictions, int n) {
        int rows = predictions.size(0);
        int cols = predictions.size(1);
        if (predictions.isColumnVectorOrScalar()) {
            predictions = predictions.ravel();
            rows = predictions.size(0);
            cols = predictions.size(1);
        }
        List> descriptions = new ArrayList<>();
        for (int batch = 0; batch < rows; batch++) {
            INDArray result = predictions.getRow(batch);
            result = Nd4j.vstack(Nd4j.linspace(0, cols, cols), result);
            result = Nd4j.sortColumns(result, 1, false);
            List current = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                int label = result.getInt(0, i);
                double prob = result.getDouble(1, i);
                current.add(new ClassPrediction(label, getLabel(label), prob));
            }
            descriptions.add(current);
        }
        return descriptions;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy