All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.test.integration.ClassifyProblem Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.test.integration;

import com.google.common.collect.Lists;
import com.simiacryptus.mindseye.eval.ArrayTrainable;
import com.simiacryptus.mindseye.eval.SampledArrayTrainable;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.SimpleLossNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.ValidatingTrainer;
import com.simiacryptus.mindseye.test.StepRecord;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.TableOutput;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.test.LabeledObject;
import guru.nidi.graphviz.engine.Format;
import guru.nidi.graphviz.engine.Graphviz;
import guru.nidi.graphviz.model.Graph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * The type Mnist apply base.
 */
public abstract class ClassifyProblem implements Problem {

  private static final Logger logger = LoggerFactory.getLogger(ClassifyProblem.class);

  private static int modelNo = 0;
  private final int categories;
  private final ImageProblemData data;
  private final FwdNetworkFactory fwdFactory;
  private final List history = new ArrayList<>();
  private final OptimizationStrategy optimizer;
  private final List labels;
  private int batchSize = 10000;
  private int timeoutMinutes = 1;

  /**
   * Instantiates a new Classify problem.
   *
   * @param fwdFactory the fwd factory
   * @param optimizer  the optimizer
   * @param data       the data
   * @param categories the categories
   */
  public ClassifyProblem(final FwdNetworkFactory fwdFactory, final OptimizationStrategy optimizer, final ImageProblemData data, final int categories) {
    this.fwdFactory = fwdFactory;
    this.optimizer = optimizer;
    this.data = data;
    this.categories = categories;
    try {
      this.labels = Stream.concat(this.data.trainingData(), this.data.validationData()).map(x -> x.label).distinct().sorted().collect(Collectors.toList());
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }


  @Nonnull
  @Override
  public List getHistory() {
    return history;
  }

  /**
   * Gets timeout minutes.
   *
   * @return the timeout minutes
   */
  public int getTimeoutMinutes() {
    return timeoutMinutes;
  }

  /**
   * Sets timeout minutes.
   *
   * @param timeoutMinutes the timeout minutes
   * @return the timeout minutes
   */
  @Nonnull
  public ClassifyProblem setTimeoutMinutes(final int timeoutMinutes) {
    this.timeoutMinutes = timeoutMinutes;
    return this;
  }

  /**
   * Get training data tensor [ ] [ ].
   *
   * @param log the log
   * @return the tensor [ ] [ ]
   */
  public Tensor[][] getTrainingData(final NotebookOutput log) {
    try {
      return data.trainingData().map(labeledObject -> {
        @Nonnull final Tensor categoryTensor = new Tensor(categories);
        final int category = parse(labeledObject.label);
        categoryTensor.set(category, 1);
        return new Tensor[]{labeledObject.data, categoryTensor};
      }).toArray(i -> new Tensor[i][]);
    } catch (@Nonnull final IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Parse int.
   *
   * @param label the label
   * @return the int
   */
  public int parse(final CharSequence label) {
    return this.labels.indexOf(label);
  }

  /**
   * Predict int [ ].
   *
   * @param network       the network
   * @param labeledObject the labeled object
   * @return the int [ ]
   */
  public int[] predict(@Nonnull final Layer network, @Nonnull final LabeledObject labeledObject) {
    @Nullable final double[] predictionSignal = network.eval(labeledObject.data).getData().get(0).getData();
    return IntStream.range(0, categories).mapToObj(x -> x).sorted(Comparator.comparing(i -> -predictionSignal[i])).mapToInt(x -> x).toArray();
  }

  @Nonnull
  @Override
  public ClassifyProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    final Tensor[][] trainingData = getTrainingData(log);

    @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);
    log.h3("Network Diagram");
    log.eval(() -> {
      return Graphviz.fromGraph((Graph) TestUtil.toGraph(network))
          .height(400).width(600).render(Format.PNG).toImage();
    });

    log.h3("Training");
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, lossLayer());
    TestUtil.instrumentPerformance(supervisedNetwork);
    int initialSampleSize = Math.max(trainingData.length / 5, Math.min(10, trainingData.length / 2));
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log,
        new SampledArrayTrainable(trainingData, supervisedNetwork, initialSampleSize, getBatchSize()),
        new ArrayTrainable(trainingData, supervisedNetwork, getBatchSize()), monitor);
    log.run(() -> {
      trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
      log.eval(() -> {
        return TestUtil.plot(history);
      });
      log.eval(() -> {
        return TestUtil.plotTime(history);
      });
    }

    @Nonnull String training_name = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
    try {
      BufferedImage image = Util.toImage(TestUtil.plot(history));
      if (null != image) ImageIO.write(image, "png", log.file(training_name));
    } catch (IOException e) {
      logger.warn("Error writing result images", e);
    }
    log.appendFrontMatterProperty("result_plot", new File(log.getResourceDir(), training_name).toString(), ";");

    TestUtil.extractPerformance(log, supervisedNetwork);
    @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));

    log.h3("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.eval(() -> {
      return data.validationData().mapToDouble(labeledObject ->
          predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0)
          .average().getAsDouble() * 100;
    });

    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.eval(() -> {
      try {
        @Nonnull final TableOutput table = new TableOutput();
        Lists.partition(data.validationData().collect(Collectors.toList()), 100).stream().flatMap(batch -> {
          @Nonnull TensorList batchIn = TensorArray.create(batch.stream().map(x -> x.data).toArray(i -> new Tensor[i]));
          TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
          return IntStream.range(0, batchOut.length())
              .mapToObj(i -> toRow(log, batch.get(i), batchOut.get(i).getData()));
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
      } catch (@Nonnull final IOException e) {
        throw new RuntimeException(e);
      }
    });
    return this;
  }

  protected abstract Layer lossLayer();

  /**
   * To row linked hash buildMap.
   *
   * @param log              the log
   * @param labeledObject    the labeled object
   * @param predictionSignal the prediction signal
   * @return the linked hash buildMap
   */
  @Nullable
  public LinkedHashMap toRow(@Nonnull final NotebookOutput log, @Nonnull final LabeledObject labeledObject, final double[] predictionSignal) {
    final int actualCategory = parse(labeledObject.label);
    final int[] predictionList = IntStream.range(0, categories).mapToObj(x -> x).sorted(Comparator.comparing(i -> -predictionSignal[i])).mapToInt(x -> x).toArray();
    if (predictionList[0] == actualCategory) return null; // We will only examine mispredicted rows
    @Nonnull final LinkedHashMap row = new LinkedHashMap<>();
    row.put("Image", log.png(labeledObject.data.toImage(), labeledObject.label));
    row.put("Prediction", Arrays.stream(predictionList).limit(3)
        .mapToObj(i -> String.format("%d (%.1f%%)", i, 100.0 * predictionSignal[i]))
        .reduce((a, b) -> a + ", " + b).get());
    return row;
  }

  /**
   * Gets batch size.
   *
   * @return the batch size
   */
  public int getBatchSize() {
    return batchSize;
  }

  /**
   * Sets batch size.
   *
   * @param batchSize the batch size
   * @return the batch size
   */
  @Nonnull
  public ClassifyProblem setBatchSize(int batchSize) {
    this.batchSize = batchSize;
    return this;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy