
opennlp.dl.doccat.DocumentCategorizerDL Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.dl.doccat;
import java.io.File;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import opennlp.dl.Inference;
import opennlp.tools.doccat.DocumentCategorizer;
/**
* An implementation of {@link DocumentCategorizer} that performs document classification
* using ONNX models.
*/
public class DocumentCategorizerDL implements DocumentCategorizer {
private final File model;
private final File vocab;
private final Map categories;
/**
* Creates a new document categorizer using ONNX models.
* @param model The ONNX model file.
* @param vocab The model's vocabulary file.
* @param categories The categories.
*/
public DocumentCategorizerDL(File model, File vocab, Map categories) {
this.model = model;
this.vocab = vocab;
this.categories = categories;
}
@Override
public double[] categorize(String[] strings) {
try {
final DocumentCategorizerInference inference = new DocumentCategorizerInference(model, vocab);
final double[][] vectors = inference.infer(strings[0]);
final double[] results = inference.softmax(vectors[0]);
return results;
} catch (Exception ex) {
System.err.println("Unload to perform document classification inference: " + ex.getMessage());
}
return new double[]{};
}
@Override
public double[] categorize(String[] strings, Map map) {
return categorize(strings);
}
@Override
public String getBestCategory(double[] doubles) {
return categories.get(Inference.maxIndex(doubles));
}
@Override
public int getIndex(String s) {
return getKey(s);
}
@Override
public String getCategory(int i) {
return categories.get(i);
}
@Override
public int getNumberOfCategories() {
return categories.size();
}
@Override
public String getAllResults(double[] doubles) {
return null;
}
@Override
public Map scoreMap(String[] strings) {
final double[] scores = categorize(strings);
final Map scoreMap = new HashMap<>();
for (int x : categories.keySet()) {
scoreMap.put(categories.get(x), scores[x]);
}
return scoreMap;
}
@Override
public SortedMap> sortedScoreMap(String[] strings) {
final double[] scores = categorize(strings);
final SortedMap> scoreMap = new TreeMap<>();
for (int x : categories.keySet()) {
if (scoreMap.get(scores[x]) == null) {
scoreMap.put(scores[x], new HashSet<>());
}
scoreMap.get(scores[x]).add(categories.get(x));
}
return scoreMap;
}
private int getKey(String value) {
for (Map.Entry entry : categories.entrySet()) {
if (entry.getValue().equals(value)) {
return entry.getKey();
}
}
// The String wasn't found as a value in the map.
return -1;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy