All downloads are free. Search and download functionality uses the official Maven repository.

de.julielab.geneexpbase.hpo.InspectionFilePrinter Maven / Gradle / Ivy

package de.julielab.geneexpbase.hpo;

import com.google.common.collect.Sets;
import de.julielab.evaluation.entities.EntityEvaluationResult;
import de.julielab.evaluation.entities.EvaluationMode;
import de.julielab.geneexpbase.candidateretrieval.CandidateRetrieval;
import de.julielab.geneexpbase.candidateretrieval.QueryGenerator;
import de.julielab.geneexpbase.genemodel.GeneDocument;
import de.julielab.geneexpbase.genemodel.GeneMention;
import de.julielab.java.utilities.Color;
import de.julielab.java.utilities.FileUtilities;

import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.text.DecimalFormat;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class InspectionFilePrinter {

    private final Function correctnessFunction;
    private final Function correctIdRenderer;
    private final Function wrongIdRenderer;
    private final Function fpRenderer;
    private final Function> geneMentionGoldIdFunction;
    private final Function> documentGoldIdFunction;
    private final Function> predictedIdFunction;

    public InspectionFilePrinter(Function correctnessFunction,
                                 Function correctIdRenderer,
                                 Function wrongIdRenderer,
                                 Function fpRenderer,
                                 Function> geneMentionGoldIdFunction,
                                 Function> documentGoldIdFunction,
                                 Function> predictedIdFunction) {
        this.correctnessFunction = correctnessFunction;
        this.correctIdRenderer = correctIdRenderer;
        this.wrongIdRenderer = wrongIdRenderer;
        this.fpRenderer = fpRenderer;
        this.geneMentionGoldIdFunction = geneMentionGoldIdFunction;
        this.documentGoldIdFunction = documentGoldIdFunction;
        this.predictedIdFunction = predictedIdFunction;


    }

    public static Color getHighlightColor(double score) {
        if (score < .4)
            return Color.RED;
        else if (score < .6)
            return Color.MAGENTA;
        return Color.GREEN;
    }

    private static void resetDocument(GeneDocument d) {
        // Clean the results created in this run, just to be sure there is no information bleed-through between
        // optimization rounds.
        d.reset();
        d.getGenes().forEach(gm -> {
            gm.setIds(new ArrayList<>());
            gm.setTaxonomyIds(new ArrayList<>());
        });
        d.addState(GeneDocument.State.GENES_SELECTED);
        d.addState(GeneDocument.State.SPECIES_MENTIONS_SET);
        d.addState(GeneDocument.State.REFERENCE_SPECIES_ADDED);
    }

    public void printInspectionFile(EntityEvaluationResult result, String idType, HpoRoute.Metric fileNameMetric, HpoInstance hpoInstance, List documents, CandidateRetrieval candidateRetrievalForInspections, QueryGenerator queryGenerator) {
        Function r;
        Function p;
        Function f;
        if (result.getEvaluationMode() == EvaluationMode.MENTION) {
            r = e -> e.getMicroRecallMentionWise();
            p = e -> e.getMicroPrecisionMentionWise();
            f = e -> e.getMicroFMeasureMentionWise();
        } else {
            r = e -> e.getMicroRecallDocWise();
            p = e -> e.getMicroPrecisionDocWise();
            f = e -> e.getMicroFMeasureDocWise();
        }
        double overallF = f.apply(result);
        double fileMetricValue;
        switch (fileNameMetric) {
            case RECALL:
            case RECALL_REJECTION:
            case NDCG:
            case MAX_RECALL:
                fileMetricValue = r.apply(result);
                break;
            case PRECISION:
            case PRECISION_REJECTION:
                fileMetricValue = p.apply(result);
                break;
            case F:
            case F_REJECTION:
                fileMetricValue = f.apply(result);
                break;
            default:
                throw new IllegalArgumentException("Unsupported file name metric: " + fileNameMetric);
        }

        DecimalFormat df = new DecimalFormat("0.##");
        File output = Path.of("inspectionfiles-" + idType, hpoInstance.toString() + "-" + df.format(fileMetricValue) + ".txt").toFile();
        if (!output.getParentFile().exists())
            output.getParentFile().mkdirs();


        Function> goldWithOffsetIdFunction = d -> d.getGoldGenes().values().stream().flatMap(Collection::stream).flatMap(geneMentionGoldIdFunction).sorted().collect(Collectors.toCollection(LinkedHashSet::new));
        Function> goldNoOffsetIdFunction = d -> documentGoldIdFunction.apply(d).sorted().collect(Collectors.toCollection(LinkedHashSet::new));
        Function> goldIdFunction = d -> d.isGoldHasOffsets() ? goldWithOffsetIdFunction.apply(d) : goldNoOffsetIdFunction.apply(d);

        Map> renderMap = Map.of(
                GeneDocument.MentionCorrectness.CORRECT_ID, correctIdRenderer,
                GeneDocument.MentionCorrectness.WRONG_ID, wrongIdRenderer,
                GeneDocument.MentionCorrectness.CANT_FIND, fpRenderer);
        try (BufferedWriter bw = FileUtilities.getWriterToFile(output)) {
            bw.write("Instance: " + hpoInstance);
            bw.newLine();
            bw.write(String.format("Overall score [R/P/F]: %s%.2f, %.2f, %.2f%s", Color.RED, r.apply(result), p.apply(result), overallF, Color.RESET));
            bw.newLine();
            bw.newLine();
            for (GeneDocument d : documents) {
                double rd = result.getRecall(d.getId());
                double pd = result.getPrecision(d.getId());
                double fd = result.getFMeasure(d.getId());
                Color rcol = getHighlightColor(rd);
                Color pcol = getHighlightColor(pd);
                Color fcol = getHighlightColor(fd);
                bw.write(String.format("%s [R/P/F]: %2$s%5$.2f, %3$s%6$.2f, %4$s%7$.2f%8$s", d.getId(), rcol, pcol, fcol, rd, pd, fd, Color.RESET));
                bw.newLine();
                Set goldIds = goldIdFunction.apply(d);
                Set predictedIds = d.getNonRejectedGenes().flatMap(predictedIdFunction).sorted().collect(Collectors.toCollection(LinkedHashSet::new));
                Set fps = Sets.difference(predictedIds, goldIds);
                Set fns = Sets.difference(goldIds, predictedIds);
                bw.write("Gold tax IDs in doc: " + goldIds.stream().collect(Collectors.joining(" ")));
                bw.newLine();
                bw.write("Predicted IDs in doc: " + predictedIds.stream().collect(Collectors.joining(" ")));
                bw.newLine();
                bw.write("Missing IDs doclevel: " + fns.stream().collect(Collectors.joining(", ")));
                bw.newLine();
                bw.write("Wrong IDs doclevel: " + fps.stream().collect(Collectors.joining(", ")));
                bw.newLine();
                bw.write(d.getInspectionText(correctnessFunction, renderMap));
                bw.newLine();
                bw.newLine();

                resetDocument(d);
            }

        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy