
com.davidbracewell.apollo.ml.data.OffHeapDataset

package com.davidbracewell.apollo.ml.data;

import com.davidbracewell.apollo.ml.Encoder;
import com.davidbracewell.apollo.ml.Example;
import com.davidbracewell.apollo.ml.LabelEncoder;
import com.davidbracewell.apollo.ml.preprocess.PreprocessorList;
import com.davidbracewell.conversion.Cast;
import com.davidbracewell.function.SerializableFunction;
import com.davidbracewell.function.Unchecked;
import com.davidbracewell.guava.common.base.Throwables;
import com.davidbracewell.io.Resources;
import com.davidbracewell.io.resource.Resource;
import com.davidbracewell.stream.MStream;
import com.davidbracewell.stream.StreamingContext;
import com.davidbracewell.string.StringUtils;
import com.davidbracewell.tuple.Tuples;
import lombok.NonNull;

import java.io.BufferedWriter;
import java.io.IOException;
import java.util.Random;
import java.util.concurrent.atomic.AtomicLong;

import static com.davidbracewell.function.Unchecked.function;

/**
 * Creates a dataset that streams examples off disk to save memory.
 *
 * @param <T> the example type parameter
 * @author David B. Bracewell
 */
public class OffHeapDataset<T extends Example> extends Dataset<T> {
   private static final long serialVersionUID = 1L;
   private final AtomicLong id = new AtomicLong();
   private Resource outputResource = Resources.temporaryDirectory();
   private Class<T> clazz;
   private int size = 0;

   /**
    * Instantiates a new Off heap dataset.
    *
    * @param featureEncoder the feature encoder
    * @param labelEncoder   the label encoder
    * @param preprocessors  the preprocessors
    */
   protected OffHeapDataset(Encoder featureEncoder, LabelEncoder labelEncoder, PreprocessorList<T> preprocessors) {
      super(featureEncoder, labelEncoder, preprocessors);
      outputResource.deleteOnExit();
   }

   @Override
   protected void addAll(@NonNull MStream<T> instances) {
      //TODO: Rewrite using MultiFileWriter
      final long binSize;
      if (instances.isReusable()) {
         binSize = instances.count() / 5000;
      } else {
         binSize = -1;
      }
      if (binSize <= 1) {
         writeInstancesTo(instances,
                          outputResource.getChild("part-" + id.incrementAndGet() + ".json").setIsCompressed(true));
      } else {
         instances.mapToPair(i -> Tuples.$((long) Math.floor(Math.random() * binSize), i))
                  .groupByKey()
                  .forEachLocal((key, list) -> {
                     Resource r = outputResource.getChild("part-" + id.incrementAndGet() + ".json")
                                                .setIsCompressed(true);
                     writeInstancesTo(StreamingContext.local().stream(list), r);
                  });
      }
   }

   @Override
   public void close() {
      outputResource.delete();
   }

   @Override
   protected Dataset<T> create(@NonNull MStream<T> instances,
                               @NonNull Encoder featureEncoder,
                               @NonNull LabelEncoder labelEncoder,
                               PreprocessorList<T> preprocessors) {
      Dataset<T> dataset = new OffHeapDataset<>(featureEncoder, labelEncoder, preprocessors);
      dataset.addAll(instances);
      return dataset;
   }

   @Override
   public DatasetType getType() {
      return DatasetType.OffHeap;
   }

   @Override
   public Dataset<T> mapSelf(@NonNull SerializableFunction<? super T, T> function) {
      outputResource.getChildren().forEach(Unchecked.consumer(r -> {
         InMemoryDataset<T> temp = new InMemoryDataset<>(getFeatureEncoder(), getLabelEncoder(), null);
         for (String line : r.readLines()) {
            temp.add(Example.fromJson(line, clazz));
         }
         temp.mapSelf(function);
         writeInstancesTo(temp.stream(), r);
      }));
      return this;
   }

   @Override
   public Dataset<T> shuffle(@NonNull Random random) {
      return create(stream().shuffle(),
                    getFeatureEncoder().createNew(),
                    getLabelEncoder().createNew(),
                    null);
   }

   @Override
   public int size() {
      return size;
   }

   @Override
   public MStream<T> stream() {
      return StreamingContext.local()
                             .stream(outputResource.getChildren().parallelStream()
                                                   .flatMap(function(r -> r.lines().javaStream()))
                                                   .filter(StringUtils::isNotNullOrBlank)
                                                   .map(function(line -> Cast.as(Example.fromJson(line, clazz)))));
   }

   @Override
   public Dataset<T> copy() {
      OffHeapDataset<T> copy = Cast.as(create(getStreamingContext().empty()));
      for (Resource child : outputResource.getChildren()) {
         try {
            copy.outputResource.getChild(child.baseName())
                               .write(child.readToString());
         } catch (IOException e) {
            throw Throwables.propagate(e);
         }
      }
      copy.size = this.size;
      copy.id.set(this.id.longValue());
      copy.clazz = this.clazz;
      return copy;
   }

   private void writeInstancesTo(MStream<T> instances, Resource file) {
      try (BufferedWriter writer = new BufferedWriter(file.writer())) {
         instances.forEach(Unchecked.consumer(ii -> {
            clazz = Cast.as(ii.getClass());
            if (ii.getFeatureSpace().count() > 0) {
               writer.write(ii.toJson());
               writer.newLine();
               size++;
            }
         }));
      } catch (IOException e) {
         throw Throwables.propagate(e);
      }
   }
}// END OF OffHeapDataset
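The class above depends on Apollo's Resource and MStream abstractions for partitioned JSON files and lazy streaming. As a rough, self-contained illustration of the same off-heap pattern using only the JDK, the sketch below writes already-serialized JSON lines into partition files in a temporary directory and streams them back lazily. The class name OffHeapSketch, its methods, and the use of plain String records are placeholders for this example, not part of the Apollo library.

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Stream;

// Illustrative only: mirrors OffHeapDataset's write-to-disk / stream-back pattern
// with JDK types in place of Apollo's Resource and MStream.
public class OffHeapSketch {

   private final Path outputDir;

   public OffHeapSketch() throws IOException {
      // Temporary directory that holds the partition files; cleaned up on JVM exit.
      this.outputDir = Files.createTempDirectory("off-heap-");
      outputDir.toFile().deleteOnExit();
   }

   // Write one partition file of JSON-line records.
   public void writePart(List<String> jsonLines, int partId) throws IOException {
      Path part = outputDir.resolve("part-" + partId + ".json");
      Files.write(part, jsonLines, StandardCharsets.UTF_8);
   }

   // Lazily stream every record back off disk; only the current line is held in memory.
   public Stream<String> stream() throws IOException {
      return Files.list(outputDir)
                  .flatMap(p -> {
                     try {
                        return Files.lines(p, StandardCharsets.UTF_8);
                     } catch (IOException e) {
                        throw new UncheckedIOException(e);
                     }
                  })
                  .filter(line -> !line.trim().isEmpty());
   }
}

A caller would write a handful of partitions with writePart and then iterate with stream().forEach(...); this is the same trade-off OffHeapDataset makes, paying per-example deserialization cost in exchange for a near-constant memory footprint.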



