// opennlp.tools.util.model.BaseModel Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.util.model;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serial;
import java.io.Serializable;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Properties;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import opennlp.tools.util.BaseToolFactory;
import opennlp.tools.util.InvalidFormatException;
import opennlp.tools.util.Version;
import opennlp.tools.util.ext.ExtensionLoader;
/**
* This is a common base model which can be used by the components' specific
* model classes.
*/
// TODO: Provide subclasses access to serializers already in constructor
public abstract class BaseModel implements ArtifactProvider, Serializable {
@Serial
private static final long serialVersionUID = -4593612444791752264L;
protected static final String MANIFEST_ENTRY = "manifest.properties";
protected static final String FACTORY_NAME = "factory";
private static final String MANIFEST_VERSION_PROPERTY = "Manifest-Version";
private static final String COMPONENT_NAME_PROPERTY = "Component-Name";
private static final String VERSION_PROPERTY = "OpenNLP-Version";
private static final String TIMESTAMP_PROPERTY = "Timestamp";
private static final String LANGUAGE_PROPERTY = "Language";
public static final String TRAINING_CUTOFF_PROPERTY = "Training-Cutoff";
public static final String TRAINING_ITERATIONS_PROPERTY = "Training-Iterations";
public static final String TRAINING_EVENTHASH_PROPERTY = "Training-Eventhash";
private static final String SERIALIZER_CLASS_NAME_PREFIX = "serializer-class-";
private Map> artifactSerializers = new HashMap<>();
protected Map artifactMap = new HashMap<>();
protected BaseToolFactory toolFactory;
private String componentName;
private boolean subclassSerializersInitiated = false;
private boolean finishedLoadingArtifacts = false;
private boolean isLoadedFromSerialized;
/**
 * Shared first step of every constructor: records the component name and
 * whether the instance is being restored from a serialized model.
 *
 * @param componentName The component name. Must not be {@code null}.
 * @param isLoadedFromSerialized {@code true} if the model is read from a stream or file.
 */
private BaseModel(String componentName, boolean isLoadedFromSerialized) {
  this.isLoadedFromSerialized = isLoadedFromSerialized;
  Objects.requireNonNull(componentName, "componentName must not be null!");
  this.componentName = componentName;
}
/**
 * Initializes a {@link BaseModel} instance. The subclass constructor should call the
 * method {@link #checkArtifactMap()} to check the artifact map is in a valid state.
 * <p>
 * Subclasses will have access to custom artifacts and serializers provided
 * by the specified {@code factory}.
 *
 * @param componentName The component name to create the model for. Must not be {@code null}.
 * @param languageCode The ISO language code to configure. Must not be {@code null}.
 * @param manifestInfoEntries Mapping for additional information in the manifest.
 *                            May be {@code null}.
 * @param factory The {@link BaseToolFactory factory} to use. May be {@code null}.
 *
 * @throws IllegalArgumentException Thrown if the tool factory could not be initialized.
 */
protected BaseModel(String componentName, String languageCode,
                    Map<String, String> manifestInfoEntries, BaseToolFactory factory) {
  this(componentName, false);
  Objects.requireNonNull(languageCode, "languageCode must not be null");

  createBaseArtifactSerializers(artifactSerializers);

  // Build the manifest with the standard properties first ...
  Properties manifest = new Properties();
  manifest.setProperty(MANIFEST_VERSION_PROPERTY, "1.0");
  manifest.setProperty(LANGUAGE_PROPERTY, languageCode);
  manifest.setProperty(VERSION_PROPERTY, Version.currentVersion().toString());
  manifest.setProperty(TIMESTAMP_PROPERTY, Long.toString(System.currentTimeMillis()));
  manifest.setProperty(COMPONENT_NAME_PROPERTY, componentName);

  // ... then add the caller supplied entries, which may overwrite the ones above.
  if (manifestInfoEntries != null) {
    for (Map.Entry<String, String> entry : manifestInfoEntries.entrySet()) {
      manifest.setProperty(entry.getKey(), entry.getValue());
    }
  }

  artifactMap.put(MANIFEST_ENTRY, manifest);
  finishedLoadingArtifacts = true;

  if (factory != null) {
    setManifestProperty(FACTORY_NAME, factory.getClass().getCanonicalName());
    artifactMap.putAll(factory.createArtifactMap());

    // new manifest entries
    Map<String, String> entries = factory.createManifestEntries();
    for (Entry<String, String> entry : entries.entrySet()) {
      setManifestProperty(entry.getKey(), entry.getValue());
    }
  }

  try {
    initializeFactory();
  } catch (InvalidFormatException e) {
    throw new IllegalArgumentException("Could not initialize tool factory. ", e);
  }
  loadArtifactSerializers();
}
/**
 * Initializes a {@link BaseModel} instance without a tool factory. The subclass
 * constructor should call the method {@link #checkArtifactMap()} to check the
 * artifact map is in a valid state.
 *
 * @param componentName The component name to create the model for. Must not be {@code null}.
 * @param languageCode The ISO language code to configure. Must not be {@code null}.
 * @param manifestInfoEntries Mapping for additional information in the manifest.
 *                            May be {@code null}.
 */
protected BaseModel(String componentName, String languageCode,
                    Map<String, String> manifestInfoEntries) {
  this(componentName, languageCode, manifestInfoEntries, null);
}
/**
 * Creates a {@link BaseModel} by reading a model package from a stream.
 * The subclass constructor should call the method {@link #checkArtifactMap()}
 * to check the artifact map is in a valid state.
 *
 * @param componentName The component name to create the model for.
 * @param in A valid, open {@link InputStream} to read the model from.
 *
 * @throws IOException Thrown if IO errors occurred.
 */
protected BaseModel(String componentName, InputStream in) throws IOException {
  this(componentName, true);
  this.loadModel(in);
}
/**
 * Creates a {@link BaseModel} by reading a model package from a file.
 * The subclass constructor should call the method {@link #checkArtifactMap()}
 * to check the artifact map is in a valid state.
 *
 * @param componentName The component name to create the model for.
 * @param modelFile A valid, accessible {@link File} to read the model from.
 *
 * @throws IOException Thrown if IO errors occurred.
 */
protected BaseModel(String componentName, File modelFile) throws IOException {
  this(componentName, true);
  // The stream is opened and closed here; the caller only hands over the file.
  try (InputStream modelStream = new BufferedInputStream(new FileInputStream(modelFile))) {
    this.loadModel(modelStream);
  }
}
/**
 * Creates a {@link BaseModel} by reading a model package from a path.
 * The subclass constructor should call the method {@link #checkArtifactMap()}
 * to check the artifact map is in a valid state.
 *
 * @param componentName The component name to create the model for.
 * @param modelPath A valid, accessible {@link Path} to read the model from.
 *
 * @throws IOException Thrown if IO errors occurred.
 */
protected BaseModel(String componentName, Path modelPath) throws IOException {
  this(componentName, true);
  // The stream is opened and closed here; the caller only hands over the path.
  try (InputStream modelStream = Files.newInputStream(modelPath)) {
    this.loadModel(modelStream);
  }
}
/**
 * Creates a {@link BaseModel} by reading a model package from a URL.
 * The subclass constructor should call the method {@link #checkArtifactMap()}
 * to check the artifact map is in a valid state.
 *
 * @param componentName The component name to create the model for.
 * @param modelURL A valid, accessible {@link URL} to read the model from.
 *
 * @throws IOException Thrown if IO errors occurred.
 */
protected BaseModel(String componentName, URL modelURL) throws IOException {
  this(componentName, true);
  // The stream is opened and closed here; the caller only hands over the URL.
  try (InputStream modelStream = new BufferedInputStream(modelURL.openStream())) {
    this.loadModel(modelStream);
  }
}
/**
 * Reads a zipped model package from {@code in} and fills the artifact map.
 * <p>
 * The package can contain artifacts which are serialized with 3rd party
 * serializers that are configured in the manifest. To be able to load the model
 * the manifest must be read first, and afterwards all the artifacts can be
 * de-serialized. The ordering of entries in a zip package is not guaranteed, so
 * the stream is read until the manifest appears, reset, and then read again to
 * load all artifacts.
 *
 * @param in The stream to read the model package from. Must not be {@code null}.
 *
 * @throws InvalidFormatException Thrown if the package contains no
 *                                {@code manifest.properties} entry.
 * @throws IOException Thrown if IO errors occurred.
 */
private void loadModel(InputStream in) throws IOException {
  Objects.requireNonNull(in, "in must not be null");

  createBaseArtifactSerializers(artifactSerializers);

  // Two passes are made over the data, so the stream must support mark/reset.
  final InputStream resettable = in.markSupported() ? in : new BufferedInputStream(in);

  // TODO: Discuss this solution. The mark limit is unbounded because the manifest
  // may be the last entry, which presumably keeps everything read so far buffered.
  resettable.mark(Integer.MAX_VALUE);

  // Intentionally not closed: closing the zip stream would also close the
  // underlying stream, which is still needed for the second pass.
  final ZipInputStream zip = new ZipInputStream(resettable);

  boolean manifestFound = false;
  ZipEntry entry;
  // The flag is tested first so that no further entry is read once the manifest is found.
  while (!manifestFound && (entry = zip.getNextEntry()) != null) {
    if (MANIFEST_ENTRY.equals(entry.getName())) {
      // TODO: Probably better to use the serializer here directly!
      ArtifactSerializer<?> manifestSerializer = artifactSerializers.get("properties");
      artifactMap.put(MANIFEST_ENTRY, manifestSerializer.create(zip));
      manifestFound = true;
    }
    zip.closeEntry();
  }

  // Without this check the manifest lookup in initializeFactory() fails with a
  // NullPointerException instead of reporting the malformed package.
  if (!manifestFound) {
    throw new InvalidFormatException("Missing the " + MANIFEST_ENTRY + "!");
  }

  initializeFactory();
  loadArtifactSerializers();

  // Always resettable: a stream without mark support was wrapped above.
  resettable.reset();

  finishLoadingArtifacts(resettable);
  checkArtifactMap();
}
private void initializeFactory() throws InvalidFormatException {
String factoryName = getManifestProperty(FACTORY_NAME);
if (factoryName == null) {
// load the default factory
Class extends BaseToolFactory> factoryClass = getDefaultFactory();
if (factoryClass != null) {
this.toolFactory = BaseToolFactory.create(factoryClass, this);
}
} else {
try {
this.toolFactory = BaseToolFactory.create(factoryName, this);
} catch (InvalidFormatException e) {
throw new IllegalArgumentException(e);
}
}
}
/**
* Subclasses should override this method if their module has a default
* {@link BaseToolFactory} subclass.
*
* @return The default {@link BaseToolFactory} for the component, or {@code null} if none.
*/
protected Class extends BaseToolFactory> getDefaultFactory() {
return null;
}
/**
 * Registers the subclass and factory specific {@link ArtifactSerializer artifact
 * serializers} by calling {@link #createArtifactSerializers(Map)}. Only the first
 * call has an effect; later calls return immediately.
 */
private void loadArtifactSerializers() {
  if (subclassSerializersInitiated) {
    return;
  }
  createArtifactSerializers(artifactSerializers);
  subclassSerializersInitiated = true;
}
/**
 * Finishes loading the artifacts now that all serializers are known.
 * <p>
 * Every entry of the package is de-serialized with the serializer registered for
 * its file name extension, unless the manifest names a dedicated serializer class
 * for that entry, which then takes precedence. The artifacts are only added to the
 * artifact map after all entries were read successfully.
 *
 * @param in The stream to read the model package from. It is closed by this method.
 *
 * @throws InvalidFormatException Thrown if an entry has no extension or no serializer
 *                                is available for it.
 * @throws IOException Thrown if IO errors occurred.
 */
private void finishLoadingArtifacts(InputStream in) throws IOException {
  try (ZipInputStream zip = new ZipInputStream(in)) {
    // Collected separately so that a failure half way leaves the artifact map untouched.
    final Map<String, Object> loadedArtifacts = new HashMap<>();

    ZipEntry entry;
    while ((entry = zip.getNextEntry()) != null) {
      // Note: The manifest.properties file will be read here again,
      // there should be no need to prevent that.
      final String entryName = entry.getName();
      final String extension = getEntryExtension(entryName);

      ArtifactSerializer<?> serializer = artifactSerializers.get(extension);

      // NOTE(review): the class name comes from the model's own manifest; confirm
      // that models from untrusted sources are not loaded through this path.
      final String serializerClassName =
          getManifestProperty(SERIALIZER_CLASS_NAME_PREFIX + entryName);
      if (serializerClassName != null) {
        serializer = ExtensionLoader.instantiateExtension(
            ArtifactSerializer.class, serializerClassName);
      }

      if (serializer == null) {
        throw new InvalidFormatException("Unknown artifact format: " + extension);
      }
      loadedArtifacts.put(entryName, serializer.create(zip));

      zip.closeEntry();
    }

    this.artifactMap.putAll(loadedArtifacts);
    finishedLoadingArtifacts = true;
  }
}
/**
 * Extracts the extension, i.e. the text after the last {@code '.'}, from an entry name.
 * <p>
 * Note: a name that contains no {@code '.'} at all is returned unchanged, because
 * the extension then starts at index 0. Only an empty name or a name ending in
 * {@code '.'} is rejected.
 *
 * @param entry The entry name which contains the extension.
 *
 * @return The extension.
 *
 * @throws InvalidFormatException Thrown if nothing follows the last {@code '.'}
 *                                or the name is empty.
 */
private String getEntryExtension(String entry) throws InvalidFormatException {
  final int start = entry.lastIndexOf('.') + 1;
  if (start < entry.length()) {
    return entry.substring(start);
  }
  throw new InvalidFormatException("Entry name must have type extension: " + entry);
}
/**
* @param resourceName The identifying name of a resource to retrieve an
* {link ArtifactSerializer} for.
* @return Retrieves an {@link ArtifactSerializer artifact serialize}.
*/
protected ArtifactSerializer> getArtifactSerializer(String resourceName) {
try {
return artifactSerializers.get(getEntryExtension(resourceName));
} catch (InvalidFormatException e) {
throw new IllegalStateException(e);
}
}
/**
* Creates and registers default {@link ArtifactSerializer artifact serializes}.
*
* @return A {@link Map} with all registered {@link ArtifactSerializer artifact serializes}.
*/
protected static Map> createArtifactSerializers() {
Map> serializers = new HashMap<>();
GenericModelSerializer.register(serializers);
PropertiesSerializer.register(serializers);
DictionarySerializer.register(serializers);
serializers.put("txt", new ByteArraySerializer());
serializers.put("html", new ByteArraySerializer());
return serializers;
}
/**
* Registers all {@link ArtifactSerializer} for their artifact file name extensions.
* The registered {@link ArtifactSerializer serializers} are used to
* create and serialize resources in the model package.
*
* Override this method to register custom {@link ArtifactSerializer serializers}.
*
* Note:
* Subclasses should generally invoke {@code super.createArtifactSerializers}
* at the beginning of this method.
*
* This method is called during construction.
*
* @param serializers The key of the map is the file extension used to look up
* an {@link ArtifactSerializer}.
*/
protected void createArtifactSerializers(Map> serializers) {
if (this.toolFactory != null)
serializers.putAll(this.toolFactory.createArtifactSerializersMap());
}
private void createBaseArtifactSerializers(Map> serializers) {
serializers.putAll(createArtifactSerializers());
}
/**
 * Validates the parsed artifacts. If something is not
 * valid, subclasses should throw an {@link InvalidFormatException}.
 * <p>
 * This base implementation checks that the manifest is present, that it carries a
 * parsable and supported version, the expected component name and a language, and
 * that a factory named in the manifest can be instantiated. Finally the tool
 * factory, if any, validates the artifacts it declares.
 * <p>
 * <b>Note:</b>
 * Subclasses should generally invoke {@code super.validateArtifactMap}
 * at the beginning of this method.
 *
 * @throws InvalidFormatException Thrown if artifacts were found to be inconsistent.
 */
protected void validateArtifactMap() throws InvalidFormatException {
  if (!(artifactMap.get(MANIFEST_ENTRY) instanceof Properties)) {
    throw new InvalidFormatException("Missing the " + MANIFEST_ENTRY + "!");
  }

  // First check version, everything else might change in the future
  final String versionString = getManifestProperty(VERSION_PROPERTY);
  if (versionString == null) {
    throw new InvalidFormatException("Missing " + VERSION_PROPERTY + " property in " +
        MANIFEST_ENTRY + "!");
  }

  final Version version;
  try {
    version = Version.parse(versionString);
  } catch (NumberFormatException e) {
    throw new InvalidFormatException("Unable to parse model version '" + versionString + "'!", e);
  }

  // Version check is only performed if current version is not the dev/debug version
  final Version current = Version.currentVersion();
  if (!current.equals(Version.DEV_VERSION)) {
    // Support OpenNLP 1.x and 2.x models.
    if (version.getMajor() != 1 && version.getMajor() != 2) {
      throw new InvalidFormatException("Model version " + version + " is not supported by this ("
          + current + ") version of OpenNLP!");
    }

    // Reject loading a snapshot model with a non-snapshot version
    if (!current.isSnapshot() && version.isSnapshot()) {
      throw new InvalidFormatException("Model version " + version
          + " is a snapshot - snapshot models are not supported by this non-snapshot version ("
          + current + ") of OpenNLP!");
    }
  }

  final String modelComponentName = getManifestProperty(COMPONENT_NAME_PROPERTY);
  if (modelComponentName == null) {
    throw new InvalidFormatException("Missing " + COMPONENT_NAME_PROPERTY + " property in " +
        MANIFEST_ENTRY + "!");
  }
  if (!modelComponentName.equals(componentName)) {
    throw new InvalidFormatException("The " + componentName + " cannot load a model for the " +
        modelComponentName + "!");
  }

  if (getManifestProperty(LANGUAGE_PROPERTY) == null) {
    throw new InvalidFormatException("Missing " + LANGUAGE_PROPERTY + " property in " +
        MANIFEST_ENTRY + "!");
  }

  // Validate the factory. We try to load it using the ExtensionLoader. It
  // will return the factory, null or raise an exception
  final String factoryName = getManifestProperty(FACTORY_NAME);
  if (factoryName != null) {
    final Object factoryInstance;
    try {
      factoryInstance = ExtensionLoader.instantiateExtension(BaseToolFactory.class, factoryName);
    } catch (Exception e) {
      throw new InvalidFormatException(
          "Could not load an user extension specified by the model: "
              + factoryName, e);
    }
    // Checked outside the try block so that this exception is not caught and
    // wrapped a second time by the handler above.
    if (factoryInstance == null) {
      throw new InvalidFormatException(
          "Could not load an user extension specified by the model: "
              + factoryName);
    }
  }

  // validate artifacts declared by the factory
  if (toolFactory != null) {
    toolFactory.validateArtifactMap();
  }
}
/**
 * Checks the artifact map by running {@link #validateArtifactMap()}.
 * <p>
 * A subclass should call this method from a constructor which accepts the individual
 * artifact map items, to validate that these items form a valid model.
 *
 * @throws IllegalArgumentException Thrown if the artifacts are not valid.
 * @throws IllegalStateException Thrown if the artifacts have not been loaded yet,
 *                               i.e. {@code finishLoadingArtifacts(InputStream)}
 *                               was not called.
 */
protected void checkArtifactMap() {
  if (!finishedLoadingArtifacts) {
    throw new IllegalStateException(
        "The method BaseModel.finishLoadingArtifacts(..) was not called by BaseModel subclass.");
  }

  try {
    this.validateArtifactMap();
  } catch (InvalidFormatException invalid) {
    // Surface the checked validation failure as an unchecked exception.
    throw new IllegalArgumentException(invalid);
  }
}
/**
 * Reads a value from the {@code manifest.properties} mapping.
 *
 * @param key The identifying key.
 * @return The value stored at {@code key}, as returned by
 *         {@link Properties#getProperty(String)}.
 */
@Override
public final String getManifestProperty(String key) {
  return ((Properties) artifactMap.get(MANIFEST_ENTRY)).getProperty(key);
}
/**
 * Stores a value in the {@code manifest.properties} mapping, replacing any
 * value previously stored for the same key.
 *
 * @param key The identifying key.
 * @param value The value to set at {@code key}.
 */
protected final void setManifestProperty(String key, String value) {
  ((Properties) artifactMap.get(MANIFEST_ENTRY)).setProperty(key, value);
}
/**
 * @return Retrieves the language code stored in the manifest of the model.
 */
@Override
public final String getLanguage() {
  final String language = getManifestProperty(LANGUAGE_PROPERTY);
  return language;
}
/**
 * Parses the version string stored in the manifest of the model.
 *
 * @return Retrieves the OpenNLP {@link Version} which was used to create the model.
 */
public final Version getVersion() {
  return Version.parse(getManifestProperty(VERSION_PROPERTY));
}
/**
* Serializes the model to the given {@link OutputStream}.
*
* @param out The {@link OutputStream} to write the model to.
*
* @throws IOException Thrown if IO errors occurred.
* @throws IllegalStateException Thrown if {@link BaseModel#loadArtifactSerializers()} was
* not called in a subclass constructor.
*/
@SuppressWarnings("unchecked")
public final void serialize(OutputStream out) throws IOException {
if (!subclassSerializersInitiated) {
throw new IllegalStateException(
"The method BaseModel.loadArtifactSerializers() was not called by BaseModel subclass constructor.");
}
for (Entry entry : artifactMap.entrySet()) {
final String name = entry.getKey();
final Object artifact = entry.getValue();
if (artifact instanceof SerializableArtifact serializableArtifact) {
String artifactSerializerName = serializableArtifact
.getArtifactSerializerClass().getName();
setManifestProperty(SERIALIZER_CLASS_NAME_PREFIX + name,
artifactSerializerName);
}
}
ZipOutputStream zip = new ZipOutputStream(out);
for (Entry entry : artifactMap.entrySet()) {
String name = entry.getKey();
zip.putNextEntry(new ZipEntry(name));
Object artifact = entry.getValue();
if (skipEntryForSerialization(entry)) {
continue;
}
ArtifactSerializer