/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.network.util;

import com.simiacryptus.mindseye.eval.ConstL12Normalizer;
import com.simiacryptus.mindseye.eval.L12Normalizer;
import com.simiacryptus.mindseye.eval.SampledArrayTrainable;
import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.layers.java.*;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.mindseye.network.SimpleLossNetwork;
import com.simiacryptus.mindseye.opt.IterativeTrainer;
import com.simiacryptus.mindseye.opt.Step;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch;
import com.simiacryptus.mindseye.opt.line.LineSearchStrategy;
import com.simiacryptus.mindseye.opt.orient.LBFGS;
import com.simiacryptus.mindseye.opt.orient.OrientationStrategy;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

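/**
 * A single-layer denoising autoencoder assembled from MindsEye components. The encoder pipeline is
 * GaussianNoise -> FullyConnected -> Bias -> frozen ReLU -> Dropout; the decoder reuses the transposed
 * encoder synapse (held in a {@link VariableLayer} placeholder) followed by Bias -> frozen ReLU.
 * Instances are configured via {@link Builder}, trained via {@link #train()}, and stacked into deep
 * autoencoders via {@link RecursiveBuilder}.
 *
 * <p>Illustrative usage (a sketch only; the dimensions, noise levels, and training settings below are
 * example values, {@code trainingData} is assumed to be a {@link TensorList} supplied by the caller,
 * and a default {@code TrainingMonitor} is assumed to be instantiable):
 * <pre>{@code
 * AutoencoderNetwork network = AutoencoderNetwork
 *     .newLayer(new int[]{28, 28}, new int[]{10, 10})
 *     .setNoise(0.1)
 *     .setDropout(0.05)
 *     .build();
 * network.train()
 *     .setMonitor(new TrainingMonitor())
 *     .setMaxIterations(100)
 *     .setTimeoutMinutes(5)
 *     .run(trainingData);
 * network.runMode(); // disable noise layers before inference
 * TensorList encoded = network.encode(trainingData);
 * }</pre>
 */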
public class AutoencoderNetwork {

  @Nonnull
  private final PipelineNetwork decoder;
  @Nonnull
  private final ReLuActivationLayer decoderActivation;
  @Nonnull
  private final BiasLayer decoderBias;
  @Nonnull
  private final VariableLayer decoderSynapsePlaceholder;
  @Nonnull
  private final DropoutNoiseLayer encodedNoise;
  @Nonnull
  private final PipelineNetwork encoder;
  @Nonnull
  private final ReLuActivationLayer encoderActivation;
  @Nonnull
  private final BiasLayer encoderBias;
  @Nonnull
  private final FullyConnectedLayer encoderSynapse;
  private final int[] innerSize;
  @Nonnull
  private final GaussianNoiseLayer inputNoise;
  @Nonnull
  private final AutoencoderNetwork.Builder networkParameters;
  private final int[] outerSize;
  private Layer decoderSynapse;

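  /**
   * Wires the encoder pipeline (input noise, fully-connected synapse, bias, frozen ReLU, dropout)
   * and the decoder pipeline (transposed synapse placeholder, bias, frozen ReLU) from the supplied
   * builder parameters.
   */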
  protected AutoencoderNetwork(@Nonnull final AutoencoderNetwork.Builder networkParameters) {
    this.networkParameters = networkParameters;
    outerSize = networkParameters.getOuterSize();
    innerSize = networkParameters.getInnerSize();

    inputNoise = new GaussianNoiseLayer().setValue(networkParameters.getNoise());
    encoderSynapse = new FullyConnectedLayer(outerSize, innerSize);
    encoderSynapse.initSpacial(networkParameters.getInitRadius(), networkParameters.getInitStiffness(), networkParameters.getInitPeak());
    encoderBias = new BiasLayer(innerSize).setWeights(i -> 0.0);
    encoderActivation = (ReLuActivationLayer) new ReLuActivationLayer().freeze();
    encodedNoise = new DropoutNoiseLayer().setValue(networkParameters.getDropout());
    decoderSynapse = encoderSynapse.getTranspose();
    decoderSynapsePlaceholder = new VariableLayer(decoderSynapse);
    decoderBias = new BiasLayer(outerSize).setWeights(i -> 0.0);
    decoderActivation = (ReLuActivationLayer) new ReLuActivationLayer().freeze();

    encoder = new PipelineNetwork();
    encoder.add(inputNoise);
    encoder.add(encoderSynapse);
    encoder.add(encoderBias);
    encoder.add(encoderActivation);
    encoder.add(encodedNoise);

    decoder = new PipelineNetwork();
    decoder.add(decoderSynapsePlaceholder);
    decoder.add(decoderBias);
    decoder.add(decoderActivation);
  }

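  /**
   * Creates a builder for an autoencoder layer mapping tensors of {@code outerSize} to {@code innerSize}.
   */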
  public static AutoencoderNetwork.Builder newLayer(final int[] outerSize, final int[] innerSize) {
    return new AutoencoderNetwork.Builder(outerSize, innerSize);
  }

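  /**
   * Evaluates the encoder on the given data and returns the encoded representation. Inputs are wrapped
   * as constant results, and intermediate references are freed after evaluation.
   */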
  public TensorList encode(@Nonnull final TensorList data) {
    Layer layer = encoder.getLayer();
    TensorList tensorList = layer
        .evalAndFree(ConstantResult.batchResultArray(data.stream().map(x -> new Tensor[]{x}).toArray(i -> new Tensor[i][])))
        .getDataAndFree();
    layer.freeRef();
    return tensorList;
  }

  @Nonnull
  public Layer getDecoder() {
    return decoder;
  }

  @Nonnull
  public Layer getDecoderActivation() {
    return decoderActivation;
  }

  @Nonnull
  public BiasLayer getDecoderBias() {
    return decoderBias;
  }

  public Layer getDecoderSynapse() {
    return decoderSynapse;
  }

  @Nonnull
  public DropoutNoiseLayer getEncodedNoise() {
    return encodedNoise;
  }

  @Nonnull
  public Layer getEncoder() {
    return encoder;
  }

  @Nonnull
  public Layer getEncoderActivation() {
    return encoderActivation;
  }

  @Nonnull
  public BiasLayer getEncoderBias() {
    return encoderBias;
  }

  @Nonnull
  public FullyConnectedLayer getEncoderSynapse() {
    return encoderSynapse;
  }

  public int[] getInnerSize() {
    return innerSize;
  }

  @Nonnull
  public GaussianNoiseLayer getInputNoise() {
    return inputNoise;
  }

  public int[] getOuterSize() {
    return outerSize;
  }

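  /**
   * Switches to evaluation mode by setting the input noise and encoded dropout levels to zero.
   */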
  public void runMode() {
    inputNoise.setValue(0.0);
    encodedNoise.setValue(0.0);
  }

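  /**
   * Returns training parameters whose loss network feeds the encoder into the decoder and scores the
   * reconstruction with mean-squared error against the original input.
   */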
  @Nonnull
  public AutoencoderNetwork.TrainingParameters train() {
    return new AutoencoderNetwork.TrainingParameters() {
      @Nonnull
      @Override
      public SimpleLossNetwork getTrainingNetwork() {
        @Nonnull final PipelineNetwork student = new PipelineNetwork();
        student.add(encoder);
        student.add(decoder);
        return new SimpleLossNetwork(student, new MeanSqLossLayer());
      }

      @Nonnull
      @Override
      protected TrainingMonitor wrap(@Nonnull final TrainingMonitor monitor) {
        return new TrainingMonitor() {
          @Override
          public void log(final String msg) {
            monitor.log(msg);
          }

          @Override
          public void onStepComplete(final Step currentPoint) {
            monitor.onStepComplete(currentPoint);
          }
        };
      }
    };
  }

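  /**
   * Switches to training mode by restoring the configured noise and dropout levels.
   */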
  public void trainingMode() {
    inputNoise.setValue(networkParameters.getNoise());
    encodedNoise.setValue(networkParameters.getDropout());
  }

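  /**
   * Fluent configuration for a single autoencoder layer: outer/inner dimensions, input noise, dropout,
   * and weight-initialization parameters (radius, stiffness, peak).
   */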
  public static class Builder {

    private final int[] innerSize;
    private final int[] outerSize;
    private double dropout = 0.0;
    private double initPeak = 0.001;
    private double initRadius = 0.5;
    private int initStiffness = 3;
    private double noise = 0.0;

    private Builder(final int[] outerSize, final int[] innerSize) {
      this.outerSize = outerSize;
      this.innerSize = innerSize;
    }

    @Nonnull
    public AutoencoderNetwork build() {
      return new AutoencoderNetwork(AutoencoderNetwork.Builder.this);
    }

    public double getDropout() {
      return dropout;
    }

    @Nonnull
    public AutoencoderNetwork.Builder setDropout(final double dropout) {
      this.dropout = dropout;
      return this;
    }

    public double getInitPeak() {
      return initPeak;
    }

    @Nonnull
    public AutoencoderNetwork.Builder setInitPeak(final double initPeak) {
      this.initPeak = initPeak;
      return this;
    }

    public double getInitRadius() {
      return initRadius;
    }

    @Nonnull
    public AutoencoderNetwork.Builder setInitRadius(final double initRadius) {
      this.initRadius = initRadius;
      return this;
    }

    public int getInitStiffness() {
      return initStiffness;
    }

    @Nonnull
    public AutoencoderNetwork.Builder setInitStiffness(final int initStiffness) {
      this.initStiffness = initStiffness;
      return this;
    }

    public int[] getInnerSize() {
      return innerSize;
    }

    public double getNoise() {
      return noise;
    }

    @Nonnull
    public AutoencoderNetwork.Builder setNoise(final double noise) {
      this.noise = noise;
      return this;
    }

    public int[] getOuterSize() {
      return outerSize;
    }
  }

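  /**
   * Greedy layer-wise builder for stacked (deep) autoencoders. Each call to {@link #growLayer} trains a
   * new {@link AutoencoderNetwork} on the current encoded representation and appends it to the stack;
   * {@link #tune()} then fine-tunes the full encoder/decoder stack end to end.
   */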
  public static class RecursiveBuilder {

    private final List<int[]> dimensions = new ArrayList<>();
    private final List<AutoencoderNetwork> layers = new ArrayList<>();
    private final List<TensorList> representations = new ArrayList<>();

    public RecursiveBuilder(@Nonnull final TensorList data) {
      representations.add(data);
      dimensions.add(data.get(0).getDimensions());
    }

    protected AutoencoderNetwork.Builder configure(final AutoencoderNetwork.Builder builder) {
      return builder;
    }

    protected AutoencoderNetwork.TrainingParameters configure(final AutoencoderNetwork.TrainingParameters trainingParameters) {
      return trainingParameters;
    }

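    /**
     * Returns the full reconstruction pipeline: the stacked encoder followed by the stacked decoder.
     */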
    @Nonnull
    public Layer echo() {
      @Nonnull final PipelineNetwork network = new PipelineNetwork();
      network.add(getEncoder());
      network.add(getDecoder());
      return network;
    }

    @Nonnull
    public Layer getDecoder() {
      @Nonnull final PipelineNetwork network = new PipelineNetwork();
      for (int i = layers.size() - 1; i >= 0; i--) {
        network.add(layers.get(i).getDecoder());
      }
      return network;
    }

    @Nonnull
    public Layer getEncoder() {
      @Nonnull final PipelineNetwork network = new PipelineNetwork();
      for (int i = 0; i < layers.size(); i++) {
        network.add(layers.get(i).getEncoder());
      }
      return network;
    }

    @Nonnull
    public List<AutoencoderNetwork> getLayers() {
      return Collections.unmodifiableList(layers);
    }

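    /**
     * Grows a new layer with default pretraining: 100 pretraining samples for the first layer (none for
     * subsequent layers), a one-minute pretraining budget, and up to 10 pretraining iterations.
     */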
    @Nonnull
    public AutoencoderNetwork growLayer(final int... dims) {
      return growLayer(layers.isEmpty() ? 100 : 0, 1, 10, dims);
    }

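    /**
     * Grows a new layer: builds an {@link AutoencoderNetwork} from the last representation's dimensions
     * to {@code dims}, optionally pretrains it on a shuffled subset of the data, re-transposes the
     * decoder synapse, trains on the full representation, and records the new encoded representation.
     */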
    @Nonnull
    public AutoencoderNetwork growLayer(final int pretrainingSize, final int pretrainingMinutes, final int pretrainIterations, final int[] dims) {
      trainingMode();
      @Nonnull final AutoencoderNetwork newLayer = configure(AutoencoderNetwork.newLayer(dimensions.get(dimensions.size() - 1), dims)).build();

      final TensorList data = representations.get(representations.size() - 1);
      dimensions.add(dims);
      layers.add(newLayer);

      if (pretrainingSize > 0 && pretrainIterations > 0 && pretrainingMinutes > 0) {
        @Nonnull final ArrayList<Tensor> list = new ArrayList<>(data.stream().collect(Collectors.toList()));
        Collections.shuffle(list);
        @Nonnull final Tensor[] pretrainingSet = list.subList(0, pretrainingSize).toArray(new Tensor[]{});
        configure(newLayer.train()).setMaxIterations(pretrainIterations).setTimeoutMinutes(pretrainingMinutes).run(TensorArray.create(pretrainingSet));
      }
      newLayer.decoderSynapse = ((FullyConnectedLayer) newLayer.decoderSynapse).getTranspose();
      newLayer.decoderSynapsePlaceholder.setInner(newLayer.decoderSynapse);
      configure(newLayer.train()).run(data);

      runMode();
      representations.add(newLayer.encode(data));
      return newLayer;
    }

    public void runMode() {
      layers.forEach(x -> x.runMode());
    }

    public void trainingMode() {
      layers.forEach(x -> x.trainingMode());
    }

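    /**
     * Fine-tunes the whole stacked encoder/decoder against the original input data using mean-squared
     * reconstruction error.
     */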
    public void tune() {
      configure(new AutoencoderNetwork.TrainingParameters() {
        @Nonnull
        @Override
        public SimpleLossNetwork getTrainingNetwork() {
          @Nonnull final PipelineNetwork student = new PipelineNetwork();
          student.add(getEncoder());
          student.add(getDecoder());
          return new SimpleLossNetwork(student, new MeanSqLossLayer());
        }

        @Nonnull
        @Override
        protected TrainingMonitor wrap(@Nonnull final TrainingMonitor monitor) {
          return new TrainingMonitor() {
            @Override
            public void log(final String msg) {
              monitor.log(msg);
            }

            @Override
            public void onStepComplete(final Step currentPoint) {
              monitor.onStepComplete(currentPoint);
            }
          };
        }
      }).run(representations.get(0));
    }
  }

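  /**
   * Training configuration and driver. Defaults: LBFGS orientation, Armijo-Wolfe line search, no L1/L2
   * normalization, unlimited iterations and sample size, and a 10-minute timeout.
   */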
  public abstract static class TrainingParameters {
    private double endFitness = Double.NEGATIVE_INFINITY;
    private double l1normalization = 0.0;
    private double l2normalization = 0.0;
    private int maxIterations = Integer.MAX_VALUE;
    @Nullable
    private TrainingMonitor monitor = null;
    private OrientationStrategy<?> orient = new LBFGS().setMinHistory(5).setMaxHistory(35);
    private int sampleSize = Integer.MAX_VALUE;
    private LineSearchStrategy step = new ArmijoWolfeSearch().setC2(0.9).setAlpha(1e-4);
    private int timeoutMinutes = 10;

    public double getEndFitness() {
      return endFitness;
    }

    @Nonnull
    public AutoencoderNetwork.TrainingParameters setEndFitness(final double endFitness) {
      this.endFitness = endFitness;
      return this;
    }

    public double getL1normalization() {
      return l1normalization;
    }

    @Nonnull
    public AutoencoderNetwork.TrainingParameters setL1normalization(final double l1normalization) {
      this.l1normalization = l1normalization;
      return this;
    }

    public double getL2normalization() {
      return l2normalization;
    }

    @Nonnull
    public AutoencoderNetwork.TrainingParameters setL2normalization(final double l2normalization) {
      this.l2normalization = l2normalization;
      return this;
    }

    public int getMaxIterations() {
      return maxIterations;
    }

    @Nonnull
    public AutoencoderNetwork.TrainingParameters setMaxIterations(final int maxIterations) {
      this.maxIterations = maxIterations;
      return this;
    }

    @Nullable
    public TrainingMonitor getMonitor() {
      return monitor;
    }

    @Nonnull
    public AutoencoderNetwork.TrainingParameters setMonitor(final TrainingMonitor monitor) {
      this.monitor = monitor;
      return this;
    }

    public OrientationStrategy<?> getOrient() {
      return orient;
    }

    @Nonnull
    public AutoencoderNetwork.TrainingParameters setOrient(final OrientationStrategy<?> orient) {
      this.orient = orient;
      return this;
    }

    public int getSampleSize() {
      return sampleSize;
    }

    @Nonnull
    public AutoencoderNetwork.TrainingParameters setSampleSize(final int sampleSize) {
      this.sampleSize = sampleSize;
      return this;
    }

    public LineSearchStrategy getStep() {
      return step;
    }

    @Nonnull
    public AutoencoderNetwork.TrainingParameters setStep(final LineSearchStrategy step) {
      this.step = step;
      return this;
    }

    public int getTimeoutMinutes() {
      return timeoutMinutes;
    }

    @Nonnull
    public AutoencoderNetwork.TrainingParameters setTimeoutMinutes(final int timeoutMinutes) {
      this.timeoutMinutes = timeoutMinutes;
      return this;
    }

    @Nonnull
    public abstract SimpleLossNetwork getTrainingNetwork();

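    /**
     * Trains the loss network on (input, input) pairs: wraps the data in a {@link SampledArrayTrainable},
     * applies constant L1/L2 normalization, and drives an {@link IterativeTrainer} with the configured
     * orientation, line search, monitor, timeout, termination threshold, and iteration limit. The monitor
     * is wrapped unconditionally, so a non-null monitor should normally be supplied via {@link #setMonitor}.
     */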
    public void run(@Nonnull final TensorList data) {
      @Nonnull final SimpleLossNetwork trainingNetwork = getTrainingNetwork();
      @Nonnull final Trainable trainable = new SampledArrayTrainable(data.stream().map(x -> new Tensor[]{x, x}).toArray(i -> new Tensor[i][]), trainingNetwork, getSampleSize());
      @Nonnull final L12Normalizer normalized = new ConstL12Normalizer(trainable).setFactor_L1(getL1normalization()).setFactor_L2(getL2normalization());
      @Nonnull final IterativeTrainer trainer = new IterativeTrainer(normalized);
      trainer.setOrientation(getOrient());
      trainer.setLineSearchFactory((s) -> getStep());
      @Nullable final TrainingMonitor monitor = getMonitor();
      trainer.setMonitor(wrap(monitor));
      trainer.setTimeout(getTimeoutMinutes(), TimeUnit.MINUTES);
      trainer.setTerminateThreshold(getEndFitness());
      trainer.setMaxIterations(maxIterations);
      trainer.runAndFree();
    }

    @Nonnull
    protected abstract TrainingMonitor wrap(TrainingMonitor monitor);
  }
}