All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.test.integration.AutoencodingProblem Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.test.integration;

import com.simiacryptus.lang.UncheckedSupplier;
import com.simiacryptus.mindseye.eval.ArrayTrainable;
import com.simiacryptus.mindseye.eval.SampledArrayTrainable;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Result;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.mindseye.layers.StochasticComponent;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.mindseye.opt.Step;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.ValidatingTrainer;
import com.simiacryptus.mindseye.test.GraphVizNetworkInspector;
import com.simiacryptus.mindseye.test.StepRecord;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.TableOutput;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.test.LabeledObject;
import guru.nidi.graphviz.engine.Format;
import guru.nidi.graphviz.engine.Graphviz;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/**
 * The type Autoencoding problem.
 */
@SuppressWarnings("FieldCanBeLocal")
public abstract class AutoencodingProblem implements Problem {

  private static int modelNo = 0;

  private final int batchSize = 10000;
  private final ImageProblemData data;
  private final double dropout;
  private final int features;
  private final FwdNetworkFactory fwdFactory;
  @Nonnull
  private final List history = new ArrayList<>();
  private final OptimizationStrategy optimizer;
  private final RevNetworkFactory revFactory;
  private int timeoutMinutes = 1;

  /**
   * Instantiates a new Autoencoding problem.
   *
   * @param fwdFactory the fwd factory
   * @param optimizer  the optimizer
   * @param revFactory the rev factory
   * @param data       the data
   * @param features   the features
   * @param dropout    the dropout
   */
  public AutoencodingProblem(final FwdNetworkFactory fwdFactory, final OptimizationStrategy optimizer,
                             final RevNetworkFactory revFactory, final ImageProblemData data, final int features, final double dropout) {
    this.fwdFactory = fwdFactory;
    this.optimizer = optimizer;
    this.revFactory = revFactory;
    this.data = data;
    this.features = features;
    this.dropout = dropout;
  }

  @Nonnull
  @Override
  public List getHistory() {
    return history;
  }

  /**
   * Gets timeout minutes.
   *
   * @return the timeout minutes
   */
  public int getTimeoutMinutes() {
    return timeoutMinutes;
  }

  /**
   * Sets timeout minutes.
   *
   * @param timeoutMinutes the timeout minutes
   * @return the timeout minutes
   */
  @Nonnull
  public AutoencodingProblem setTimeoutMinutes(final int timeoutMinutes) {
    this.timeoutMinutes = timeoutMinutes;
    return this;
  }

  /**
   * Get training data tensor [ ] [ ].
   *
   * @return the tensor [ ] [ ]
   */
  @Nonnull
  public Tensor[][] getTrainingData() {
    try {
      return data.trainingData().map(labeledObject -> {
        Tensor[] tensors = {labeledObject.data};
        labeledObject.freeRef();
        return tensors;
      }).toArray(Tensor[][]::new);
    } catch (@Nonnull final IOException e) {
      throw Util.throwException(e);
    }
  }

  /**
   * Parse int.
   *
   * @param label the label
   * @return the int
   */
  public int parse(@Nonnull final String label) {
    return Integer.parseInt(label.replaceAll("[^\\d]", ""));
  }

  @Nonnull
  @Override
  public AutoencodingProblem run(@Nonnull final NotebookOutput log) {

    @Nonnull final DAGNetwork fwdNetwork = fwdFactory.imageToVector(log, features);
    @Nonnull final DAGNetwork revNetwork = revFactory.vectorToImage(log, features);

    @Nonnull final PipelineNetwork echoNetwork = new PipelineNetwork(1);
    RefUtil.freeRef(echoNetwork.add(fwdNetwork.addRef()));
    RefUtil.freeRef(echoNetwork.add(revNetwork.addRef()));

    @Nonnull final PipelineNetwork supervisedNetwork = new PipelineNetwork(1);
    RefUtil.freeRef(supervisedNetwork.add(fwdNetwork.addRef()));
    @Nonnull final StochasticComponent dropoutNoiseLayer = dropout(dropout);
    RefUtil.freeRef(supervisedNetwork.add(dropoutNoiseLayer));
    RefUtil.freeRef(supervisedNetwork.add(revNetwork.addRef()));
    RefUtil.freeRef(supervisedNetwork.add(lossLayer(), supervisedNetwork.getHead(), supervisedNetwork.getInput(0)));

    log.h3("Network Diagrams");
    log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
      return Graphviz.fromGraph(GraphVizNetworkInspector.toGraphviz(fwdNetwork.addRef())).height(400)
          .width(600).render(Format.PNG).toImage();
    }, fwdNetwork.addRef()));
    log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
      return Graphviz.fromGraph(GraphVizNetworkInspector.toGraphviz(revNetwork.addRef())).height(400)
          .width(600).render(Format.PNG).toImage();
    }, revNetwork.addRef()));
    log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
      return Graphviz.fromGraph(GraphVizNetworkInspector.toGraphviz(supervisedNetwork.addRef()))
          .height(400).width(600).render(Format.PNG).toImage();
    }, supervisedNetwork.addRef()));

    @Nonnull final TrainingMonitor monitor = new TrainingMonitor() {
      @Nonnull
      final TrainingMonitor inner = TestUtil.getMonitor(history);

      @Override
      public void log(final String msg) {
        inner.log(msg);
      }

      @Override
      public void onStepComplete(@Nullable final Step currentPoint) {
        inner.onStepComplete(currentPoint == null ? null : currentPoint.addRef());
        if (null != currentPoint)
          currentPoint.freeRef();
      }
    };

    final Tensor[][] trainingData = getTrainingData();

    //MonitoredObject monitoringRoot = new MonitoredObject();
    //TestUtil.addMonitoring(supervisedNetwork, monitoringRoot);

    log.h3("Training");
    TestUtil.instrumentPerformance(supervisedNetwork.addRef());
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log,
        new SampledArrayTrainable(RefUtil.addRef(trainingData),
            supervisedNetwork.addRef(), trainingData.length / 2, batchSize),
        new ArrayTrainable(RefUtil.addRef(trainingData), supervisedNetwork.addRef(),
            batchSize),
        monitor);
    RefUtil.freeRef(trainingData);
    log.run(RefUtil.wrapInterface(() -> {
      trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES);
      ValidatingTrainer temp_21_0003 = trainer.addRef();
      temp_21_0003.setMaxIterations(10000);
      ValidatingTrainer temp_21_0004 = temp_21_0003.addRef();
      temp_21_0004.run();
      temp_21_0004.freeRef();
      temp_21_0003.freeRef();
    }, trainer));
    if (!history.isEmpty()) {
      log.eval(() -> {
        return TestUtil.plot(history);
      });
      log.eval(() -> {
        return TestUtil.plotTime(history);
      });
    }
    TestUtil.extractPerformance(log, supervisedNetwork);

    {
      @Nonnull final String modelName = "encoder_model" + AutoencodingProblem.modelNo++ + ".json";
      log.p("Saved model as " + log.file(fwdNetwork.getJson().toString(), modelName, modelName));
    }

    fwdNetwork.freeRef();
    @Nonnull final String modelName = "decoder_model" + AutoencodingProblem.modelNo++ + ".json";
    log.p("Saved model as " + log.file(revNetwork.getJson().toString(), modelName, modelName));

    //    log.h3("Metrics");
    //    log.run(() -> {
    //      return TestUtil.toFormattedJson(monitoringRoot.getMetrics());
    //    });

    log.h3("Validation");

    log.p("Here are some re-encoded examples:");
    log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
      @Nonnull final TableOutput table = new TableOutput();
      data.validationData().map(RefUtil.wrapInterface(
          (Function, ? extends LinkedHashMap>) labeledObject -> {
            Result temp_21_0006 = echoNetwork.eval(labeledObject.data.addRef());
            assert temp_21_0006 != null;
            TensorList data = temp_21_0006.getData();
            Tensor tensor = data.get(0);
            LinkedHashMap row = toRow(log, labeledObject, tensor.getData());
            tensor.freeRef();
            data.freeRef();
            temp_21_0006.freeRef();
            return row;
          }, echoNetwork.addRef())).filter(Objects::nonNull).limit(10)
          .forEach(table::putRow);
      return table;
    }, echoNetwork));

    log.p("Some rendered unit vectors:");
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
      Tensor temp_21_0001 = new Tensor(features);
      temp_21_0001.set(featureNumber, 1);
      @Nonnull final Tensor input = temp_21_0001.addRef();
      temp_21_0001.freeRef();
      Result temp_21_0007 = revNetwork.eval(input);
      assert temp_21_0007 != null;
      TensorList temp_21_0008 = temp_21_0007.getData();
      @Nullable final Tensor tensor = temp_21_0008.get(0);
      temp_21_0008.freeRef();
      temp_21_0007.freeRef();
      log.out(log.png(tensor.toImage(), ""));
      tensor.freeRef();
    }
    revNetwork.freeRef();
    return this;
  }

  /**
   * To row linked hash map.
   *
   * @param log              the log
   * @param labeledObject    the labeled object
   * @param predictionSignal the prediction signal
   * @return the linked hash map
   */
  @Nonnull
  public LinkedHashMap toRow(@Nonnull final NotebookOutput log,
                                                   @Nonnull final LabeledObject labeledObject, final double[] predictionSignal) {
    @Nonnull final LinkedHashMap row = new LinkedHashMap<>();
    row.put("Image", log.png(labeledObject.data.toImage(), labeledObject.label));
    Tensor temp_21_0002 = new Tensor(predictionSignal, labeledObject.data.getDimensions());
    row.put("Echo", log.png(temp_21_0002.toImage(), labeledObject.label));
    labeledObject.freeRef();
    temp_21_0002.freeRef();
    return row;
  }

  /**
   * Loss layer layer.
   *
   * @return the layer
   */
  @Nonnull
  protected abstract Layer lossLayer();

  /**
   * Dropout stochastic component.
   *
   * @param dropout the dropout
   * @return the stochastic component
   */
  @Nonnull
  protected abstract StochasticComponent dropout(double dropout);
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy