/*
 * All downloads are free. Search and download functionality uses the official Maven repository.
 *
 * io.github.mianalysis.mia.module.objects.detect.StarDistDetection Maven / Gradle / Ivy
 *
 * Go to download
 *
 * ModularImageAnalysis (MIA) is an ImageJ plugin which provides a modular framework for assembling image and object analysis workflows. Detected objects can be transformed, filtered, measured and related. Analysis workflows are batch-enabled by default, allowing easy processing of high-content datasets.
 *
 * There is a newer version: 1.6.8
 * Show newest version
 */
package io.github.mianalysis.mia.module.objects.detect;

import java.awt.Color;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.stream.IntStream;

import org.jetbrains.annotations.Nullable;
import org.scijava.Context;
import org.scijava.Priority;
import org.scijava.command.CommandModule;
import org.scijava.plugin.Plugin;

import de.csbdresden.stardist.Candidates;
import de.csbdresden.stardist.StarDist2DModel;
import de.csbdresden.stardist.Utils;
import ij.ImagePlus;
import ij.gui.PolygonRoi;
import io.github.mianalysis.mia.MIA;
import io.github.mianalysis.mia.module.Categories;
import io.github.mianalysis.mia.module.Category;
import io.github.mianalysis.mia.module.Module;
import io.github.mianalysis.mia.module.Modules;
import io.github.mianalysis.mia.module.images.configure.SetDisplayRange;
import io.github.mianalysis.mia.module.images.transform.ExtractSubstack;
import io.github.mianalysis.mia.module.visualise.overlays.AddObjectOutline;
import io.github.mianalysis.mia.object.Obj;
import io.github.mianalysis.mia.object.Objs;
import io.github.mianalysis.mia.object.Workspace;
import io.github.mianalysis.mia.object.coordinates.volume.SpatCal;
import io.github.mianalysis.mia.object.coordinates.volume.VolumeType;
import io.github.mianalysis.mia.object.image.Image;
import io.github.mianalysis.mia.object.parameters.BooleanP;
import io.github.mianalysis.mia.object.parameters.ChoiceP;
import io.github.mianalysis.mia.object.parameters.FilePathP;
import io.github.mianalysis.mia.object.parameters.InputImageP;
import io.github.mianalysis.mia.object.parameters.Parameters;
import io.github.mianalysis.mia.object.parameters.SeparatorP;
import io.github.mianalysis.mia.object.parameters.objects.OutputObjectsP;
import io.github.mianalysis.mia.object.parameters.text.DoubleP;
import io.github.mianalysis.mia.object.parameters.text.IntegerP;
import io.github.mianalysis.mia.object.refs.collections.ImageMeasurementRefs;
import io.github.mianalysis.mia.object.refs.collections.MetadataRefs;
import io.github.mianalysis.mia.object.refs.collections.ObjMeasurementRefs;
import io.github.mianalysis.mia.object.refs.collections.ObjMetadataRefs;
import io.github.mianalysis.mia.object.refs.collections.ParentChildRefs;
import io.github.mianalysis.mia.object.refs.collections.PartnerRefs;
import io.github.mianalysis.mia.object.system.Status;
import io.github.mianalysis.mia.object.units.TemporalUnit;
import io.github.mianalysis.mia.process.ColourFactory;
import io.github.mianalysis.mia.process.exceptions.IntegerOverflowException;
import net.imagej.Dataset;
import net.imagej.DatasetService;
import net.imagej.DefaultDataset;
import net.imagej.ImageJ;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imagej.axis.AxisType;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Pair;
import net.imglib2.util.ValuePair;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;


/**
 * Implements the StarDist plugin to detect objects. For more information on
 * StarDist please see https://imagej.net/plugins/stardist.
 *
 * <p>
 * Each Z-slice of each timepoint of the input image is processed
 * independently. The slice is run through the CSBDeep {@code GenericNetwork}
 * command. The resulting prediction is split into probability and distance
 * images, which are passed to the StarDist {@code StarDist2DNMS} command. Every
 * winning polygon becomes one output object assigned to that slice and
 * timepoint.
 * </p>
 */
@Plugin(type = Module.class, priority = Priority.LOW, visible = true)
public class StarDistDetection extends Module {

    /**
     * Separator heading for the image input / object output parameters.
     */
    public static final String INPUT_SEPARATOR = "Image input/object output";

    /**
     * Workspace image on which detection is run. Every Z-slice of every
     * timepoint is processed separately.
     */
    public static final String INPUT_IMAGE = "Input image";

    /**
     * Name of the object collection that detected objects are added to.
     */
    public static final String OUTPUT_OBJECTS = "Output objects";

    /**
     * Separator heading for the model selection parameters.
     */
    public static final String MODEL_SEPARATOR = "Model settings";

    /**
     * Whether the model is one of the pre-defined StarDist models or is loaded
     * from a file (see {@link ModelModes}).
     */
    public static final String MODEL_MODE = "Model mode";

    /**
     * Pre-defined model to use when {@link #MODEL_MODE} is
     * {@link ModelModes#PRE_DEFINED} (see {@link ModelNames}).
     */
    public static final String MODEL_NAME = "Model name";

    /**
     * Path to the model file used when {@link #MODEL_MODE} is
     * {@link ModelModes#FROM_FILE}.
     */
    public static final String MODEL_PATH = "Model path";

    /**
     * Separator heading for the network prediction parameters.
     */
    public static final String PREDICTION_SEPARATOR = "Prediction settings";

    /**
     * Whether the input is normalised before prediction. Passed to the network
     * as {@code normalizeInput}.
     */
    public static final String NORMALISE_INPUT = "Normalise image";

    /**
     * Lower normalisation percentile. Passed to the network as
     * {@code percentileBottom}.
     */
    public static final String PERCENTILE_LOW = "Percentile low";

    /**
     * Upper normalisation percentile. Passed to the network as
     * {@code percentileTop}.
     */
    public static final String PERCENTILE_HIGH = "Percentile high";

    /**
     * Number of tiles the input is split into for prediction. Passed to the
     * network as {@code nTiles}.
     */
    public static final String NUMBER_OF_TILES = "Number of tiles";

    /**
     * Separator heading for the non-maximum suppression parameters.
     */
    public static final String NMS_SEPARATOR = "NMS post-processing";

    /**
     * Probability threshold passed to the NMS command as {@code probThresh}.
     */
    public static final String PROB_THRESHOLD = "Probability threshold";

    /**
     * Overlap threshold passed to the NMS command as {@code nmsThresh}.
     */
    public static final String OVERLAP_THRESHOLD = "Overlap threshold";

    /**
     * Boundary exclusion passed to the NMS command as {@code excludeBoundary}.
     */
    public static final String BOUNDARY_EXCLUSION = "Boundary exclusion";

    // From
    // https://github.com/stardist/stardist-imagej/blob/master/src/main/java/de/csbdresden/stardist/StarDist2DModel.java
    // (Accessed 2021-11-10)
    static final String MODEL_DSB2018_HEAVY_AUGMENTATION = "Versatile (fluorescent nuclei)";
    static final String MODEL_DSB2018_PAPER = "DSB 2018 (from StarDist 2D paper)";
    static final String MODEL_HE_HEAVY_AUGMENTATION = "Versatile (H&E nuclei)";
    static final String MODEL_DEFAULT = MODEL_DSB2018_HEAVY_AUGMENTATION;

    /**
     * Pre-defined models bundled with the StarDist plugin, keyed by display
     * name. Only each model's {@code sizeDivBy} and {@code tileOverlap} are read
     * by this module. The user-supplied NMS thresholds are always used instead
     * of any thresholds stored in the model.
     */
    static final Map<String, StarDist2DModel> MODELS = new LinkedHashMap<>();
    static {
        MODELS.put(MODEL_DSB2018_PAPER,
                new StarDist2DModel(StarDist2DModel.class.getClassLoader().getResource("models/2D/dsb2018_paper.zip"),
                        0.417819, 0.5, 8, 48));
        MODELS.put(MODEL_DSB2018_HEAVY_AUGMENTATION,
                new StarDist2DModel(
                        StarDist2DModel.class.getClassLoader().getResource("models/2D/dsb2018_heavy_augment.zip"),
                        0.479071, 0.3, 16, 96));
        MODELS.put(MODEL_HE_HEAVY_AUGMENTATION,
                new StarDist2DModel(
                        StarDist2DModel.class.getClassLoader().getResource("models/2D/he_heavy_augment.zip"), 0.692478,
                        0.3, 16, 96));
    }

    /**
     * Choices for {@link #MODEL_MODE}.
     */
    public interface ModelModes {
        String FROM_FILE = "From file";
        String PRE_DEFINED = "Pre-defined";

        String[] ALL = new String[] { FROM_FILE, PRE_DEFINED };

    }

    /**
     * Choices for {@link #MODEL_NAME}.
     */
    public interface ModelNames {
        String VERSATILE_FLUORESCENT_NUCLEI = "Versatile (fluorescent nuclei)";
        String VERSATILE_HE_NUCLEI = "Versatile (H&E nuclei)";
        String DSB_2018 = "DSB 2018 (from StarDist 2D paper)";

        String[] ALL = new String[] { VERSATILE_FLUORESCENT_NUCLEI, VERSATILE_HE_NUCLEI, DSB_2018 };

    }

    public StarDistDetection(Modules modules) {
        super("StarDist detection", modules);

    }

    @Override
    public Category getCategory() {
        return Categories.OBJECTS_DETECT;
    }

    @Override
    public String getVersionNumber() {
        return "1.0.0";
    }

    public String getDescription() {
        return "Implements the StarDist plugin to detect objects.  For more information on StarDist please see https://imagej.net/plugins/stardist.";
    }

    /**
     * Resolves the model file selected by the current parameters.
     *
     * <p>
     * For a pre-defined model, the model's {@code sizeDivBy} and
     * {@code tileOverlap} are written into {@code paramsCNN} as
     * {@code blockMultiple} and {@code overlap}. This only happens once the
     * model file has been obtained, and only if {@code paramsCNN} is non-null.
     * </p>
     *
     * @param paramsCNN network parameter map to update with model-specific
     *                  tiling settings (may be {@code null})
     * @param workspace workspace used to resolve parameter values
     * @return the model file, or {@code null} if it could not be obtained (a
     *         warning is logged)
     */
    File getModelFile(@Nullable HashMap<String, Object> paramsCNN, Workspace workspace) {
        String modelMode = parameters.getValue(MODEL_MODE, workspace);
        switch (modelMode) {
            case ModelModes.FROM_FILE: {
                String modelPath = parameters.getValue(MODEL_PATH, workspace);
                if (modelPath == null || !new File(modelPath).exists()) {
                    MIA.log.writeWarning("Can't find StarDist model file \"" + modelPath + "\"");
                    return null;
                }
                return new File(modelPath);
            }

            case ModelModes.PRE_DEFINED: {
                String modelName = parameters.getValue(MODEL_NAME, workspace);
                StarDist2DModel model = getPredefinedModel(modelName);
                if (model == null || !model.canGetFile()) {
                    MIA.log.writeWarning("Can't get StarDist model");
                    return null;
                }

                try {
                    File modelFile = model.getFile();
                    // Only override the generic tiling defaults once the model is known to be usable
                    if (paramsCNN != null) {
                        paramsCNN.put("blockMultiple", model.sizeDivBy);
                        paramsCNN.put("overlap", model.tileOverlap);
                    }
                    return modelFile;
                } catch (IOException e) {
                    MIA.log.writeWarning("Can't get StarDist model");
                    return null;
                }
            }
        }

        return null;

    }

    /**
     * Looks up the bundled model for a {@link ModelNames} choice.
     *
     * @return the matching model, or {@code null} for an unknown or null name
     */
    private static StarDist2DModel getPredefinedModel(String modelName) {
        if (modelName == null)
            return null;

        switch (modelName) {
            case ModelNames.DSB_2018:
                return MODELS.get(MODEL_DSB2018_PAPER);
            case ModelNames.VERSATILE_FLUORESCENT_NUCLEI:
                return MODELS.get(MODEL_DSB2018_HEAVY_AUGMENTATION);
            case ModelNames.VERSATILE_HE_NUCLEI:
                return MODELS.get(MODEL_HE_HEAVY_AUGMENTATION);
            default:
                return null;
        }
    }

    /**
     * Splits a network prediction into a probability image and a distance
     * image.
     *
     * <p>
     * Channel 0 of the prediction becomes the probability image, with the
     * channel axis removed. Channels 1..end become the distance image, keeping
     * the channel axis.
     * </p>
     *
     * @param prediction     network output containing a channel axis
     * @param datasetService service used to wrap the views as {@link Dataset}s
     * @return pair of (probability, distance) datasets
     */
    Pair<Dataset, Dataset> splitPrediction(final Dataset prediction, final DatasetService datasetService) {
        // The network output is treated as float data, as in the StarDist ImageJ plugin
        @SuppressWarnings("unchecked")
        final RandomAccessibleInterval<FloatType> predictionRAI = (RandomAccessibleInterval<FloatType>) prediction
                .getImgPlus();
        final LinkedHashSet<AxisType> predAxes = Utils.orderedAxesSet(prediction);

        final int predChannelDim = IntStream.range(0, predAxes.size())
                .filter(d -> prediction.axis(d).type() == Axes.CHANNEL).findFirst().getAsInt();
        // Offset/size selecting channels 1..end (the distance channels)
        final long[] predStart = predAxes.stream().mapToLong(axis -> axis == Axes.CHANNEL ? 1L : 0L).toArray();
        final long[] predSize = predAxes.stream()
                .mapToLong(axis -> axis == Axes.CHANNEL ? prediction.dimension(axis) - 1 : prediction.dimension(axis))
                .toArray();

        final IntervalView<FloatType> probRAI = Views.hyperSlice(predictionRAI, predChannelDim, 0);
        final IntervalView<FloatType> distRAI = Views.offsetInterval(predictionRAI, predStart, predSize);

        final Dataset probDS = Utils.raiToDataset(datasetService, "Probability/Score Image", probRAI,
                predAxes.stream().filter(axis -> axis != Axes.CHANNEL));
        final Dataset distDS = Utils.raiToDataset(datasetService, "Distance Image", distRAI, predAxes);

        return new ValuePair<>(probDS, distDS);

    }

    @Override
    protected Status process(Workspace workspace) {
        // Getting input image
        String inputImageName = parameters.getValue(INPUT_IMAGE, workspace);
        String outputObjectsName = parameters.getValue(OUTPUT_OBJECTS, workspace);

        Image image = workspace.getImages().get(inputImageName);
        ImagePlus ipl = image.getImagePlus();
        SpatCal cal = SpatCal.getFromImage(ipl);
        int nFrames = ipl.getNFrames();
        int nSlices = ipl.getNSlices();
        double frameInterval = ipl.getCalibration().frameInterval;

        // Creating output object collection
        Objs outputObjects = new Objs(outputObjectsName, cal, nFrames, frameInterval, TemporalUnit.getOMEUnit());

        // Reuse MIA's existing SciJava context. "new ImageJ()" would build (and
        // never dispose) a whole new context on every run, and would run the
        // commands in a different context from the one owning the input datasets.
        Context context = MIA.getIJService().context();
        ImageJ ij = new ImageJ(context);
        DatasetService datasetService = context.service(DatasetService.class);

        // Initialising parameter sets
        HashMap<String, Object> paramsCNN = new HashMap<>();
        paramsCNN.put("normalizeInput", parameters.getValue(NORMALISE_INPUT, workspace));
        paramsCNN.put("percentileBottom", parameters.getValue(PERCENTILE_LOW, workspace));
        paramsCNN.put("percentileTop", parameters.getValue(PERCENTILE_HIGH, workspace));
        paramsCNN.put("clip", false);
        paramsCNN.put("nTiles", parameters.getValue(NUMBER_OF_TILES, workspace));
        // Generic tiling defaults; replaced by model-specific values for pre-defined models
        paramsCNN.put("blockMultiple", 64);
        paramsCNN.put("overlap", 64);
        paramsCNN.put("batchSize", 1);
        paramsCNN.put("showProgressDialog", false);

        File modelFile = getModelFile(paramsCNN, workspace);
        if (modelFile == null)
            return Status.FAIL;
        paramsCNN.put("modelFile", modelFile);

        HashMap<String, Object> paramsNMS = new HashMap<>();
        paramsNMS.put("probThresh", parameters.getValue(PROB_THRESHOLD, workspace));
        paramsNMS.put("nmsThresh", parameters.getValue(OVERLAP_THRESHOLD, workspace));
        paramsNMS.put("excludeBoundary", parameters.getValue(BOUNDARY_EXCLUSION, workspace));
        paramsNMS.put("roiPosition", "Stack");
        paramsNMS.put("verbose", false);
        paramsNMS.put("outputType", "Polygons");

        int count = 0;
        int total = nFrames * nSlices;
        for (int t = 0; t < nFrames; t++) {
            for (int z = 0; z < nSlices; z++) {
                // Single-slice, single-frame substack (all channels)
                Image subs = ExtractSubstack.extractSubstack(image, "Subs", "1-end", String.valueOf(z + 1),
                        String.valueOf(t + 1));
                ImgPlus img = subs.getImgPlus();
                paramsCNN.put("input", new DefaultDataset(context, img));

                try {
                    Future<CommandModule> futureCNN = ij.command().run(
                            de.csbdresden.csbdeep.commands.GenericNetwork.class,
                            false, paramsCNN);
                    Dataset prediction = (Dataset) futureCNN.get().getOutput("output");

                    Pair<Dataset, Dataset> probAndDist = splitPrediction(prediction, datasetService);
                    paramsNMS.put("prob", probAndDist.getA());
                    paramsNMS.put("dist", probAndDist.getB());

                    Future<CommandModule> futureNMS = ij.command().run(de.csbdresden.stardist.StarDist2DNMS.class,
                            false, paramsNMS);
                    Candidates polygons = (Candidates) futureNMS.get().getOutput("polygons");

                    addDetections(polygons, outputObjects, z, t);

                } catch (InterruptedException e) {
                    // Preserve the interrupt so callers can observe the cancellation
                    Thread.currentThread().interrupt();
                    MIA.log.writeWarning("StarDist detection was interrupted");
                    return Status.FAIL;
                } catch (ExecutionException e) {
                    MIA.log.writeWarning("StarDist detection failed: " + e.getMessage());
                    return Status.FAIL;
                }

                Module.writeProgressStatus(++count, total, "slices", "StarDist");

            }
        }

        workspace.addObjects(outputObjects);

        if (showOutput)
            showDetections(image, outputObjects);

        return Status.PASS;

    }

    /**
     * Creates one output object per winning NMS candidate, on slice {@code z}
     * and frame {@code t}.
     */
    private static void addDetections(Candidates polygons, Objs outputObjects, int z, int t) {
        List<Integer> indices = polygons.getWinner();

        for (Integer idx : indices) {
            PolygonRoi polygon = polygons.getPolygonRoi(idx);
            Obj obj = outputObjects.createAndAddNewObject(VolumeType.QUADTREE);
            try {
                obj.addPointsFromRoi(polygon, z);
            } catch (IntegerOverflowException e) {
                // Object is kept (it was already added) but may be missing coordinates
                MIA.log.writeWarning("Integer overflow while adding StarDist object (slice " + (z + 1) + ", frame "
                        + (t + 1) + ")");
            }
            obj.setT(t);
        }
    }

    /**
     * Displays a duplicate of the input image with detected object outlines
     * overlaid.
     */
    private static void showDetections(Image image, Objs outputObjects) {
        ImagePlus overlayIpl = image.getImagePlus().duplicate();

        HashMap<Integer, Float> hues = ColourFactory.getIDHues(outputObjects, true);
        HashMap<Integer, Color> colours = ColourFactory.getColours(hues);
        AddObjectOutline.addOverlay(overlayIpl, outputObjects, 1, 1, colours, false, true);
        SetDisplayRange.setDisplayRangeAuto(overlayIpl, SetDisplayRange.CalculationModes.FAST,
                new double[] { 0, 0 }, new boolean[] { true, true }, null);
        overlayIpl.show();
    }

    @Override
    protected void initialiseParameters() {
        parameters.add(new SeparatorP(INPUT_SEPARATOR, this));
        parameters.add(new InputImageP(INPUT_IMAGE, this));
        parameters.add(new OutputObjectsP(OUTPUT_OBJECTS, this));

        parameters.add(new SeparatorP(MODEL_SEPARATOR, this));
        parameters.add(new ChoiceP(MODEL_MODE, this, ModelModes.PRE_DEFINED, ModelModes.ALL));
        parameters.add(new ChoiceP(MODEL_NAME, this, ModelNames.VERSATILE_FLUORESCENT_NUCLEI, ModelNames.ALL));
        parameters.add(new FilePathP(MODEL_PATH, this));

        parameters.add(new SeparatorP(PREDICTION_SEPARATOR, this));
        parameters.add(new BooleanP(NORMALISE_INPUT, this, true));
        parameters.add(new DoubleP(PERCENTILE_LOW, this, 0.0));
        parameters.add(new DoubleP(PERCENTILE_HIGH, this, 95.0));
        parameters.add(new IntegerP(NUMBER_OF_TILES, this, 1));

        parameters.add(new SeparatorP(NMS_SEPARATOR, this));
        parameters.add(new DoubleP(PROB_THRESHOLD, this, 0.4));
        parameters.add(new DoubleP(OVERLAP_THRESHOLD, this, 0.4));
        parameters.add(new IntegerP(BOUNDARY_EXCLUSION, this, 2));

    }

    @Override
    public Parameters updateAndGetParameters() {
        Parameters returnedParameters = new Parameters();

        returnedParameters.add(parameters.getParameter(INPUT_SEPARATOR));
        returnedParameters.add(parameters.getParameter(INPUT_IMAGE));
        returnedParameters.add(parameters.getParameter(OUTPUT_OBJECTS));

        returnedParameters.add(parameters.getParameter(MODEL_SEPARATOR));
        returnedParameters.add(parameters.getParameter(MODEL_MODE));
        // Only show the model source parameter relevant to the selected mode
        switch ((String) parameters.getValue(MODEL_MODE, null)) {
            case ModelModes.FROM_FILE:
                returnedParameters.add(parameters.getParameter(MODEL_PATH));
                break;
            case ModelModes.PRE_DEFINED:
                returnedParameters.add(parameters.getParameter(MODEL_NAME));
                break;
        }

        returnedParameters.add(parameters.getParameter(PREDICTION_SEPARATOR));
        returnedParameters.add(parameters.getParameter(NORMALISE_INPUT));
        returnedParameters.add(parameters.getParameter(PERCENTILE_LOW));
        returnedParameters.add(parameters.getParameter(PERCENTILE_HIGH));
        returnedParameters.add(parameters.getParameter(NUMBER_OF_TILES));

        returnedParameters.add(parameters.getParameter(NMS_SEPARATOR));
        returnedParameters.add(parameters.getParameter(PROB_THRESHOLD));
        returnedParameters.add(parameters.getParameter(OVERLAP_THRESHOLD));
        returnedParameters.add(parameters.getParameter(BOUNDARY_EXCLUSION));

        return returnedParameters;

    }

    @Override
    public ImageMeasurementRefs updateAndGetImageMeasurementRefs() {
        return null;
    }

    @Override
    public ObjMeasurementRefs updateAndGetObjectMeasurementRefs() {
        return null;
    }

    @Override
    public ObjMetadataRefs updateAndGetObjectMetadataRefs() {
        return null;
    }

    @Override
    public MetadataRefs updateAndGetMetadataReferences() {
        return null;
    }

    @Override
    public ParentChildRefs updateAndGetParentChildRefs() {
        return null;
    }

    @Override
    public PartnerRefs updateAndGetPartnerRefs() {
        return null;
    }

    @Override
    public boolean verify() {
        return true;
    }

}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy