// opennlp.tools.util.model.BaseModel (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.util.model;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.net.URL;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Properties;
import java.util.UUID;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import ai.idylnlp.opennlp.custom.EncryptedDataOutputStream;
import opennlp.tools.util.BaseToolFactory;
import opennlp.tools.util.InvalidFormatException;
import opennlp.tools.util.Version;
import opennlp.tools.util.ext.ExtensionLoader;
/**
* This model is a common based which can be used by the components
* model classes.
*
* TODO:
* Provide sub classes access to serializers already in constructor
*/
public abstract class BaseModel implements ArtifactProvider, Serializable {
protected static final String MANIFEST_ENTRY = "manifest.properties";
protected static final String FACTORY_NAME = "factory";
private static final String MANIFEST_VERSION_PROPERTY = "Manifest-Version";
private static final String COMPONENT_NAME_PROPERTY = "Component-Name";
private static final String VERSION_PROPERTY = "OpenNLP-Version";
private static final String TIMESTAMP_PROPERTY = "Timestamp";
private static final String LANGUAGE_PROPERTY = "Language";
public static final String TRAINING_CUTOFF_PROPERTY = "Training-Cutoff";
public static final String TRAINING_ITERATIONS_PROPERTY = "Training-Iterations";
public static final String TRAINING_EVENTHASH_PROPERTY = "Training-Eventhash";
private static String SERIALIZER_CLASS_NAME_PREFIX = "serializer-class-";
private Map artifactSerializers = new HashMap<>();
protected Map artifactMap = new HashMap<>();
protected BaseToolFactory toolFactory;
private String componentName;
private boolean subclassSerializersInitiated = false;
private boolean finishedLoadingArtifacts = false;
private boolean isLoadedFromSerialized;
private BaseModel(String componentName, boolean isLoadedFromSerialized) {
this.isLoadedFromSerialized = isLoadedFromSerialized;
this.componentName = Objects.requireNonNull(componentName, "componentName must not be null!");
}
/**
* Initializes the current instance. The sub-class constructor should call the
* method {@link #checkArtifactMap()} to check the artifact map is OK.
*
* Sub-classes will have access to custom artifacts and serializers provided
* by the factory.
*
* @param componentName
* the component name
* @param languageCode
* the language code
* @param manifestInfoEntries
* additional information in the manifest
* @param factory
* the factory
*/
protected BaseModel(String componentName, String languageCode,
Map manifestInfoEntries, BaseToolFactory factory) {
this(componentName, false);
Objects.requireNonNull(languageCode, "languageCode must not be null");
createBaseArtifactSerializers(artifactSerializers);
Properties manifest = new Properties();
manifest.setProperty(MANIFEST_VERSION_PROPERTY, "1.0");
manifest.setProperty(LANGUAGE_PROPERTY, languageCode);
manifest.setProperty(VERSION_PROPERTY, Version.currentVersion().toString());
manifest.setProperty(TIMESTAMP_PROPERTY, Long.toString(System.currentTimeMillis()));
manifest.setProperty(COMPONENT_NAME_PROPERTY, componentName);
if (manifestInfoEntries != null) {
for (Map.Entry entry : manifestInfoEntries.entrySet()) {
manifest.setProperty(entry.getKey(), entry.getValue());
}
}
artifactMap.put(MANIFEST_ENTRY, manifest);
finishedLoadingArtifacts = true;
if (factory != null) {
setManifestProperty(FACTORY_NAME, factory.getClass().getCanonicalName());
artifactMap.putAll(factory.createArtifactMap());
// new manifest entries
Map entries = factory.createManifestEntries();
for (Entry entry : entries.entrySet()) {
setManifestProperty(entry.getKey(), entry.getValue());
}
}
try {
initializeFactory();
} catch (InvalidFormatException e) {
throw new IllegalArgumentException("Could not initialize tool factory. ", e);
}
loadArtifactSerializers();
}
/**
* Initializes the current instance. The sub-class constructor should call the
* method {@link #checkArtifactMap()} to check the artifact map is OK.
*
* @param componentName
* the component name
* @param languageCode
* the language code
* @param manifestInfoEntries
* additional information in the manifest
*/
protected BaseModel(String componentName, String languageCode, Map manifestInfoEntries) {
this(componentName, languageCode, manifestInfoEntries, null);
}
/**
* Initializes the current instance.
*
* @param componentName the component name
* @param in the input stream containing the model
*
* @throws IOException
*/
protected BaseModel(String componentName, InputStream in) throws IOException {
this(componentName, true);
loadModel(in);
}
protected BaseModel(String componentName, File modelFile) throws IOException {
this(componentName, true);
try (InputStream in = new BufferedInputStream(new FileInputStream(modelFile))) {
loadModel(in);
}
}
protected BaseModel(String componentName, URL modelURL) throws IOException {
this(componentName, true);
try (InputStream in = new BufferedInputStream(modelURL.openStream())) {
loadModel(in);
}
}
protected BaseModel() {
// Used when the model is not an OpenNLP model.
}
private void loadModel(InputStream in) throws IOException {
Objects.requireNonNull(in, "in must not be null");
createBaseArtifactSerializers(artifactSerializers);
if (!in.markSupported()) {
in = new BufferedInputStream(in);
}
// TODO: Discuss this solution, the buffering should
int MODEL_BUFFER_SIZE_LIMIT = Integer.MAX_VALUE;
in.mark(MODEL_BUFFER_SIZE_LIMIT);
final ZipInputStream zip = new ZipInputStream(in);
// The model package can contain artifacts which are serialized with 3rd party
// serializers which are configured in the manifest file. To be able to load
// the model the manifest must be read first, and afterwards all the artifacts
// can be de-serialized.
// The ordering of artifacts in a zip package is not guaranteed. The stream is first
// read until the manifest appears, reseted, and read again to load all artifacts.
boolean isSearchingForManifest = true;
ZipEntry entry;
while ((entry = zip.getNextEntry()) != null && isSearchingForManifest) {
if ("manifest.properties".equals(entry.getName())) {
// TODO: Probably better to use the serializer here directly!
ArtifactSerializer factory = artifactSerializers.get("properties");
artifactMap.put(entry.getName(), factory.create(zip));
isSearchingForManifest = false;
}
zip.closeEntry();
}
initializeFactory();
loadArtifactSerializers();
// The Input Stream should always be reset-able because if markSupport returns
// false it is wrapped before hand into an Buffered InputStream
in.reset();
finishLoadingArtifacts(in);
checkArtifactMap();
}
private void initializeFactory() throws InvalidFormatException {
String factoryName = getManifestProperty(FACTORY_NAME);
if (factoryName == null) {
// load the default factory
Class extends BaseToolFactory> factoryClass = getDefaultFactory();
if (factoryClass != null) {
this.toolFactory = BaseToolFactory.create(factoryClass, this);
}
} else {
try {
this.toolFactory = BaseToolFactory.create(factoryName, this);
} catch (InvalidFormatException e) {
throw new IllegalArgumentException(e);
}
}
}
/**
* Sub-classes should override this method if their module has a default
* BaseToolFactory sub-class.
*
* @return the default {@link BaseToolFactory} for the module, or null if none.
*/
protected Class extends BaseToolFactory> getDefaultFactory() {
return null;
}
/**
* Loads the artifact serializers.
*/
private void loadArtifactSerializers() {
if (!subclassSerializersInitiated)
createArtifactSerializers(artifactSerializers);
subclassSerializersInitiated = true;
}
/**
* Finish loading the artifacts now that it knows all serializers.
*/
private void finishLoadingArtifacts(InputStream in)
throws IOException {
final ZipInputStream zip = new ZipInputStream(in);
Map artifactMap = new HashMap<>();
ZipEntry entry;
while ((entry = zip.getNextEntry()) != null ) {
// Note: The manifest.properties file will be read here again,
// there should be no need to prevent that.
String entryName = entry.getName();
String extension = getEntryExtension(entryName);
ArtifactSerializer factory = artifactSerializers.get(extension);
String artifactSerializerClazzName =
getManifestProperty(SERIALIZER_CLASS_NAME_PREFIX + entryName);
if (artifactSerializerClazzName != null) {
factory = ExtensionLoader.instantiateExtension(ArtifactSerializer.class, artifactSerializerClazzName);
}
if (factory != null) {
artifactMap.put(entryName, factory.create(zip));
} else {
throw new InvalidFormatException("Unknown artifact format: " + extension);
}
zip.closeEntry();
}
this.artifactMap.putAll(artifactMap);
finishedLoadingArtifacts = true;
}
/**
* Extracts the "." extension from an entry name.
*
* @param entry the entry name which contains the extension
*
* @return the extension
*
* @throws InvalidFormatException if no extension can be extracted
*/
private String getEntryExtension(String entry) throws InvalidFormatException {
int extensionIndex = entry.lastIndexOf('.') + 1;
if (extensionIndex == -1 || extensionIndex >= entry.length())
throw new InvalidFormatException("Entry name must have type extension: " + entry);
return entry.substring(extensionIndex);
}
protected ArtifactSerializer getArtifactSerializer(String resourceName) {
try {
return artifactSerializers.get(getEntryExtension(resourceName));
} catch (InvalidFormatException e) {
throw new IllegalStateException(e);
}
}
protected static Map createArtifactSerializers() {
Map serializers = new HashMap<>();
GenericModelSerializer.register(serializers);
PropertiesSerializer.register(serializers);
DictionarySerializer.register(serializers);
serializers.put("txt", new ByteArraySerializer());
serializers.put("html", new ByteArraySerializer());
return serializers;
}
/**
* Registers all {@link ArtifactSerializer} for their artifact file name extensions.
* The registered {@link ArtifactSerializer} are used to create and serialize
* resources in the model package.
*
* Override this method to register custom {@link ArtifactSerializer}s.
*
* Note:
* Subclasses should generally invoke super.createArtifactSerializers at the beginning
* of this method.
*
* This method is called during construction.
*
* @param serializers the key of the map is the file extension used to lookup
* the {@link ArtifactSerializer}.
*/
protected void createArtifactSerializers(
Map serializers) {
if (this.toolFactory != null)
serializers.putAll(this.toolFactory.createArtifactSerializersMap());
}
private void createBaseArtifactSerializers(
Map serializers) {
serializers.putAll(createArtifactSerializers());
}
/**
* Validates the parsed artifacts. If something is not
* valid subclasses should throw an {@link InvalidFormatException}.
*
* Note:
* Subclasses should generally invoke super.validateArtifactMap at the beginning
* of this method.
*
* @throws InvalidFormatException
*/
protected void validateArtifactMap() throws InvalidFormatException {
if (!(artifactMap.get(MANIFEST_ENTRY) instanceof Properties))
throw new InvalidFormatException("Missing the " + MANIFEST_ENTRY + "!");
// First check version, everything else might change in the future
String versionString = getManifestProperty(VERSION_PROPERTY);
if (versionString != null) {
Version version;
try {
version = Version.parse(versionString);
}
catch (NumberFormatException e) {
throw new InvalidFormatException("Unable to parse model version '" + versionString + "'!", e);
}
// Version check is only performed if current version is not the dev/debug version
if (!Version.currentVersion().equals(Version.DEV_VERSION)) {
// Major and minor version must match, revision might be
// this check allows for the use of models of n minor release behind current minor release
if (Version.currentVersion().getMajor() != version.getMajor() ||
Version.currentVersion().getMinor() - 3 > version.getMinor()) {
throw new InvalidFormatException("Model version " + version + " is not supported by this ("
+ Version.currentVersion() + ") version of OpenNLP!");
}
// Reject loading a snapshot model with a non-snapshot version
if (!Version.currentVersion().isSnapshot() && version.isSnapshot()) {
throw new InvalidFormatException("Model version " + version
+ " is a snapshot - snapshot models are not supported by this non-snapshot version ("
+ Version.currentVersion() + ") of OpenNLP!");
}
}
}
else {
throw new InvalidFormatException("Missing " + VERSION_PROPERTY + " property in " +
MANIFEST_ENTRY + "!");
}
if (getManifestProperty(COMPONENT_NAME_PROPERTY) == null)
throw new InvalidFormatException("Missing " + COMPONENT_NAME_PROPERTY + " property in " +
MANIFEST_ENTRY + "!");
if (!getManifestProperty(COMPONENT_NAME_PROPERTY).equals(componentName))
throw new InvalidFormatException("The " + componentName + " cannot load a model for the " +
getManifestProperty(COMPONENT_NAME_PROPERTY) + "!");
if (getManifestProperty(LANGUAGE_PROPERTY) == null)
throw new InvalidFormatException("Missing " + LANGUAGE_PROPERTY + " property in " +
MANIFEST_ENTRY + "!");
// Validate the factory. We try to load it using the ExtensionLoader. It
// will return the factory, null or raise an exception
String factoryName = getManifestProperty(FACTORY_NAME);
if (factoryName != null) {
try {
if (ExtensionLoader.instantiateExtension(BaseToolFactory.class,
factoryName) == null) {
throw new InvalidFormatException(
"Could not load an user extension specified by the model: "
+ factoryName);
}
} catch (Exception e) {
throw new InvalidFormatException(
"Could not load an user extension specified by the model: "
+ factoryName, e);
}
}
// validate artifacts declared by the factory
if (toolFactory != null) {
toolFactory.validateArtifactMap();
}
}
/**
* Checks the artifact map.
*
* A subclass should call this method from a constructor which accepts the individual
* artifact map items, to validate that these items form a valid model.
*
* If the artifacts are not valid an IllegalArgumentException will be thrown.
*/
protected void checkArtifactMap() {
if (!finishedLoadingArtifacts)
throw new IllegalStateException(
"The method BaseModel.finishLoadingArtifacts(..) was not called by BaseModel sub-class.");
try {
validateArtifactMap();
} catch (InvalidFormatException e) {
throw new IllegalArgumentException(e);
}
}
/**
* Retrieves the value to the given key from the manifest.properties
* entry.
*
* @param key
*
* @return the value
*/
public final String getManifestProperty(String key) {
Properties manifest = (Properties) artifactMap.get(MANIFEST_ENTRY);
return manifest.getProperty(key);
}
/**
* Sets a given value for a given key to the manifest.properties entry.
*
* @param key
* @param value
*/
protected final void setManifestProperty(String key, String value) {
Properties manifest = (Properties) artifactMap.get(MANIFEST_ENTRY);
manifest.setProperty(key, value);
}
public String getModelId() {
return getManifestProperty("model.id");
}
/**
* Retrieves the language code of the material which
* was used to train the model or x-unspecified if
* non was set.
*
* @return the language code of this model
*/
public final String getLanguage() {
return getManifestProperty(LANGUAGE_PROPERTY);
}
/**
* Retrieves the OpenNLP version which was used
* to create the model.
*
* @return the version
*/
public final Version getVersion() {
String version = getManifestProperty(VERSION_PROPERTY);
return Version.parse(version);
}
/**
* Serializes the model to the given {@link OutputStream}.
*
* @param out stream to write the model to
* @throws IOException
*/
@SuppressWarnings("unchecked")
public final String serialize(OutputStream out) throws IOException {
if (!subclassSerializersInitiated) {
throw new IllegalStateException(
"The method BaseModel.loadArtifactSerializers() was not called by BaseModel subclass constructor.");
}
// The model ID is generated here in order to reduce the code changes
// necessary. If the modelId is generated external to this function then
// anywhere this function is called must be changed. I just want to
// minimize the number of code changes to upgrades to newer OpenNLP
// versions are simplified.
final String modelId = UUID.randomUUID().toString();
// Write the model ID to the model's properties.
setManifestProperty("model.id", modelId);
for (Entry entry : artifactMap.entrySet()) {
final String name = entry.getKey();
final Object artifact = entry.getValue();
if (artifact instanceof SerializableArtifact) {
SerializableArtifact serializableArtifact = (SerializableArtifact) artifact;
String artifactSerializerName = serializableArtifact
.getArtifactSerializerClass().getName();
setManifestProperty(SERIALIZER_CLASS_NAME_PREFIX + name,
artifactSerializerName);
}
}
ZipOutputStream zip = new ZipOutputStream(out);
for (Entry entry : artifactMap.entrySet()) {
String name = entry.getKey();
zip.putNextEntry(new ZipEntry(name));
Object artifact = entry.getValue();
ArtifactSerializer serializer = getArtifactSerializer(name);
// If model is serialize-able always use the provided serializer
if (artifact instanceof SerializableArtifact) {
SerializableArtifact serializableArtifact = (SerializableArtifact) artifact;
String artifactSerializerName =
serializableArtifact.getArtifactSerializerClass().getName();
serializer = ExtensionLoader.instantiateExtension(ArtifactSerializer.class, artifactSerializerName);
}
if (serializer == null) {
throw new IllegalStateException("Missing serializer for " + name);
}
serializer.serialize(artifactMap.get(name), zip);
zip.closeEntry();
}
zip.finish();
zip.flush();
return modelId;
}
public final void serialize(File model) throws IOException {
try (OutputStream out = new BufferedOutputStream(new FileOutputStream(model))) {
serialize(out);
}
}
public final void serialize(Path model) throws IOException {
serialize(model.toFile());
}
@SuppressWarnings("unchecked")
public T getArtifact(String key) {
Object artifact = artifactMap.get(key);
if (artifact == null)
return null;
return (T) artifact;
}
public boolean isLoadedFromSerialized() {
return isLoadedFromSerialized;
}
// These methods are required to serialize/deserialize the model because
// many of the included objects in this model are not Serializable.
// An alternative to this solution is to make all included objects
// Serializable and remove the writeObject and readObject methods.
// This will allow the usage of final for fields that should not change.
private void writeObject(EncryptedDataOutputStream out) throws IOException {
out.writeEncryptedUTF(componentName);
this.serialize(out);
}
private void readObject(final ObjectInputStream in) throws IOException {
isLoadedFromSerialized = true;
artifactSerializers = new HashMap<>();
artifactMap = new HashMap<>();
componentName = in.readUTF();
this.loadModel(in);
}
}