com.simiacryptus.mindseye.eval.LocalSparkTrainable Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-core Show documentation
Show all versions of mindseye-core Show documentation
Core Neural Networks Framework
///*
// * Copyright (c) 2018 by Andrew Charneski.
// *
// * The author licenses this file to you under the
// * Apache License, Version 2.0 (the "License");
// * you may not use this file except in compliance
// * with the License. You may obtain a copy
// * of the License at
// *
// * http://www.apache.org/licenses/LICENSE-2.0
// *
// * Unless required by applicable law or agreed to in writing,
// * software distributed under the License is distributed on an
// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// * KIND, either express or implied. See the License for the
// * specific language governing permissions and limitations
// * under the License.
// */
//
//package com.simiacryptus.mindseye.eval;
//
//import com.simiacryptus.mindseye.lang.DeltaSet;
//import com.simiacryptus.mindseye.lang.Layer;
//import com.simiacryptus.mindseye.lang.PointSample;
//import com.simiacryptus.mindseye.lang.StateSet;
//import com.simiacryptus.mindseye.lang.Tensor;
//import com.simiacryptus.mindseye.opt.TrainingMonitor;
//import org.apache.spark.api.java.JavaRDD;
//import org.apache.spark.rdd.RDD;
//import org.slf4j.Logger;
//import org.slf4j.LoggerFactory;
//
//import javax.annotation.Nonnull;
//import java.util.Arrays;
//import java.util.Iterator;
//import java.util.List;
//import java.util.stream.Collectors;
//import java.util.stream.Stream;
//
///**
// * A debugging class which replaces SparkTrainable with an implementation that uses direct method calls instead of RMI.
// * This can be useful for debugging in some situations.
// */
//public class LocalSparkTrainable extends SparkTrainable {
// private static final Logger log = LoggerFactory.getLogger(LocalSparkTrainable.class);
//
// /**
// * Instantiates a new Local spark trainable.
// *
// * @param trainingData the training data
// * @param network the network
// */
// public LocalSparkTrainable(final RDD trainingData, final Layer network) {
// super(trainingData, network);
// }
//
// /**
// * Instantiates a new Local spark trainable.
// *
// * @param trainingData the training data
// * @param network the network
// * @param sampleSize the sample size
// */
// public LocalSparkTrainable(final RDD trainingData, final Layer network, final int sampleSize) {
// super(trainingData, network, sampleSize);
// }
//
//
// @Nonnull
// @Override
// public PointSample measure(final TrainingMonitor monitor) {
// final long time1 = com.simiacryptus.ref.wrappers.RefSystem.nanoTime();
// final JavaRDD javaRDD = sampledRDD.toJavaRDD();
// assert !javaRDD.isEmpty();
// final List mapPartitions = javaRDD.partitions().stream().map(partition -> {
// try {
// final List[] array = javaRDD.collectPartitions(new int[]{partition.index()});
// assert 0 < array.length;
// if (0 == Arrays.stream(array).mapToInt((@Nonnull final List x) -> x.size()).sum()) {
// return null;
// }
// assert 0 < Arrays.stream(array).mapToInt(x -> x.stream().mapToInt(y -> y.length).sum()).sum();
// final Stream stream = Arrays.stream(array).flatMap(i -> i.stream());
// @Nonnull final Iterator iterator = stream.iterator();
// return new PartitionTask(network).call(iterator).next();
// } catch (@Nonnull final RuntimeException e) {
// throw e;
// } catch (@Nonnull final Exception e) {
// throw new RuntimeException(e);
// }
// }).filter(x -> null != x).collect(Collectors.toList());
// final long time2 = com.simiacryptus.ref.wrappers.RefSystem.nanoTime();
// @Nonnull final SparkTrainable.ReducableResult result = mapPartitions.stream().reduce(SparkTrainable.ReducableResult::add).get();
// if (isVerbose()) {
// log.info(String.format("Measure timing: %.3f / %.3f for %s items", (time2 - time1) * 1e-9, (com.simiacryptus.ref.wrappers.RefSystem.nanoTime() - time2) * 1e-9, sampledRDD.count()));
// }
// @Nonnull final DeltaSet xxx = getDelta(result);
// return new PointSample(xxx, new StateSet(xxx), result.sum, 0.0, result.count).normalize();
// }
//
//}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy