All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.zoo.util.BaseLabels Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.zoo.util;

import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.common.resources.ResourceType;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.resources.Downloader;

import java.io.*;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

public abstract class BaseLabels implements Labels {

    protected ArrayList labels;

    /** Override {@link #getLabels()} when using this constructor. */
    protected BaseLabels() throws IOException {
        this.labels = getLabels();
    }

    /**
     * No need to override anything with this constructor.
     *
     * @param textResource name of a resource containing labels as a list in a text file.
     * @throws IOException 
     */
    protected BaseLabels(String textResource) throws IOException {
        this.labels = getLabels(textResource);
    }

    /**
     * Override to return labels when not calling {@link #BaseLabels(String)}.
     */
    protected ArrayList getLabels() throws IOException {
        return null;
    }

    /**
     * Returns labels based on the text file resource.
     */
    protected ArrayList getLabels(String textResource) throws IOException {
        ArrayList labels = new ArrayList<>();
        File resourceFile = getResourceFile();  //Download if required
        try (InputStream is = new BufferedInputStream(new FileInputStream(resourceFile)); Scanner s = new Scanner(is)) {
            while (s.hasNextLine()) {
                labels.add(s.nextLine());
            }
        }
        return labels;
    }

    @Override
    public String getLabel(int n) {
        Preconditions.checkArgument(n >= 0 && n < labels.size(), "Invalid index: %s. Must be in range" +
                "0 <= n < %s", n, labels.size());
        return labels.get(n);
    }

    @Override
    public List> decodePredictions(INDArray predictions, int n) {
        if(predictions.rank() == 1){
            //Reshape 1d edge case to [1, nClasses] 2d
            predictions = predictions.reshape(1, predictions.length());
        }
        Preconditions.checkState(predictions.size(1) == labels.size(), "Invalid input array:" +
                " expected array with size(1) equal to numLabels (%s), got array with shape %s", labels.size(), predictions.shape());

        long rows = predictions.size(0);
        long cols = predictions.size(1);
        if (predictions.isColumnVectorOrScalar()) {
            predictions = predictions.ravel();
            rows = (int) predictions.size(0);
            cols = (int) predictions.size(1);
        }
        List> descriptions = new ArrayList<>();
        for (int batch = 0; batch < rows; batch++) {
            INDArray result = predictions.getRow(batch, true);
            result = Nd4j.vstack(Nd4j.linspace(result.dataType(), 0, cols, 1).reshape(1,cols), result);
            result = Nd4j.sortColumns(result, 1, false);
            List current = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                int label = result.getInt(0, i);
                double prob = result.getDouble(1, i);
                current.add(new ClassPrediction(label, getLabel(label), prob));
            }
            descriptions.add(current);
        }
        return descriptions;
    }

    /**
     * @return URL of the resource to download
     */
    protected abstract URL getURL();

    /**
     * @return Name of the resource (used for inferring local storage parent directory)
     */
    protected abstract String resourceName();

    /**
     * @return MD5 of the resource at getURL()
     */
    protected abstract String resourceMD5();

    /**
     * Download the resource at getURL() to the local resource directory, and return the local copy as a File
     *
     * @return File of the local resource
     */
    protected File getResourceFile(){

        URL url = getURL();
        String urlString = url.toString();
        String filename = urlString.substring(urlString.lastIndexOf('/')+1);
        File resourceDir = DL4JResources.getDirectory(ResourceType.RESOURCE, resourceName());
        File localFile = new File(resourceDir, filename);

        String expMD5 = resourceMD5();
        if(localFile.exists()){
            try{
                if(Downloader.checkMD5OfFile(expMD5, localFile)){
                    return localFile;
                }
            } catch (IOException e){
                //Ignore
            }
            //MD5 failed
            localFile.delete();
        }

        //Download
        try {
            Downloader.download(resourceName(), url, localFile, expMD5, 3);
        } catch (IOException e){
            throw new RuntimeException("Error downloading labels",e);
        }

        return localFile;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy