io.github.mianalysis.mia.module.images.process.ApplyDeepImageJModel Maven / Gradle / Ivy

ModularImageAnalysis (MIA) is an ImageJ plugin which provides a modular framework for assembling image and object analysis workflows. Detected objects can be transformed, filtered, measured and related. Analysis workflows are batch-enabled by default, allowing easy processing of high-content datasets.
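
For orientation, the sketch below shows roughly how this module could be wired into a workflow when MIA is driven from Java rather than through its GUI. It is a minimal, hypothetical example: the image names "Raw" and "DL_result" are placeholders, and it assumes the updateParameterValue helper inherited from MIA's Module class; only the class itself and its parameter constants come from the listing below.

    import io.github.mianalysis.mia.module.Modules;
    import io.github.mianalysis.mia.module.images.process.ApplyDeepImageJModel;

    public class ExampleDeepImageJWorkflow {
        public static void main(String[] args) {
            // Module collection that will hold the workflow steps
            Modules modules = new Modules();

            // Create the DeepImageJ module and register it with the collection
            ApplyDeepImageJModel applyModel = new ApplyDeepImageJModel(modules);
            modules.add(applyModel);

            // Placeholder image names; assumes Module.updateParameterValue(String, Object)
            applyModel.updateParameterValue(ApplyDeepImageJModel.INPUT_IMAGE, "Raw");
            applyModel.updateParameterValue(ApplyDeepImageJModel.OUTPUT_IMAGE, "DL_result");
        }
    }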

package io.github.mianalysis.mia.module.images.process;

import org.scijava.Priority;
import org.scijava.plugin.Plugin;

import deepimagej.DeepImageJ;
import ij.IJ;
import ij.ImagePlus;
import ij.ImageStack;
import ij.measure.Calibration;
import io.github.mianalysis.mia.MIA;
import io.github.mianalysis.mia.module.Categories;
import io.github.mianalysis.mia.module.Category;
import io.github.mianalysis.mia.module.Module;
import io.github.mianalysis.mia.module.Modules;
import io.github.mianalysis.mia.object.Workspace;
import io.github.mianalysis.mia.object.image.Image;
import io.github.mianalysis.mia.object.image.ImageFactory;
import io.github.mianalysis.mia.object.parameters.BooleanP;
import io.github.mianalysis.mia.object.parameters.ChoiceP;
import io.github.mianalysis.mia.object.parameters.InputImageP;
import io.github.mianalysis.mia.object.parameters.OutputImageP;
import io.github.mianalysis.mia.object.parameters.ParameterState;
import io.github.mianalysis.mia.object.parameters.Parameters;
import io.github.mianalysis.mia.object.parameters.SeparatorP;
import io.github.mianalysis.mia.object.parameters.text.MessageP;
import io.github.mianalysis.mia.object.parameters.text.StringP;
import io.github.mianalysis.mia.object.refs.collections.ImageMeasurementRefs;
import io.github.mianalysis.mia.object.refs.collections.MetadataRefs;
import io.github.mianalysis.mia.object.refs.collections.ObjMeasurementRefs;
import io.github.mianalysis.mia.object.refs.collections.ObjMetadataRefs;
import io.github.mianalysis.mia.object.refs.collections.ParentChildRefs;
import io.github.mianalysis.mia.object.refs.collections.PartnerRefs;
import io.github.mianalysis.mia.object.system.Status;
import io.github.mianalysis.mia.process.deepimagej.PrepareDeepImageJ;

/**
 * Uses DeepImageJ to run Tensorflow and Pytorch models from the BioImage
 * Model Zoo. This module will detect and run any models already installed in
 * the active copy of Fiji.
 */
@Plugin(type = Module.class, priority = Priority.LOW, visible = true)
public class ApplyDeepImageJModel extends Module {

    /**
    * 
    */
    public static final String INPUT_SEPARATOR = "Image input/output";

    /**
     * Image from the workspace to apply deep learning model to.
     */
    public static final String INPUT_IMAGE = "Input image";

    /**
    * 
    */
    public static final String AXES_ORDER = "Axes order";

    /**
     * Final image generated by model, which will be stored in the workspace with
     * this name.
     */
    public static final String OUTPUT_IMAGE = "Output image";

    /**
    * 
    */
    public static final String MODEL_SEPARATOR = "Model controls";

    /**
     * Model to apply to input image. This can be any model currently installed in
     * MIA. When using MIA's GUI, the available models will automatically appear as
     * options.
     */
    public static final String MODEL = "Model";

    /**
     * If post-processing routines are available for the chosen model this option
     * will be visible. Note: If pre-processing routines are available, these will
     * always be applied.
     */
    public static final String USE_POSTPROCESSING = "Use postprocessing";

    /**
    * 
    */
    public static final String PATCH_SIZE = "Patch size";

    private String currModelName = "";

    public interface Models {
        String[] ALL = PrepareDeepImageJ.getAvailableModels();

    }

    public interface FormatsBoth {
        String PYTORCH = "Pytorch";
        String TENSORFLOW = "Tensorflow";

        String[] ALL = new String[] { PYTORCH, TENSORFLOW };

    }

    public ApplyDeepImageJModel(Modules modules) {
        super("Apply DeepImageJ model", modules);
    }

    @Override
    public Category getCategory() {
        return Categories.IMAGES_PROCESS;
    }

    @Override
    protected Status process(Workspace workspace) {
        // Getting parameters
        String inputImageName = parameters.getValue(INPUT_IMAGE, workspace);
        String outputImageName = parameters.getValue(OUTPUT_IMAGE, workspace);
        String modelName = parameters.getValue(MODEL, workspace);
        // String preprocessing = parameters.getValue(PREPROCESSING, workspace);
        boolean usePostprocessing = parameters.getValue(USE_POSTPROCESSING, workspace);
        // String postprocessing = parameters.getValue(POSTPROCESSING, workspace);
        String patchSize = parameters.getValue(PATCH_SIZE, workspace);

        // Get input image
        Image inputImage = workspace.getImage(inputImageName);
        ImagePlus inputIpl = inputImage.getImagePlus();

        // Running deep learning model
        DeepImageJ model = PrepareDeepImageJ.getModel(modelName);

        // Updating pre and post processing options
        boolean usePreprocessing = true;
        if (PrepareDeepImageJ.getPreprocessings(modelName).length == 0)
            usePreprocessing = false;
        if (PrepareDeepImageJ.getPostprocessings(modelName).length == 0)
            usePostprocessing = false;

        String format = PrepareDeepImageJ.getFormats(modelName)[0];
        PrepareDeepImageJ pDIJ = new PrepareDeepImageJ();

        ImageStack inputIst = inputIpl.getStack();
        ImagePlus outputIpl = null;
        ImageStack outputIst = null;

        int count = 0;
        for (int z = 0; z < inputIpl.getNSlices(); z++) {
            for (int t = 0; t < inputIpl.getNFrames(); t++) {
                int inputIdx = inputIpl.getStackIndex(1, z + 1, t + 1);
                ImagePlus tempIpl = new ImagePlus("Temp", inputIst.getProcessor(inputIdx));
                ImagePlus tempOutputIpl = pDIJ.runModel(tempIpl, model, format, usePreprocessing, usePostprocessing,
                        patchSize);
                
                // If it hasn't already been created (i.e. this is the first slice), create output ImagePlus
                if (outputIpl == null) {
                    int width = tempOutputIpl.getWidth();
                    int height = tempOutputIpl.getHeight();
                    int nChannels = tempOutputIpl.getNChannels();
                    int nSlices = inputIpl.getNSlices();
                    int nFrames = inputIpl.getNFrames();

                    outputIpl = IJ.createHyperStack(outputImageName, width, height, nChannels, nSlices, nFrames, 32);

                    Calibration inputCal = inputIpl.getCalibration();
                    Calibration outputCal = new Calibration();
                    outputCal.pixelHeight = inputCal.pixelHeight;
                    outputCal.pixelWidth = inputCal.pixelWidth;
                    outputCal.pixelDepth = inputCal.pixelDepth;
                    outputCal.setUnit(inputCal.getUnits());
                    outputCal.setTimeUnit(inputCal.getTimeUnit());
                    outputCal.fps = inputCal.fps;
                    outputCal.frameInterval = inputCal.frameInterval;
                    outputIpl.setCalibration(outputCal);

                    outputIst = outputIpl.getStack();

                }

                ImageStack tempIst = tempOutputIpl.getStack();
                // Copy each channel of the model output into the corresponding
                // position of the output hyperstack
                for (int c = 0; c < tempOutputIpl.getNChannels(); c++) {
                    int tempIdx = tempOutputIpl.getStackIndex(c + 1, 1, 1);
                    int outputIdx = outputIpl.getStackIndex(c + 1, z + 1, t + 1);
                    outputIst.setProcessor(tempIst.getProcessor(tempIdx), outputIdx);
                }

                writeProgressStatus(++count, inputIpl.getNSlices() * inputIpl.getNFrames(), "images");

            }
        }

        // Storing the result image in the workspace
        Image outputImage = ImageFactory.createImage(outputImageName, outputIpl);
        workspace.addImage(outputImage);

        if (showOutput)
            outputImage.show();

        return Status.PASS;

    }

    @Override
    protected void initialiseParameters() {
        // Parameter objects correspond to the constants declared above
        parameters.add(new SeparatorP(INPUT_SEPARATOR, this));
        parameters.add(new InputImageP(INPUT_IMAGE, this));
        parameters.add(new OutputImageP(OUTPUT_IMAGE, this));

        parameters.add(new SeparatorP(MODEL_SEPARATOR, this));
        parameters.add(new ChoiceP(MODEL, this, "", Models.ALL));
        parameters.add(new BooleanP(USE_POSTPROCESSING, this, true));
        parameters.add(new StringP(PATCH_SIZE, this));

        addParameterDescriptions();

    }

    @Override
    public Parameters updateAndGetParameters() {
        Workspace workspace = null;
        Parameters returnedParameters = new Parameters();

        returnedParameters.add(parameters.getParameter(INPUT_SEPARATOR));
        returnedParameters.add(parameters.getParameter(INPUT_IMAGE));
        returnedParameters.add(parameters.getParameter(OUTPUT_IMAGE));

        returnedParameters.add(parameters.getParameter(MODEL_SEPARATOR));
        returnedParameters.add(parameters.getParameter(MODEL));

        String modelName = parameters.getValue(MODEL, workspace);

        // Only expose the postprocessing toggle for models that provide
        // postprocessing routines
        if (PrepareDeepImageJ.getPostprocessings(modelName).length > 0)
            returnedParameters.add(parameters.getParameter(USE_POSTPROCESSING));

        if (!currModelName.equals(modelName)) {
            // We don't know the actual image size, so creating a small one.
            ImagePlus testIpl = IJ.createHyperStack("Test", 10, 10, 1, 1, 1, 8);
            String patchSize = PrepareDeepImageJ.getOptimalPatch(modelName, testIpl);
            parameters.getParameter(PATCH_SIZE).setValue(patchSize);
        }
        returnedParameters.add(parameters.getParameter(PATCH_SIZE));

        currModelName = modelName;

        return returnedParameters;

    }

    @Override
    public ImageMeasurementRefs updateAndGetImageMeasurementRefs() {
        return null;
    }

    @Override
    public ObjMeasurementRefs updateAndGetObjectMeasurementRefs() {
        return null;
    }

    @Override
    public ObjMetadataRefs updateAndGetObjectMetadataRefs() {
        return null;
    }

    @Override
    public MetadataRefs updateAndGetMetadataReferences() {
        return null;
    }

    @Override
    public ParentChildRefs updateAndGetParentChildRefs() {
        return null;
    }

    @Override
    public PartnerRefs updateAndGetPartnerRefs() {
        return null;
    }

    @Override
    public boolean verify() {
        return true;
    }

    @Override
    public String getVersionNumber() {
        return "1.0.0";
    }

    @Override
    public String getDescription() {
        return "Uses DeepImageJ to run Tensorflow and Pytorch models from the BioImage Model Zoo.  This module will detect and run any models already installed in the active copy of Fiji.  For more information on DeepImageJ and the BioImage Model Zoo, please go to DeepImageJ and BioImage Model Zoo.";
    }

    protected void addParameterDescriptions() {
        parameters.get(INPUT_IMAGE).setDescription("Image from the workspace to apply deep learning model to.");
        parameters.get(OUTPUT_IMAGE).setDescription(
                "Final image generated by model, which will be stored in the workspace with this name.");
        parameters.get(MODEL).setDescription(
                "Model to apply to input image.  This can be any model currently installed in MIA.  When using MIA's GUI, the available modules will automatically appear as options.");
        parameters.get(USE_POSTPROCESSING).setDescription(
                "If post-processing routines are available for the chosen model this option will be visible.  Note: If pre-processing routines are available, these will always be applied.");

    }
}