org.datavec.api.split.FileSplit Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.split;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.filefilter.IOFileFilter;
import org.apache.commons.io.filefilter.RegexFileFilter;
import org.apache.commons.io.filefilter.SuffixFileFilter;
import org.datavec.api.util.files.URIUtil;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.collection.CompactHeapStringList;
import org.nd4j.linalg.util.MathUtils;
import java.io.*;
import java.net.URI;
import java.util.*;
/**
* File input split. Splits up a root directory in to files.
*
* @author Adam Gibson
*/
public class FileSplit extends BaseInputSplit {
protected File rootDir;
// Use for Collections, pass in list of file type strings
protected String[] allowFormat = null;
protected boolean recursive = true;
protected Random random;
protected boolean randomize = false;
protected FileSplit(File rootDir, String[] allowFormat, boolean recursive, Random random, boolean runMain) {
this.allowFormat = allowFormat;
this.recursive = recursive;
this.rootDir = rootDir;
if (random != null) {
this.random = random;
this.randomize = true;
}
if (runMain)
this.initialize();
}
public FileSplit(File rootDir) {
this(rootDir, null, true, null, true);
}
public FileSplit(File rootDir, Random rng) {
this(rootDir, null, true, rng, true);
}
public FileSplit(File rootDir, String[] allowFormat) {
this(rootDir, allowFormat, true, null, true);
}
public FileSplit(File rootDir, String[] allowFormat, Random rng) {
this(rootDir, allowFormat, true, rng, true);
}
public FileSplit(File rootDir, String[] allowFormat, boolean recursive) {
this(rootDir, allowFormat, recursive, null, true);
}
protected void initialize() {
// Collection subFiles;
if (rootDir == null)
throw new IllegalArgumentException("File path must not be null");
else if(rootDir.isAbsolute() && !rootDir.exists()) {
try {
if(!rootDir.createNewFile()) {
throw new IllegalArgumentException("Unable to create file " + rootDir.getAbsolutePath());
}
//ensure uri strings has the root file if it's not a directory
else {
uriStrings = new ArrayList<>();
uriStrings.add(rootDir.toURI().toString());
}
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
else if (!rootDir.getAbsoluteFile().exists())
// When implementing wild card characters in the rootDir, remove this if exists,
// verify expanded paths exist and check for the edge case when expansion cannot be
// translated to existed locations
throw new IllegalArgumentException("No such file or directory: " + rootDir.getAbsolutePath());
else if (rootDir.isDirectory()) {
List list = listFiles(rootDir, allowFormat, recursive);
uriStrings = new CompactHeapStringList();
if (randomize) {
iterationOrder = new int[list.size()];
for (int i = 0; i < iterationOrder.length; i++) {
iterationOrder[i] = i;
}
MathUtils.shuffleArray(iterationOrder, random);
}
for (File f : list) {
uriStrings.add(URIUtil.fileToURI(f).toString());
++length;
}
} else {
// Lists one file
String toString = URIUtil.fileToURI(rootDir).toString(); //URI.getPath(), getRawPath() etc don't have file:/ prefix necessary for conversion back to URI
uriStrings = new ArrayList<>(1);
uriStrings.add(toString);
length += rootDir.length();
}
}
@Override
public String addNewLocation() {
if(rootDir.isDirectory())
return addNewLocation(new File(rootDir, UUID.randomUUID().toString()).toURI().toString());
else {
//add a file in the same directory as the file with the same extension as the original file
return addNewLocation(new File(rootDir.getParent(), UUID.randomUUID().toString() + "." + FilenameUtils.getExtension(rootDir.getAbsolutePath())).toURI().toString());
}
}
@Override
public String addNewLocation(String location) {
File f = new File(URI.create(location));
try {
f.createNewFile();
} catch (IOException e) {
throw new IllegalStateException(e);
}
uriStrings.add(location);
++length;
return location;
}
@Override
public void updateSplitLocations(boolean reset) {
if (reset) {
initialize();
}
}
@Override
public boolean needsBootstrapForWrite() {
return locations() == null ||
locations().length < 1
|| locations().length == 1 && !locations()[0].isAbsolute();
}
@Override
public void bootStrapForWrite() {
if(locations().length == 1 && !locations()[0].isAbsolute()) {
File parentDir = new File(locations()[0]);
File writeFile = new File(parentDir,"write-file");
try {
writeFile.createNewFile();
//since locations are dynamically generated, allow
uriStrings.add(writeFile.toURI().toString());
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
}
@Override
public OutputStream openOutputStreamFor(String location) throws Exception {
FileOutputStream ret = location.startsWith("file:") ? new FileOutputStream(new File(URI.create(location))):
new FileOutputStream(new File(location));
return ret;
}
@Override
public InputStream openInputStreamFor(String location) throws Exception {
FileInputStream ret = location.startsWith("file:") ? new FileInputStream(new File(URI.create(location))):
new FileInputStream(new File(location));
return ret;
}
@Override
public long length() {
return length;
}
@Override
public void reset() {
if (randomize) {
//Shuffle the iteration order
MathUtils.shuffleArray(iterationOrder, random);
}
}
@Override
public boolean resetSupported() {
return true;
}
public File getRootDir() {
return rootDir;
}
private List listFiles(File dir, String[] allowedFormats, boolean recursive) {
Preconditions.checkState(dir.isDirectory(), "Argument is not a directory: %s", dir);
IOFileFilter filter;
if (allowedFormats == null) {
filter = new RegexFileFilter(".*");
} else {
filter = new SuffixFileFilter(allowedFormats);
}
LinkedList queue = new LinkedList<>();
queue.add(dir);
List out = new ArrayList<>();
while(!queue.isEmpty()){
File[] listFiles = queue.remove().listFiles();
if(listFiles != null){
for(File f : listFiles){
boolean isDir = f.isDirectory();
if(isDir && recursive){
queue.add(f);
} else if(!isDir && filter.accept(f)){
out.add(f);
}
}
}
}
return out;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy