All downloads are free. The search and download functionality uses the official Maven repository.

org.deeplearning4j.models.word2vec.iterator.Word2VecDataFetcher Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.models.word2vec.iterator;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.movingwindow.Window;
import org.deeplearning4j.text.movingwindow.WindowConverter;
import org.deeplearning4j.text.movingwindow.Windows;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.fetcher.DataSetFetcher;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;

/**
 * {@link DataSetFetcher} that reads the text files found under a directory,
 * splits every line into moving {@link Window}s and converts those windows
 * into {@link DataSet} mini-batches using a {@link Word2Vec} model.
 *
 * <p>Each example row is the window converted by
 * {@link WindowConverter#asExampleMatrix(Window, Word2Vec)}; each outcome row
 * is a one-hot vector over the configured labels. A window whose label is not
 * in the label list is assigned to index 0.
 *
 * <p>Typical usage: call {@link #fetch(int)} to set the batch size, then
 * call {@link #next()} while {@link #hasMore()} returns {@code true}.
 * The file iterator is created lazily, so calling {@link #reset()} first is
 * optional.
 *
 * <p>This class keeps mutable, unsynchronized state (file iterator and window
 * cache).
 */
public class Word2VecDataFetcher implements DataSetFetcher {

    private static final long serialVersionUID = 3245955804749769475L;
    private static final Logger log = LoggerFactory.getLogger(Word2VecDataFetcher.class);

    /** Matches an opening label tag such as {@code <LABEL>}. */
    private static final Pattern begin = Pattern.compile("<[A-Z]+>");
    /** Matches a closing label tag such as {@code </LABEL>}. */
    private static final Pattern end = Pattern.compile("</[A-Z]+>");

    /** Files still to be read; transient, so it is rebuilt lazily when null. */
    private transient Iterator<File> files;
    private final Word2Vec vec;
    private final List<String> labels;
    /** Windows that have been read from disk but not yet returned in a batch. */
    private final List<Window> cache = new ArrayList<>();
    private final String path;
    private int batch;
    private int totalExamples;

    /**
     * Creates a fetcher over all files below {@code path}.
     *
     * @param path   root directory that is scanned recursively for input files
     * @param vec    word2vec model used to vectorize the windows
     * @param labels the possible window labels; the position in this list is
     *               the outcome index
     * @throws IllegalArgumentException if any argument is {@code null} or
     *                                  {@code labels} is empty
     */
    public Word2VecDataFetcher(String path, Word2Vec vec, List<String> labels) {
        if (path == null || vec == null || labels == null || labels.isEmpty()) {
            throw new IllegalArgumentException(
                            "Unable to initialize due to missing argument or empty label set");
        }
        this.vec = vec;
        this.labels = labels;
        this.path = path;
    }

    /**
     * Returns the file iterator, creating it on first use (or after
     * deserialization, where the transient field is {@code null}).
     */
    private Iterator<File> fileIterator() {
        if (files == null) {
            files = FileUtils.iterateFiles(new File(path), null, true);
        }
        return files;
    }

    /**
     * Reads every line of {@code file} and appends its windows to the cache.
     * The line iterator is always closed, even when reading fails.
     */
    private void cacheWindows(File file) {
        try {
            LineIterator lines = FileUtils.lineIterator(file);
            try {
                while (lines.hasNext()) {
                    cache.addAll(Windows.windows(lines.nextLine()));
                }
            } finally {
                lines.close();
            }
        } catch (IOException e) {
            throw new RuntimeException("Unable to read " + file, e);
        }
    }

    /**
     * Builds a data set from the first {@code min(batch, cache.size())} cached
     * windows and removes those windows from the cache.
     *
     * @return the data set, or {@code null} when the cache is empty
     */
    private DataSet fromCache() {
        int rows = Math.min(batch, cache.size());
        if (rows <= 0) {
            return null;
        }

        // View onto the head of the cache; clearing it below removes the
        // served windows from the cache itself.
        List<Window> served = cache.subList(0, rows);
        INDArray input = Nd4j.create(rows, inputColumns());
        INDArray outcomes = Nd4j.create(rows, labels.size());
        for (int i = 0; i < rows; i++) {
            Window window = served.get(i);
            input.putRow(i, WindowConverter.asExampleMatrix(window, vec));
            int idx = labels.indexOf(window.getLabel());
            if (idx < 0) {
                // unknown labels fall back to the first label
                idx = 0;
            }
            outcomes.putRow(i, FeatureUtil.toOutcomeVector(idx, labels.size()));
        }
        served.clear();
        return new DataSet(input, outcomes);
    }

    /**
     * Returns the next mini-batch.
     *
     * <p>Files are read into the window cache until at least one full batch
     * is available or no files remain; the batch is then taken from the front
     * of the cache. The last batch may therefore contain fewer than
     * {@link #getBatch()} rows. Because whole files are buffered, windows that
     * do not fit into the current batch are kept for the following calls
     * instead of being discarded.
     *
     * @return the next data set, or {@code null} if no windows are left
     * @throws IllegalStateException if no positive batch size has been set via
     *                               {@link #fetch(int)}
     * @throws RuntimeException      if a file cannot be read
     */
    @Override
    public DataSet next() {
        if (batch <= 0) {
            throw new IllegalStateException(
                            "fetch(int) must be called with a positive batch size before next(); batch=" + batch);
        }

        Iterator<File> remaining = fileIterator();
        while (cache.size() < batch && remaining.hasNext()) {
            cacheWindows(remaining.next());
        }
        return fromCache();
    }

    /**
     * @return the stored example count; this class never assigns it, so the
     *         value is 0
     */
    @Override
    public int totalExamples() {
        return totalExamples;
    }

    /**
     * @return the width of one example row: the model's layer size multiplied
     *         by its window size
     */
    @Override
    public int inputColumns() {
        return vec.lookupTable().layerSize() * vec.getWindow();
    }

    /**
     * @return the number of labels, i.e. the width of one outcome row
     */
    @Override
    public int totalOutcomes() {
        return labels.size();
    }

    /**
     * Restarts iteration over all files below the configured path and drops
     * any windows that are still cached.
     */
    @Override
    public void reset() {
        files = FileUtils.iterateFiles(new File(path), null, true);
        cache.clear();
    }

    /**
     * @return always 0; the position is not tracked by this fetcher
     */
    @Override
    public int cursor() {
        return 0;
    }

    /**
     * @return {@code true} if cached windows or unread files remain. Note that
     *         remaining files may turn out to contain no windows, in which
     *         case {@link #next()} returns {@code null}.
     */
    @Override
    public boolean hasMore() {
        return !cache.isEmpty() || fileIterator().hasNext();
    }

    /**
     * Sets the batch size used by subsequent calls to {@link #next()}.
     *
     * @param numExamples the number of rows per batch
     */
    @Override
    public void fetch(int numExamples) {
        this.batch = numExamples;
    }

    /**
     * @return the current file iterator; {@code null} until it has been
     *         created by {@link #reset()}, {@link #hasMore()} or {@link #next()}
     */
    public Iterator<File> getFiles() {
        return files;
    }

    public Word2Vec getVec() {
        return vec;
    }

    public static Pattern getBegin() {
        return begin;
    }

    public static Pattern getEnd() {
        return end;
    }

    public List<String> getLabels() {
        return labels;
    }

    public int getBatch() {
        return batch;
    }

    /**
     * @return the live list of windows that were read but not yet returned
     */
    public List<Window> getCache() {
        return cache;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy