All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.ParallelTransformerIterator Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.models.sequencevectors.transformers.impl.iterables;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.AsyncLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;

import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * TransformerIterator implementation that's does transformation/tokenization/normalization/whatever in parallel threads.
 * Suitable for cases when tokenization takes too much time for single thread.
 *
 * TL/DR: we read data from sentence iterator, and apply tokenization in parallel threads.
 *
 * @author [email protected]
 */
@Slf4j
public class ParallelTransformerIterator extends BasicTransformerIterator {

    protected static final int capacity = 1024;
    protected BlockingQueue>> buffer = new LinkedBlockingQueue<>(capacity);
    //protected BlockingQueue stringBuffer;
    //protected TokenizerThread[] threads;
    protected AtomicBoolean underlyingHas = new AtomicBoolean(true);
    protected AtomicInteger processing = new AtomicInteger(0);

    private ExecutorService executorService;

    protected static final AtomicInteger count = new AtomicInteger(0);

    private static final int PREFETCH_SIZE = 100;

    public ParallelTransformerIterator(@NonNull LabelAwareIterator iterator, @NonNull SentenceTransformer transformer) {
        this(iterator, transformer, true);
    }

    private void prefetchIterator() {
        /*for (int i = 0; i < PREFETCH_SIZE; ++i) {
            //boolean before = underlyingHas;

                if (underlyingHas.get())
                    underlyingHas.set(iterator.hasNextDocument());
                else
                    underlyingHas.set(false);

            if (underlyingHas.get()) {
                CallableTransformer callableTransformer = new CallableTransformer(iterator.nextDocument(), sentenceTransformer);
                Future> futureSequence = executorService.submit(callableTransformer);
                try {
                    buffer.put(futureSequence);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }*/
    }

    public ParallelTransformerIterator(@NonNull LabelAwareIterator iterator, @NonNull SentenceTransformer transformer,
                                       boolean allowMultithreading) {
        super(new AsyncLabelAwareIterator(iterator, 512), transformer);
        //super(iterator, transformer);
        this.allowMultithreading = allowMultithreading;
        //this.stringBuffer = new LinkedBlockingQueue<>(512);

        //threads = new TokenizerThread[1];
        //threads = new TokenizerThread[allowMultithreading ? Math.max(Runtime.getRuntime().availableProcessors(), 2) : 1];
        executorService = Executors.newFixedThreadPool(allowMultithreading ? Math.max(Runtime.getRuntime().availableProcessors(), 2) : 1);

        prefetchIterator();
    }

    @Override
    public void reset() {
        this.executorService.shutdownNow();
        this.iterator.reset();
        underlyingHas.set(true);
        prefetchIterator();
        this.buffer.clear();
        this.executorService = Executors.newFixedThreadPool(allowMultithreading ? Math.max(Runtime.getRuntime().availableProcessors(), 2) : 1);
    }

    public void shutdown() {
        executorService.shutdown();
    }

    private static class CallableTransformer implements Callable> {

        private LabelledDocument document;
        private SentenceTransformer transformer;

        public CallableTransformer(LabelledDocument document, SentenceTransformer transformer) {
            this.transformer = transformer;
            this.document = document;
        }

        @Override
        public Sequence call() {
            Sequence sequence = new Sequence<>();

            if (document != null && document.getContent() != null) {
                sequence = transformer.transformToSequence(document.getContent());
                if (document.getLabels() != null) {
                    for (String label : document.getLabels()) {
                        if (label != null && !label.isEmpty())
                            sequence.addSequenceLabel(new VocabWord(1.0, label));
                    }
                }
            }
            return sequence;
        }
    }

    @Override
    public boolean hasNext() {
        //boolean before = underlyingHas;

        //if (underlyingHas.get()) {
            if (buffer.size() < capacity && iterator.hasNextDocument()) {
                CallableTransformer transformer = new CallableTransformer(iterator.nextDocument(), sentenceTransformer);
                Future> futureSequence = executorService.submit(transformer);
                try {
                    buffer.put(futureSequence);
                } catch (InterruptedException e) {
                    log.error("",e);
                }
            }
          /*  else
                underlyingHas.set(false);

        }
        else {
           underlyingHas.set(false);
        }*/

        return (/*underlyingHas.get() ||*/ !buffer.isEmpty() || /*!stringBuffer.isEmpty() ||*/ processing.get() > 0);
    }

    @Override
    public Sequence next() {
        try {
            /*if (underlyingHas)
                stringBuffer.put(iterator.nextDocument());*/
            processing.incrementAndGet();
            Future> future = buffer.take();
            Sequence  sequence = future.get();
            processing.decrementAndGet();
            return sequence;

        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy