// com.simiacryptus.mindseye.art.photo.FastPhotoStyleTransfer (Maven / Gradle / Ivy)
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.art.photo;
import com.simiacryptus.mindseye.art.photo.affinity.ContextAffinity;
import com.simiacryptus.mindseye.art.photo.affinity.RelativeAffinity;
import com.simiacryptus.mindseye.art.photo.cuda.RefUnaryOperator;
import com.simiacryptus.mindseye.art.photo.cuda.SmoothSolver_Cuda;
import com.simiacryptus.mindseye.art.photo.topology.RadiusRasterTopology;
import com.simiacryptus.mindseye.art.photo.topology.RasterTopology;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.util.Util;
import javax.annotation.Nonnull;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import static com.simiacryptus.mindseye.lang.Layer.fromJson;
import static com.simiacryptus.util.JsonUtil.toJson;
/**
* Implements the process detailed in:
* A Closed-form Solution to Photorealistic Image Stylization
* https://arxiv.org/pdf/1802.06474.pdf
*/
public class FastPhotoStyleTransfer extends ReferenceCountingBase implements Function> {
public final Layer encode_1;
public final Layer decode_1;
public final Layer encode_2;
@Nonnull
public final Layer decode_2;
public final Layer encode_3;
@Nonnull
public final Layer decode_3;
public final Layer encode_4;
@Nonnull
public final Layer decode_4;
private boolean useCuda = true;
private boolean smooth = true;
private double lambda = 1e-4;
private double epsilon = 1e-7;
/**
 * Creates a transfer pipeline from four encoder/decoder pairs.
 * <p>
 * Note the argument order: every level is supplied as (decoder, encoder), level 1 first.
 * The layers are stored as-is and are released by {@link #_free()}.
 *
 * @param decode_1 decoder used by {@code photoWCT_1}
 * @param encode_1 encoder used by {@code photoWCT_1}
 * @param decode_2 decoder used by {@code photoWCT_2}
 * @param encode_2 encoder used by {@code photoWCT_2}
 * @param decode_3 decoder used by {@code photoWCT_3}
 * @param encode_3 encoder used by {@code photoWCT_3}
 * @param decode_4 decoder used by {@code photoWCT_4}
 * @param encode_4 encoder used by {@code photoWCT_4}
 */
public FastPhotoStyleTransfer(Layer decode_1, Layer encode_1, @Nonnull Layer decode_2, Layer encode_2,
    @Nonnull Layer decode_3, Layer encode_3, @Nonnull Layer decode_4, Layer encode_4) {
  // Level 1
  this.decode_1 = decode_1;
  this.encode_1 = encode_1;
  // Level 2
  this.decode_2 = decode_2;
  this.encode_2 = encode_2;
  // Level 3
  this.decode_3 = decode_3;
  this.encode_3 = encode_3;
  // Level 4
  this.decode_4 = decode_4;
  this.encode_4 = encode_4;
}
/**
 * @return the configured epsilon (initially 1e-7); no other method visible in this class reads it
 */
public double getEpsilon() {
  return this.epsilon;
}
/**
 * Replaces the stored epsilon value.
 *
 * @param epsilon the new value, stored without validation
 */
public void setEpsilon(double epsilon) {
  this.epsilon = epsilon;
}
/**
 * @return the lambda value handed to the smoothing solver in {@code photoSmooth} (initially 1e-4)
 */
public double getLambda() {
  return this.lambda;
}
/**
 * Replaces the lambda value handed to the smoothing solver in {@code photoSmooth}.
 *
 * @param lambda the new value, stored without validation
 */
public void setLambda(double lambda) {
  this.lambda = lambda;
}
/**
 * @return {@code true} (the initial setting) if {@code photoSmooth} should build a real smoothing
 * operator; {@code false} if it should return a pass-through operator
 */
public boolean isSmooth() {
  return this.smooth;
}
/**
 * Enables or disables the smoothing step built by {@code photoSmooth}.
 *
 * @param smooth {@code true} to smooth, {@code false} to pass results through unchanged
 */
public void setSmooth(boolean smooth) {
  this.smooth = smooth;
}
/**
 * @return {@code true} (the initial setting) if {@code photoSmooth} should use
 * {@code SmoothSolver_Cuda}; {@code false} if it should use {@code SmoothSolver_EJML}
 */
public boolean isUseCuda() {
  return this.useCuda;
}
/**
 * Chooses which smoothing solver {@code photoSmooth} instantiates.
 *
 * @param useCuda {@code true} for {@code SmoothSolver_Cuda}, {@code false} for {@code SmoothSolver_EJML}
 * @return this instance, so calls can be chained
 */
@Nonnull
public FastPhotoStyleTransfer setUseCuda(boolean useCuda) {
  this.useCuda = useCuda;
  return this;
}
/**
 * Rebuilds a transfer pipeline from a zip archive.
 * <p>
 * The archive contents are extracted with {@code ZipSerializable.extract}; each layer is then
 * deserialized from its {@code decode_N.json} / {@code encode_N.json} entry (N = 1..4), with the
 * full resource map supplied alongside.
 *
 * @param zipfile archive holding the eight layer definitions
 * @return a new instance wrapping the deserialized layers
 */
@Nonnull
public static FastPhotoStyleTransfer fromZip(@Nonnull final ZipFile zipfile) {
  @Nonnull
  HashMap resources = ZipSerializable.extract(zipfile);
  // Layers are loaded in constructor-argument order: (decoder, encoder) for levels 1 through 4.
  final Layer decoder1 = fromJson(toJson(resources.get("decode_1.json")), resources);
  final Layer encoder1 = fromJson(toJson(resources.get("encode_1.json")), resources);
  final Layer decoder2 = fromJson(toJson(resources.get("decode_2.json")), resources);
  final Layer encoder2 = fromJson(toJson(resources.get("encode_2.json")), resources);
  final Layer decoder3 = fromJson(toJson(resources.get("decode_3.json")), resources);
  final Layer encoder3 = fromJson(toJson(resources.get("encode_3.json")), resources);
  final Layer decoder4 = fromJson(toJson(resources.get("decode_4.json")), resources);
  final Layer encoder4 = fromJson(toJson(resources.get("encode_4.json")), resources);
  return new FastPhotoStyleTransfer(decoder1, encoder1, decoder2, encoder2,
      decoder3, encoder3, decoder4, encoder4);
}
/**
 * Runs a single encode / transform / decode pass.
 * <p>
 * Both images are encoded with {@code encode}; the encoded content is passed through the network
 * built by {@code WCTUtil.applicator} from the encoded style; the result is decoded by
 * {@code decode}, which also receives the original content image as a second input.
 * The {@code encode} and {@code decode} references passed in are released by this method.
 *
 * @param contentImage   image whose structure is kept
 * @param styleImage     image providing the style statistics
 * @param encode         encoder layer; released here
 * @param decode         decoder layer taking (transformed features, content image); released here
 * @param contentDensity forwarded unchanged to {@code WCTUtil.applicator}
 * @param styleDensity   forwarded unchanged to {@code WCTUtil.applicator}
 * @return the decoded, restyled tensor
 */
@Nonnull
public static Tensor transfer(Tensor contentImage, Tensor styleImage, @Nonnull Layer encode, @Nonnull Layer decode,
    double contentDensity, double styleDensity) {
  // The content image is needed again by the decoder below, so the encoder gets an extra reference.
  final Tensor contentFeatures = Result.getData0(encode.eval(contentImage.addRef()));
  final Tensor styleFeatures = Result.getData0(encode.eval(styleImage));
  encode.freeRef();
  final PipelineNetwork wctNetwork = WCTUtil.applicator(styleFeatures, contentDensity, styleDensity);
  final Tensor restyledFeatures = Result.getData0(wctNetwork.eval(contentFeatures));
  wctNetwork.freeRef();
  final Tensor decoded = Result.getData0(decode.eval(restyledFeatures, contentImage));
  decode.freeRef();
  return decoded;
}
/**
 * Releases the eight encoder/decoder layers held by this instance, then delegates to the superclass.
 */
public void _free() {
  // Same release order as the field pairs: encoder then decoder, level 1 through 4.
  final Layer[] ownedLayers = {encode_1, decode_1, encode_2, decode_2, encode_3, decode_3, encode_4, decode_4};
  for (final Layer layer : ownedLayers) {
    layer.freeRef();
  }
  super._free();
}
/**
 * Serializes all eight layers into a single zip archive.
 * <p>
 * Each layer is written through its own {@code writeZip} under the entry name
 * {@code decode_N.json} / {@code encode_N.json} (N = 1..4); one resource map is shared by all
 * eight calls.
 *
 * @param out       destination file
 * @param precision precision setting forwarded to every layer's {@code writeZip}
 */
public void writeZip(@Nonnull File out, SerialPrecision precision) {
  // Parallel tables: layers[i] is stored under entryNames[i].
  final Layer[] layers = {decode_1, encode_1, decode_2, encode_2, decode_3, encode_3, decode_4, encode_4};
  final String[] entryNames = {"decode_1.json", "encode_1.json", "decode_2.json", "encode_2.json",
      "decode_3.json", "encode_3.json", "decode_4.json", "encode_4.json"};
  try (@Nonnull
  ZipOutputStream zipOutputStream = new ZipOutputStream(new FileOutputStream(out))) {
    final HashMap resources = new HashMap<>();
    for (int i = 0; i < layers.length; i++) {
      layers[i].writeZip(zipOutputStream, precision, resources, entryNames[i]);
    }
  } catch (IOException e) {
    throw Util.throwException(e);
  }
}
/**
 * Binds a content image and returns an operator that maps a style image to the stylized result.
 * <p>
 * The returned operator releases its parent reference when it is freed, so it is given its own
 * reference via {@link #addRef()}. Passing {@code this} directly would make freeing the operator
 * drop a reference that belongs to the caller.
 *
 * @param contentImage the content image; ownership passes to the returned operator
 * @return an operator applying {@code photoWCT} followed by the smoothing step
 */
@Nonnull
public RefUnaryOperator apply(@Nonnull Tensor contentImage) {
  return new StyleUnaryOperator(contentImage, this.addRef());
}
/**
 * Runs the full four-level transfer with both density arguments set to 1.0.
 *
 * @param style   style image
 * @param content content image
 * @return the result of {@link #photoWCT(Tensor, Tensor, double, double)}
 */
@Nonnull
public Tensor photoWCT(Tensor style, Tensor content) {
  final double defaultDensity = 1.0;
  return photoWCT(style, content, defaultDensity, defaultDensity);
}

/**
 * Runs the four transfer levels in sequence, coarsest-numbered first: level 4 is applied to
 * {@code content}, and each lower level is applied to the previous level's output.
 * The same style image is used at every level.
 *
 * @param style          style image; its reference is released before returning
 * @param content        content image, passed to level 4
 * @param contentDensity forwarded unchanged to every level
 * @param styleDensity   forwarded unchanged to every level
 * @return the output of level 1
 */
@Nonnull
public Tensor photoWCT(Tensor style, Tensor content, double contentDensity, double styleDensity) {
  // Take one style reference per level up front, in the order level 1, 2, 3, 4.
  final Tensor styleForLevel1 = style.addRef();
  final Tensor styleForLevel2 = style.addRef();
  final Tensor styleForLevel3 = style.addRef();
  final Tensor styleForLevel4 = style.addRef();
  final Tensor afterLevel4 = photoWCT_4(styleForLevel4, content, contentDensity, styleDensity);
  final Tensor afterLevel3 = photoWCT_3(styleForLevel3, afterLevel4, contentDensity, styleDensity);
  final Tensor afterLevel2 = photoWCT_2(styleForLevel2, afterLevel3, contentDensity, styleDensity);
  final Tensor afterLevel1 = photoWCT_1(styleForLevel1, afterLevel2, contentDensity, styleDensity);
  style.freeRef();
  return afterLevel1;
}
/**
 * Level-1 transfer with both density arguments set to 1.0.
 *
 * @param style   style image
 * @param content content image
 * @return the result of {@link #photoWCT_1(Tensor, Tensor, double, double)}
 */
@Nonnull
public Tensor photoWCT_1(Tensor style, Tensor content) {
  final double defaultDensity = 1.0;
  return photoWCT_1(style, content, defaultDensity, defaultDensity);
}

/**
 * Level-1 transfer: encodes both images with {@code encode_1}, passes the encoded content
 * through the network built by {@code WCTUtil.applicator} from the encoded style, and decodes
 * the result with {@code decode_1}. Unlike the static {@code transfer} helper, the decoder here
 * receives only the transformed features.
 *
 * @param style          style image
 * @param content        content image
 * @param contentDensity forwarded unchanged to {@code WCTUtil.applicator}
 * @param styleDensity   forwarded unchanged to {@code WCTUtil.applicator}
 * @return the decoded, restyled tensor
 */
@Nonnull
public Tensor photoWCT_1(Tensor style, Tensor content, double contentDensity, double styleDensity) {
  final Tensor contentFeatures = Result.getData0(encode_1.eval(content));
  final Tensor styleFeatures = Result.getData0(encode_1.eval(style));
  final PipelineNetwork wctNetwork = WCTUtil.applicator(styleFeatures, contentDensity, styleDensity);
  final Tensor restyledFeatures = Result.getData0(wctNetwork.eval(contentFeatures));
  wctNetwork.freeRef();
  final Tensor decoded = Result.getData0(decode_1.eval(restyledFeatures));
  return decoded;
}
/**
 * Level-2 transfer with both density arguments set to 1.0.
 *
 * @param style   style image
 * @param content content image
 * @return the result of {@link #photoWCT_2(Tensor, Tensor, double, double)}
 */
@Nonnull
public Tensor photoWCT_2(Tensor style, Tensor content) {
  final double defaultDensity = 1.0;
  return photoWCT_2(style, content, defaultDensity, defaultDensity);
}

/**
 * Level-2 transfer: delegates to the static {@code transfer} helper with fresh references to
 * {@code encode_2} and {@code decode_2}.
 *
 * @param style          style image
 * @param content        content image
 * @param contentDensity forwarded unchanged to {@code transfer}
 * @param styleDensity   forwarded unchanged to {@code transfer}
 * @return the tensor returned by {@code transfer}
 */
@Nonnull
public Tensor photoWCT_2(Tensor style, Tensor content, double contentDensity, double styleDensity) {
  // transfer() releases the layers it is given, so hand it extra references.
  final Layer encoder = encode_2.addRef();
  final Layer decoder = decode_2.addRef();
  return transfer(content, style, encoder, decoder, contentDensity, styleDensity);
}
/**
 * Level-3 transfer with both density arguments set to 1.0.
 *
 * @param style   style image
 * @param content content image
 * @return the result of {@link #photoWCT_3(Tensor, Tensor, double, double)}
 */
@Nonnull
public Tensor photoWCT_3(Tensor style, Tensor content) {
  final double defaultDensity = 1.0;
  return photoWCT_3(style, content, defaultDensity, defaultDensity);
}

/**
 * Level-3 transfer: delegates to the static {@code transfer} helper with fresh references to
 * {@code encode_3} and {@code decode_3}.
 *
 * @param style          style image
 * @param content        content image
 * @param contentDensity forwarded unchanged to {@code transfer}
 * @param styleDensity   forwarded unchanged to {@code transfer}
 * @return the tensor returned by {@code transfer}
 */
@Nonnull
public Tensor photoWCT_3(Tensor style, Tensor content, double contentDensity, double styleDensity) {
  // transfer() releases the layers it is given, so hand it extra references.
  final Layer encoder = encode_3.addRef();
  final Layer decoder = decode_3.addRef();
  return transfer(content, style, encoder, decoder, contentDensity, styleDensity);
}
/**
 * Level-4 transfer with both density arguments set to 1.0.
 *
 * @param style   style image
 * @param content content image
 * @return the result of {@link #photoWCT_4(Tensor, Tensor, double, double)}
 */
@Nonnull
public Tensor photoWCT_4(Tensor style, Tensor content) {
  final double defaultDensity = 1.0;
  return photoWCT_4(style, content, defaultDensity, defaultDensity);
}

/**
 * Level-4 transfer: delegates to the static {@code transfer} helper with fresh references to
 * {@code encode_4} and {@code decode_4}.
 *
 * @param style          style image
 * @param content        content image
 * @param contentDensity forwarded unchanged to {@code transfer}
 * @param styleDensity   forwarded unchanged to {@code transfer}
 * @return the tensor returned by {@code transfer}
 */
@Nonnull
public Tensor photoWCT_4(Tensor style, Tensor content, double contentDensity, double styleDensity) {
  // transfer() releases the layers it is given, so hand it extra references.
  final Layer encoder = encode_4.addRef();
  final Layer decoder = decode_4.addRef();
  return transfer(content, style, encoder, decoder, contentDensity, styleDensity);
}
/**
 * Builds the post-processing operator for a given content image.
 * <p>
 * With smoothing disabled, the content reference is released and a pass-through operator is
 * returned. Otherwise a {@code RadiusRasterTopology} over the content's dimensions and a
 * {@code RelativeAffinity} over the content are handed, together with {@link #getLambda()}, to
 * either {@code SmoothSolver_Cuda} or {@code SmoothSolver_EJML} depending on {@link #isUseCuda()}.
 *
 * @param content the content image; this method does not keep using it after the call
 * @return the operator produced by the solver, or a pass-through operator when smoothing is off
 */
@Nonnull
public RefUnaryOperator photoSmooth(@Nonnull Tensor content) {
  if (!isSmooth()) {
    content.freeRef();
    return new NullUnaryOperator();
  }
  RasterTopology topology = new RadiusRasterTopology(content.getDimensions(), RadiusRasterTopology.getRadius(1, 1),
      -1);
  // Earlier experiments (kept here as notes only): MattingAffinity, GaussianAffinity with
  // parameter 20, and a TruncatedAffinity wrapper with setMin(1e-2).
  ContextAffinity affinity = new RelativeAffinity(content, topology.addRef());
  if (isUseCuda()) {
    return new SmoothSolver_Cuda().solve(topology, affinity, getLambda());
  }
  return new SmoothSolver_EJML().solve(topology, affinity, getLambda());
}
/**
 * Covariant override so callers get a {@code FastPhotoStyleTransfer} back without casting.
 *
 * @return the value of {@code super.addRef()}, cast to this type
 */
@Nonnull
@Override
@SuppressWarnings("unused")
public FastPhotoStyleTransfer addRef() {
  final FastPhotoStyleTransfer self = (FastPhotoStyleTransfer) super.addRef();
  return self;
}
/**
 * Pass-through operator: {@link #apply(Object)} returns its argument unchanged.
 * <p>
 * The type parameter {@code T} is declared on the class; without it the {@code apply}
 * signature refers to an undeclared type and does not compile.
 *
 * @param <T> the operand and result type
 */
private static class NullUnaryOperator<T> extends ReferenceCountingBase implements RefUnaryOperator<T> {
  /**
   * @param tensor the value to pass through
   * @return {@code tensor} itself
   */
  @Override
  public T apply(T tensor) {
    return tensor;
  }

  /** Holds no resources of its own; simply delegates to the superclass. */
  public @SuppressWarnings("unused")
  void _free() {
    super._free();
  }

  /**
   * Covariant override of {@code addRef}.
   *
   * @return the value of {@code super.addRef()}, cast to this type
   */
  @Nonnull
  @Override
  @SuppressWarnings({"unused", "unchecked"})
  public NullUnaryOperator<T> addRef() {
    return (NullUnaryOperator<T>) super.addRef();
  }
}
/**
 * Operator bound to one content image: given a style image, it runs the parent's
 * {@code photoWCT} and then the smoothing operator built for that content image.
 * <p>
 * The operator owns the {@code contentImage} and {@code parent} references passed to its
 * constructor, plus the smoothing operator it creates; all three are released in {@link #_free()}.
 */
private static class StyleUnaryOperator extends ReferenceCountingBase implements RefUnaryOperator<Tensor> {
  /** Smoothing step applied to the output of {@code photoWCT}. */
  @Nonnull
  final RefUnaryOperator<Tensor> photoSmooth;
  /** Content image reused for every style image passed to {@link #apply(Tensor)}. */
  @Nonnull
  private final Tensor contentImage;
  /** Pipeline whose layers and settings are used for the transfer. */
  @Nonnull
  private final FastPhotoStyleTransfer parent;

  /**
   * @param contentImage content image; this operator takes over the reference
   * @param parent       owning pipeline; this operator takes over the reference and releases it
   *                     in {@link #_free()}
   */
  public StyleUnaryOperator(@Nonnull Tensor contentImage, @Nonnull FastPhotoStyleTransfer parent) {
    this.parent = parent;
    this.contentImage = contentImage;
    // photoSmooth() does not hand the tensor back (it frees it outright when smoothing is off),
    // while this operator keeps using contentImage in apply() and frees it in _free().
    // It therefore needs a reference of its own.
    photoSmooth = this.parent.photoSmooth(this.contentImage.addRef());
  }

  /** Releases the content image, the smoothing operator and the parent reference. */
  public void _free() {
    contentImage.freeRef();
    photoSmooth.freeRef();
    parent.freeRef();
    super._free();
  }

  /**
   * Stylizes the bound content image with the given style image.
   *
   * @param styleImage the style image, passed on to {@code photoWCT}
   * @return the smoothed transfer result
   */
  @Override
  public Tensor apply(Tensor styleImage) {
    // photoWCT receives its own content reference so this operator can be applied repeatedly.
    return photoSmooth.apply(parent.photoWCT(styleImage, contentImage.addRef()));
  }

  /**
   * Covariant override of {@code addRef}.
   *
   * @return the value of {@code super.addRef()}, cast to this type
   */
  @Nonnull
  @Override
  @SuppressWarnings("unused")
  public StyleUnaryOperator addRef() {
    return (StyleUnaryOperator) super.addRef();
  }
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy