All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.applications.SegmentedStyleTransfer Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.applications;

import com.simiacryptus.lang.Tuple2;
import com.simiacryptus.mindseye.eval.ArrayTrainable;
import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.ReferenceCountingBase;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.cudnn.Precision;
import com.simiacryptus.mindseye.layers.cudnn.*;
import com.simiacryptus.mindseye.layers.java.ImgTileSelectLayer;
import com.simiacryptus.mindseye.models.CVPipe;
import com.simiacryptus.mindseye.models.CVPipe_VGG16;
import com.simiacryptus.mindseye.models.CVPipe_VGG19;
import com.simiacryptus.mindseye.models.LayerEnum;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.network.InnerNode;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.mindseye.opt.IterativeTrainer;
import com.simiacryptus.mindseye.opt.line.BisectionSearch;
import com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy;
import com.simiacryptus.mindseye.opt.region.RangeConstraint;
import com.simiacryptus.mindseye.opt.region.TrustRegion;
import com.simiacryptus.mindseye.test.StepRecord;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.notebook.FileHTTPD;
import com.simiacryptus.notebook.MarkdownNotebookOutput;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.NullNotebookOutput;
import com.simiacryptus.util.JsonUtil;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.data.ScalarStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.Closeable;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

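/**
 * This notebook implements the Style Transfer protocol outlined in "A Neural Algorithm of Artistic Style",
 * applied per image segment: style losses are evaluated on masked regions of the canvas.
 *
 * @param <T> the layer enum type identifying the network layers that are measured
 * @param <U> the vision pipeline type that supplies the network and its layer nodes
 */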
public abstract class SegmentedStyleTransfer, U extends CVPipe> {

  private static final Logger logger = LoggerFactory.getLogger(SegmentedStyleTransfer.class);
  private final Map> maskCache = new ConcurrentHashMap<>();
  /**
   * The Parallel loss functions.
   */
  public boolean parallelLossFunctions = true;
  private boolean tiled = false;
  private int content_masks = 3;
  private int content_colorClusters = 3;
  private int content_textureClusters = 3;
  private int style_masks = 3;
  private int stlye_colorClusters = 3;
  private int style_textureClusters = 3;

  /**
   * Applies {@link #alpha(Tensor, Tensor)} to the style input once per mask.
   *
   * @param styleInput the image to be masked
   * @param tensors    the masks to apply
   * @return one masked copy of {@code styleInput} per mask, in the set's iteration order
   */
  public static List<Tensor> alphaList(final Tensor styleInput, final Set<Tensor> tensors) {
    return tensors.stream().map(mask -> alpha(styleInput, mask)).collect(Collectors.toList());
  }

  /**
   * Multiplies every element of {@code content} by the mask value at the same x/y position.
   * When the content has more bands than the mask, the mask's last band is reused.
   *
   * @param content the tensor to be weighted
   * @param mask    the per-pixel weights
   * @return the tensor produced by {@code content.mapCoords}
   */
  public static Tensor alpha(final Tensor content, final Tensor mask) {
    final int lastMaskBand = mask.getDimensions()[2] - 1;
    return content.mapCoords(coordinate -> {
      final int[] xyz = coordinate.getCoords();
      // Clamp the band index so a mask with fewer bands still covers every content band.
      final int maskBand = Math.min(xyz[2], lastMaskBand);
      return content.get(coordinate) * mask.get(xyz[0], xyz[1], maskBand);
    });
  }

  /**
   * Scores how closely a content mask matches a style mask.
   * Both masks are collapsed with {@code sumChannels}, the content mask is resized to the style
   * mask's width and height, and the dot product of the two {@code unit()} tensors is returned.
   *
   * @param contentMask the content mask; not released by this method
   * @param styleMask   the style mask; not released by this method
   * @return the dot product of the two normalized single-channel masks
   */
  public static double alphaMaskSimilarity(final Tensor contentMask, final Tensor styleMask) {
    Tensor l = contentMask.sumChannels();
    Tensor r = styleMask.sumChannels();
    int[] dimensions = r.getDimensions();
    Tensor resizedRgb = Tensor.fromRGB(TestUtil.resize(l.toImage(), dimensions[0], dimensions[1]));
    Tensor resize = resizedRgb.sumChannels();
    Tensor a = resize.unit();
    Tensor b = r.unit();
    double dot = a.dot(b);
    a.freeRef();
    b.freeRef();
    // The resized temporaries are owned here just like l and r, so they are released as well.
    resize.freeRef();
    resizedRgb.freeRef();
    r.freeRef();
    l.freeRef();
    return dot;
  }

  /**
   * Builds a map from each distinct mask to the style input weighted by that mask.
   *
   * @param styleInput the image to be masked; must not be null
   * @param tensors    the masks; must not be null or contain null
   * @return a map whose keys are the masks and whose values are the results of
   *     {@link #alpha(Tensor, Tensor)}
   */
  public static Map<Tensor, Tensor> alphaMap(final Tensor styleInput, final Set<Tensor> tensors) {
    assert null != styleInput;
    assert null != tensors;
    assert tensors.stream().allMatch(Objects::nonNull);
    return tensors.stream().distinct().collect(Collectors.toMap(Function.identity(), mask -> alpha(styleInput, mask)));
  }

  /**
   * Wraps the network and canvas in a verbose, masked {@link ArrayTrainable} with batch size 1.
   *
   * @param canvas  the single training input
   * @param network the network to be trained against
   * @return the configured trainable
   */
  @Nonnull
  public Trainable getTrainable(final Tensor canvas, final PipelineNetwork network) {
    final Tensor[][] trainingData = {{canvas}};
    return new ArrayTrainable(network, 1)
        .setVerbose(true)
        .setMask(true)
        .setData(Arrays.asList(trainingData));
  }

  /**
   * Gets style components.
   *
   * @param node          the node
   * @param network       the network
   * @param styleParams   the style params
   * @param mean          the mean
   * @param covariance    the covariance
   * @param centeringMode the centering mode
   * @return the style components
   */
  @Nonnull
  public ArrayList> getStyleComponents(
      final DAGNode node,
      final PipelineNetwork network,
      final LayerStyleParams styleParams,
      final Tensor mean,
      final Tensor covariance,
      final CenteringMode centeringMode
  ) {
    ArrayList> styleComponents = new ArrayList<>();
    if (null != styleParams && null != mean && (styleParams.cov != 0 || styleParams.mean != 0)) {
      double meanRms = mean.rms();
      double meanScale = 0 == meanRms ? 1 : (1.0 / meanRms);
      InnerNode negTarget = network.wrap(new ValueLayer(mean.scale(-1)), new DAGNode[]{});
      node.addRef();
      InnerNode negAvg = network.wrap(new BandAvgReducerLayer().setAlpha(-1), node);
      if (styleParams.enhance != 0 || styleParams.cov != 0) {
        DAGNode recentered;
        switch (centeringMode) {
          case Origin:
            node.addRef();
            recentered = node;
            break;
          case Dynamic:
            negAvg.addRef();
            node.addRef();
            recentered = network.wrap(new GateBiasLayer(), node, negAvg);
            break;
          case Static:
            node.addRef();
            negTarget.addRef();
            recentered = network.wrap(new GateBiasLayer(), node, negTarget);
            break;
          default:
            throw new RuntimeException();
        }
        int[] covDim = covariance.getDimensions();
        double covRms = covariance.rms();
        if (styleParams.enhance != 0) {
          recentered.addRef();
          styleComponents.add(new Tuple2<>(-(0 == covRms ? styleParams.enhance : (styleParams.enhance / covRms)), network.wrap(
              new AvgReducerLayer(),
              network.wrap(new SquareActivationLayer(), recentered)
          )));
        }
        if (styleParams.cov != 0) {
          assert 0 < covDim[2] : Arrays.toString(covDim);
          int inputBands = mean.getDimensions()[2];
          assert 0 < inputBands : Arrays.toString(mean.getDimensions());
          int outputBands = covDim[2] / inputBands;
          assert 0 < outputBands : Arrays.toString(covDim) + " / " + inputBands;
          double covScale = 0 == covRms ? 1 : (1.0 / covRms);
          recentered.addRef();
          styleComponents.add(new Tuple2<>(styleParams.cov, network.wrap(
              new MeanSqLossLayer().setAlpha(covScale),
              network.wrap(new ValueLayer(covariance), new DAGNode[]{}),
              network.wrap(new GramianLayer(), recentered)
          )
          ));
        }
        recentered.freeRef();
      }
      if (styleParams.mean != 0) {
        styleComponents.add(new Tuple2<>(
            styleParams.mean,
            network.wrap(new MeanSqLossLayer().setAlpha(meanScale), negAvg, negTarget)
        ));
      }
    }
    return styleComponents;
  }

  /**
   * Paints the canvas by training it against the segmented fitness network.
   * <p>
   * The content image is partitioned into masks, a fitness network is built from them, and an
   * {@link IterativeTrainer} is run on the canvas. While training runs, the canvas and the training
   * plot are served through the notebook's HTTP server when one is available; afterwards both are
   * written to notebook files.
   *
   * @param log             the notebook to report to
   * @param styleParameters the style parameters
   * @param trainingMinutes the training timeout in minutes
   * @param measureStyle    the previously measured content and style targets
   * @param maxIterations   the maximum number of training iterations
   * @param verbose         when false, the training steps are logged to a {@link NullNotebookOutput}
   * @param canvas          the tensor that is trained in place
   * @return the {@code canvas} instance that was passed in
   */
  public Tensor transfer(
      @Nonnull final NotebookOutput log,
      final StyleSetup<T> styleParameters,
      final int trainingMinutes,
      final NeuralSetup<T> measureStyle,
      final int maxIterations,
      final boolean verbose,
      final Tensor canvas
  ) {
    System.gc();
    NotebookOutput trainingLog = verbose ? log : new NullNotebookOutput();
    if (1 < getContent_masks()) log.h2("Content Partitioning");
    Set<Tensor> masks = getMasks(
        log,
        measureStyle.contentSource,
        new MaskJob(getContent_masks(), getContent_colorClusters(), getContent_textureClusters(), "content")
    );
    System.gc();
    if (1 < getContent_masks()) log.h2("Content Painting");
    if (verbose) {
      log.p("Input Parameters:");
      log.eval(() -> {
        return ArtistryUtil.toJson(styleParameters);
      });
    }
    // getHttpd() is treated as nullable everywhere, matching the layers-handler check below.
    final FileHTTPD server = log.getHttpd();
    TestUtil.monitorImage(canvas, false, false);
    String imageName = String.format("etc/image_%s.jpg", Long.toHexString(MarkdownNotebookOutput.random.nextLong()));
    log.p(String.format("<a href=\"%s\"><img src=\"%s\"></a>", imageName, imageName));
    Closeable jpeg = null == server ? null : server.addGET(imageName, "image/jpeg", r -> {
      try {
        ImageIO.write(canvas.toImage(), "jpeg", r);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    });
    Trainable trainable = trainingLog.eval(() -> {
      PipelineNetwork network = fitnessNetwork(measureStyle, masks);
      network.setFrozen(true);
      if (null != server) ArtistryUtil.addLayersHandler(network, server);
      if (tiled) network = ArtistryUtil.tileCycle(network, 3);
      ArtistryUtil.setPrecision(network, styleParameters.precision);
      TestUtil.instrumentPerformance(network);
      Trainable trainable1 = getTrainable(canvas, network);
      network.freeRef();
      return trainable1;
    });
    masks.forEach(ReferenceCountingBase::freeRef);
    try {
      @Nonnull ArrayList<StepRecord> history = new ArrayList<>();
      String training_name = String.format("etc/training_plot_%s.png", Long.toHexString(MarkdownNotebookOutput.random.nextLong()));
      log.p(String.format("<a href=\"%s\"><img src=\"%s\"></a>", training_name, training_name));
      Closeable plotHandler = null == server ? null : server.addGET(training_name, "image/png", r -> {
        try {
          BufferedImage image1 = Util.toImage(TestUtil.plot(history));
          if (null != image1) ImageIO.write(image1, "png", r);
        } catch (IOException e) {
          logger.warn("Error writing result images", e);
        }
      });
      trainingLog.run(() -> {
        new IterativeTrainer(trainable)
            .setMonitor(TestUtil.getMonitor(history))
            .setOrientation(new TrustRegionStrategy() {
              @Override
              public TrustRegion getRegionPolicy(final Layer layer) {
                return new RangeConstraint().setMin(1e-2).setMax(256);
              }
            })
            .setMaxIterations(maxIterations)
            .setIterationsPerSample(100)
            .setLineSearchFactory(name -> new BisectionSearch().setSpanTol(1e-1).setCurrentRate(1e6))
            .setTimeout(trainingMinutes, TimeUnit.MINUTES)
            .setTerminateThreshold(Double.NEGATIVE_INFINITY)
            .runAndFree();
      });
      // Stop serving each live view before its final file is written under the same name.
      try {
        if (null != plotHandler) plotHandler.close();
        BufferedImage image = Util.toImage(TestUtil.plot(history));
        if (null != image) ImageIO.write(image, "png", log.file(training_name));
      } catch (IOException e) {
        logger.warn("Error writing result images", e);
      }
      try {
        if (null != jpeg) jpeg.close();
        ImageIO.write(canvas.toImage(), "jpeg", log.file(imageName));
      } catch (IOException e) {
        logger.warn("Error writing result images", e);
      }
      return canvas;
    } finally {
      trainable.freeRef();
    }
  }

  /**
   * Creates a {@code NeuralSetup} for the given style and fills it by measuring the style images
   * first and the content image second.
   *
   * @param log   the notebook passed on to the measurement steps
   * @param style the style configuration
   * @return the populated setup
   */
  public NeuralSetup<T> measureStyle(final NotebookOutput log, final StyleSetup<T> style) {
    NeuralSetup<T> self = new NeuralSetup<>(style);
    measureStyles(log, style, self);
    measureContent(log, style, self);
    return self;
  }

  /**
   * Lists the keys of the configured style images.
   *
   * @param style the style configuration
   * @return a new list holding the keys of {@code style.styleImages}
   */
  public List<CharSequence> getStyleKeys(final StyleSetup<T> style) {
    return new ArrayList<>(style.styleImages.keySet());
  }

  /**
   * Evaluates the content image at every layer type and stores the results as content targets.
   * <p>
   * {@code self.contentTarget} is replaced and {@code self.contentSource} is set to the content image
   * in every case. When there is no content image, no network is evaluated.
   *
   * @param log   unused by this method
   * @param style supplies the content image and the precision
   * @param self  the setup that receives the content targets
   */
  public void measureContent(
      final NotebookOutput log,
      final StyleSetup<T> style,
      final NeuralSetup<T> self
  ) {
    self.contentTarget = new ContentTarget<>();
    self.contentSource = style.contentImage;
    // Without a content image there is nothing to measure, so skip building the layer networks.
    if (null == style.contentImage) return;
    for (final T layerType : getLayerTypes()) {
      System.gc();
      Layer network = layerType.network();
      try {
        ArtistryUtil.setPrecision((DAGNetwork) network, style.precision);
        Tensor content = network.eval(style.contentImage).getDataAndFree().getAndFree(0);
        System.gc();
        self.contentTarget.content.put(layerType, content);
        logger.info(String.format("%s : target content = %s", layerType.name(), content.prettyPrint()));
        logger.info(String.format(
            "%s : content statistics = %s",
            layerType.name(),
            JsonUtil.toJson(new ScalarStatistics().add(content.getData()).getMetrics())
        ));
      } finally {
        network.freeRef();
      }
    }
  }

  public Map> measureStyles(
      final NotebookOutput log,
      final StyleSetup style,
      final NeuralSetup self
  ) {
    final List styleKeys = getStyleKeys(style);
    IntStream.range(0, styleKeys.size()).forEach(i -> {
      self.styleTargets.put(styleKeys.get(i), new SegmentedStyleTarget<>());
    });
    Map styleInputs = styleKeys.stream().collect(Collectors.toMap(x -> x, x -> {
      Tensor tensor = style.styleImages.get(x);
      tensor.assertAlive();
      return tensor;
    }));
    if (1 < getStyle_masks())
      log.h2(String.format("Style Partitioning (%d/%d/%d)", getStyle_masks(), getStlye_colorClusters(), getStyle_textureClusters()));
    Map> masks = styleInputs.entrySet().stream().collect(Collectors.toMap(x -> x.getValue(), (styleInput) -> {
      Set masks1 = getMasks(
          log,
          styleInput.getValue(),
          new MaskJob(getStyle_masks(), getStlye_colorClusters(), getStyle_textureClusters(), styleInput.getKey())
      );
      assert null != masks1;
      assert 0 != masks1.size();
      assert masks1.stream().allMatch(x -> x != null);
      assert masks1.stream().count() == masks1.stream().distinct().count();
      return masks1;
    }));
    for (final T layerType : getLayerTypes()) {
      System.gc();
      Layer network = layerType.network();
      try {
        ArtistryUtil.setPrecision((DAGNetwork) network, style.precision);
        for (Map.Entry styleEntry : styleInputs.entrySet()) {
          CharSequence key = styleEntry.getKey();
          SegmentedStyleTarget segmentedStyleTarget = self.styleTargets.get(key);
          if (0 == self.style.styles.entrySet().stream().filter(e1 -> e1.getKey().contains(key)).map(x -> x.getValue().params.get(
              layerType)).filter(x -> null != x).filter(x -> x.mean != 0 || x.cov != 0).count())
            continue;
          Tensor styleInput = styleEntry.getValue();
          alphaMap(styleInput, masks.get(styleInput)).forEach((mask, styleMask) -> {
            StyleTarget styleTarget = segmentedStyleTarget.getSegment(mask);
            if (0 == self.style.styles.entrySet().stream().filter(e1 -> e1.getKey().contains(key)).map(x -> x.getValue().params.get(
                layerType)).filter(x -> null != x).filter(x -> x.cov != 0).count())
              return;
            measureStyle(network, styleTarget, layerType, styleMask, 800);
            logStyle(styleTarget, layerType);
          });
        }
      } finally {
        network.freeRef();
      }
    }
    masks.forEach((k, v) -> {
      //k.freeRef();
      v.forEach(ReferenceCountingBase::freeRef);
    });
    return masks;
  }

  /**
   * Logs the measured mean and covariance targets of one layer, together with their statistics.
   * Nothing is logged when no mean was recorded; a covariance that was not recorded is skipped.
   *
   * @param styleTarget the measured style
   * @param layerType   the layer whose measurements are logged
   */
  public void logStyle(final StyleTarget<T> styleTarget, final T layerType) {
    Tensor mean = styleTarget.mean.get(layerType);
    if (null == mean) return;
    int featureBands = mean.getDimensions()[2];
    logger.info(String.format("%s : style mean = %s", layerType.name(), mean.prettyPrint()));
    logger.info(String.format(
        "%s : mean statistics = %s",
        layerType.name(),
        JsonUtil.toJson(new ScalarStatistics().add(mean.getData()).getMetrics())
    ));
    logCovariance(layerType, "cov0", styleTarget.cov0.get(layerType), featureBands);
    logCovariance(layerType, "cov1", styleTarget.cov1.get(layerType), featureBands);
  }

  /**
   * Logs one covariance tensor, reshaped to {@code featureBands} columns, and its statistics.
   */
  private void logCovariance(final T layerType, final String label, final Tensor cov, final int featureBands) {
    // Measurements are stored only when non-empty, so a covariance can be absent.
    if (null == cov) return;
    int selectedBands = cov.getDimensions()[2] / featureBands;
    logger.info(String.format(
        "%s : target %s = %s",
        layerType.name(),
        label,
        cov.reshapeCast(featureBands, selectedBands, 1).prettyPrintAndFree()
    ));
    logger.info(String.format(
        "%s : %s statistics = %s",
        layerType.name(),
        label,
        JsonUtil.toJson(new ScalarStatistics().add(cov.getData()).getMetrics())
    ));
  }

  /**
   * Measures the style of an image, splitting it into square tiles when it is larger than one tile.
   * <p>
   * A single-tile image is measured directly. Otherwise each tile is measured separately, scaled by
   * its share of the image's element count, and the scaled results are summed into
   * {@code styleTarget}. In both cases {@code image} is released before this method returns.
   *
   * @param network     the layer network to evaluate
   * @param styleTarget receives the measurements
   * @param layerType   the layer being measured
   * @param image       the image to measure; released by this method
   * @param tileSize    the tile edge length, also used as the stride
   */
  public void measureStyle(final Layer network, final StyleTarget<T> styleTarget, final T layerType, final Tensor image, int tileSize) {
    int[] dimensions = image.getDimensions();
    int width = tileSize;
    int height = tileSize;
    int strideX = tileSize;
    int strideY = tileSize;
    int cols = (int) Math.max(1, (Math.ceil((dimensions[0] - width) * 1.0 / strideX) + 1));
    int rows = (int) Math.max(1, (Math.ceil((dimensions[1] - height) * 1.0 / strideY) + 1));
    if (cols == 1 && rows == 1) {
      // The untiled overload releases the image itself.
      measureStyle(network, styleTarget, layerType, image);
      return;
    }
    try {
      StyleTarget<T> tiledStyle = IntStream.range(0, rows).boxed().flatMap(row -> {
        return IntStream.range(0, cols).mapToObj(col -> {
          StyleTarget<T> styleTarget1 = new StyleTarget<>();
          int positionX = col * strideX;
          int positionY = row * strideY;
          assert positionX >= 0;
          assert positionY >= 0;
          assert positionX < dimensions[0];
          assert positionY < dimensions[1];
          ImgTileSelectLayer tileSelectLayer = new com.simiacryptus.mindseye.layers.java.ImgTileSelectLayer(width, height, positionX, positionY);
          Tensor selectedTile = tileSelectLayer.eval(image).getDataAndFree().getAndFree(0);
          tileSelectLayer.freeRef();
          double factor = (double) selectedTile.length() / image.length();
          measureStyle(network, styleTarget1, layerType, selectedTile);
          StyleTarget<T> scale = styleTarget1.scale(factor);
          styleTarget1.freeRef();
          return scale;
        });
      }).reduce((a, b) -> {
        StyleTarget<T> add = a.add(b);
        a.freeRef();
        b.freeRef();
        return add;
      }).get();
      System.gc();
      put(tiledStyle, styleTarget);
      tiledStyle.freeRef();
    } finally {
      // Match the untiled path, which always releases the image.
      image.freeRef();
    }
  }

  /**
   * Copies every mean, cov0 and cov1 entry of {@code fromStyle} into {@code toStyle} and adds one
   * reference to each copied tensor. Entries already present in {@code toStyle} under the same key
   * are overwritten without being released.
   *
   * @param fromStyle the source of the measurements
   * @param toStyle   the target that receives them
   */
  public void put(final StyleTarget<T> fromStyle, final StyleTarget<T> toStyle) {
    toStyle.mean.putAll(fromStyle.mean);
    fromStyle.mean.values().forEach(ReferenceCountingBase::addRef);
    toStyle.cov0.putAll(fromStyle.cov0);
    fromStyle.cov0.values().forEach(ReferenceCountingBase::addRef);
    toStyle.cov1.putAll(fromStyle.cov1);
    fromStyle.cov1.values().forEach(ReferenceCountingBase::addRef);
  }

  /**
   * Measures the mean and two Gram matrices of one layer for an image and records them.
   * <p>
   * Three copies of the network are evaluated at float precision: a tiled average (mean), a tiled
   * average of the Gram network (cov0), and a tiled average of the Gram network built with the
   * measured mean (cov1). Each result is stored only when it is non-empty.
   *
   * @param network     the layer network; copied, never modified
   * @param styleTarget receives the measurements
   * @param layerType   the key the measurements are stored under
   * @param image       the image to measure; always released by this method
   * @throws IllegalArgumentException if the image has no elements
   * @throws AssertionError           if a measurement already exists for {@code layerType}
   */
  public void measureStyle(final Layer network, final StyleTarget<T> styleTarget, final T layerType, final Tensor image) {
    try {
      if (image.length() <= 0) throw new IllegalArgumentException(Arrays.toString(image.getDimensions()));
      Layer wrapAvg = null;
      Tensor mean;
      try {
        wrapAvg = ArtistryUtil.wrapTiledAvg(network.copy(), 400);
        System.gc();
        ArtistryUtil.setPrecision((DAGNetwork) wrapAvg, Precision.Float);
        mean = wrapAvg.eval(image).getDataAndFree().getAndFree(0);
        if (mean.length() > 0 && styleTarget.mean.put(layerType, mean) != null) throw new AssertionError();
      } finally {
        if (null != wrapAvg) wrapAvg.freeRef();
      }

      Layer gram0 = null;
      try {
        gram0 = ArtistryUtil.wrapTiledAvg(ArtistryUtil.gram(network.copy()), 400);
        System.gc();
        ArtistryUtil.setPrecision((DAGNetwork) gram0, Precision.Float);
        Tensor cov0 = gram0.eval(image).getDataAndFree().getAndFree(0);
        if (cov0.length() > 0 && styleTarget.cov0.put(layerType, cov0) != null) throw new AssertionError();
      } finally {
        if (null != gram0) gram0.freeRef();
      }
      // A separate variable: a failure while building this network must not release gram0 again.
      Layer gram1 = null;
      try {
        gram1 = ArtistryUtil.wrapTiledAvg(ArtistryUtil.gram(network.copy(), mean), 400);
        ArtistryUtil.setPrecision((DAGNetwork) gram1, Precision.Float);
        System.gc();
        Tensor cov1 = gram1.eval(image).getDataAndFree().getAndFree(0);
        if (cov1.length() > 0 && styleTarget.cov1.put(layerType, cov1) != null) throw new AssertionError();
      } finally {
        if (null != gram1) gram1.freeRef();
      }
    } finally {
      image.freeRef();
      System.gc();
    }
  }

  /**
   * Returns the segmentation masks for a job, resized to the width and height of {@code value}.
   * <p>
   * Segmentation runs only when the mask cache has no entry for {@code maskJob1}; the cached masks
   * are not returned themselves, fresh resized copies are. The cache lookup uses the job alone,
   * not the image.
   *
   * @param log      the notebook passed to the segmenter
   * @param value    the image to segment; also defines the size of the returned masks
   * @param maskJob1 the segmentation parameters and cache key
   * @return newly created mask tensors
   */
  public Set<Tensor> getMasks(final NotebookOutput log, final Tensor value, final MaskJob maskJob1) {
    int width = value.getDimensions()[0];
    int height = value.getDimensions()[1];
    return getMaskCache().computeIfAbsent(maskJob1, maskJob -> {
      Set<Tensor> tensors = ImageSegmenter.quickMasks(
          log,
          value,
          maskJob.getStyle_masks(),
          maskJob.getStlye_colorClusters(),
          maskJob.getStyle_textureClusters()
      )
          .stream().distinct().collect(Collectors.toSet());
      assert null != tensors;
      return tensors;
    }).stream().map(img -> {
      Tensor tensor = Tensor.fromRGB(TestUtil.resize(img.toImage(), width, height));
      assert null != tensor;
      return tensor;
    }).collect(Collectors.toSet());
  }

  /**
   * Gets style components.
   *
   * @param setup    the setup
   * @param nodeMap  the node buildMap
   * @param selector the selector
   * @return the style components
   */
  @Nonnull
  public ArrayList> getStyleComponents(
      NeuralSetup setup,
      final Map nodeMap,
      final Function, StyleTarget> selector
  ) {
    ArrayList> styleComponents = new ArrayList<>();
    for (final List keys : setup.style.styles.keySet()) {
      StyleTarget styleTarget = keys.stream().map(x -> {
        SegmentedStyleTarget obj = setup.styleTargets.get(x);
        StyleTarget choose = selector.apply(obj);
        choose.addRef();
        return choose;
      }).reduce((a, b) -> {
        StyleTarget r = a.add(b);
        a.freeRef();
        b.freeRef();
        return r;
      }).map(x -> {
        StyleTarget r = x.scale(1.0 / keys.size());
        x.freeRef();
        return r;
      }).get();
      assert null != styleTarget;
      for (final T layerType : getLayerTypes()) {
        final StyleCoefficients styleCoefficients = setup.style.styles.get(keys);
        assert null != styleCoefficients;
        styleComponents.addAll(getStyleComponents(nodeMap, layerType, styleCoefficients, styleTarget));
      }
      styleTarget.freeRef();
    }
    return styleComponents;
  }

  /**
   * Gets style components.
   *
   * @param nodeMap            the node map
   * @param layerType          the key type
   * @param styleCoefficients  the style coefficients
   * @param chooseStyleSegment the choose style segment
   * @return the style components
   */
  @Nonnull
  public ArrayList> getStyleComponents(
      final Map nodeMap,
      final T layerType,
      final StyleCoefficients styleCoefficients,
      final StyleTarget chooseStyleSegment
  ) {
    final DAGNode node = nodeMap.get(layerType);
    if (null == node) throw new RuntimeException("Not Found: " + layerType);
    final PipelineNetwork network = (PipelineNetwork) node.getNetwork();
    LayerStyleParams styleParams = styleCoefficients.params.get(layerType);
    Tensor mean = chooseStyleSegment.mean.get(layerType);
    Tensor covariance;
    switch (styleCoefficients.centeringMode) {
      case Origin:
        covariance = chooseStyleSegment.cov0.get(layerType);
        break;
      case Dynamic:
      case Static:
        covariance = chooseStyleSegment.cov1.get(layerType);
        break;
      default:
        throw new RuntimeException();
    }
    return getStyleComponents(node, network, styleParams, mean, covariance, styleCoefficients.centeringMode);
  }

  /**
   * Assembles the loss network that the canvas is trained against.
   * <p>
   * Content components are attached to the main network. For every content mask, a scrambled copy
   * of the main network is given the style components of the best-matching style segment (highest
   * {@link #alphaMaskSimilarity}), and is fed the input multiplied by that mask. Each branch enters
   * the final reduction with weight 1.0.
   *
   * @param setup the measured targets and the style configuration
   * @param masks the content masks, one branch per mask
   * @return the main network of the network model, extended in place with the loss nodes
   */
  @Nonnull
  public PipelineNetwork fitnessNetwork(final NeuralSetup<T> setup, final Set<Tensor> masks) {
    U networkModel = getNetworkModel();
    PipelineNetwork mainNetwork = networkModel.getNetwork();
    Map<T, UUID> modelNodes = networkModel.getNodes();
    List<Tuple2<Double, DAGNode>> mainFunctions = new ArrayList<>();
    Map<T, DAGNode> mainNodes = getNodes(modelNodes, mainNetwork, null);
    mainFunctions.addAll(getContentComponents(setup, mainNodes));
    masks.forEach((contentMask) -> {
      // NOTE(review): the type arguments of the id map are not visible from here; it is left raw
      // so it matches whatever scrambleCopy declares - confirm and parameterize.
      HashMap idMap = new HashMap<>();
      DAGNetwork branchNetwork = mainNetwork.scrambleCopy(idMap);
      Map<T, DAGNode> nodeMap = getNodes(modelNodes, branchNetwork, idMap);
      List<Tuple2<Double, DAGNode>> branchFunctions = new ArrayList<>();
      branchFunctions.addAll(getStyleComponents(setup, nodeMap,
          x -> x.segments.entrySet().stream().max(Comparator.comparingDouble(e -> alphaMaskSimilarity(
              contentMask,
              e.getKey()
          ))).get().getValue()
      ));
      if (!branchFunctions.isEmpty()) ArtistryUtil.reduce(branchNetwork, branchFunctions, parallelLossFunctions);
      InnerNode importNode = mainNetwork.wrap(
          branchNetwork,
          mainNetwork.wrap(new ProductLayer(), mainNetwork.getInput(0), mainNetwork.constValue(contentMask))
      );
      mainFunctions.add(new Tuple2<>(1.0, importNode));
    });
    ArtistryUtil.reduce(mainNetwork, mainFunctions, parallelLossFunctions);
    ArtistryUtil.setPrecision(mainNetwork, setup.style.precision);
    return mainNetwork;
  }

  /**
   * Resolves the model's node ids to nodes of the given network.
   * Ids that cannot be found are logged as warnings and left out of the result.
   *
   * @param modelNodes   the node id of each layer type
   * @param network      the network to look the nodes up in
   * @param replacements id substitutions to apply first, or null to use the ids unchanged
   * @return the nodes that were found, by layer type
   */
  @Nonnull
  public Map<T, DAGNode> getNodes(final Map<T, UUID> modelNodes, final DAGNetwork network, final HashMap replacements) {
    Map<T, DAGNode> nodes = new HashMap<>();
    modelNodes.forEach((l, id) -> {
      UUID replaced = null == replacements ? id : replace(replacements, id);
      DAGNode childNode = network.getChildNode(replaced);
      if (null == childNode) {
        logger.warn(String.format("Could not find Node ID %s (replaced from %s) to represent %s", replaced, id, l));
      } else {
        nodes.put(l, childNode);
      }
    });
    return nodes;
  }

  /**
   * Maps a node id through a replacement table keyed by the id's string form.
   *
   * @param replacements the table, looked up with {@code id.toString()}
   * @param id           the id to translate
   * @return the id parsed from the mapped value, or {@code id} itself when there is no mapping
   */
  @Nonnull
  public UUID replace(final HashMap replacements, final UUID id) {
    // Map.get(Object) works for any parameterization of the table, and an unmapped id
    // no longer needs a toString/fromString round-trip.
    final Object mapped = replacements.get(id.toString());
    return null == mapped ? id : UUID.fromString(mapped.toString());
  }

  /**
   * Lists the layer types this transfer works with. The array is iterated when content and style
   * are measured and when the loss components are assembled.
   *
   * @return the layer types, in the order they are visited
   */
  @Nonnull
  public abstract T[] getLayerTypes();

  /**
   * Gets content components.
   *
   * @param setup   the setup
   * @param nodeMap the node buildMap
   * @return the content components
   */
  @Nonnull
  public ArrayList> getContentComponents(NeuralSetup setup, final Map nodeMap) {
    ArrayList> contentComponents = new ArrayList<>();
    for (final T layerType : getLayerTypes()) {
      final DAGNode node = nodeMap.get(layerType);
      final double coeff_content = !setup.style.content.params.containsKey(layerType) ? 0 : setup.style.content.params.get(layerType);
      final PipelineNetwork network1 = (PipelineNetwork) node.getNetwork();
      if (coeff_content != 0) {
        Tensor content = setup.contentTarget.content.get(layerType);
        contentComponents.add(new Tuple2<>(coeff_content, network1.wrap(new MeanSqLossLayer().setAlpha(1.0 / content.rms()),
            node, network1.wrap(new ValueLayer(content), new DAGNode[]{})
        )));
      }
    }
    return contentComponents;
  }

  /**
   * Supplies the vision pipeline whose network and layer node ids are used to build the
   * fitness network.
   *
   * @return the network model
   */
  public abstract U getNetworkModel();

  /**
   * Reports whether the fitness network is wrapped with a tile cycle before training.
   *
   * @return the current tiled flag
   */
  public boolean isTiled() {
    return this.tiled;
  }

  /**
   * Sets whether the fitness network is wrapped with a tile cycle before training.
   *
   * @param tiled the new tiled flag
   * @return this instance, for call chaining
   */
  public SegmentedStyleTransfer<T, U> setTiled(boolean tiled) {
    this.tiled = tiled;
    return this;
  }

  /**
   * @return the mask count used for the content mask job
   */
  public int getContent_masks() {
    return this.content_masks;
  }

  /**
   * @param content_masks the mask count to use for the content mask job
   * @return this instance, for call chaining
   */
  public SegmentedStyleTransfer<T, U> setContent_masks(int content_masks) {
    this.content_masks = content_masks;
    return this;
  }

  /**
   * @return the color cluster count used for the content mask job
   */
  public int getContent_colorClusters() {
    return this.content_colorClusters;
  }

  /**
   * @param content_colorClusters the color cluster count to use for the content mask job
   * @return this instance, for call chaining
   */
  public SegmentedStyleTransfer<T, U> setContent_colorClusters(int content_colorClusters) {
    this.content_colorClusters = content_colorClusters;
    return this;
  }

  /**
   * @return the texture cluster count used for the content mask job
   */
  public int getContent_textureClusters() {
    return this.content_textureClusters;
  }

  /**
   * @param content_textureClusters the texture cluster count to use for the content mask job
   * @return this instance, for call chaining
   */
  public SegmentedStyleTransfer<T, U> setContent_textureClusters(int content_textureClusters) {
    this.content_textureClusters = content_textureClusters;
    return this;
  }

  /**
   * @return the mask count used for the style mask jobs
   */
  public int getStyle_masks() {
    return this.style_masks;
  }

  /**
   * @param style_masks the mask count to use for the style mask jobs
   * @return this instance, for call chaining
   */
  public SegmentedStyleTransfer<T, U> setStyle_masks(int style_masks) {
    this.style_masks = style_masks;
    return this;
  }

  /**
   * The misspelled name is part of the public interface and is kept as is.
   *
   * @return the color cluster count used for the style mask jobs
   */
  public int getStlye_colorClusters() {
    return this.stlye_colorClusters;
  }

  /**
   * The misspelled name is part of the public interface and is kept as is.
   *
   * @param stlye_colorClusters the color cluster count to use for the style mask jobs
   * @return this instance, for call chaining
   */
  public SegmentedStyleTransfer<T, U> setStlye_colorClusters(int stlye_colorClusters) {
    this.stlye_colorClusters = stlye_colorClusters;
    return this;
  }

  /**
   * @return the texture cluster count used for the style mask jobs
   */
  public int getStyle_textureClusters() {
    return this.style_textureClusters;
  }

  /**
   * @param style_textureClusters the texture cluster count to use for the style mask jobs
   * @return this instance, for call chaining
   */
  public SegmentedStyleTransfer<T, U> setStyle_textureClusters(int style_textureClusters) {
    this.style_textureClusters = style_textureClusters;
    return this;
  }

  public Map> getMaskCache() {
    return maskCache;
  }

  /**
   * Selects how a layer's signal is re-centered before the enhancement and covariance style terms
   * are built, and which measured covariance ({@code cov0} or {@code cov1}) is used as the target.
   */
  public enum CenteringMode {
    /**
     * Re-centers with the negated band average of the signal itself, computed inside the network;
     * targets {@code cov1}.
     */
    Dynamic,
    /**
     * Re-centers with the negated mean measured from the style; targets {@code cov1}.
     */
    Static,
    /**
     * Uses the signal without re-centering; targets {@code cov0}.
     */
    Origin
  }

  /**
   * The type Vgg 16.
   */
  public static class VGG16 extends SegmentedStyleTransfer {

    public CVPipe_VGG16 getNetworkModel() {
      return CVPipe_VGG16.INSTANCE;
    }

    @Nonnull
    public CVPipe_VGG16.Layer[] getLayerTypes() {
      return CVPipe_VGG16.Layer.values();
    }

  }

  /**
   * {@link SegmentedStyleTransfer} variant bound to the VGG19 pipeline.
   */
  public static class VGG19 extends SegmentedStyleTransfer {
    /**
     * Creates a VGG19-backed instance; no additional setup is performed here.
     */
    public VGG19() {
    }

    /**
     * Lists the layer keys of the VGG19 pipeline.
     *
     * @return every constant of {@link CVPipe_VGG19.Layer}
     */
    @Nonnull
    public CVPipe_VGG19.Layer[] getLayerTypes() {
      final CVPipe_VGG19.Layer[] allLayers = CVPipe_VGG19.Layer.values();
      return allLayers;
    }

    /**
     * Gets the network model used by this variant.
     *
     * @return {@link CVPipe_VGG19#INSTANCE}
     */
    public CVPipe_VGG19 getNetworkModel() {
      final CVPipe_VGG19 model = CVPipe_VGG19.INSTANCE;
      return model;
    }

  }

  /**
   * The type Content coefficients.
   *
   * @param  the type parameter
   */
  public static class ContentCoefficients> {
    /**
     * The Params.
     */
    public final Map params = new HashMap<>();

    /**
     * Set content coefficients.
     *
     * @param l the l
     * @param v the v
     * @return the content coefficients
     */
    public ContentCoefficients set(final T l, final double v) {
      params.put(l, v);
      return this;
    }

  }

  /**
   * Immutable set of three coefficients attached to a single layer of a style.
   */
  public static class LayerStyleParams {
    /**
     * The value passed as {@code mean} at construction.
     */
    public final double mean;
    /**
     * The value passed as {@code cov} at construction.
     */
    public final double cov;
    /**
     * The value passed as {@code enhance} at construction; this class exposes no accessor for it.
     */
    private final double enhance;

    /**
     * Captures the three coefficients; none of them is validated or transformed.
     *
     * @param mean    the mean coefficient
     * @param cov     the covariance coefficient
     * @param enhance the enhance coefficient
     */
    public LayerStyleParams(final double mean, final double cov, final double enhance) {
      this.enhance = enhance;
      this.cov = cov;
      this.mean = mean;
    }
  }

  /**
   * Bundles the inputs of a style-transfer run: numeric precision, the content image with its
   * coefficients, and the style images with their coefficients.
   * <p>
   * NOTE(review): the generic type arguments of this declaration (class header, {@code Map} fields and
   * constructor parameters) appear to have been stripped when this listing was extracted - note the
   * stray {@code >} after the class name and {@code Map,}. Restore them from the upstream source.
   */
  public static class StyleSetup> {
    /**
     * The precision supplied at construction.
     */
    public final Precision precision;
    /**
     * The style images supplied at construction; declared {@code transient}.
     * The constructor requires every value of this map to be a {@code Tensor}.
     */
    public final transient Map styleImages;
    /**
     * The style coefficient map supplied at construction.
     */
    public final Map, StyleCoefficients> styles;
    /**
     * The content coefficients supplied at construction.
     */
    public final ContentCoefficients content;
    /**
     * The content image; declared {@code transient} and, unlike the other fields, not final.
     */
    public transient Tensor contentImage;


    /**
     * Stores the given arguments in the corresponding fields without copying them.
     *
     * @param precision           the precision
     * @param contentImage        the content image
     * @param contentCoefficients the content coefficients
     * @param styleImages         the style images; must be non-null and every value must be a {@code Tensor}
     * @param styles              the style coefficient map
     * @throws AssertionError if any value in {@code styleImages} is not a {@code Tensor} (a null value also fails the check)
     */
    public StyleSetup(
        final Precision precision,
        final Tensor contentImage,
        ContentCoefficients contentCoefficients,
        final Map styleImages,
        final Map, StyleCoefficients> styles
    ) {
      // Fail fast on malformed input before any field is assigned.
      if (!styleImages.values().stream().allMatch(x -> x instanceof Tensor)) throw new AssertionError();
      this.precision = precision;
      this.contentImage = contentImage;
      this.styleImages = styleImages;
      this.styles = styles;
      this.content = contentCoefficients;
    }

  }

  /**
   * The type Style coefficients.
   *
   * @param  the type parameter
   */
  public static class StyleCoefficients> {
    /**
     * The Dynamic center.
     */
    public final CenteringMode centeringMode;
    /**
     * The Params.
     */
    public final Map params = new HashMap<>();


    /**
     * Instantiates a new Style coefficients.
     *
     * @param centeringMode the dynamic center
     */
    public StyleCoefficients(final CenteringMode centeringMode) {
      this.centeringMode = centeringMode;
    }

    /**
     * Set style coefficients.
     *
     * @param layerType        the key type
     * @param coeff_style_mean the coeff style mean
     * @param coeff_style_cov  the coeff style bandCovariance
     * @return the style coefficients
     */
    public StyleCoefficients set(final T layerType, final double coeff_style_mean, final double coeff_style_cov) {
      return set(
          layerType,
          coeff_style_mean,
          coeff_style_cov,
          0.0
      );
    }

    /**
     * Set style coefficients.
     *
     * @param layerType        the key type
     * @param coeff_style_mean the coeff style mean
     * @param coeff_style_cov  the coeff style bandCovariance
     * @param dream            the dream
     * @return the style coefficients
     */
    public StyleCoefficients set(final T layerType, final double coeff_style_mean, final double coeff_style_cov, final double dream) {
      params.put(layerType, new LayerStyleParams(coeff_style_mean, coeff_style_cov, dream));
      return this;
    }

  }

  /**
   * Holder exposing a single public, mutable {@link #content} map.
   * <p>
   * NOTE(review): the generic type arguments of the class header and of the {@code content} field
   * appear to have been stripped when this listing was extracted (note the stray {@code >} after the
   * class name). Restore them from the upstream source.
   */
  public static class ContentTarget> {
    /**
     * The content map; starts out empty and is filled by code outside this class.
     */
    public Map content = new HashMap<>();
  }

  /**
   * Immutable value object combining three integer settings with a key.
   * Equality and hash code take all four components into account.
   */
  public static class MaskJob {
    private final int style_masks;
    private final int stlye_colorClusters;
    private final int style_textureClusters;
    private final CharSequence key;

    /**
     * Captures the four components as given; only callable from within the enclosing top-level class.
     */
    private MaskJob(final int style_masks, final int stlye_colorClusters, final int style_textureClusters, final CharSequence key) {
      this.key = key;
      this.style_textureClusters = style_textureClusters;
      this.stlye_colorClusters = stlye_colorClusters;
      this.style_masks = style_masks;
    }

    /**
     * @return the {@code style_masks} component
     */
    public int getStyle_masks() {
      return this.style_masks;
    }

    /**
     * @return the {@code stlye_colorClusters} component (the "Stlye" spelling is part of the API)
     */
    public int getStlye_colorClusters() {
      return this.stlye_colorClusters;
    }

    /**
     * @return the {@code style_textureClusters} component
     */
    public int getStyle_textureClusters() {
      return this.style_textureClusters;
    }

    /**
     * @return the key component
     */
    public CharSequence getKey() {
      return this.key;
    }

    /**
     * Two jobs are equal when all three integers match and the keys are equal per
     * {@link Objects#equals(Object, Object)}.
     */
    @Override
    public boolean equals(final Object o) {
      if (o == this) {
        return true;
      }
      if (o instanceof MaskJob) {
        final MaskJob that = (MaskJob) o;
        if (this.style_masks != that.style_masks) {
          return false;
        }
        if (this.stlye_colorClusters != that.stlye_colorClusters) {
          return false;
        }
        if (this.style_textureClusters != that.style_textureClusters) {
          return false;
        }
        return Objects.equals(this.key, that.key);
      }
      return false;
    }

    /**
     * Hash over the same four components that {@link #equals(Object)} compares.
     */
    @Override
    public int hashCode() {
      final int combined = Objects.hash(style_masks, stlye_colorClusters, style_textureClusters, key);
      return combined;
    }
  }

  /**
   * The type Segmented style target.
   *
   * @param  the type parameter
   */
  public static class SegmentedStyleTarget> {
    /**
     * The Segments.
     */
    private final Map> segments = new HashMap<>();

    /**
     * Gets segment.
     *
     * @param styleMask the style mask
     * @return the segment
     */
    public StyleTarget getSegment(final Tensor styleMask) {
      synchronized (segments) {
        StyleTarget styleTarget = segments.computeIfAbsent(styleMask, x -> {
          StyleTarget tStyleTarget = new StyleTarget<>();
          styleMask.addRef();
          return tStyleTarget;
        });
        styleTarget.addRef();
        return styleTarget;
      }
    }
  }

  /**
   * The type Style target.
   *
   * @param  the type parameter
   */
  public static class StyleTarget> extends ReferenceCountingBase {
    /**
     * The Cov.
     */
    public Map cov0 = new HashMap<>();
    /**
     * The Cov.
     */
    public Map cov1 = new HashMap<>();
    /**
     * The Mean.
     */
    public Map mean = new HashMap<>();

    @Override
    protected void _free() {
      super._free();
      if (null != cov0) cov0.values().forEach(ReferenceCountingBase::freeRef);
      if (null != cov1) cov1.values().forEach(ReferenceCountingBase::freeRef);
      if (null != mean) mean.values().forEach(ReferenceCountingBase::freeRef);
    }

    /**
     * Add style target.
     *
     * @param right the right
     * @return the style target
     */
    public StyleTarget add(StyleTarget right) {
      StyleTarget newStyle = new StyleTarget<>();
      Stream.concat(mean.keySet().stream(), right.mean.keySet().stream()).distinct().forEach(layer -> {
        Tensor l = mean.get(layer);
        Tensor r = right.mean.get(layer);
        if (l != null && l != r) {
          Tensor add = l.add(r);
          newStyle.mean.put(layer, add);
        } else if (l != null) {
          l.addRef();
          newStyle.mean.put(layer, l);
        } else if (r != null) {
          r.addRef();
          newStyle.mean.put(layer, r);
        }
      });
      Stream.concat(cov0.keySet().stream(), right.cov0.keySet().stream()).distinct().forEach(layer -> {
        Tensor l = cov0.get(layer);
        Tensor r = right.cov0.get(layer);
        if (l != null && l != r) {
          Tensor add = l.add(r);
          newStyle.cov0.put(layer, add);
        } else if (l != null) {
          l.addRef();
          newStyle.cov0.put(layer, l);
        } else if (r != null) {
          r.addRef();
          newStyle.cov0.put(layer, r);
        }
      });
      Stream.concat(cov1.keySet().stream(), right.cov1.keySet().stream()).distinct().forEach(layer -> {
        Tensor l = cov1.get(layer);
        Tensor r = right.cov1.get(layer);
        if (l != null && l != r) {
          Tensor add = l.add(r);
          newStyle.cov1.put(layer, add);
        } else if (l != null) {
          l.addRef();
          newStyle.cov1.put(layer, l);
        } else if (r != null) {
          r.addRef();
          newStyle.cov1.put(layer, r);
        }
      });
      return newStyle;
    }

    /**
     * Scale style target.
     *
     * @param value the value
     * @return the style target
     */
    public StyleTarget scale(double value) {
      StyleTarget newStyle = new StyleTarget<>();
      mean.keySet().stream().distinct().forEach(layer -> {
        newStyle.mean.put(layer, mean.get(layer).scale(value));
      });
      cov0.keySet().stream().distinct().forEach(layer -> {
        newStyle.cov0.put(layer, cov0.get(layer).scale(value));
      });
      cov1.keySet().stream().distinct().forEach(layer -> {
        newStyle.cov1.put(layer, cov1.get(layer).scale(value));
      });
      return newStyle;
    }

  }

  /**
   * Mutable working state that accompanies a {@link StyleSetup}: the content target, the style
   * targets and the content source tensor.
   * <p>
   * NOTE(review): the generic type arguments of this declaration (class header, fields and constructor
   * parameter) appear to have been stripped when this listing was extracted - note the stray
   * {@code >} after the class name and in {@code Map>}. Restore them from the upstream source.
   */
  public static class NeuralSetup> {

    /**
     * The style setup supplied at construction.
     */
    public final StyleSetup style;
    /**
     * The content target; starts out as a fresh, empty instance.
     */
    public ContentTarget contentTarget = new ContentTarget<>();
    /**
     * The style targets; starts out as an empty map.
     */
    public Map> styleTargets = new HashMap<>();
    /**
     * The content source tensor; not assigned by this class.
     */
    public Tensor contentSource;


    /**
     * Creates a setup bound to the given style; all other fields keep their initial values.
     *
     * @param style the style setup to keep a reference to
     */
    public NeuralSetup(final StyleSetup style) {
      this.style = style;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy