com.simiacryptus.mindseye.test.unit.SerializationTest Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-test Show documentation
Show all versions of mindseye-test Show documentation
Testing Tools for Neural Network Components
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.test.unit;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.SerialPrecision;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.test.ToleranceStatistics;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.util.Util;
import org.apache.commons.io.IOUtils;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.HashMap;
import java.util.zip.GZIPOutputStream;
import java.util.zip.ZipException;
import java.util.zip.ZipFile;
/**
* The type Json apply.
*/
public class SerializationTest extends ComponentTestBase {
@Nonnull
private final HashMap models = new HashMap<>();
private boolean persist = false;
/**
* Compress gz byte [ ].
*
* @param prettyPrint the pretty print
* @return the byte [ ]
*/
public static byte[] compressGZ(@Nonnull String prettyPrint) {
return compressGZ(prettyPrint.getBytes(Charset.forName("UTF-8")));
}
/**
* Compress gz byte [ ].
*
* @param bytes the bytes
* @return the byte [ ]
*/
public static byte[] compressGZ(byte[] bytes) {
@Nonnull ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
try {
try (@Nonnull GZIPOutputStream out = new GZIPOutputStream(byteArrayOutputStream)) {
IOUtils.write(bytes, out);
}
} catch (IOException e) {
throw new RuntimeException(e);
}
return byteArrayOutputStream.toByteArray();
}
@Nullable
@Override
public ToleranceStatistics test(@Nonnull final NotebookOutput log, @Nonnull final Layer layer, final Tensor... inputPrototype) {
log.h1("Serialization");
log.p("This apply will demonstrate the key's JSON serialization, and verify deserialization integrity.");
String prettyPrint = "";
log.h2("Raw Json");
try {
prettyPrint = log.eval(() -> {
final JsonObject json = layer.getJson().getAsJsonObject();
@Nonnull final Layer echo = Layer.fromJson(json);
if (echo == null) throw new AssertionError("Failed to deserialize");
if (layer == echo) throw new AssertionError("Serialization did not copy");
if (!layer.equals(echo)) throw new AssertionError("Serialization not equal");
echo.freeRef();
return new GsonBuilder().setPrettyPrinting().create().toJson(json);
});
@Nonnull String filename = layer.getClass().getSimpleName() + "_" + log.getName() + ".json";
log.p(log.file(prettyPrint, filename, String.format("Wrote Model to %s; %s characters", filename, prettyPrint.length())));
} catch (RuntimeException e) {
e.printStackTrace();
Util.sleep(1000);
} catch (OutOfMemoryError e) {
e.printStackTrace();
Util.sleep(1000);
}
log.p("");
@Nonnull Object outSync = new Object();
if (prettyPrint.isEmpty() || prettyPrint.length() > 1024 * 64)
Arrays.stream(SerialPrecision.values()).parallel().forEach(precision -> {
try {
@Nonnull File file = new File(log.getResourceDir(), log.getName() + "_" + precision.name() + ".zip");
layer.writeZip(file, precision);
@Nonnull final Layer echo = Layer.fromZip(new ZipFile(file));
getModels().put(precision, echo);
synchronized (outSync) {
log.h2(String.format("Zipfile %s", precision.name()));
log.p(log.link(file, String.format("Wrote Model apply %s precision to %s; %.3fMiB bytes", precision, file.getName(), file.length() * 1.0 / (0x100000))));
}
if (!isPersist()) file.delete();
if (echo == null) throw new AssertionError("Failed to deserialize");
if (layer == echo) throw new AssertionError("Serialization did not copy");
if (!layer.equals(echo)) throw new AssertionError("Serialization not equal");
echo.freeRef();
} catch (RuntimeException e) {
e.printStackTrace();
} catch (OutOfMemoryError e) {
e.printStackTrace();
} catch (ZipException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
});
return null;
}
/**
* Gets models.
*
* @return the models
*/
@Nonnull
public HashMap getModels() {
return models;
}
/**
* Is persist boolean.
*
* @return the boolean
*/
public boolean isPersist() {
return persist;
}
/**
* Sets persist.
*
* @param persist the persist
* @return the persist
*/
@Nonnull
public SerializationTest setPersist(boolean persist) {
this.persist = persist;
return this;
}
@Nonnull
@Override
public String toString() {
return "SerializationTest{" +
"models=" + models +
", persist=" + persist +
'}';
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy