All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.util.TestDataSetConsumer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.util;

import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Class that consumes DataSets with specified delays, suitable for testing.
 *
 * <p>Each consumed DataSet costs {@code delay} milliseconds of simulated work, spent either
 * sleeping or busy-waiting. The number of consumed DataSets is tracked in an {@link AtomicLong},
 * so individual counter updates are atomic.
 *
 * @author [email protected]
 */
public class TestDataSetConsumer {
    private final DataSetIterator iterator;
    private final long delay;
    private final AtomicLong count = new AtomicLong(0);
    protected static final Logger logger = LoggerFactory.getLogger(TestDataSetConsumer.class);

    /**
     * Creates a consumer without a backing iterator; only {@link #consumeOnce(DataSet, boolean)}
     * can be used with an instance built this way.
     *
     * @param delay simulated processing time per DataSet, in milliseconds
     */
    public TestDataSetConsumer(long delay) {
        this.iterator = null;
        this.delay = delay;
    }

    /**
     * Creates a consumer that can drain the given iterator via {@link #consumeWhileHasNext(boolean)}.
     *
     * @param iterator source of DataSets; must not be null
     * @param delay    simulated processing time per DataSet, in milliseconds
     */
    public TestDataSetConsumer(@NonNull DataSetIterator iterator, long delay) {
        this.iterator = iterator;
        this.delay = delay;
    }


    /**
     * Cycles through the iterator while {@code iterator.hasNext()} returns true. After each
     * DataSet, execution time is simulated either with a sleep or with a busy-wait loop.
     *
     * @param consumeWithSleep {@code true} to sleep during the simulated work, {@code false} to busy-wait
     * @return the number of DataSets counted after the counter was reset at the start of this call
     * @throws IllegalStateException if this consumer was created without an iterator
     */
    public long consumeWhileHasNext(boolean consumeWithSleep) {
        count.set(0);
        if (iterator == null)
            throw new IllegalStateException("Can't use consumeWhileHasNext() if iterator isn't set");

        while (iterator.hasNext()) {
            DataSet ds = iterator.next();
            this.consumeOnce(ds, consumeWithSleep);
        }

        return count.get();
    }

    /**
     * Consumes a single DataSet, spending {@code delay} milliseconds simulating its processing.
     *
     * @param dataSet          the DataSet to "consume"; must not be null (its contents are not read)
     * @param consumeWithSleep {@code true} to sleep during the simulated work, {@code false} to busy-wait
     * @return the counter value produced by this call's increment
     * @throws RuntimeException wrapping {@link InterruptedException} if the thread is interrupted
     *                          while sleeping (the interrupt flag is restored first)
     */
    public long consumeOnce(@NonNull DataSet dataSet, boolean consumeWithSleep) {
        // Monotonic clock: immune to wall-clock adjustments. toNanos saturates instead of overflowing,
        // and the "deadline - now" comparison below is wraparound-safe.
        final long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(delay);
        long remaining;
        while ((remaining = deadline - System.nanoTime()) > 0) {
            if (consumeWithSleep) {
                try {
                    // Sleep only for what is left, rather than a full delay each iteration.
                    TimeUnit.NANOSECONDS.sleep(remaining);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
            }
            // When not sleeping, the loop itself is the busy-wait that simulates CPU-bound work.
        }

        // Use the value produced by this increment so that concurrent callers don't observe
        // each other's updates in the log check or the return value.
        final long current = count.incrementAndGet();

        if (current % 100 == 0)
            logger.info("Passed {} datasets...", current);

        return current;
    }

    /**
     * Returns the number of DataSets consumed since construction or since the last
     * {@link #consumeWhileHasNext(boolean)} reset.
     *
     * @return current value of the consumption counter
     */
    public long getCount() {
        return count.get();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy