All downloads are free. The search and download functionality uses the official Maven repository.

org.nd4j.api.loader.FileBatch Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.api.loader;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

/**
 * FileBatch: stores the raw contents of multiple files in byte arrays (one per file) along with their original paths.
 * FileBatch can be stored to disk (or any output stream) in zip format.
 * Typical use cases include creating batches of data in their raw format for distributed training. This can reduce the
 * number of disk reads required (fewer files) and network transfers when reading from remote storage, due to the zip
 * format used for compression.
 * <p>
 * Zip layout: one text entry named {@link #ORIGINAL_PATHS_FILENAME} holding the original URIs (one per line, UTF-8),
 * followed by one entry per file named {@code file_<index>.<extension>}, where the extension is taken from the
 * original URI ({@code bin} if it has none).
 *
 * @author Alex Black
 */
@AllArgsConstructor
@Data
public class FileBatch implements Serializable {
    /**
     * Name of the file in the zip file that contains the original paths/filenames
     */
    public static final String ORIGINAL_PATHS_FILENAME = "originalUris.txt";

    /** Raw content of each file; element {@code i} corresponds to element {@code i} of {@link #originalUris}. */
    private final List<byte[]> fileBytes;
    /** Original URI of each file, in the same order as {@link #fileBytes}. */
    private final List<String> originalUris;

    /**
     * Read a FileBatch from the specified file. This method assumes the FileBatch was previously saved to
     * zip format using {@link #writeAsZip(File)} or {@link #writeAsZip(OutputStream)}
     *
     * @param file File to read from
     * @return The loaded FileBatch
     * @throws IOException If an error occurs during reading
     */
    public static FileBatch readFromZip(File file) throws IOException {
        try (FileInputStream fis = new FileInputStream(file)) {
            return readFromZip(fis);
        }
    }

    /**
     * Read a FileBatch from the specified input stream. This method assumes the FileBatch was previously saved to
     * zip format using {@link #writeAsZip(File)} or {@link #writeAsZip(OutputStream)}.
     * The input stream is closed once reading completes (successfully or not).
     *
     * @param is Input stream to read from
     * @return The loaded FileBatch
     * @throws IOException If an error occurs during reading, or if the zip content does not have the expected
     *                     layout (missing {@link #ORIGINAL_PATHS_FILENAME} entry, malformed entry name, or a gap
     *                     in the file indices)
     */
    public static FileBatch readFromZip(InputStream is) throws IOException {
        String originalUris = null;
        Map<Integer, byte[]> bytesMap = new HashMap<>();
        try (ZipInputStream zis = new ZipInputStream(new BufferedInputStream(is))) {
            ZipEntry ze;
            while ((ze = zis.getNextEntry()) != null) {
                String name = ze.getName();
                byte[] bytes = IOUtils.toByteArray(zis);
                if (name.equals(ORIGINAL_PATHS_FILENAME)) {
                    originalUris = new String(bytes, StandardCharsets.UTF_8);
                } else {
                    bytesMap.put(parseFileIndex(name), bytes);
                }
            }
        }

        if (originalUris == null) {
            //Without this check, the split below would fail with an uninformative NullPointerException
            throw new IOException("Invalid FileBatch zip: no \"" + ORIGINAL_PATHS_FILENAME + "\" entry found");
        }

        //Entries may appear in any order in the zip: restore the original order via the index in the entry name
        List<byte[]> list = new ArrayList<>(bytesMap.size());
        for (int i = 0; i < bytesMap.size(); i++) {
            byte[] bytes = bytesMap.get(i);
            if (bytes == null) {
                throw new IOException("Invalid FileBatch zip: no entry found for file index " + i
                        + " (" + bytesMap.size() + " file entries present)");
            }
            list.add(bytes);
        }

        //"".split("\n") returns a single empty string: an empty batch must round-trip to an empty list of URIs
        List<String> origPaths;
        if (originalUris.isEmpty() && list.isEmpty()) {
            origPaths = new ArrayList<>();
        } else {
            origPaths = Arrays.asList(originalUris.split("\n"));
        }
        return new FileBatch(list, origPaths);
    }

    /**
     * Extract the file index from a zip entry name of the form {@code file_<index>.<extension>}: the index is the
     * text between the first underscore and the first dot.
     *
     * @param name Zip entry name
     * @return The file index encoded in the name
     * @throws IOException If the name does not contain an integer index in the expected position
     */
    private static int parseFileIndex(String name) throws IOException {
        int idxSplit = name.indexOf('_');
        int idxSplit2 = name.indexOf('.');
        if (idxSplit2 <= idxSplit + 1) {
            throw new IOException("Invalid FileBatch zip: unexpected entry name \"" + name
                    + "\" - expected format file_<index>.<extension>");
        }
        try {
            return Integer.parseInt(name.substring(idxSplit + 1, idxSplit2));
        } catch (NumberFormatException e) {
            throw new IOException("Invalid FileBatch zip: unexpected entry name \"" + name
                    + "\" - expected format file_<index>.<extension>", e);
        }
    }

    /**
     * Create a FileBatch from the specified files
     *
     * @param files Files to create the FileBatch from
     * @return The created FileBatch
     * @throws IOException If an error occurs during reading of the file content
     */
    public static FileBatch forFiles(File... files) throws IOException {
        return forFiles(Arrays.asList(files));
    }

    /**
     * Create a FileBatch from the specified files. The content of each file is read fully into memory, and the
     * file's URI (as per {@link File#toURI()}) is recorded as its original path.
     *
     * @param files Files to create the FileBatch from
     * @return The created FileBatch
     * @throws IOException If an error occurs during reading of the file content
     */
    public static FileBatch forFiles(List<File> files) throws IOException {
        List<String> origPaths = new ArrayList<>(files.size());
        List<byte[]> bytes = new ArrayList<>(files.size());
        for (File f : files) {
            bytes.add(FileUtils.readFileToByteArray(f));
            origPaths.add(f.toURI().toString());
        }
        return new FileBatch(bytes, origPaths);
    }

    /**
     * Write the FileBatch to the specified File, in zip file format
     *
     * @param f File to write to
     * @throws IOException If an error occurs during writing
     */
    public void writeAsZip(File f) throws IOException {
        //Own the stream here so that it is closed even if wrapping it fails in writeAsZip(OutputStream)
        try (FileOutputStream fos = new FileOutputStream(f)) {
            writeAsZip(fos);
        }
    }

    /**
     * Write the FileBatch to the specified output stream, in zip file format.
     * The output stream is closed once writing completes (successfully or not).
     *
     * @param os Output stream to write to
     * @throws IOException If an error occurs during writing
     */
    public void writeAsZip(OutputStream os) throws IOException {
        try (ZipOutputStream zos = new ZipOutputStream(new BufferedOutputStream(os))) {

            //Write original paths as a text file:
            ZipEntry ze = new ZipEntry(ORIGINAL_PATHS_FILENAME);
            String originalUrisJoined = StringUtils.join(originalUris, "\n"); //Java String.join is Java 8
            zos.putNextEntry(ze);
            zos.write(originalUrisJoined.getBytes(StandardCharsets.UTF_8));

            //Write each file as file_<index>.<extension>; putNextEntry closes the previous entry
            for (int i = 0; i < fileBytes.size(); i++) {
                String ext = FilenameUtils.getExtension(originalUris.get(i));
                if (ext == null || ext.isEmpty()) {
                    ext = "bin";
                }
                String name = "file_" + i + "." + ext;
                ze = new ZipEntry(name);
                zos.putNextEntry(ze);
                zos.write(fileBytes.get(i));
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy