/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.art.util;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.simiacryptus.mindseye.art.VisionPipelineLayer;
import com.simiacryptus.mindseye.lang.Coordinate;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.cudnn.CudaSystem;
import com.simiacryptus.mindseye.layers.cudnn.ImgBandBiasLayer;
import com.simiacryptus.mindseye.layers.cudnn.conv.SimpleConvolutionLayer;
import com.simiacryptus.mindseye.layers.tensorflow.TFLayer;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.mindseye.util.TFConverter;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.tensorflow.GraphModel;
import com.simiacryptus.util.JsonUtil;
import com.simiacryptus.util.Util;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.s3a.S3AFileSystem;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.GraphDef;
import javax.annotation.Nonnull;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.*;
import java.net.URL;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
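/**
 * Utility methods for MindsEye vision pipelines: CUDA diagnostics reporting,
 * TensorFlow graph conversion, receptive-field connectivity testing, and
 * image loading from URLs or Hadoop filesystems.
 */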
public class VisionPipelineUtil {
private static final Logger log = LoggerFactory.getLogger(VisionPipelineUtil.class);
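/**
 * Registers HTTP handlers on the notebook's embedded server that serve CUDA
 * diagnostics: cuda/info.txt (system header) and cuda/stats.json (execution
 * statistics). When interceptLog is set, CUDA log messages are also captured
 * into files within a "cuda_log" subreport, rotated after roughly 10MB or one
 * minute of output.
 *
 * <p>A typical usage sketch (the workload shown is hypothetical):
 * <pre>{@code
 * try (Closeable reports = VisionPipelineUtil.cudaReports(log, true)) {
 *   // run CUDA workloads; diagnostics are served while this is open
 * }
 * }</pre>
 *
 * @param log          the notebook output
 * @param interceptLog whether to capture the CUDA log into the report
 * @return a Closeable that unregisters both HTTP handlers
 */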
public static Closeable cudaReports(NotebookOutput log, boolean interceptLog) {
Closeable handler_info = log.getHttpd().addGET("cuda/info.txt", "text/plain", outputStream -> {
try {
PrintStream stream = new PrintStream(outputStream);
CudaSystem.printHeader(stream);
stream.flush();
} catch (Throwable e) {
try {
outputStream.write(Util.toString(e).getBytes("UTF-8"));
} catch (IOException e1) {
e1.printStackTrace();
}
}
});
Closeable handler_stats = log.getHttpd().addGET("cuda/stats.json", "application/json", outputStream -> {
try {
PrintStream stream = new PrintStream(outputStream);
stream.println(JsonUtil.toJson(CudaSystem.getExecutionStatistics()));
stream.flush();
} catch (Throwable e) {
try {
outputStream.write(Util.toString(e).getBytes("UTF-8"));
} catch (IOException e1) {
e1.printStackTrace();
}
}
});
if (interceptLog) log.subreport("cuda_log", sublog -> {
CudaSystem.addLog(new Consumer<String>() {
PrintWriter out;
long remainingOut = 0;
long killAt = 0;
@Override
public void accept(String formattedMessage) {
if (null == out) {
SimpleDateFormat dateFormat = new SimpleDateFormat("dd_HH_mm_ss");
String date = dateFormat.format(new Date());
try {
String caption = String.format("Log at %s", date);
String filename = String.format("%s_cuda.log", date);
out = new PrintWriter(sublog.file(filename));
sublog.p("[%s](etc/%s)", caption, filename);
sublog.write();
} catch (Throwable e) {
throw new RuntimeException(e);
}
killAt = System.currentTimeMillis() + TimeUnit.MINUTES.toMillis(1);
remainingOut = 10L * 1024 * 1024;
}
out.println(formattedMessage);
out.flush();
int length = formattedMessage.length();
remainingOut -= length;
if (remainingOut < 0 || killAt < System.currentTimeMillis()) {
out.close();
out = null;
}
}
});
return null;
});
return new Closeable() {
@Override
public void close() throws IOException {
handler_info.close();
handler_stats.close();
}
};
}
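/**
 * Converts a TensorFlow graph into a chain of MindsEye layers, one per named
 * node. The first node's subgraph is converted standalone; each subsequent
 * node's subgraph is cut at the previous node, so the converted layers
 * compose in sequence.
 *
 * @param graphDef the TensorFlow graph definition
 * @param nodes    the node names delimiting each pipeline stage, in order
 * @return a map from node name to converted layer
 */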
@NotNull
public static Map<String, Layer> convertPipeline(GraphDef graphDef, String... nodes) {
GraphModel graphModel = new GraphModel(graphDef.toByteArray());
Map<String, Layer> graphs = new HashMap<>();
TFConverter tfConverter = new TFConverter();
TFLayer tfLayer0 = new TFLayer(
graphModel.getChild(nodes[0]).subgraph(new HashSet<>(Arrays.asList())).toByteArray(),
new HashMap<>(),
nodes[0],
"input");
graphs.put(nodes[0], tfConverter.convert(tfLayer0));
tfLayer0.freeRef();
for (int i = 1; i < nodes.length; i++) {
String currentNode = nodes[i];
String priorNode = nodes[i - 1];
TFLayer tfLayer1 = new TFLayer(
graphModel.getChild(currentNode).subgraph(new HashSet<>(Arrays.asList(priorNode))).toByteArray(),
new HashMap<>(),
currentNode,
priorNode);
graphs.put(currentNode, tfConverter.convert(tfLayer1));
tfLayer1.freeRef();
}
return graphs;
}
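/**
 * Extracts the subgraph for each named node, where each subgraph after the
 * first is cut at the previous node in the list.
 *
 * @param graphModel the source graph model
 * @param nodes      the node names, in pipeline order
 * @return the extracted subgraphs, in node order
 */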
@NotNull
public static ArrayList<GraphDef> getNodes(GraphModel graphModel, List<String> nodes) {
ArrayList<GraphDef> graphs = new ArrayList<>();
graphs.add(graphModel.getChild(nodes.get(0)).subgraph(new HashSet<>(Arrays.asList())));
for (int i = 1; i < nodes.size(); i++) {
graphs.add(graphModel.getChild(nodes.get(i)).subgraph(new HashSet<>(Arrays.asList(nodes.get(i - 1)))));
}
return graphs;
}
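/**
 * Empirically validates the receptive-field geometry of a pipeline layer.
 * Convolution kernels are set to all ones and biases to zero, then a unit
 * impulse is pushed through each input pixel (band 0) to record which output
 * pixels it reaches. Connection counts that fall outside the bounds implied
 * by the declared kernel size, strides, and borders are logged as warnings.
 *
 * @param layer     the vision pipeline layer to test
 * @param inputDims the input dimensions to test with
 */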
public static void testPinConnectivity(VisionPipelineLayer layer, int... inputDims) {
DAGNetwork liveTestingNetwork = (DAGNetwork) layer.getLayer();
liveTestingNetwork.visitLayers(l -> {
if (l instanceof SimpleConvolutionLayer) {
Tensor kernel = ((SimpleConvolutionLayer) l).getKernel().map(x -> 1.0);
((SimpleConvolutionLayer) l).set(kernel);
kernel.freeRef();
} else if (l instanceof ImgBandBiasLayer) {
((ImgBandBiasLayer) l).setWeights(x -> 0);
} else if (l instanceof DAGNetwork) {
// Ignore
} else if (!l.state().isEmpty()) {
throw new RuntimeException(l.getClass().toString());
}
});
int[] outputDims = evalDims(inputDims, liveTestingNetwork.addRef());
log.info(String.format("testPins(%s,%s) => %s", layer, Arrays.toString(inputDims), Arrays.toString(outputDims)));
Tensor coordSource = new Tensor(inputDims);
Map<Coordinate, List<Coordinate>> fwdPinMapping = coordSource.coordStream(true).distinct().filter(x -> x.getCoords()[2] == 0).collect(Collectors.toMap(
inputPin -> inputPin,
inputPin -> {
Tensor testInput = new Tensor(inputDims).setAll(0.0).set(inputPin, 1.0);
Tensor testOutput = liveTestingNetwork.eval(testInput).getDataAndFree().getAndFree(0).mapAndFree(outValue -> outValue == 0.0 ? 0.0 : 1.0);
List<Coordinate> coordinates = testOutput.coordStream(true).filter(c -> testOutput.get(c) != 0.0 && c.getCoords()[2] == 0).collect(Collectors.toList());
testOutput.freeRef();
testInput.freeRef();
return coordinates;
}));
coordSource.freeRef();
liveTestingNetwork.freeRef();
Map<Coordinate, Integer> fwdSizes = fwdPinMapping.entrySet().stream().collect(Collectors.groupingBy(
e -> e.getKey(), Collectors.summingInt(e -> e.getValue().size())));
log.info("fwdSizes=" + fwdSizes.entrySet().stream().collect(Collectors.groupingBy(x -> x.getValue(), Collectors.counting())).toString());
int minDividedKernelSize = IntStream.range(0, 2).map(d -> {
return (int) Math.floor((double) layer.getKernelSize()[d] / layer.getStrides()[d]);
}).reduce((a, b) -> a * b).getAsInt();
int maxDividedKernelSize = IntStream.range(0, 2).map(d -> {
return (int) Math.ceil((double) layer.getKernelSize()[d] / layer.getStrides()[d]);
}).reduce((a, b) -> a * b).getAsInt();
if (fwdSizes.entrySet().stream().noneMatch(e -> e.getValue() == maxDividedKernelSize)) {
log.warn("No Fully Represented Input Found");
}
int kernelSize = IntStream.range(0, 2).map(d -> {
return layer.getKernelSize()[d];
}).reduce((a, b) -> a * b).getAsInt();
Map<Coordinate, List<Coordinate>> bakPinMapping = fwdPinMapping.entrySet().stream().flatMap(fwdEntry -> fwdEntry.getValue().stream()
.map(outputCoord -> new Coordinate[]{fwdEntry.getKey(), outputCoord}))
.collect(Collectors.groupingBy(x -> x[1])).entrySet().stream().collect(Collectors.toMap(
e -> e.getKey(),
e -> e.getValue().stream().map(x -> x[0]).collect(Collectors.toList())));
Map<Coordinate, Integer> bakSizes = bakPinMapping.entrySet().stream().collect(Collectors.groupingBy(
e -> e.getKey(), Collectors.summingInt(e -> e.getValue().size())));
log.info("bakSizes=" + bakSizes.entrySet().stream().collect(Collectors.groupingBy(x -> x.getValue(), Collectors.counting())).toString());
if (bakSizes.entrySet().stream().noneMatch(e -> e.getValue() == kernelSize)) {
log.warn("No Fully Represented Output Found");
}
fwdSizes.entrySet().stream().filter(e -> e.getValue() > maxDividedKernelSize).forEach(e -> {
log.info("Overrepresented Input: " + e.getKey() + " = " + e.getValue());
});
fwdSizes.entrySet().stream().filter(e -> e.getValue() < minDividedKernelSize).forEach(e -> {
int[] coords = e.getKey().getCoords();
int[] inputBorders = layer.getInputBorders();
int[] array = IntStream.range(0, inputBorders.length).filter(d -> {
if (inputBorders[d] > coords[d]) return true;
if (((inputDims[d]) - inputBorders[d]) <= coords[d]) return true;
return false;
}).toArray();
if (array.length == 0) {
log.warn("Underrepresented Input: " + e.getKey() + " = " + e.getValue());
}
});
bakSizes.entrySet().stream().filter(e -> e.getValue() < kernelSize).forEach(e -> {
int[] coords = e.getKey().getCoords();
int[] outputBorders = layer.getOutputBorders();
int[] array = IntStream.range(0, outputBorders.length).filter(d -> {
if (outputBorders[d] > coords[d]) return true;
if (((outputDims[d]) - outputBorders[d]) <= coords[d]) return true;
return false;
}).toArray();
if (0 == array.length) {
log.warn("Underrepresented Output: " + e.getKey() + " = " + e.getValue());
}
});
}
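/**
 * Determines a layer's output dimensions by evaluating it against a zero
 * tensor of the given input dimensions. The layer reference is freed.
 *
 * @param inputDims the input dimensions
 * @param layer     the layer to evaluate (consumed)
 * @return the output dimensions
 */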
public static int[] evalDims(int[] inputDims, Layer layer) {
Tensor input = new Tensor(inputDims);
Tensor tensor = layer.eval(input).getDataAndFree().getAndFree(0);
input.freeRef();
int[] dimensions = tensor.getDimensions();
tensor.freeRef();
layer.freeRef();
return dimensions;
}
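/**
 * Loads an image and, when imageSize is positive, resizes it via
 * TestUtil.resize.
 *
 * @param image     the image location (http(s) URL or Hadoop path)
 * @param imageSize the target size in pixels; non-positive leaves the image
 *                  unscaled
 * @return the loaded image
 */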
@Nonnull
public static BufferedImage load(final CharSequence image, final int imageSize) {
BufferedImage source = getImage(image);
return imageSize <= 0 ? source : TestUtil.resize(source, imageSize, true);
}
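/**
 * Loads an image and, when width is positive, resizes it to the given width
 * and height.
 *
 * @param image  the image location (http(s) URL or Hadoop path)
 * @param width  the target width in pixels; non-positive leaves the image
 *               unscaled
 * @param height the target height in pixels
 * @return the loaded image
 */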
@Nonnull
public static BufferedImage load(final CharSequence image, final int width, final int height) {
BufferedImage source = getImage(image);
return width <= 0 ? source : TestUtil.resize(source, width, height);
}
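/**
 * Reads an image from an http(s) URL or from any filesystem supported by the
 * Hadoop configuration (e.g. s3a:// or git:// paths).
 *
 * @param file the image location
 * @return the decoded image
 */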
@Nonnull
public static BufferedImage getImage(final CharSequence file) {
if (file.toString().startsWith("http")) {
try {
BufferedImage read = ImageIO.read(new URL(file.toString()));
if (null == read) throw new IllegalArgumentException("Error reading " + file);
return read;
} catch (Throwable e) {
throw new RuntimeException("Error reading " + file, e);
}
}
FileSystem fileSystem = getFileSystem(file.toString());
Path path = new Path(file.toString());
try {
if (!fileSystem.exists(path)) throw new IllegalArgumentException("Not Found: " + path);
try (FSDataInputStream open = fileSystem.open(path)) {
byte[] bytes = IOUtils.toByteArray(open);
try (ByteArrayInputStream in = new ByteArrayInputStream(bytes)) {
return ImageIO.read(in);
}
}
} catch (Throwable e) {
throw new RuntimeException("Error reading " + file, e);
}
}
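/**
 * Resolves the Hadoop FileSystem for the given path's URI, using the
 * configuration from getHadoopConfig().
 *
 * @param file the file path or URI
 * @return the filesystem
 */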
public static FileSystem getFileSystem(final CharSequence file) {
Configuration conf = getHadoopConfig();
FileSystem fileSystem;
try {
fileSystem = FileSystem.get(new Path(file.toString()).toUri(), conf);
} catch (IOException e) {
throw new RuntimeException(e);
}
return fileSystem;
}
/**
 * Builds a minimal Hadoop config using a local "temp" scratch directory, with
 * filesystem bindings for git, s3, and s3a (authenticated via the default AWS
 * credentials provider chain).
 *
 * @return the hadoop config
 */
@Nonnull
public static Configuration getHadoopConfig() {
Configuration configuration = new Configuration(false);
File tempDir = new File("temp");
tempDir.mkdirs();
configuration.set("hadoop.tmp.dir", tempDir.getAbsolutePath());
// configuration.set("fs.http.impl", org.apache.hadoop.fs.http.HttpFileSystem.class.getCanonicalName());
// configuration.set("fs.https.impl", org.apache.hadoop.fs.http.HttpsFileSystem.class.getCanonicalName());
configuration.set("fs.git.impl", com.simiacryptus.hadoop_jgit.GitFileSystem.class.getCanonicalName());
configuration.set("fs.s3a.impl", S3AFileSystem.class.getCanonicalName());
configuration.set("fs.s3.impl", S3AFileSystem.class.getCanonicalName());
configuration.set("fs.s3a.aws.credentials.provider", DefaultAWSCredentialsProviderChain.class.getCanonicalName());
return configuration;
}
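/**
 * Partitions the flattened kernel indices of a convolution layer into
 * sqrt(bands) contiguous blocks of w * h * sqrt(bands) indices each. This
 * assumes the band count (kernelDimensions[2]) is a perfect square.
 *
 * @param layer the convolution layer
 * @return the index blocks, one row per block
 */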
public static int[][] getIndexMap(final SimpleConvolutionLayer layer) {
int[] kernelDimensions = layer.getKernelDimensions();
double b = Math.sqrt(kernelDimensions[2]);
int h = kernelDimensions[1];
int w = kernelDimensions[0];
int l = (int) (w * h * b);
return IntStream.range(0, (int) b).mapToObj(i -> {
return IntStream.range(0, l).map(j -> j + l * i).toArray();
}).toArray(i -> new int[i][]);
}
}