All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.art.util.ImageArtUtil Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.art.util;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.simiacryptus.hadoop_jgit.GitFileSystem;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.cudnn.CudaSystem;
import com.simiacryptus.mindseye.layers.cudnn.conv.SimpleConvolutionLayer;
import com.simiacryptus.mindseye.layers.tensorflow.TFLayer;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.mindseye.util.ImageUtil;
import com.simiacryptus.mindseye.util.TFConverter;
import com.simiacryptus.notebook.MarkdownNotebookOutput;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.OptionalUploadImageQuery;
import com.simiacryptus.notebook.UploadImageQuery;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.*;
import com.simiacryptus.tensorflow.GraphModel;
import com.simiacryptus.util.FastRandom;
import com.simiacryptus.util.JsonUtil;
import com.simiacryptus.util.Util;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.*;
import org.apache.hadoop.fs.s3a.S3AFileSystem;
import org.tensorflow.framework.GraphDef;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.*;
import java.net.URL;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

public class ImageArtUtil {

  @Nonnull
  public static Configuration getHadoopConfig() {
    Configuration configuration = new Configuration(false);
    File tempDir = new File("temp");
    tempDir.mkdirs();
    configuration.set("hadoop.tmp.dir", tempDir.getAbsolutePath());
    configuration.set("fs.git.impl", GitFileSystem.class.getCanonicalName());
    configuration.set("fs.s3a.impl", S3AFileSystem.class.getCanonicalName());
    configuration.set("fs.s3.impl", S3AFileSystem.class.getCanonicalName());
    configuration.set("fs.s3a.aws.credentials.provider", DefaultAWSCredentialsProviderChain.class.getCanonicalName());
    return configuration;
  }

  @Nonnull
  public static Closeable cudaReports(@Nonnull NotebookOutput log, boolean interceptLog) {
    Closeable handler_info = log.getHttpd().addGET("cuda/info.txt", "text/plain", outputStream -> {
      try {
        PrintStream stream = new PrintStream(outputStream);
        CudaSystem.printHeader(stream);
        stream.flush();
      } catch (Throwable e) {
        try {
          outputStream.write(Util.toString(e).getBytes("UTF-8"));
        } catch (IOException e1) {
          e1.printStackTrace();
        }
      }
    });
    Closeable handler_stats = log.getHttpd().addGET("cuda/stats.json", "application/json", outputStream -> {
      try {
        PrintStream stream = new PrintStream(outputStream);
        stream.println(JsonUtil.toJson(CudaSystem.getExecutionStatistics()));
        stream.flush();
      } catch (Throwable e) {
        try {
          outputStream.write(Util.toString(e).getBytes("UTF-8"));
        } catch (IOException e1) {
          e1.printStackTrace();
        }
      }
    });
    if (interceptLog)
      log.subreport("Cuda Logs",
          sublog -> {
            CudaSystem.addLog(new RefConsumer() {
              @Nullable
              PrintWriter out;
              long remainingOut = 0;
              long killAt = 0;

              @Override
              public void accept(@Nonnull String formattedMessage) {
                if (null == out) {
                  SimpleDateFormat dateFormat = new SimpleDateFormat("dd_HH_mm_ss");
                  String date = dateFormat.format(new Date());
                  try {
                    String caption = RefString.format("Log at %s", date);
                    String filename = RefString.format("%s_cuda.log", date);
                    out = new PrintWriter(sublog.file(filename));
                    sublog.p("[%s](etc/%s)", caption, filename);
                    sublog.write();
                  } catch (Throwable e) {
                    throw Util.throwException(e);
                  }
                  killAt = RefSystem.currentTimeMillis() + TimeUnit.MINUTES.toMillis(1);
                  remainingOut = 10L * 1024 * 1024;
                }
                out.println(formattedMessage);
                out.flush();
                int length = formattedMessage.length();
                remainingOut -= length;
                if (remainingOut < 0 || killAt < RefSystem.currentTimeMillis()) {
                  out.close();
                  out = null;
                }
              }
            });
            return null;
          });

    return new Closeable() {
      @Override
      public void close() throws IOException {
        handler_info.close();
        handler_stats.close();
      }
    };
  }

  @Nonnull
  public static RefMap convertPipeline(@Nonnull GraphDef graphDef, @Nonnull String... nodes) {
    GraphModel graphModel = new GraphModel(graphDef.toByteArray());
    RefMap graphs = new RefHashMap<>();
    TFConverter tfConverter = new TFConverter();
    TFLayer tfLayer0 = new TFLayer(
        graphModel.getChild(nodes[0]).subgraph(new HashSet<>(Arrays.asList())).toByteArray(), new RefHashMap<>(),
        nodes[0], "input");
    RefUtil.freeRef(graphs.put(nodes[0], tfConverter.convert(tfLayer0)));
    for (int i = 1; i < nodes.length; i++) {
      String currentNode = nodes[i];
      String priorNode = nodes[i - 1];
      TFLayer tfLayer1 = new TFLayer(
          graphModel.getChild(currentNode).subgraph(new HashSet<>(Arrays.asList(priorNode))).toByteArray(),
          new RefHashMap<>(), currentNode, priorNode);
      RefUtil.freeRef(graphs.put(currentNode, tfConverter.convert(tfLayer1)));
    }
    return graphs;
  }

  @Nonnull
  public static BufferedImage loadImage(@Nonnull NotebookOutput log, @Nonnull final CharSequence image, final int imageSize) {
    Tensor imageTensor = getImageTensor(image, log, imageSize);
    BufferedImage bufferedImage = imageTensor.toImage();
    imageTensor.freeRef();
    return bufferedImage;
  }

  @Nonnull
  public static BufferedImage[] loadImages(@Nonnull NotebookOutput log, @Nonnull final CharSequence image, final int imageSize) {
    return RefArrays.stream(getImageTensors(image, log, imageSize)).map(tensor -> {
      BufferedImage bufferedImage = tensor.toImage();
      tensor.freeRef();
      return bufferedImage;
    }).toArray(BufferedImage[]::new);
  }

  @Nonnull
  public static BufferedImage[] loadImages(@Nonnull NotebookOutput log, @Nonnull final List image, final int imageSize) {
    return image.stream().flatMap(img -> Arrays.stream(getImageTensors(img, log, imageSize)))
        .map(tensor -> {
          BufferedImage bufferedImage = tensor.toImage();
          tensor.freeRef();
          return bufferedImage;
        }).toArray(BufferedImage[]::new);
  }

  @Nonnull
  public static BufferedImage loadImage(@Nonnull NotebookOutput log, @Nonnull final CharSequence image, final int width, final int height) {
    Tensor imageTensor = getImageTensor(image, log, width, height);
    BufferedImage bufferedImage = imageTensor.toImage();
    imageTensor.freeRef();
    return bufferedImage;
  }

  @Nonnull
  public static BufferedImage[] loadImages(@Nonnull NotebookOutput log, @Nonnull final CharSequence image, final int width, final int height) {
    return RefArrays.stream(getImageTensors(image, log, width, height))
        .map(tensor -> {
          BufferedImage img = tensor.toImage();
          tensor.freeRef();
          return img;
        }).toArray(BufferedImage[]::new);
  }

  @Nonnull
  public static BufferedImage[] loadImages(@Nonnull NotebookOutput log, @Nonnull final List image, final int width, final int height) {
    return image.stream().flatMap(img -> Arrays.stream(getImageTensors(img, log, width, height)))
        .map(tensor -> {
          BufferedImage img = tensor.toImage();
          tensor.freeRef();
          return img;
        }).toArray(BufferedImage[]::new);
  }

  @Nonnull
  public static Tensor getImageTensor(@Nonnull final CharSequence file, @Nonnull NotebookOutput log, int width) {
    String fileStr = file.toString();
    int length = fileStr.split("\\:")[0].length();
    if (length <= 0 || length >= Math.min(7, fileStr.length())) {
      if (fileStr.contains(" + ")) {
        Tensor sampleImage = getImageTensor(fileStr.split(" +\\+ +")[0], log, width);
        int[] sampleImageDimensions = sampleImage.getDimensions();
        sampleImage.freeRef();
        return RefUtil.get(RefArrays.stream(fileStr.split(" +\\+ +"))
            .map(x -> getImageTensor(x, log, sampleImageDimensions[0], sampleImageDimensions[1]))
            .reduce((a, b) -> {
              Tensor r = a.mapCoords(c -> a.get(c) + b.get(c));
              b.freeRef();
              a.freeRef();
              return r;
            }));
      } else if (fileStr.contains(" * ")) {
        Tensor sampleImage = RefUtil.get(RefArrays.stream(fileStr.split(" +\\* +")).map(x -> getImageTensor(x, log, width)).findFirst());
        int[] sampleImageDimensions = sampleImage.getDimensions();
        sampleImage.freeRef();
        return RefUtil.get(RefArrays.stream(fileStr.split(" +\\* +"))
            .map(x -> getImageTensor(x, log, sampleImageDimensions[0], sampleImageDimensions[1]))
            .reduce((a, b) -> {
              Tensor r = a.mapCoords(c -> a.get(c) * b.get(c));
              b.freeRef();
              a.freeRef();
              return r;
            }));
      } else if (fileStr.trim().toLowerCase().equals("plasma")) {
        return new Plasma().paint(width, width);
      } else if (fileStr.trim().toLowerCase().equals("noise")) {
        Tensor baseTensor = new Tensor(width, width, 3);
        Tensor map = baseTensor.map(x -> FastRandom.INSTANCE.random() * 100);
        baseTensor.freeRef();
        return map;
      } else if (fileStr.matches("\\-?\\d+(?:\\.\\d*)?(?:[eE]\\-?\\d+)?")) {
        double v = Double.parseDouble(fileStr);
        Tensor baseTensor = new Tensor(width, width, 3);
        Tensor map = baseTensor.map(x -> v);
        baseTensor.freeRef();
        return map;
      }
    }
    Tensor tensor = getTensor(log, file);
    Tensor resized = Tensor.fromRGB(ImageUtil.resize(tensor.toImage(), width, true));
    tensor.freeRef();
    return resized;
  }

  @Nonnull
  public static Tensor[] getImageTensors(@Nonnull final CharSequence file, @Nonnull NotebookOutput log, int width) {
    return RefArrays.stream(getTensors(log, file))
        .map(tensor -> {
          BufferedImage bufferedImage = tensor.toImage();
          tensor.freeRef();
          return bufferedImage;
        }).map(image -> ImageUtil.resize(image, width, true))
        .map(resize -> Tensor.fromRGB(resize))
        .toArray(Tensor[]::new);
  }

  @Nonnull
  public static Tensor[] getImageTensors(@Nonnull final CharSequence file, @Nonnull NotebookOutput log, int width, int height) {
    return RefArrays.stream(getTensors(log, file))
        .map(tensor -> {
          BufferedImage img = tensor.toImage();
          tensor.freeRef();
          return img;
        })
        .map(image -> ImageUtil.resize(image, width, height))
        .map(resize -> Tensor.fromRGB(resize))
        .toArray(Tensor[]::new);
  }

  @Nonnull
  public static Tensor getImageTensor(@Nonnull final CharSequence file, @Nonnull NotebookOutput log, int width, int height) {
    String fileStr = file.toString();
    int length = fileStr.split("\\:")[0].length();
    if (length <= 0 || length >= Math.min(7, fileStr.length())) {
      if (fileStr.contains(" + ")) {
        Tensor sampleImage = getImageTensor(fileStr.split(" +\\+ +")[0], log, width, height);
        int[] sampleImageDimensions = sampleImage.getDimensions();
        sampleImage.freeRef();
        return RefUtil.get(RefArrays.stream(fileStr.split(" +\\+ +"))
            .map(x -> getImageTensor(x, log, sampleImageDimensions[0], sampleImageDimensions[1]))
            .reduce((a, b) -> {
              Tensor r = a.mapCoords(c -> Math.min(255, Math.max(0, a.get(c) + b.get(c))));
              a.freeRef();
              b.freeRef();
              return r;
            }));
      } else if (fileStr.contains(" * ")) {
        Tensor sampleImage = getImageTensor(fileStr.split(" +\\* +")[0], log, width, height);
        ;
        int[] sampleImageDimensions = sampleImage.getDimensions();
        sampleImage.freeRef();
        return RefUtil.get(RefArrays.stream(fileStr.split(" +\\* +"))
            .map(x -> getImageTensor(x, log, sampleImageDimensions[0], sampleImageDimensions[1]))
            .reduce((a, b) -> {
              try {
                return a.mapCoords(c -> Math.min(255, Math.max(0, a.get(c) * b.get(c))));
              } finally {
                b.freeRef();
                a.freeRef();
              }
            }));
      } else if (fileStr.trim().toLowerCase().equals("plasma")) {
        return new Plasma().paint(width, height);
      } else {
        Tensor dims = new Tensor(width, height, 3);
        try {
          if (fileStr.trim().toLowerCase().equals("noise")) {
            return dims.map(x -> FastRandom.INSTANCE.random() * 100);
          } else if (fileStr.matches("\\-?\\d+(?:\\.\\d*)?(?:[eE]\\-?\\d+)?")) {
            double v = Double.parseDouble(fileStr);
            return dims.map(x -> v);
          }
        } finally {
          dims.freeRef();
        }
      }
    }
    Tensor tensor = getTensor(log, file);
    Tensor resized = Tensor.fromRGB(ImageUtil.resize(tensor.toImage(), width, height));
    tensor.freeRef();
    return resized;
  }

  @Nonnull
  public static Tensor[] getTensors(@Nonnull NotebookOutput log, @Nonnull CharSequence file) {
    String fileStr = file.toString();
    try {
      String uploadPrefix = "upload:";
      if (fileStr.trim().toLowerCase().startsWith(uploadPrefix)) {
        String key = fileStr.substring(uploadPrefix.length());
        return (Tensor[]) MarkdownNotebookOutput.uploadCache.computeIfAbsent(key, (RefFunction) k -> {
          try {
            RefArrayList tensors = new RefArrayList<>();
            while (true) {
              OptionalUploadImageQuery uploadImageQuery = new OptionalUploadImageQuery(k, log);
              Optional optionalFile = uploadImageQuery.print().get();
              if (optionalFile.isPresent()) {
                tensors.add(Tensor.fromRGB(ImageIO.read(optionalFile.get())));
              } else {
                break;
              }
            }
            Tensor[] array = tensors.toArray(new Tensor[]{});
            tensors.freeRef();
            return array;
          } catch (IOException e) {
            throw Util.throwException(e);
          }
        });
      } else if (fileStr.startsWith("http")) {
        BufferedImage read = ImageIO.read(new URL(fileStr));
        if (null == read)
          throw new IllegalArgumentException("Error reading " + file);
        return new Tensor[]{Tensor.fromRGB(read)};
      } else {
        FileSystem fileSystem = getFileSystem(fileStr);
        Path path = new Path(fileStr);
        if (!fileSystem.exists(path))
          throw new IllegalArgumentException("Not Found: " + path);
        RefArrayList tensors = new RefArrayList<>();
        RemoteIterator iterator = fileSystem.listFiles(path, true);
        while (iterator.hasNext()) {
          LocatedFileStatus locatedFileStatus = iterator.next();
          try (FSDataInputStream open = fileSystem.open(locatedFileStatus.getPath())) {
            byte[] bytes = IOUtils.toByteArray(open);
            try (ByteArrayInputStream in = new ByteArrayInputStream(bytes)) {
              tensors.add(Tensor.fromRGB(ImageIO.read(in)));
            }
          }
        }
        Tensor[] array = tensors.toArray(new Tensor[]{});
        tensors.freeRef();
        return array;
      }
    } catch (Throwable e) {
      throw new RuntimeException("Error reading " + file, e);
    }
  }

  @Nonnull
  public static Tensor getTensor(@Nonnull NotebookOutput log, @Nonnull CharSequence file) {
    String fileStr = file.toString();
    if (fileStr.trim().toLowerCase().startsWith("upload:")) {
      String key = fileStr.substring("upload:".length());
      return (Tensor) MarkdownNotebookOutput.uploadCache.computeIfAbsent(key, (RefFunction) k -> {
        try {
          UploadImageQuery uploadImageQuery = new UploadImageQuery(k, log);
          return Tensor.fromRGB(ImageIO.read(uploadImageQuery.print().get()));
        } catch (IOException e) {
          throw Util.throwException(e);
        }
      });
    } else if (fileStr.startsWith("http")) {
      try {
        BufferedImage read = ImageIO.read(new URL(fileStr));
        if (null == read)
          throw new IllegalArgumentException("Error reading " + file);
        return Tensor.fromRGB(read);
      } catch (Throwable e) {
        throw new RuntimeException("Error reading " + file, e);
      }
    }
    FileSystem fileSystem = getFileSystem(fileStr);
    Path path = new Path(fileStr);
    try {
      if (!fileSystem.exists(path))
        throw new IllegalArgumentException("Not Found: " + path);
      try (FSDataInputStream open = fileSystem.open(path)) {
        byte[] bytes = IOUtils.toByteArray(open);
        try (ByteArrayInputStream in = new ByteArrayInputStream(bytes)) {
          return Tensor.fromRGB(ImageIO.read(in));
        }
      }
    } catch (Throwable e) {
      throw new RuntimeException("Error reading " + file, e);
    }
  }

  public static FileSystem getFileSystem(@Nonnull final CharSequence file) {
    Configuration conf = getHadoopConfig();
    FileSystem fileSystem;
    try {
      fileSystem = FileSystem.get(new Path(file.toString()).toUri(), conf);
    } catch (IOException e) {
      throw Util.throwException(e);
    }
    return fileSystem;
  }

  @Nonnull
  public static int[][] getIndexMap(@Nonnull final SimpleConvolutionLayer layer) {
    int[] kernelDimensions = layer.getKernelDimensions();
    layer.freeRef();
    double b = Math.sqrt(kernelDimensions[2]);
    int h = kernelDimensions[1];
    int w = kernelDimensions[0];
    int l = (int) (w * h * b);
    return IntStream.range(0, (int) b).mapToObj(i -> {
      return IntStream.range(0, l).map(j -> j + l * i).toArray();
    }).toArray(i -> new int[i][]);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy