
com.simiacryptus.mindseye.opt.trainable.SparkTrainable
A neural network library for Java 8
/*
* Copyright (c) 2017 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.opt.trainable;
import com.simiacryptus.mindseye.network.graph.DAGNetwork;
import com.simiacryptus.mindseye.layers.DeltaSet;
import com.simiacryptus.mindseye.layers.NNLayer;
import com.simiacryptus.mindseye.layers.NNResult;
import com.simiacryptus.util.ml.Tensor;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.rdd.RDD;
import java.io.Serializable;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
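/**
 * A {@link Trainable} that evaluates a {@link DAGNetwork} over a Spark RDD of
 * training rows: each partition computes its gradient and objective locally,
 * and the serializable partial results are summed with a Spark reduce.
 */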
public class SparkTrainable implements Trainable {
  /** Serializable partial result: per-layer gradient sums keyed by layer id, plus the objective sum and row count. */
  private static class ReducableResult implements Serializable {
    public final Map<String, double[]> deltas;
    public final double sum;
    public final int count;

    public ReducableResult(Map<String, double[]> deltas, double sum, int count) {
      this.deltas = deltas;
      this.sum = sum;
      this.count = count;
    }

    public void accumulate(DeltaSet source) {
      // Layers are keyed by their string ids so the result stays serializable;
      // rebuild the id-to-layer index to apply the deltas to this DeltaSet.
      Map<String, NNLayer> idIndex = source.map.entrySet().stream().collect(Collectors.toMap(
          e -> e.getKey().id.toString(), e -> e.getKey()
      ));
      deltas.forEach((k, v) -> source.get(idIndex.get(k), (double[]) null).accumulate(v));
    }
    public SparkTrainable.ReducableResult add(SparkTrainable.ReducableResult right) {
      HashMap<String, double[]> map = new HashMap<>();
      Set<String> keys = Stream.concat(deltas.keySet().stream(), right.deltas.keySet().stream()).collect(Collectors.toSet());
      for (String key : keys) {
        double[] l = deltas.get(key);
        double[] r = right.deltas.get(key);
        if (null != r) {
          if (null != l) {
            assert (l.length == r.length);
            // Element-wise sum when both sides carry a delta for this layer.
            double[] x = new double[l.length];
            for (int i = 0; i < l.length; i++) x[i] = l[i] + r[i];
            map.put(key, x);
          } else {
            map.put(key, r);
          }
        } else {
          assert (null != l);
          map.put(key, l);
        }
      }
      return new SparkTrainable.ReducableResult(map, sum + right.sum, count + right.count);
    }

    public double meanValue() {
      return sum / count;
    }
  }

  /** Spark task: evaluates the network on one partition of rows and emits a single partial result. */
  private static class PartitionTask implements FlatMapFunction<Iterator<Tensor[]>, SparkTrainable.ReducableResult> {
    final DAGNetwork network;

    private PartitionTask(DAGNetwork network) {
      this.network = network;
    }

    @Override
    public Iterator<SparkTrainable.ReducableResult> call(Iterator<Tensor[]> partition) throws Exception {
      // Evaluate the network on all rows of this partition in one batch.
      Tensor[][] tensors = SparkTrainable.getStream(partition).toArray(i -> new Tensor[i][]);
      NNResult eval = network.eval(NNResult.batchResultArray(tensors));
      DeltaSet deltaSet = new DeltaSet();
      eval.accumulate(deltaSet);
      // The network's output is taken as a scalar objective per row.
      double[] doubles = Arrays.stream(eval.data).mapToDouble(x -> x.getData()[0]).toArray();
      return Arrays.asList(SparkTrainable.getResult(deltaSet, doubles)).iterator();
    }
  }
  private static SparkTrainable.ReducableResult getResult(DeltaSet delta, double[] values) {
    // Key each layer's delta by the layer's string id so the map serializes cleanly.
    Map<String, double[]> deltas = delta.map.entrySet().stream().collect(Collectors.toMap(
        e -> e.getKey().id.toString(), e -> e.getValue().delta
    ));
    return new SparkTrainable.ReducableResult(deltas, Arrays.stream(values).sum(), values.length);
  }
  private DeltaSet getDelta(SparkTrainable.ReducableResult reduce) {
    // Evaluate a single prototype row locally (a zero backprop) so the DeltaSet
    // is populated with live NNLayer references, then fold in the reduced remote deltas.
    DeltaSet deltaSet = new DeltaSet();
    Tensor[] prototype = dataRDD.take(1).get(0);
    NNResult result = network.eval(NNResult.batchResultArray(new Tensor[][]{prototype}));
    result.accumulate(deltaSet, 0);
    reduce.accumulate(deltaSet);
    return deltaSet;
  }
  private final JavaRDD<Tensor[]> dataRDD;
  private final DAGNetwork network;

  public SparkTrainable(RDD<Tensor[]> trainingData, DAGNetwork network) {
    this.dataRDD = trainingData.toJavaRDD();
    this.network = network;
    resetSampling();
  }
  @Override
  public Trainable.PointSample measure() {
    // Compute per-partition gradients across the cluster, then sum them with a reduce.
    SparkTrainable.ReducableResult result = dataRDD.mapPartitions(new PartitionTask(network))
        .reduce(SparkTrainable.ReducableResult::add);
    DeltaSet deltaSet = getDelta(result);
    // Snapshot the current weights alongside the accumulated gradient.
    DeltaSet stateSet = new DeltaSet();
    deltaSet.map.forEach((layer, layerDelta) -> {
      stateSet.get(layer, layerDelta.target).accumulate(layerDelta.target);
    });
    return new Trainable.PointSample(deltaSet, stateSet, result.meanValue());
  }
  @Override
  public void resetToFull() {
    // This trainable always evaluates the full RDD, so there is no sample to reset.
  }

  private static Stream<Tensor[]> getStream(Iterator<Tensor[]> partition) {
    int characteristics = Spliterator.ORDERED;
    boolean parallel = false;
    Spliterator<Tensor[]> spliterator = Spliterators.spliteratorUnknownSize(partition, characteristics);
    return StreamSupport.stream(spliterator, parallel);
  }
}
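
A minimal usage sketch, not part of the artifact itself: it assumes an existing JavaSparkContext `sc`, a DAGNetwork `network` whose single output is a scalar objective, and a List<Tensor[]> `rows` of training data.

// Sketch only; `sc`, `network`, and `rows` are assumptions, not defined in this file.
JavaRDD<Tensor[]> data = sc.parallelize(rows);
SparkTrainable trainable = new SparkTrainable(data.rdd(), network);
// One distributed pass: partitions are mapped with PartitionTask, partial
// results are merged with ReducableResult::add, and the PointSample returned
// here carries the summed deltas, the current weights, and the mean objective.
Trainable.PointSample sample = trainable.measure();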