All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.tika.detect.NNExampleModelDetector Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tika.detect;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.file.Path;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.tika.mime.MediaType;

import static java.nio.charset.StandardCharsets.UTF_8;

public class NNExampleModelDetector extends TrainedModelDetector {
    /**
     * Name of the bundled example model resource, looked up in the
     * {@link TrainedModelDetector} package directory on the class path.
     */
    private static final String EXAMPLE_NNMODEL_FILE = "tika-example.nnmodel";

    private static final long serialVersionUID = 1L;

    private static final Logger log = Logger.getLogger(NNExampleModelDetector.class.getName());

    /**
     * Creates a detector via the no-argument constructor of
     * {@link TrainedModelDetector}.
     */
    public NNExampleModelDetector() {
        super();
    }

    /**
     * Creates a detector and passes {@code modelFile} to the inherited
     * {@code loadDefaultModels(Path)} overload.
     *
     * @param modelFile model file, presumably in the line format read by
     *                  {@link #loadDefaultModels(InputStream)}
     */
    public NNExampleModelDetector(final Path modelFile) {
        loadDefaultModels(modelFile);
    }

    /**
     * Creates a detector and passes {@code modelFile} to the inherited
     * {@code loadDefaultModels(File)} overload.
     *
     * @param modelFile model file, presumably in the line format read by
     *                  {@link #loadDefaultModels(InputStream)}
     */
    public NNExampleModelDetector(final File modelFile) {
        loadDefaultModels(modelFile);
    }

    /**
     * Reads neural network model definitions from a stream and registers each
     * model with this detector.
     *
     * <p>The stream is decoded as UTF-8 and read line by line. Each line is
     * trimmed, then:
     * <ul>
     *   <li>blank lines are skipped;</li>
     *   <li>lines starting with {@code #} are description lines that configure
     *       the shared builder (see {@link #readDescription});</li>
     *   <li>every other line is a tab-separated list of float parameters
     *       (see {@link #readNNParams}); after it is parsed, the built model is
     *       registered under the builder's current type, i.e. the type set by
     *       the most recent description line.</li>
     * </ul>
     *
     * <p>This method does not close {@code modelStream}; the caller owns it.
     *
     * @param modelStream stream containing the model definitions
     * @throws RuntimeException if reading the stream fails or a line cannot be parsed
     */
    @Override
    public void loadDefaultModels(InputStream modelStream) {
        BufferedReader bReader = new BufferedReader(new InputStreamReader(modelStream, UTF_8));

        // One builder is reused for the whole stream: a description line sets
        // the network dimensions/type, and each following parameter line
        // completes a model.
        NNTrainedModelBuilder nnBuilder = new NNTrainedModelBuilder();
        String line;
        try {
            while ((line = bReader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    // Blank lines carry no data; readNNParams would reject them
                    // because Float.parseFloat("") throws.
                    continue;
                }
                if (line.startsWith("#")) {
                    readDescription(nnBuilder, line);
                } else {
                    readNNParams(nnBuilder, line);
                    // add this model into map of trained models.
                    super.registerModels(nnBuilder.getType(), nnBuilder.build());
                }
            }
        } catch (IOException e) {
            throw new RuntimeException("Unable to read the default media type registry", e);
        }
    }

    /**
     * Overrides the parent implementation to load the bundled example neural
     * network model ({@value #EXAMPLE_NNMODEL_FILE}) from the class path.
     *
     * <p>The resource is resolved relative to the package of
     * {@link TrainedModelDetector}, mirroring what {@code Class.getResource}
     * would do for a relative name.
     *
     * @param classLoader loader used to find the resource; if {@code null},
     *                    the loader of {@link TrainedModelDetector} is used
     * @throws NullPointerException if the model resource cannot be found
     * @throws RuntimeException     if the resource cannot be opened or read
     */
    @Override
    public void loadDefaultModels(ClassLoader classLoader) {
        if (classLoader == null) {
            classLoader = TrainedModelDetector.class.getClassLoader();
        }

        // This allows us to replicate class.getResource() when using
        // the classloader directly
        String classPrefix = TrainedModelDetector.class.getPackage().getName()
                .replace('.', '/')
                + "/";

        // Locate the single bundled example model resource.
        URL modelURL = classLoader.getResource(classPrefix + EXAMPLE_NNMODEL_FILE);
        Objects.requireNonNull(modelURL, "required resource " + classPrefix + EXAMPLE_NNMODEL_FILE + " not found");
        try (InputStream stream = modelURL.openStream()) {
            loadDefaultModels(stream);
        } catch (IOException e) {
            throw new RuntimeException("Unable to read the default media type registry", e);
        }
    }

    /**
     * Parses a model description line and applies it to the builder.
     *
     * <p>The line must start with {@code #} and contain tab-separated fields:
     * <ol start="0">
     *   <li>the {@code #} marker (ignored),</li>
     *   <li>the media type the model detects,</li>
     *   <li>the number of input units,</li>
     *   <li>the number of hidden units,</li>
     *   <li>the number of output units.</li>
     * </ol>
     *
     * @param builder builder to configure
     * @param line    trimmed description line
     * @throws RuntimeException if a field is missing or not a valid integer;
     *                          the failure is also logged at WARNING
     */
    private void readDescription(final NNTrainedModelBuilder builder,
                                 final String line) {
        String[] sarr = line.split("\t");

        try {
            MediaType type = MediaType.parse(sarr[1]);
            int numInputs = Integer.parseInt(sarr[2]);
            int numHidden = Integer.parseInt(sarr[3]);
            int numOutputs = Integer.parseInt(sarr[4]);
            builder.setNumOfInputs(numInputs);
            builder.setNumOfHidden(numHidden);
            builder.setNumOfOutputs(numOutputs);
            builder.setType(type);
        } catch (Exception e) {
            if (log.isLoggable(Level.WARNING)) {
                log.log(Level.WARNING, "Unable to parse the model configuration", e);
            }
            throw new RuntimeException("Unable to parse the model configuration", e);
        }
    }

    /**
     * Parses a line of tab-separated float model parameters and stores them in
     * the builder, which is later used to build the {@code TrainedModel}.
     *
     * @param builder builder that receives the parameters
     * @param line    trimmed, non-description line of tab-separated floats
     * @throws RuntimeException if any field is not a valid float; the failure
     *                          is also logged at WARNING
     */
    private void readNNParams(final NNTrainedModelBuilder builder,
                              final String line) {
        String[] sarr = line.split("\t");
        float[] params = new float[sarr.length];
        try {
            for (int i = 0; i < sarr.length; i++) {
                params[i] = Float.parseFloat(sarr[i]);
            }
            builder.setParams(params);
        } catch (Exception e) {
            if (log.isLoggable(Level.WARNING)) {
                log.log(Level.WARNING, "Unable to parse the model configuration", e);
            }
            throw new RuntimeException("Unable to parse the model configuration", e);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy