All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.common.loader.FileBatch Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.common.loader;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

/**
 * FileBatch: stores the raw contents of multiple files in byte arrays (one per file) along with their original paths.
 * FileBatch can be stored to disk (or any output stream) in zip format.
 * Typical use cases include creating batches of data in their raw format for distributed training. This can reduce the
 * number of disk reads required (fewer files) and network transfers when reading from remote storage, due to the zip
 * format used for compression.
 * <p>
 * Zip layout: one text entry named {@link #ORIGINAL_PATHS_FILENAME} holding the original URIs (one per line,
 * UTF-8), followed by one entry per file named {@code file_<index>.<extension>}, where the extension is taken from
 * the original URI (or {@code bin} if it has none).
 *
 * @author Alex Black
 */
@AllArgsConstructor
@Data
public class FileBatch implements Serializable {
    /**
     * Name of the file in the zip file that contains the original paths/filenames
     */
    public static final String ORIGINAL_PATHS_FILENAME = "originalUris.txt";

    /** Raw content of each file; element i corresponds to element i of {@link #originalUris} */
    private final List<byte[]> fileBytes;
    /** Original URI (as a string) of each file, in the same order as {@link #fileBytes} */
    private final List<String> originalUris;

    /**
     * Read a FileBatch from the specified file. This method assumes the FileBatch was previously saved to
     * zip format using {@link #writeAsZip(File)} or {@link #writeAsZip(OutputStream)}
     *
     * @param file File to read from
     * @return The loaded FileBatch
     * @throws IOException If an error occurs during reading
     */
    public static FileBatch readFromZip(File file) throws IOException {
        try (FileInputStream fis = new FileInputStream(file)) {
            return readFromZip(fis);
        }
    }

    /**
     * Read a FileBatch from the specified input stream. This method assumes the FileBatch was previously saved to
     * zip format using {@link #writeAsZip(File)} or {@link #writeAsZip(OutputStream)}.
     * The input stream is closed once reading has finished.
     *
     * @param is Input stream to read from
     * @return The loaded FileBatch
     * @throws IOException If an error occurs during reading, or if the zip content is not a valid FileBatch
     *                     (missing original paths entry, unexpected entry name, or missing file index)
     */
    public static FileBatch readFromZip(InputStream is) throws IOException {
        String originalUris = null;
        //Zip entries are not guaranteed to be in index order, so collect by index first
        Map<Integer, byte[]> bytesMap = new HashMap<>();
        try (ZipInputStream zis = new ZipInputStream(new BufferedInputStream(is))) {
            ZipEntry ze;
            while ((ze = zis.getNextEntry()) != null) {
                String name = ze.getName();
                byte[] bytes = IOUtils.toByteArray(zis);
                if (name.equals(ORIGINAL_PATHS_FILENAME)) {
                    originalUris = new String(bytes, StandardCharsets.UTF_8);
                } else {
                    bytesMap.put(parseFileIndex(name), bytes);
                }
            }
        }

        if (originalUris == null) {
            throw new IOException("Invalid FileBatch zip: no \"" + ORIGINAL_PATHS_FILENAME + "\" entry found");
        }

        List<byte[]> list = new ArrayList<>(bytesMap.size());
        for (int i = 0; i < bytesMap.size(); i++) {
            byte[] bytes = bytesMap.get(i);
            if (bytes == null) {
                //Indices must be contiguous from 0, otherwise the content can't be matched up with the paths
                throw new IOException("Invalid FileBatch zip: no entry found for file index " + i
                        + " (" + bytesMap.size() + " file entries present)");
            }
            list.add(bytes);
        }

        //"".split("\n") returns a single empty string, which would turn an empty batch into a batch with one
        //(empty) URI - so handle the empty batch explicitly
        List<String> origPaths;
        if (list.isEmpty() && originalUris.isEmpty()) {
            origPaths = Collections.<String>emptyList();
        } else {
            origPaths = Arrays.asList(originalUris.split("\n"));
        }
        return new FileBatch(list, origPaths);
    }

    /**
     * Extract the file index from a zip entry name of the form {@code file_<index>.<extension>}: the index is
     * the text between the first underscore and the first dot.
     *
     * @param name Zip entry name
     * @return The file index encoded in the name
     * @throws IOException If the name does not have the expected format
     */
    private static int parseFileIndex(String name) throws IOException {
        int idxSplit = name.indexOf("_");
        int idxSplit2 = name.indexOf(".");
        if (idxSplit < 0 || idxSplit2 <= idxSplit + 1) {
            throw new IOException("Invalid FileBatch zip: unexpected entry name \"" + name
                    + "\" (expected format: file_<index>.<extension>)");
        }
        try {
            return Integer.parseInt(name.substring(idxSplit + 1, idxSplit2));
        } catch (NumberFormatException e) {
            throw new IOException("Invalid FileBatch zip: could not parse file index from entry name \""
                    + name + "\"", e);
        }
    }

    /**
     * Create a FileBatch from the specified files
     *
     * @param files Files to create the FileBatch from
     * @return The created FileBatch
     * @throws IOException If an error occurs during reading of the file content
     */
    public static FileBatch forFiles(File... files) throws IOException {
        return forFiles(Arrays.asList(files));
    }

    /**
     * Create a FileBatch from the specified files. The content of every file is read fully into memory, and
     * the original path of each file is recorded as a URI string.
     *
     * @param files Files to create the FileBatch from
     * @return The created FileBatch
     * @throws IOException If an error occurs during reading of the file content
     */
    public static FileBatch forFiles(List<File> files) throws IOException {
        List<String> origPaths = new ArrayList<>(files.size());
        List<byte[]> bytes = new ArrayList<>(files.size());
        for (File f : files) {
            bytes.add(FileUtils.readFileToByteArray(f));
            origPaths.add(f.toURI().toString());
        }
        return new FileBatch(bytes, origPaths);
    }

    /**
     * Write the FileBatch to the specified File, in zip file format
     *
     * @param f File to write to
     * @throws IOException If an error occurs during writing
     */
    public void writeAsZip(File f) throws IOException {
        //try-with-resources here so the file stream is always released by the code that opened it
        try (OutputStream os = new FileOutputStream(f)) {
            writeAsZip(os);
        }
    }

    /**
     * Write the FileBatch to the specified output stream, in zip file format.
     * The output stream is closed once writing has finished.
     *
     * @param os Output stream to write to
     * @throws IOException If an error occurs during writing
     */
    public void writeAsZip(OutputStream os) throws IOException {
        try (ZipOutputStream zos = new ZipOutputStream(new BufferedOutputStream(os))) {

            //Write original paths as a text file:
            ZipEntry ze = new ZipEntry(ORIGINAL_PATHS_FILENAME);
            String originalUrisJoined = StringUtils.join(originalUris, "\n"); //Java String.join is Java 8
            zos.putNextEntry(ze);
            zos.write(originalUrisJoined.getBytes(StandardCharsets.UTF_8));
            zos.closeEntry();

            for (int i = 0; i < fileBytes.size(); i++) {
                //Keep the original extension so the content type is still recognisable inside the zip
                String ext = FilenameUtils.getExtension(originalUris.get(i));
                if (ext == null || ext.isEmpty()) {
                    ext = "bin";
                }
                String name = "file_" + i + "." + ext;
                ze = new ZipEntry(name);
                zos.putNextEntry(ze);
                zos.write(fileBytes.get(i));
                zos.closeEntry();
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy