All downloads are FREE. The search and download functionality uses the official Maven repository.

io.github.mianalysis.mia.module.objects.detect.manualextensions.SAMJExtension Maven / Gradle / Ivy

Go to download

ModularImageAnalysis (MIA) is an ImageJ plugin which provides a modular framework for assembling image and object analysis workflows. Detected objects can be transformed, filtered, measured and related. Analysis workflows are batch-enabled by default, allowing easy processing of high-content datasets.

There is a newer version: 1.6.12
Show newest version
package io.github.mianalysis.mia.module.objects.detect.manualextensions;

import java.awt.Color;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.awt.Insets;
import java.awt.Point;
import java.awt.Polygon;
import java.awt.Rectangle;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.awt.event.MouseListener;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import javax.swing.JCheckBox;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JSeparator;

import org.apache.commons.compress.archivers.ArchiveException;
import org.scijava.Priority;
import org.scijava.plugin.Plugin;

import com.drew.lang.annotations.Nullable;

import ai.nets.samj.install.EfficientSamEnvManager;
import ai.nets.samj.install.Sam2EnvManager;
import ai.nets.samj.install.SamEnvManagerAbstract;
import ai.nets.samj.models.AbstractSamJ;
import ai.nets.samj.models.EfficientSamJ;
import ai.nets.samj.models.Sam2;
import ij.ImagePlus;
import ij.gui.PolygonRoi;
import ij.gui.Roi;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.github.mianalysis.mia.MIA;
import io.github.mianalysis.mia.module.Module;
import io.github.mianalysis.mia.module.objects.detect.ManuallyIdentifyObjects;
import io.github.mianalysis.mia.module.objects.detect.extensions.ManualExtension;
import io.github.mianalysis.mia.object.Workspace;
import io.github.mianalysis.mia.object.image.Image;
import io.github.mianalysis.mia.object.parameters.BooleanP;
import io.github.mianalysis.mia.object.parameters.ChoiceP;
import io.github.mianalysis.mia.object.parameters.FolderPathP;
import io.github.mianalysis.mia.object.parameters.Parameters;
import io.github.mianalysis.mia.object.parameters.SeparatorP;
import io.github.mianalysis.mia.object.system.Status;
import io.github.mianalysis.mia.process.selectors.ObjectSelector;
import net.imagej.ImageJ;
import net.imagej.patcher.LegacyInjector;

/**
 * Manual-annotation extension which uses Segment Anything (via SAMJ) to turn
 * the ROI a user draws on the displayed image into a polygon ROI.
 * <p>
 * The extension registers itself as a {@link MouseListener} on the display
 * canvas. On each mouse release the current ROI is sent to the SAMJ model as
 * either a set of points or a bounding box, and the returned polygon replaces
 * the ROI on the image. Image encodings can be stored per input file so they
 * can be re-selected on a later pass instead of being recomputed.
 */
@Plugin(type = ManualExtension.class, priority = Priority.LOW, visible = true)
public class SAMJExtension extends ManualExtension implements MouseListener {
    public static final String SAMJ_SEPARATOR = "Segment Anything (SAMJ) controls";
    public static final String USE_SAM = "Enable Segment Anything";
    public static final String MODEL = "Segment Anything model";
    public static final String ENVIRONMENT_PATH_MODE = "Environment path mode";
    public static final String ENVIRONMENT_PATH = "Environment path";
    public static final String INSTALL_IF_MISSING = "Install model if missing";
    public static final String USE_ENCODING_IF_AVAILABLE = "Use encoding if available";
    public static final String STORE_ENCODING = "Store encoding";
    public static final String TERMINATE_AFTER_ENCODING = "Terminate after encoding";

    /**
     * Identifiers returned by {@code samJ.persistEncoding()}, keyed by the
     * absolute path of the workspace's input file.
     */
    HashMap<String, String> encodings = new HashMap<>();

    /** Names of the selectable Segment Anything models. */
    public interface SAMJModels {
        public String EFFICIENT = "Efficient";
        public String LARGE = "Large";
        public String SMALL = "Small";
        public String TINY = "Tiny";

        // Only these two models are offered in the model choice parameter.
        public String[] ALL = new String[] { EFFICIENT, TINY };

    }

    /** Ways of locating the SAMJ environment directory. */
    public interface EnvironmentPathModes {
        String DEFAULT = "Default";
        String SPECIFIC = "Specific path";

        String[] ALL = new String[] { DEFAULT, SPECIFIC };

    }

    protected Workspace workspace;
    protected AbstractSamJ samJ = null;
    protected ImagePlus displayIpl;
    /** State of the "Use Segment Anything" checkbox on the control panel. */
    protected boolean useSAM = true;

    // Settings used for the most recent initialisation attempt; compared with
    // the current parameter values to decide when the model must be rebuilt.
    protected String prevModelName = "";
    protected String prevEnvionmentPathMode = "";
    protected String prevEnvironmentPath = "";

    /**
     * Development entry point: starts ImageJ and runs the MIA command.
     *
     * @param args unused
     * @throws Exception declared for compatibility; failures inside the method
     *                   are caught and logged
     */
    public static void main(String[] args) throws Exception {
        try {
            LegacyInjector.preinit();
        } catch (Exception ignored) {
            // Best effort: start-up continues even if pre-initialisation fails.
        }

        try {
            new ij.ImageJ();
            new ImageJ().command().run("io.github.mianalysis.mia.MIA_", false);

        } catch (Exception e) {
            MIA.log.writeError(e);
        }
    }

    /**
     * Creates the extension for the given module.
     *
     * @param module module this extension is attached to
     */
    public SAMJExtension(Module module) {
        super(module);
    }

    /**
     * Logs an exception through the MIA log and restores the thread's
     * interrupt flag when the exception is an {@link InterruptedException}.
     */
    private static void logException(Exception e) {
        if (e instanceof InterruptedException)
            Thread.currentThread().interrupt();

        MIA.log.writeError(e);
    }

    /**
     * Creates the SAMJ model selected by the current parameters, installing it
     * first if it is missing and {@link #INSTALL_IF_MISSING} is enabled.
     * <p>
     * On success {@link #samJ} holds the new model; on failure the error is
     * logged and {@link #samJ} is left {@code null}. In both cases the
     * {@code prev*} fields record the settings that were attempted.
     */
    protected void initialiseSAMJ() {
        String modelName = parameters.getValue(MODEL, workspace);
        prevModelName = modelName;
        String environmentPathMode = parameters.getValue(ENVIRONMENT_PATH_MODE, workspace);
        prevEnvionmentPathMode = environmentPathMode;
        String environmentPath = parameters.getValue(ENVIRONMENT_PATH, workspace) + "/";
        prevEnvironmentPath = environmentPath;
        boolean installIfMissing = parameters.getValue(INSTALL_IF_MISSING, workspace);

        // Drop the previous model so a failed re-initialisation cannot leave a
        // model that no longer matches the recorded settings.
        // NOTE(review): the previous instance is not explicitly released here -
        // confirm whether AbstractSamJ requires cleanup.
        samJ = null;

        // Stored encodings were persisted by the previous model instance.
        // NOTE(review): presumably they cannot be selected on a new instance -
        // confirm against SAMJ.
        encodings.clear();

        AbstractSamJ.MAX_ENCODED_AREA_RS = 3000;
        AbstractSamJ.MAX_ENCODED_SIDE = 3000;

        try {
            switch (environmentPathMode) {
                case EnvironmentPathModes.DEFAULT:
                    environmentPath = SamEnvManagerAbstract.DEFAULT_DIR;
                    break;
            }

            SamEnvManagerAbstract manager;
            switch (modelName) {
                case SAMJModels.EFFICIENT:
                default:
                    manager = EfficientSamEnvManager.create(environmentPath);
                    break;
                case SAMJModels.LARGE:
                    manager = Sam2EnvManager.create(environmentPath, "large");
                    break;
                case SAMJModels.SMALL:
                    manager = Sam2EnvManager.create(environmentPath, "small");
                    break;
                case SAMJModels.TINY:
                    manager = Sam2EnvManager.create(environmentPath, "tiny");
                    break;
            }

            if (!manager.checkEverythingInstalled()) {
                if (installIfMissing) {
                    module.writeStatus("Installing SAM model");
                    manager.installEverything();
                } else {
                    MIA.log.writeWarning(modelName + " not available.  Please install manually or enable \""
                            + INSTALL_IF_MISSING + "\" parameter.");
                }
            }

            switch (modelName) {
                case SAMJModels.EFFICIENT:
                default:
                    samJ = EfficientSamJ.initializeSam(manager);
                    break;
                case SAMJModels.LARGE:
                case SAMJModels.SMALL:
                case SAMJModels.TINY:
                    samJ = Sam2.initializeSam(manager);
                    break;
            }

        } catch (IOException | RuntimeException | InterruptedException | ArchiveException | URISyntaxException
                | MambaInstallException e) {
            logException(e);
        }
    }

    /**
     * Prepares the model for the workspace's input image before it is shown.
     * <p>
     * The model is (re)initialised when none exists or when the model or
     * environment settings have changed. A stored encoding for the input file
     * is selected when allowed and available; otherwise the input image is
     * passed to the model. The encoding is then optionally stored.
     *
     * @param workspace workspace being processed
     * @return {@link Status#TERMINATE_SILENT} when the encoding was stored and
     *         {@link #TERMINATE_AFTER_ENCODING} is enabled, otherwise
     *         {@link Status#PASS}
     */
    @Override
    public Status initialiseBeforeImageShown(Workspace workspace) {
        this.workspace = workspace;

        // Getting parameters
        boolean samEnabled = parameters.getValue(USE_SAM, workspace);
        if (!samEnabled)
            return Status.PASS;

        String modelName = parameters.getValue(MODEL, workspace);
        String environmentPathMode = parameters.getValue(ENVIRONMENT_PATH_MODE, workspace);
        String environmentPath = parameters.getValue(ENVIRONMENT_PATH, workspace) + "/";
        boolean useEncoding = parameters.getValue(USE_ENCODING_IF_AVAILABLE, workspace);
        boolean storeEncoding = parameters.getValue(STORE_ENCODING, workspace);
        boolean terminateAfterEncoding = parameters.getValue(TERMINATE_AFTER_ENCODING, workspace);

        boolean settingsChanged = !modelName.equals(prevModelName)
                || !environmentPathMode.equals(prevEnvionmentPathMode)
                || !environmentPath.equals(prevEnvironmentPath);
        if (samJ == null || settingsChanged)
            initialiseSAMJ();

        // Initialisation failures are logged by initialiseSAMJ; without a
        // model there is nothing to encode, so manual annotation proceeds
        // unaided.
        if (samJ == null) {
            MIA.log.writeWarning("Segment Anything could not be initialised.  Continuing without it.");
            return Status.PASS;
        }

        String path = workspace.getMetadata().getFile().getAbsolutePath();

        boolean encodingSelected = false;
        if (useEncoding && encodings.containsKey(path)) {
            try {
                samJ.selectEncoding(encodings.get(path));
                encodingSelected = true;
            } catch (Exception e) {
                logException(e);
            }
        }

        // Encode the image when no stored encoding was used, including when
        // selecting the stored encoding failed.
        if (!encodingSelected) {
            try {
                String imageName = module.getParameterValue(ManuallyIdentifyObjects.INPUT_IMAGE, workspace);
                Image image = workspace.getImage(imageName);
                samJ.setImage(image.getImgPlus());
            } catch (IOException | RuntimeException | InterruptedException e) {
                logException(e);
            }
        }

        if (storeEncoding) {
            try {
                encodings.put(path, samJ.persistEncoding());

                if (terminateAfterEncoding)
                    return Status.TERMINATE_SILENT;

            } catch (IOException | InterruptedException e) {
                logException(e);
            }
        }

        return Status.PASS;

    }

    /**
     * Starts listening for mouse events on the displayed image's canvas.
     *
     * @param displayIpl image shown to the user; may be {@code null}, in which
     *                   case nothing is registered
     * @return always {@link Status#PASS}
     */
    @Override
    public Status initialiseAfterImageShown(@Nullable ImagePlus displayIpl) {
        // Getting parameters
        boolean samEnabled = parameters.getValue(USE_SAM, workspace);
        if (!samEnabled)
            return Status.PASS;

        this.displayIpl = displayIpl;
        if (displayIpl == null || displayIpl.getCanvas() == null)
            return Status.PASS;

        // Removing first is a no-op when not registered and prevents duplicate
        // notifications if this canvas has been seen before.
        displayIpl.getCanvas().removeMouseListener(this);
        displayIpl.getCanvas().addMouseListener(this);

        return Status.PASS;

    }

    /**
     * Registers all parameters of this extension.
     *
     * @return the extension's parameter collection
     */
    @Override
    public Parameters initialiseParameters() {
        parameters.add(new SeparatorP(SAMJ_SEPARATOR, module));
        parameters.add(new BooleanP(USE_SAM, module, false));
        parameters.add(new ChoiceP(MODEL, module, SAMJModels.EFFICIENT, SAMJModels.ALL));
        parameters.add(
                new ChoiceP(ENVIRONMENT_PATH_MODE, module, EnvironmentPathModes.DEFAULT, EnvironmentPathModes.ALL));
        parameters.add(new FolderPathP(ENVIRONMENT_PATH, module));
        parameters.add(new BooleanP(INSTALL_IF_MISSING, module, true));
        parameters.add(new BooleanP(USE_ENCODING_IF_AVAILABLE, module, true));
        parameters.add(new BooleanP(STORE_ENCODING, module, true));
        parameters.add(new BooleanP(TERMINATE_AFTER_ENCODING, module, false));

        return parameters;

    }

    /**
     * Returns the parameters relevant to the current settings: the separator
     * and enable switch always, the remaining controls only when Segment
     * Anything is enabled, and the environment path only in "Specific path"
     * mode.
     *
     * @return the currently applicable parameters
     */
    @Override
    public Parameters updateAndGetParameters() {
        Parameters returnedParameters = new Parameters();

        returnedParameters.add(parameters.getParameter(SAMJ_SEPARATOR));
        returnedParameters.add(parameters.getParameter(USE_SAM));

        if ((boolean) parameters.getValue(USE_SAM, workspace)) {
            returnedParameters.add(parameters.getParameter(MODEL));
            returnedParameters.add(parameters.getParameter(ENVIRONMENT_PATH_MODE));

            switch ((String) parameters.getValue(ENVIRONMENT_PATH_MODE, workspace)) {
                case EnvironmentPathModes.SPECIFIC:
                    returnedParameters.add(parameters.getParameter(ENVIRONMENT_PATH));
                    break;
            }

            returnedParameters.add(parameters.getParameter(INSTALL_IF_MISSING));
            returnedParameters.add(parameters.getParameter(USE_ENCODING_IF_AVAILABLE));
            returnedParameters.add(parameters.getParameter(STORE_ENCODING));
            returnedParameters.add(parameters.getParameter(TERMINATE_AFTER_ENCODING));

        }

        return returnedParameters;

    }

    /**
     * @return {@code true} when Segment Anything is enabled both in the
     *         parameters and on the control panel
     */
    @Override
    public boolean skipAutoAccept() {
        return (boolean) parameters.getValue(USE_SAM, workspace) && useSAM;
    }

    /**
     * Builds the control panel holding the "Use Segment Anything" checkbox.
     *
     * @return the panel, or {@code null} when Segment Anything is disabled in
     *         the parameters
     */
    @Override
    public JPanel getControlPanel() {
        if (!(boolean) parameters.getValue(USE_SAM, workspace))
            return null;

        JPanel panel = new JPanel(new GridBagLayout());
        GridBagConstraints c = new GridBagConstraints();

        c.gridwidth = 1;
        c.gridx = 0;
        c.anchor = GridBagConstraints.LINE_START;
        c.fill = GridBagConstraints.BOTH;
        c.insets = new Insets(0, 5, 5, 0);

        JLabel useSAMLabel = new JLabel("Use Segment Anything");
        useSAMLabel.setHorizontalAlignment(JLabel.RIGHT);
        panel.add(useSAMLabel, c);

        JCheckBox useSAMCheckBox = new JCheckBox();
        useSAMCheckBox.setEnabled(true);
        // Show the actual state so a newly built panel never disagrees with
        // the flag read by mouseReleased and skipAutoAccept.
        useSAMCheckBox.setSelected(useSAM);
        useSAMCheckBox.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                useSAM = useSAMCheckBox.isSelected();
            }
        });
        c.gridx++;
        panel.add(useSAMCheckBox, c);

        // Transparent separator which takes up the remaining horizontal space.
        JSeparator separator = new JSeparator();
        separator.setForeground(new Color(0, 0, 0, 0));
        c.gridx++;
        c.gridwidth = 4;
        c.weightx = 1;
        panel.add(separator, c);

        return panel;

    }

    @Override
    public void mouseClicked(MouseEvent e) {

    }

    @Override
    public void mousePressed(MouseEvent e) {
    }

    /**
     * Converts the ROI currently on the displayed image into a polygon using
     * the model, then applies the object selector's auto-accept action.
     * <p>
     * Area-type ROIs are passed to the model as their bounding box; all other
     * ROI types are passed as points. Nothing happens when Segment Anything is
     * switched off, there is no image, ROI or model, or the model returned no
     * polygon.
     */
    @Override
    public void mouseReleased(MouseEvent e) {
        if (!useSAM)
            return;

        if (displayIpl == null)
            return;

        Roi roi = displayIpl.getRoi();
        if (roi == null)
            return;

        if (samJ == null) {
            MIA.log.writeWarning("SAMJ not ready yet");
            return;
        }

        List<Polygon> polygons;
        switch (roi.getType()) {
            case Roi.FREEROI:
            case Roi.OVAL:
            case Roi.POLYGON:
            case Roi.TRACED_ROI:
            //     polygons = getPolygonFromMask(roi);
            // break;
            case Roi.RECTANGLE:
                polygons = getPolygonFromRectangle(roi);
                break;

            case Roi.FREELINE:
            case Roi.LINE:
            case Roi.POINT:
            case Roi.POLYLINE:
            default:
                polygons = getPolygonFromPoints(roi);
                break;
        }

        // With no polygon the ROI on the image is still the user's prompt, so
        // stop before the auto-accept step can act on it.
        if (polygons == null || polygons.isEmpty())
            return;

        // Each polygon replaces the previous ROI, so the last one returned is
        // the ROI left on the image.
        for (Polygon polygon : polygons) {
            PolygonRoi polygonRoi = new PolygonRoi(polygon, PolygonRoi.POLYGON);
            displayIpl.setRoi(polygonRoi);
        }

        // Checking if the object selector is set to automatically accept objects
        ObjectSelector objectSelector = ((ManuallyIdentifyObjects) module).getObjectSelector();

        switch (objectSelector.getAutoAcceptMode()) {
            case ObjectSelector.AutoAcceptModes.DO_NOTHING:
            default:
                break;
            case ObjectSelector.AutoAcceptModes.ADD_AS_NEW_OBJECT:
                new Thread(() -> {
                    objectSelector.addNewObject();
                }).start();
                break;
            case ObjectSelector.AutoAcceptModes.ADD_TO_EXISTING_OBJECT:
                objectSelector.addToExistingObject();
                break;
        }
    }

    @Override
    public void mouseEntered(MouseEvent e) {
    }

    @Override
    public void mouseExited(MouseEvent e) {
    }

    // Mask-based prompting is currently disabled.
    // public List getPolygonFromMask(Roi roi) {
        
    //     ImageProcessor mask = roi.getMask();
    //     Img img = ImageJFunctions.wrap(new ImagePlus("Mask",mask));
    //     ImageJFunctions.show(img);
    //     try {
    //         return samJ.processMask(img);
    //     } catch (IOException | InterruptedException | RuntimeException e) {
    //         MIA.log.writeError(e);
    //     }

    //     return null;

    // }

    /**
     * Passes every point contained in the ROI to the model as a point prompt.
     *
     * @param roi ROI whose contained points are used
     * @return polygons returned by the model, or {@code null} if the model call
     *         failed (the error is logged)
     */
    public List<Polygon> getPolygonFromPoints(Roi roi) {
        Point[] points = roi.getContainedPoints();

        ArrayList<int[]> pts = new ArrayList<>(points.length);
        for (Point point : points)
            pts.add(new int[] { point.x, point.y });

        try {
            return samJ.processPoints(pts);
        } catch (IOException | InterruptedException | RuntimeException e) {
            logException(e);
        }

        return null;

    }

    /**
     * Passes the ROI's bounding rectangle to the model as a box prompt in the
     * form {@code {x0, y0, x1, y1}}.
     *
     * @param roi ROI whose bounds are used
     * @return polygons returned by the model, or {@code null} if the model call
     *         failed (the error is logged)
     */
    public List<Polygon> getPolygonFromRectangle(Roi roi) {
        Rectangle rectangle = roi.getBounds();
        int x0 = rectangle.x;
        int y0 = rectangle.y;
        int x1 = x0 + rectangle.width;
        int y1 = y0 + rectangle.height;
        int[] box = new int[] { x0, y0, x1, y1 };

        try {
            return samJ.processBox(box);
        } catch (IOException | InterruptedException | RuntimeException e) {
            logException(e);
        }

        return null;

    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy