All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.deep.DatasetImpl Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2024 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */
package smile.deep;

import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import smile.deep.tensor.Index;
import smile.deep.tensor.Tensor;
import smile.math.MathEx;

/**
 * A dataset of arrays. The samples are copied into a single row-major
 * float array at construction time and wrapped in a tensor; iteration
 * yields mini-batches selected through a permutation of the sample indices.
 *
 * @author Haifeng Li
 */
class DatasetImpl implements Dataset {
    /** The samples, created by {@code Tensor.of(x, n, p)} from the flattened rows. */
    private final Tensor data;
    /** The targets, created by {@code Tensor.of(y, n)}. */
    private final Tensor target;
    /** The number of samples. */
    private final int size;
    /** The mini-batch size, always positive. */
    private final int batch;

    /**
     * Constructor.
     * @param data the data, one sample per row; all rows must have the same length.
     * @param target the target labels, one per sample.
     * @param batch the mini-batch size.
     * @throws IllegalArgumentException if the data is empty or ragged, if the
     *         number of targets differs from the number of samples, or if
     *         the batch size is not positive.
     */
    public DatasetImpl(float[][] data, int[] target, int batch) {
        int n = data.length;
        checkArguments(n, target.length, batch);
        int p = data[0].length;

        this.data = Tensor.of(flatten(data), n, p);
        this.target = Tensor.of(toLong(target), n);
        this.size = n;
        this.batch = batch;
    }

    /**
     * Constructor. The samples are narrowed to single precision.
     * @param data the data, one sample per row; all rows must have the same length.
     * @param target the target labels, one per sample.
     * @param batch the mini-batch size.
     * @throws IllegalArgumentException if the data is empty or ragged, if the
     *         number of targets differs from the number of samples, or if
     *         the batch size is not positive.
     */
    public DatasetImpl(double[][] data, int[] target, int batch) {
        int n = data.length;
        checkArguments(n, target.length, batch);
        int p = data[0].length;

        this.data = Tensor.of(flatten(data), n, p);
        this.target = Tensor.of(toLong(target), n);
        this.size = n;
        this.batch = batch;
    }

    /**
     * Constructor.
     * @param data the data, one sample per row; all rows must have the same length.
     * @param target the target values, one per sample.
     * @param batch the mini-batch size.
     * @throws IllegalArgumentException if the data is empty or ragged, if the
     *         number of targets differs from the number of samples, or if
     *         the batch size is not positive.
     */
    public DatasetImpl(float[][] data, float[] target, int batch) {
        int n = data.length;
        checkArguments(n, target.length, batch);
        int p = data[0].length;

        this.data = Tensor.of(flatten(data), n, p);
        this.target = Tensor.of(target, n);
        this.size = n;
        this.batch = batch;
    }

    /**
     * Constructor. The samples are narrowed to single precision; the targets
     * are passed to {@code Tensor.of} as given.
     * @param data the data, one sample per row; all rows must have the same length.
     * @param target the target values, one per sample.
     * @param batch the mini-batch size.
     * @throws IllegalArgumentException if the data is empty or ragged, if the
     *         number of targets differs from the number of samples, or if
     *         the batch size is not positive.
     */
    public DatasetImpl(double[][] data, double[] target, int batch) {
        int n = data.length;
        checkArguments(n, target.length, batch);
        int p = data[0].length;

        this.data = Tensor.of(flatten(data), n, p);
        this.target = Tensor.of(target, n);
        this.size = n;
        this.batch = batch;
    }

    /**
     * Validates the constructor arguments shared by all overloads.
     * @param n the number of samples.
     * @param targets the number of targets.
     * @param batch the mini-batch size.
     */
    private static void checkArguments(int n, int targets, int batch) {
        if (n == 0) {
            throw new IllegalArgumentException("Empty dataset");
        }
        if (targets != n) {
            throw new IllegalArgumentException(
                    "The sizes of data and target don't match: " + n + " != " + targets);
        }
        // A zero batch size would make the iterator spin forever without
        // advancing; a negative one cannot size an index array.
        if (batch <= 0) {
            throw new IllegalArgumentException("Invalid batch size: " + batch);
        }
    }

    /**
     * Copies a non-empty rectangular matrix into a row-major array.
     * @param data the matrix.
     * @return the flattened copy.
     */
    private static float[] flatten(float[][] data) {
        int n = data.length;
        int p = data[0].length;
        // multiplyExact fails loudly instead of wrapping to a bogus length.
        float[] x = new float[Math.multiplyExact(n, p)];
        for (int i = 0; i < n; i++) {
            float[] xi = data[i];
            checkRowLength(i, xi.length, p);
            System.arraycopy(xi, 0, x, i * p, p);
        }
        return x;
    }

    /**
     * Copies a non-empty rectangular matrix into a row-major array,
     * narrowing each element to single precision.
     * @param data the matrix.
     * @return the flattened copy.
     */
    private static float[] flatten(double[][] data) {
        int n = data.length;
        int p = data[0].length;
        // multiplyExact fails loudly instead of wrapping to a bogus length.
        float[] x = new float[Math.multiplyExact(n, p)];
        for (int i = 0; i < n; i++) {
            double[] xi = data[i];
            checkRowLength(i, xi.length, p);
            int offset = i * p;
            for (int j = 0; j < p; j++) {
                x[offset + j] = (float) xi[j];
            }
        }
        return x;
    }

    /**
     * Rejects rows whose length differs from that of the first row.
     * @param row the row index, used in the error message.
     * @param length the actual row length.
     * @param expected the length of the first row.
     */
    private static void checkRowLength(int row, int length, int expected) {
        if (length != expected) {
            throw new IllegalArgumentException(
                    "Row " + row + " has " + length + " columns, expected " + expected);
        }
    }

    /**
     * Widens integer labels to long, the element type handed to
     * {@code Tensor.of} for the target.
     * @param target the labels.
     * @return the widened copy.
     */
    private static long[] toLong(int[] target) {
        return Arrays.stream(target).asLongStream().toArray();
    }

    @Override
    public void close() {
        data.close();
        target.close();
    }

    @Override
    public long size() {
        return size;
    }

    /**
     * Returns an iterator over the mini-batches. Each iterator obtains its
     * own permutation from {@code MathEx.permutate} and walks through it in
     * chunks of {@code batch} indices; the last chunk may be smaller.
     *
     * @return the mini-batch iterator.
     */
    @Override
    public Iterator<SampleBatch> iterator() {
        return new Iterator<>() {
            /** The order in which the samples are visited. */
            final int[] permutation = MathEx.permutate(size);
            /** The position in the permutation of the next sample to serve. */
            int i = 0;

            @Override
            public boolean hasNext() {
                return i < size;
            }

            @Override
            public SampleBatch next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }

                // The last batch holds whatever is left over.
                int end = i + Math.min(batch, size - i);
                int[] index = Arrays.copyOfRange(permutation, i, end);
                i = end;

                var idx = Index.of(index);
                return new SampleBatch(data.get(idx), target.get(idx));
            }
        };
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy