/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.applications;

import com.simiacryptus.lang.Tuple2;
import com.simiacryptus.mindseye.eval.ArrayTrainable;
import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.ReferenceCountingBase;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.cudnn.Precision;
import com.simiacryptus.mindseye.layers.cudnn.*;
import com.simiacryptus.mindseye.models.CVPipe;
import com.simiacryptus.mindseye.models.CVPipe_VGG16;
import com.simiacryptus.mindseye.models.CVPipe_VGG19;
import com.simiacryptus.mindseye.models.LayerEnum;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.network.InnerNode;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.mindseye.opt.IterativeTrainer;
import com.simiacryptus.mindseye.opt.line.BisectionSearch;
import com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy;
import com.simiacryptus.mindseye.opt.region.RangeConstraint;
import com.simiacryptus.mindseye.opt.region.TrustRegion;
import com.simiacryptus.mindseye.test.StepRecord;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.notebook.FileHTTPD;
import com.simiacryptus.notebook.MarkdownNotebookOutput;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.NullNotebookOutput;
import com.simiacryptus.util.JsonUtil;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.data.ScalarStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.Closeable;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * This notebook implements the Style Transfer protocol outlined in "A Neural Algorithm of Artistic Style"
 * (Gatys, Ecker, and Bethge; https://arxiv.org/abs/1508.06576).
 *
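 * <p>A minimal usage sketch (illustrative only; the file names, layer choice, and coefficient values
 * below are assumptions rather than recommended defaults):</p>
 * <pre>{@code
 *   StyleTransfer.VGG19 styleTransfer = new StyleTransfer.VGG19();
 *   CVPipe_VGG19.Layer layer = styleTransfer.getLayerTypes()[0];
 *   Map<CharSequence, BufferedImage> styleImages = new HashMap<>();
 *   styleImages.put("style", ImageIO.read(new File("style.jpg")));
 *   ContentCoefficients<CVPipe_VGG19.Layer> content = new ContentCoefficients<>();
 *   content.set(layer, 1e0);
 *   Map<List<CharSequence>, StyleCoefficients<CVPipe_VGG19.Layer>> styles = new HashMap<>();
 *   styles.put(new ArrayList<>(styleImages.keySet()),
 *       new StyleCoefficients<CVPipe_VGG19.Layer>(CenteringMode.Origin).set(layer, 1e0, 1e0));
 *   StyleSetup<CVPipe_VGG19.Layer> setup = new StyleSetup<>(Precision.Float,
 *       Tensor.fromRGB(ImageIO.read(new File("content.jpg"))), content, styleImages, styles);
 *   Tensor canvas = Tensor.fromRGB(ImageIO.read(new File("content.jpg")));
 *   Tensor result = styleTransfer.transfer(canvas, setup, 60, styleTransfer.measureStyle(setup));
 * }</pre>
 *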
 * @param <T> the type parameter
 * @param <U> the type parameter
 */
public abstract class StyleTransfer<T extends LayerEnum<T>, U extends CVPipe<T>> {

  private static final Logger logger = LoggerFactory.getLogger(StyleTransfer.class);
  /**
   * Whether the loss function terms are evaluated in parallel.
   */
  public boolean parallelLossFunctions = true;
  private boolean tiled = false;

  /**
   * Transfer tensor.
   *
   * @param canvasImage     the canvas image
   * @param styleParameters the style parameters
   * @param trainingMinutes the training minutes
   * @param measureStyle    the measured style setup
   * @return the tensor
   */
  public Tensor transfer(final Tensor canvasImage, final StyleSetup<T> styleParameters, final int trainingMinutes, final NeuralSetup measureStyle) {
    return transfer(new NullNotebookOutput(), canvasImage, styleParameters, trainingMinutes, measureStyle, 50, true);
  }

  /**
   * Transfer tensor.
   *
   * @param log             the log
   * @param canvasData      the canvas data
   * @param styleParameters the style parameters
   * @param trainingMinutes the training minutes
   * @param measureStyle    the measured style setup
   * @param maxIterations   the max iterations
   * @param verbose         the verbose
   * @return the tensor
   */
  public Tensor transfer(
      @Nonnull final NotebookOutput log,
      final Tensor canvasData,
      final StyleSetup<T> styleParameters,
      final int trainingMinutes,
      final NeuralSetup measureStyle,
      final int maxIterations,
      final boolean verbose
  ) {
    try {
      transfer(log, styleParameters, trainingMinutes, measureStyle, maxIterations, verbose, canvasData);
      log.p("Result:");
      log.p(log.png(canvasData.toImage(), "Output Canvas"));
      return canvasData;
    } catch (Throwable e) {
      logger.warn("Error in style transfer", e);
      return canvasData;
    }
  }

  /**
   * Transfer.
   *
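   * <p>Builds a frozen fitness network from the measured content and style targets, publishes the
   * evolving canvas and training plot over the notebook's HTTP daemon, and then optimizes the canvas
   * pixels with an {@link IterativeTrainer} using a range-constrained trust region and a bisection line
   * search, bounded by both {@code maxIterations} and {@code trainingMinutes}.</p>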
   * @param log             the log
   * @param styleParameters the style parameters
   * @param trainingMinutes the training minutes
   * @param measureStyle    the measured style setup
   * @param maxIterations   the max iterations
   * @param verbose         the verbose
   * @param canvas          the canvas
   */
  public void transfer(
      @Nonnull final NotebookOutput log,
      final StyleSetup<T> styleParameters,
      final int trainingMinutes,
      final NeuralSetup measureStyle,
      final int maxIterations,
      final boolean verbose,
      final Tensor canvas
  ) {
//      log.p("Input Content:");
//      log.p(log.png(styleParameters.contentImage, "Content Image"));
//      log.p("Style Content:");
//      styleParameters.styleImages.forEach((file, styleImage) -> {
//        log.p(log.png(styleImage, file));
//      });
//      log.p("Input Canvas:");
//      log.p(log.png(canvasImage, "Input Canvas"));
    System.gc();
    TestUtil.monitorImage(canvas, false, false);
    String imageName = String.format("etc/image_%s.jpg", Long.toHexString(MarkdownNotebookOutput.random.nextLong()));
    log.p(String.format("<a href=\"%s\"><img src=\"%s\"></a>", imageName, imageName));
    Closeable jpeg = log.getHttpd().addGET(imageName, "image/jpeg", r -> {
      try {
        ImageIO.write(canvas.toImage(), "jpeg", r);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    });
    if (verbose) {
      log.p("Input Parameters:");
      log.eval(() -> {
        return ArtistryUtil.toJson(styleParameters);
      });
    }
    NotebookOutput trainingLog = verbose ? log : new NullNotebookOutput();
    Trainable trainable = trainingLog.eval(() -> {
      PipelineNetwork network = fitnessNetwork(measureStyle);
      network.setFrozen(true);
      ArtistryUtil.setPrecision(network, styleParameters.precision);
      TestUtil.instrumentPerformance(network);
      final FileHTTPD server = log.getHttpd();
      if (null != server) ArtistryUtil.addLayersHandler(network, server);
      if (tiled) network = ArtistryUtil.tileCycle(network, 3);
      Trainable trainable1 = getTrainable(canvas, network);
      network.freeRef();
      return trainable1;
    });
    try {
      @Nonnull ArrayList<StepRecord> history = new ArrayList<>();
      String training_name = String.format("etc/training_%s.png", Long.toHexString(MarkdownNotebookOutput.random.nextLong()));
      log.p(String.format("<a href=\"%s\"><img src=\"%s\"></a>", training_name, training_name));
      Closeable png = log.getHttpd().addGET(training_name, "image/png", r -> {
        try {
          ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", r);
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
      });
      trainingLog.eval(() -> {
        new IterativeTrainer(trainable)
            .setMonitor(TestUtil.getMonitor(history))
            .setOrientation(new TrustRegionStrategy() {
              @Override
              public TrustRegion getRegionPolicy(final Layer layer) {
                return new RangeConstraint().setMin(1e-2).setMax(256);
              }
            })
            .setMaxIterations(maxIterations)
            .setIterationsPerSample(100)
            .setLineSearchFactory(name -> new BisectionSearch().setSpanTol(1e-1).setCurrentRate(1e6))
            .setTimeout(trainingMinutes, TimeUnit.MINUTES)
            .setTerminateThreshold(Double.NEGATIVE_INFINITY)
            .runAndFree();
        return TestUtil.plot(history);
      });
      try {
        jpeg.close();
        ImageIO.write(canvas.toImage(), "jpeg", log.file(imageName));
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
      try {
        png.close();
        BufferedImage image = Util.toImage(TestUtil.plot(history));
        if (null != image) ImageIO.write(image, "png", log.file(training_name));
      } catch (IOException e) {
        logger.warn("Error writing result images", e);
      }
      log.p("Result:");
      log.p(log.png(canvas.toImage(), "Output Canvas"));
    } finally {
      trainable.freeRef();
    }
  }

  /**
   * Gets trainable.
   *
   * @param canvas  the canvas
   * @param network the network
   * @return the trainable
   */
  @Nonnull
  public Trainable getTrainable(final Tensor canvas, final PipelineNetwork network) {
    return new ArrayTrainable(network, 1).setVerbose(true).setMask(true).setData(Arrays.asList(new Tensor[][]{{canvas}}));
  }

  /**
   * Gets style components.
   *
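   * <p>An informal sketch of the loss terms assembled below, where {@code a} is the activation at
   * {@code node} and {@code recenter} follows the chosen {@code centeringMode}:</p>
   * <pre>
   *   mean term:    styleParams.mean * meanSq(bandAvg(a) - mean) / rms(mean)
   *   cov term:     styleParams.cov * meanSq(gramian(recenter(a)) - covariance) / rms(covariance)
   *   enhance term: -styleParams.enhance * avg(recenter(a)^2) / rms(covariance)
   * </pre>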
   * @param node          the node
   * @param network       the network
   * @param styleParams   the style params
   * @param mean          the mean
   * @param covariance    the covariance
   * @param centeringMode the centering mode
   * @return the style components
   */
  @Nonnull
  public ArrayList<Tuple2<Double, DAGNode>> getStyleComponents(
      final DAGNode node,
      final PipelineNetwork network,
      final LayerStyleParams styleParams,
      final Tensor mean,
      final Tensor covariance,
      final CenteringMode centeringMode
  ) {
    ArrayList<Tuple2<Double, DAGNode>> styleComponents = new ArrayList<>();
    if (null != styleParams && (styleParams.cov != 0 || styleParams.mean != 0)) {
      double meanRms = mean.rms();
      double meanScale = 0 == meanRms ? 1 : (1.0 / meanRms);
      InnerNode negTarget = network.wrap(new ValueLayer(mean.scale(-1)), new DAGNode[]{});
      InnerNode negAvg = network.wrap(new BandAvgReducerLayer().setAlpha(-1), node);
      if (styleParams.enhance != 0 || styleParams.cov != 0) {
        DAGNode recentered;
        switch (centeringMode) {
          case Origin:
            recentered = node;
            break;
          case Dynamic:
            recentered = network.wrap(new GateBiasLayer(), node, negAvg);
            break;
          case Static:
            recentered = network.wrap(new GateBiasLayer(), node, negTarget);
            break;
          default:
            throw new RuntimeException();
        }
        int[] covDim = covariance.getDimensions();
        double covRms = covariance.rms();
        if (styleParams.enhance != 0) {
          styleComponents.add(new Tuple2<>(-(0 == covRms ? styleParams.enhance : (styleParams.enhance / covRms)), network.wrap(
              new AvgReducerLayer(),
              network.wrap(new SquareActivationLayer(), recentered)
          )));
        }
        if (styleParams.cov != 0) {
          assert 0 < covDim[2] : Arrays.toString(covDim);
          int inputBands = mean.getDimensions()[2];
          assert 0 < inputBands : Arrays.toString(mean.getDimensions());
          int outputBands = covDim[2] / inputBands;
          assert 0 < outputBands : Arrays.toString(covDim) + " / " + inputBands;
          double covScale = 0 == covRms ? 1 : (1.0 / covRms);
          styleComponents.add(new Tuple2<>(styleParams.cov, network.wrap(
              new MeanSqLossLayer().setAlpha(covScale),
              network.wrap(new ValueLayer(covariance), new DAGNode[]{}),
              network.wrap(new GramianLayer(), recentered)
          )
          ));
        }
      }
      if (styleParams.mean != 0) {
        styleComponents.add(new Tuple2<>(
            styleParams.mean,
            network.wrap(new MeanSqLossLayer().setAlpha(meanScale), negAvg, negTarget)
        ));
      }
    }
    return styleComponents;
  }

  /**
   * Measure style neural setup.
   *
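   * <p>Evaluates each layer of the backbone against the content image and every style image, recording
   * the content activations, the per-band style means, and both the raw ({@code cov0}) and mean-centered
   * ({@code cov1}) Gramians, computed with tiled averaging; layers whose style coefficients are all zero
   * are skipped.</p>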
   * @param style the style
   * @return the neural setup
   */
  public NeuralSetup measureStyle(final StyleSetup<T> style) {
    NeuralSetup self = new NeuralSetup(style);
    List<CharSequence> keyList = style.styleImages.keySet().stream().collect(Collectors.toList());
    List<Tensor> styleInputs = keyList.stream().map(x -> style.styleImages.get(x)).map(img -> Tensor.fromRGB(img)).collect(Collectors.toList());
    IntStream.range(0, keyList.size()).forEach(i -> {
      self.styleTargets.put(keyList.get(i), new StyleTarget<>());
    });
    self.contentTarget = new ContentTarget<>();
    for (final T layerType : getLayerTypes()) {
      System.gc();
      Layer network = layerType.network();
      try {
        ArtistryUtil.setPrecision((DAGNetwork) network, style.precision);
        //network = new ImgTileSubnetLayer(network, 400,400,400,400);
        Tensor content = network.eval(style.contentImage).getDataAndFree().getAndFree(0);
        self.contentTarget.content.put(layerType, content);
        logger.info(String.format("%s : target content = %s", layerType.name(), content.prettyPrint()));
        logger.info(String.format(
            "%s : content statistics = %s",
            layerType.name(),
            JsonUtil.toJson(new ScalarStatistics().add(content.getData()).getMetrics())
        ));
        for (int i = 0; i < styleInputs.size(); i++) {
          Tensor styleInput = styleInputs.get(i);
          CharSequence key = keyList.get(i);
          StyleTarget<T> styleTarget = self.styleTargets.get(key);
          if (0 == self.style.styles.entrySet().stream().filter(e1 -> e1.getKey().contains(key)).map(x -> x.getValue().params.get(
              layerType)).filter(x -> null != x).filter(x -> x.mean != 0 || x.cov != 0).count())
            continue;
          System.gc();
          Layer wrapAvg = ArtistryUtil.wrapTiledAvg(network.copy(), 400);
          Tensor mean = wrapAvg.eval(styleInput).getDataAndFree().getAndFree(0);
          wrapAvg.freeRef();
          styleTarget.mean.put(layerType, mean);
          logger.info(String.format("%s : style mean = %s", layerType.name(), mean.prettyPrint()));
          logger.info(String.format(
              "%s : mean statistics = %s",
              layerType.name(),
              JsonUtil.toJson(new ScalarStatistics().add(mean.getData()).getMetrics())
          ));
          if (0 == self.style.styles.entrySet().stream().filter(e1 -> e1.getKey().contains(key)).map(x -> x.getValue().params.get(
              layerType)).filter(x -> null != x).filter(x -> x.cov != 0).count())
            continue;
          System.gc();
          Layer gram = ArtistryUtil.wrapTiledAvg(ArtistryUtil.gram(network.copy()), 400);
          Tensor cov0 = gram.eval(styleInput).getDataAndFree().getAndFree(0);
          gram.freeRef();
          gram = ArtistryUtil.wrapTiledAvg(ArtistryUtil.gram(network.copy(), mean), 400);
          Tensor cov1 = gram.eval(styleInput).getDataAndFree().getAndFree(0);
          gram.freeRef();
          styleTarget.cov0.put(layerType, cov0);
          styleTarget.cov1.put(layerType, cov1);
          int featureBands = mean.getDimensions()[2];
          int covarianceElements = cov1.getDimensions()[2];
          int selectedBands = covarianceElements / featureBands;
          logger.info(String.format("%s : target cov0 = %s", layerType.name(), cov0.reshapeCast(featureBands, selectedBands, 1).prettyPrint()));
          logger.info(String.format(
              "%s : cov0 statistics = %s",
              layerType.name(),
              JsonUtil.toJson(new ScalarStatistics().add(cov0.getData()).getMetrics())
          ));
          logger.info(String.format("%s : target cov1 = %s", layerType.name(), cov1.reshapeCast(featureBands, selectedBands, 1).prettyPrint()));
          logger.info(String.format(
              "%s : cov1 statistics = %s",
              layerType.name(),
              JsonUtil.toJson(new ScalarStatistics().add(cov1.getData()).getMetrics())
          ));
        }
      } finally {
        network.freeRef();
      }
    }
    style.contentImage.freeRef();
    return self;
  }

  /**
   * Gets fitness components.
   *
   * @param setup   the setup
   * @param nodeMap the node map
   * @return the fitness components
   */
  @Nonnull
  public List<Tuple2<Double, DAGNode>> getFitnessComponents(NeuralSetup setup, final Map<T, DAGNode> nodeMap) {
    List<Tuple2<Double, DAGNode>> functions = new ArrayList<>();
    functions.addAll(getContentComponents(setup, nodeMap));
    functions.addAll(getStyleComponents(setup, nodeMap));
    return functions;
  }

  /**
   * Gets style components.
   *
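   * <p>When a style key names several measured style images, their targets are summed and scaled by
   * {@code 1/n}, so the resulting loss terms compare against the average of those targets.</p>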
   * @param setup   the setup
   * @param nodeMap the node map
   * @return the style components
   */
  @Nonnull
  public ArrayList<Tuple2<Double, DAGNode>> getStyleComponents(NeuralSetup setup, final Map<T, DAGNode> nodeMap) {
    ArrayList<Tuple2<Double, DAGNode>> styleComponents = new ArrayList<>();
    for (final List<CharSequence> keys : setup.style.styles.keySet()) {
      StyleTarget<T> styleTarget = keys.stream().map(x -> {
        StyleTarget<T> obj = setup.styleTargets.get(x);
        obj.addRef();
        return obj;
      }).reduce((a, b) -> {
        StyleTarget<T> r = a.add(b);
        a.freeRef();
        b.freeRef();
        return r;
      }).map(x -> {
        StyleTarget<T> r = x.scale(1.0 / keys.size());
        x.freeRef();
        return r;
      }).get();
      for (final T layerType : getLayerTypes()) {
        StyleCoefficients<T> styleCoefficients = setup.style.styles.get(keys);
        assert null != styleCoefficients;
        assert null != styleTarget;
        final DAGNode node = nodeMap.get(layerType);
        final PipelineNetwork network = (PipelineNetwork) node.getNetwork();
        LayerStyleParams styleParams = styleCoefficients.params.get(layerType);
        Tensor mean = styleTarget.mean.get(layerType);
        Tensor covariance;
        switch (styleCoefficients.centeringMode) {
          case Origin:
            covariance = styleTarget.cov0.get(layerType);
            break;
          case Dynamic:
          case Static:
            covariance = styleTarget.cov1.get(layerType);
            break;
          default:
            throw new RuntimeException();
        }
        styleComponents.addAll(getStyleComponents(node, network, styleParams, mean, covariance, styleCoefficients.centeringMode));
      }
      styleTarget.freeRef();

    }
    return styleComponents;
  }

  /**
   * Fitness function pipeline network.
   *
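   * <p>Obtains the backbone network from {@link #getInstance()}, resolves each layer type to its node in
   * that network, and appends the weighted content and style loss terms via {@code buildNetwork}.</p>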
   * @param setup the setup
   * @return the pipeline network
   */
  @Nonnull
  public PipelineNetwork fitnessNetwork(NeuralSetup setup) {
    PipelineNetwork pipelineNetwork = getInstance().getNetwork();
    Map<T, DAGNode> nodes = new HashMap<>();
    Map<T, UUID> ids = getInstance().getNodes();
    ids.forEach((l, id) -> nodes.put(l, pipelineNetwork.getChildNode(id)));
    PipelineNetwork network = buildNetwork(setup, nodes, pipelineNetwork);
    //network = withClamp(network);
    ArtistryUtil.setPrecision(network, setup.style.precision);
    return network;
  }

  /**
   * Gets the layer types.
   *
   * @return the layer types
   */
  @Nonnull
  public abstract T[] getLayerTypes();

  /**
   * Gets content components.
   *
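   * <p>For each layer with a nonzero content coefficient this adds a mean-squared-error term between the
   * canvas activations and the measured content activations, scaled by the reciprocal of the target's
   * RMS.</p>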
   * @param setup   the setup
   * @param nodeMap the node map
   * @return the content components
   */
  @Nonnull
  public ArrayList<Tuple2<Double, DAGNode>> getContentComponents(NeuralSetup setup, final Map<T, DAGNode> nodeMap) {
    ArrayList<Tuple2<Double, DAGNode>> contentComponents = new ArrayList<>();
    for (final T layerType : getLayerTypes()) {
      final DAGNode node = nodeMap.get(layerType);
      final double coeff_content = !setup.style.content.params.containsKey(layerType) ? 0 : setup.style.content.params.get(layerType);
      final PipelineNetwork network1 = (PipelineNetwork) node.getNetwork();
      if (coeff_content != 0) {
        Tensor content = setup.contentTarget.content.get(layerType);
        contentComponents.add(new Tuple2<>(coeff_content, network1.wrap(new MeanSqLossLayer().setAlpha(1.0 / content.rms()),
            node, network1.wrap(new ValueLayer(content), new DAGNode[]{})
        )));
      }
    }
    return contentComponents;
  }

  /**
   * Gets instance.
   *
   * @return the instance
   */
  public abstract U getInstance();

  /**
   * Builds the fitness pipeline network from the measured setup.
   *
   * @param setup   the setup
   * @param nodeMap the node map
   * @param network the network
   * @return the pipeline network
   */
  public PipelineNetwork buildNetwork(NeuralSetup setup, final Map<T, DAGNode> nodeMap, final PipelineNetwork network) {
    List<Tuple2<Double, DAGNode>> functions = getFitnessComponents(setup, nodeMap);
    ArtistryUtil.reduce(network, functions, parallelLossFunctions);
    return network;
  }

  /**
   * Is tiled boolean.
   *
   * @return the boolean
   */
  public boolean isTiled() {
    return tiled;
  }

  /**
   * Sets tiled.
   *
   * @param tiled the tiled
   * @return the tiled
   */
  public StyleTransfer<T, U> setTiled(boolean tiled) {
    this.tiled = tiled;
    return this;
  }

  /**
   * The enum Centering mode.
   */
  public enum CenteringMode {
    /**
     * Dynamic centering mode.
     */
    Dynamic,
    /**
     * Static centering mode.
     */
    Static,
    /**
     * Origin centering mode.
     */
    Origin
  }

  /**
   * The type Vgg 16.
   */
  public static class VGG16 extends StyleTransfer<CVPipe_VGG16.Layer, CVPipe_VGG16> {

    public CVPipe_VGG16 getInstance() {
      return CVPipe_VGG16.INSTANCE;
    }

    @Nonnull
    public CVPipe_VGG16.Layer[] getLayerTypes() {
      return CVPipe_VGG16.Layer.values();
    }

  }

  /**
   * The type Vgg 19.
   */
  public static class VGG19 extends StyleTransfer<CVPipe_VGG19.Layer, CVPipe_VGG19> {

    public CVPipe_VGG19 getInstance() {
      return CVPipe_VGG19.INSTANCE;
    }

    @Nonnull
    public CVPipe_VGG19.Layer[] getLayerTypes() {
      return CVPipe_VGG19.Layer.values();
    }

  }

  /**
   * The type Content coefficients.
   *
   * @param <T> the type parameter
   */
  public static class ContentCoefficients<T extends LayerEnum<T>> {
    /**
     * The Params.
     */
    public final Map<T, Double> params = new HashMap<>();

    /**
     * Set content coefficients.
     *
     * @param l the l
     * @param v the v
     * @return the content coefficients
     */
    public ContentCoefficients<T> set(final T l, final double v) {
      params.put(l, v);
      return this;
    }

  }

  /**
   * The type Layer style params.
   */
  public static class LayerStyleParams {
    /**
     * The style mean coefficient.
     */
    public final double mean;
    /**
     * The style covariance coefficient.
     */
    public final double cov;
    private final double enhance;

    /**
     * Instantiates a new Layer style params.
     *
     * @param mean    the mean
     * @param cov     the bandCovariance
     * @param enhance the enhance
     */
    public LayerStyleParams(final double mean, final double cov, final double enhance) {
      this.mean = mean;
      this.cov = cov;
      this.enhance = enhance;
    }
  }

  /**
   * The type Style setup.
   *
   * @param <T> the type parameter
   */
  public static class StyleSetup<T extends LayerEnum<T>> {
    /**
     * The Precision.
     */
    public final Precision precision;
    /**
     * The style images.
     */
    public final transient Map<CharSequence, BufferedImage> styleImages;
    /**
     * The Styles.
     */
    public final Map<List<CharSequence>, StyleCoefficients<T>> styles;
    /**
     * The Content.
     */
    public final ContentCoefficients<T> content;
    /**
     * The content image.
     */
    public transient Tensor contentImage;


    /**
     * Instantiates a new Style setup.
     *
     * @param precision           the precision
     * @param contentImage        the content image
     * @param contentCoefficients the content coefficients
     * @param styleImages         the style images
     * @param styles              the styles
     */
    public StyleSetup(
        final Precision precision,
        final Tensor contentImage,
        ContentCoefficients<T> contentCoefficients,
        final Map<CharSequence, BufferedImage> styleImages,
        final Map<List<CharSequence>, StyleCoefficients<T>> styles
    ) {
      this.precision = precision;
      this.contentImage = contentImage;
      this.styleImages = styleImages;
      this.styles = styles;
      this.content = contentCoefficients;
    }

  }

  /**
   * The type Style coefficients.
   *
   * @param <T> the type parameter
   */
  public static class StyleCoefficients<T extends LayerEnum<T>> {
    /**
     * The centering mode.
     */
    public final CenteringMode centeringMode;
    /**
     * The Params.
     */
    public final Map<T, LayerStyleParams> params = new HashMap<>();


    /**
     * Instantiates a new Style coefficients.
     *
     * @param centeringMode the centering mode
     */
    public StyleCoefficients(final CenteringMode centeringMode) {
      this.centeringMode = centeringMode;
    }

    /**
     * Set style coefficients.
     *
     * @param layerType        the layer type
     * @param coeff_style_mean the coeff style mean
     * @param coeff_style_cov  the coeff style bandCovariance
     * @return the style coefficients
     */
    public StyleCoefficients<T> set(final T layerType, final double coeff_style_mean, final double coeff_style_cov) {
      return set(
          layerType,
          coeff_style_mean,
          coeff_style_cov,
          0.0
      );
    }

    /**
     * Set style coefficients.
     *
     * @param layerType        the layer type
     * @param coeff_style_mean the coeff style mean
     * @param coeff_style_cov  the coeff style bandCovariance
     * @param dream            the dream
     * @return the style coefficients
     */
    public StyleCoefficients<T> set(final T layerType, final double coeff_style_mean, final double coeff_style_cov, final double dream) {
      params.put(layerType, new LayerStyleParams(coeff_style_mean, coeff_style_cov, dream));
      return this;
    }

  }

  /**
   * The type Content target.
   *
   * @param <T> the type parameter
   */
  public static class ContentTarget<T extends LayerEnum<T>> {
    /**
     * The Content.
     */
    public Map<T, Tensor> content = new HashMap<>();
  }

  /**
   * The type Style target.
   *
   * @param <T> the type parameter
   */
  public class StyleTarget<T extends LayerEnum<T>> extends ReferenceCountingBase {
    /**
     * The Cov.
     */
    public Map<T, Tensor> cov0 = new HashMap<>();
    /**
     * The Cov.
     */
    public Map<T, Tensor> cov1 = new HashMap<>();
    /**
     * The Mean.
     */
    public Map<T, Tensor> mean = new HashMap<>();

    @Override
    protected void _free() {
      super._free();
      if (null != cov0) cov0.values().forEach(ReferenceCountingBase::freeRef);
      if (null != cov1) cov1.values().forEach(ReferenceCountingBase::freeRef);
      if (null != mean) mean.values().forEach(ReferenceCountingBase::freeRef);
    }

    /**
     * Add style target.
     *
     * @param right the right
     * @return the style target
     */
    public StyleTarget<T> add(StyleTarget<T> right) {
      StyleTarget<T> newStyle = new StyleTarget<>();
      Stream.concat(mean.keySet().stream(), right.mean.keySet().stream()).distinct().forEach(layer -> {
        Tensor l = mean.get(layer);
        Tensor r = right.mean.get(layer);
        if (l != null && l != r) {
          Tensor add = l.add(r);
          newStyle.mean.put(layer, add);
        } else if (l != null) {
          l.addRef();
          newStyle.mean.put(layer, l);
        } else if (r != null) {
          r.addRef();
          newStyle.mean.put(layer, r);
        }
      });
      Stream.concat(cov0.keySet().stream(), right.cov0.keySet().stream()).distinct().forEach(layer -> {
        Tensor l = cov0.get(layer);
        Tensor r = right.cov0.get(layer);
        if (l != null && l != r) {
          Tensor add = l.add(r);
          newStyle.cov0.put(layer, add);
        } else if (l != null) {
          l.addRef();
          newStyle.cov0.put(layer, l);
        } else if (r != null) {
          r.addRef();
          newStyle.cov0.put(layer, r);
        }
      });
      Stream.concat(cov1.keySet().stream(), right.cov1.keySet().stream()).distinct().forEach(layer -> {
        Tensor l = cov1.get(layer);
        Tensor r = right.cov1.get(layer);
        if (l != null && l != r) {
          Tensor add = l.add(r);
          newStyle.cov1.put(layer, add);
        } else if (l != null) {
          l.addRef();
          newStyle.cov1.put(layer, l);
        } else if (r != null) {
          r.addRef();
          newStyle.cov1.put(layer, r);
        }
      });
      return newStyle;
    }

    /**
     * Scale style target.
     *
     * @param value the value
     * @return the style target
     */
    public StyleTarget<T> scale(double value) {
      StyleTarget<T> newStyle = new StyleTarget<>();
      mean.keySet().stream().distinct().forEach(layer -> {
        newStyle.mean.put(layer, mean.get(layer).scale(value));
      });
      cov0.keySet().stream().distinct().forEach(layer -> {
        newStyle.cov0.put(layer, cov0.get(layer).scale(value));
      });
      cov1.keySet().stream().distinct().forEach(layer -> {
        newStyle.cov1.put(layer, cov1.get(layer).scale(value));
      });
      return newStyle;
    }

  }

  public class NeuralSetup {

    /**
     * The Style parameters.
     */
    public final StyleSetup<T> style;
    /**
     * The Content target.
     */
    public ContentTarget<T> contentTarget = new ContentTarget<>();
    /**
     * The Style targets.
     */
    public Map<CharSequence, StyleTarget<T>> styleTargets = new HashMap<>();


    /**
     * Instantiates a new Neural setup.
     *
     * @param style the style
     */
    public NeuralSetup(final StyleSetup<T> style) {
      this.style = style;
    }
  }

}



