
de.julielab.genemapper.resources.TransformerDisambiguationDataWriter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of gene-mapper-resources Show documentation
This project assembles code and files required to build the dictionaries and indexes used by the JCoRe
Gene Mapper.
The newest version!
package de.julielab.genemapper.resources;
import com.google.inject.Guice;
import com.google.inject.Injector;
import de.julielab.geneexpbase.data.DocumentLoader;
import de.julielab.geneexpbase.data.DocumentLoadingException;
import de.julielab.geneexpbase.data.DocumentSourceFileRegistry;
import de.julielab.geneexpbase.data.DocumentSourceFiles;
import de.julielab.geneexpbase.genemodel.GeneDocument;
import de.julielab.geneexpbase.ioc.ServicesShutdownHub;
import de.julielab.genemapper.Configuration;
import de.julielab.genemapper.GeneMapper;
import de.julielab.genemapper.classification.TransformerDisambiguationDataUtils;
import de.julielab.genemapper.ioc.GeneMappingModule;
import de.julielab.genemapper.utils.GeneMapperException;
import de.julielab.genemapper.utils.GeneMapperInitializationException;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Command-line tool that creates TSV training and development data for transformer-based
 * gene mention disambiguation. Documents are loaded from a configured corpus, run through
 * the {@link GeneMapper}, and serialized via {@link TransformerDisambiguationDataUtils}.
 * <p>
 * The corpus, split mapping and output file are selected by (un)commenting the respective
 * lines in {@link #main(String[])}; the active {@code TransformerDisambiguationDataUtils}
 * switches are encoded into the output file name so differently configured runs do not
 * overwrite each other.
 */
public class TransformerDisambiguationDataWriter {
    private final static Logger log = LoggerFactory.getLogger(TransformerDisambiguationDataWriter.class);

    /**
     * Entry point. Wires the mapper via Guice, selects the corpus and output file, writes
     * the disambiguation data and shuts the service hub down.
     *
     * @param args unused
     * @throws IOException                       on output or split-mapping file errors
     * @throws GeneMapperInitializationException if the mapper configuration cannot be loaded
     * @throws ExecutionException                propagated from asynchronous mapping work
     * @throws GeneMapperException               on mapping failures
     * @throws DocumentLoadingException          if the corpus cannot be read
     */
    public static void main(String[] args) throws IOException, GeneMapperInitializationException, ExecutionException, GeneMapperException, DocumentLoadingException {
        Configuration configuration = new Configuration(new File("configurations/genemapper_transformer_data.properties"));
        Injector injector = Guice.createInjector(new GeneMappingModule(configuration));
        GeneMapper geneMapper = injector.getInstance(GeneMapper.class);
        DocumentLoader documentLoader = injector.getInstance(DocumentLoader.class);
        // Encode the active data-generation switches into the file name.
        String goldTaxMode = TransformerDisambiguationDataUtils.USE_GOLD_TAX_FOR_CANDIDATE_RETRIEVAL ? "goldTax" : "noGoldTax";
        String matchMode;
        if (TransformerDisambiguationDataUtils.ONLY_APPROX_MATCHES)
            matchMode = "onlyApproxMatches";
        else if (TransformerDisambiguationDataUtils.ONLY_EXACT_MATCHES)
            matchMode = "onlyExactMatches";
        else
            matchMode = "allMatches";
        String fpMentionMode = TransformerDisambiguationDataUtils.EXCLUDE_FP_GM ? "excludeFpMentions" : "includeFpMentions";
        // Corpus selection: uncomment the desired source files (and the matching split
        // mapping / output file below) to switch corpora.
        // DocumentSourceFiles documentSourceFiles = DocumentSourceFileRegistry.gnpBc2gnTrain();
        // DocumentSourceFiles documentSourceFiles = DocumentSourceFileRegistry.gnpNlmIat();
        DocumentSourceFiles documentSourceFiles = DocumentSourceFileRegistry.gnpBc2gnTest();
        // File corpusSplitMapping = new File("splitmappings/lexrank-gnormplus-bc2train-10split-5devfreq.txt");
        // File corpusSplitMapping = new File("splitmappings/lexrank-gnormplus-nlmiat-10split-5devfreq.txt");
        File corpusSplitMapping = null;
        // File outputFile = new File("transformerDisambiguationData-gnpbc2gntrain-v" + TransformerDisambiguationDataUtils.VERSION + "-" + goldTaxMode + "-" + matchMode + "-" + fpMentionMode + ".tsv");
        // File outputFile = new File("transformerDisambiguationData-nlmiat-v"+ TransformerDisambiguationDataUtils.VERSION+"-" + goldTaxMode + "-" + matchMode + "-" + fpMentionMode + ".tsv");
        File outputFile = new File("transformerDisambiguationData-bc2test-v" + TransformerDisambiguationDataUtils.VERSION + "-" + goldTaxMode + "-" + matchMode + "-" + fpMentionMode + ".tsv");
        createDisambiguationData(documentSourceFiles, documentLoader, geneMapper, outputFile, corpusSplitMapping);
        injector.getInstance(ServicesShutdownHub.class).shutdown();
        log.info("Data creation complete.");
    }

    /**
     * Writes transformer disambiguation data for the given corpus.
     * <p>
     * If a corpus split mapping is given, the documents assigned to the {@code dev} partition
     * are written to a separate {@code -dev.tsv} file next to {@code outputFile} and omitted
     * from the training data; otherwise all documents go into {@code outputFile}.
     *
     * @param sourceFiles        the corpus to load documents from
     * @param documentLoader     loader resolving {@code sourceFiles} into documents
     * @param mapper             the gene mapper used to produce candidate data
     * @param outputFile         the (training) data output file
     * @param corpusSplitMapping optional whitespace-separated {@code docId<TAB>partition}
     *                           mapping file; may be {@code null}
     */
    public static void createDisambiguationData(DocumentSourceFiles sourceFiles, DocumentLoader documentLoader, GeneMapper mapper, File outputFile, File corpusSplitMapping) throws IOException, ExecutionException, DocumentLoadingException, GeneMapperException {
        // Split mappings are created when running SmacOptimizationRoute implementations. They
        // store a file mapping each document in the corpus to a partition.
        // Format (tab-separated): docId<TAB>partition
        // We use this information to write separate training and dev files for transformer
        // fine-tuning and evaluation. (Per-split train/test file generation was removed; see
        // version history if round-wise splits are needed again.)
        String outputPath = outputFile.getAbsolutePath();
        List<GeneDocument> documents = documentLoader.getDocuments(sourceFiles).collect(Collectors.toList());
        if (corpusSplitMapping != null) {
            List<String> dataSplitLines = FileUtils.readLines(corpusSplitMapping, StandardCharsets.UTF_8);
            log.info("Read {} document IDs from {}", dataSplitLines.size(), corpusSplitMapping);
            Set<String> devDocIds = dataSplitLines.stream().map(line -> line.split("\\s+")).filter(s -> s[1].equals("dev")).map(s -> s[0]).collect(Collectors.toSet());
            // Derive the dev file name from the output file by swapping its extension;
            // guard against an extension-less output path.
            int extIdx = outputPath.lastIndexOf('.');
            String basePath = extIdx >= 0 ? outputPath.substring(0, extIdx) : outputPath;
            File devFile = new File(basePath + "-dev.tsv");
            log.info("Got {} dev docs from {} that will be omitted from the training data and written to {}.", devDocIds.size(), corpusSplitMapping, devFile);
            Stream<GeneDocument> trainStream = documents.stream().filter(d -> !devDocIds.contains(d.getId()));
            Stream<GeneDocument> devStream = documents.stream().filter(d -> devDocIds.contains(d.getId()));
            log.info("Writing transformer training data for corpus {} to {}", sourceFiles.getName(), outputFile);
            TransformerDisambiguationDataUtils.writeData(mapper, outputFile, trainStream);
            TransformerDisambiguationDataUtils.writeData(mapper, devFile, devStream);
        } else {
            // No split mapping: the whole corpus becomes one data file.
            TransformerDisambiguationDataUtils.writeData(mapper, outputFile, documents.stream());
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy