All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.tika.parser.captioning.tf.im2txtapi.py Maven / Gradle / Ivy

#!/usr/bin/env python
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.


"""
    This script exposes image captioning service over a REST API. Image captioning implementation based on the paper,

        "Show and Tell: A Neural Image Caption Generator"
        Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan

    For more details, please visit :
        http://arxiv.org/abs/1411.4555
    Requirements :
      Flask
      tensorflow
      numpy
      requests
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import logging
import math
import requests
import sys

from flask import Flask, request, Response, jsonify
from io import BytesIO
from PIL import Image
from time import time

import tensorflow as tf
import xml.etree.ElementTree as ET

import model_wrapper
import vocabulary
import caption_generator

# turning off the traceback by limiting its depth
sys.tracebacklimit = 0

# informative log messages for advanced users to troubleshoot errors when modifying model_info.xml
try:
    info = ET.parse('/usr/share/apache-tika/models/dl/image/caption/model_info.xml').getroot()
except IOError:
    logging.exception('model_info.xml is not found')
    sys.exit(1)

model_main = info.find('model_main')
if model_main is None:
    logging.exception(' tag under  tag in model_info.xml is not found')
    sys.exit(1)

checkpoint_path = model_main.find('checkpoint_path')
if checkpoint_path is None:
    logging.exception(' tag under  tag in model_info.xml is not found')
    sys.exit(1)
else:
    checkpoint_path = checkpoint_path.text

vocab_file = model_main.find('vocab_file')
if vocab_file is None:
    logging.exception(' tag under  tag in model_info.xml is not found')
    sys.exit(1)
else:
    vocab_file = vocab_file.text

port = info.get('port')
if port is None:
    logging.exception('port attribute in  tag in model_info.xml is not found')
    sys.exit(1)

# turning on the traceback by setting it to default
sys.tracebacklimit = 1000

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("checkpoint_path", checkpoint_path, """Directory containing the model checkpoint file.""")
tf.flags.DEFINE_string('vocab_file', vocab_file, """Text file containing the vocabulary.""")
tf.flags.DEFINE_integer('port', port, """Server PORT, default:8764""")

tf.logging.set_verbosity(tf.logging.INFO)


class Initializer(Flask):
    """
        Class to initialize the REST API, this class loads the model from the given checkpoint path in model_info.xml
        and prepares a caption_generator object
    """

    def __init__(self, name):
        super(Initializer, self).__init__(name)
        # build the inference graph
        g = tf.Graph()
        with g.as_default():
            model = model_wrapper.ModelWrapper()
            restore_fn = model.build_graph(FLAGS.checkpoint_path)
        g.finalize()
        # make the model globally available
        self.model = model
        # create the vocabulary
        self.vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
        self.sess = tf.Session(graph=g)
        # load the model from checkpoint
        restore_fn(self.sess)


def current_time():
    """Returns current time in milli seconds"""

    return int(1000 * time())


app = Initializer(__name__)


def get_remote_file(url, success=200, timeout=10):
    """
        Given HTTP URL, this api gets the content of it
        returns (Content-Type, image_content)
    """
    try:
        app.logger.info("GET: %s" % url)
        auth = None
        res = requests.get(url, stream=True, timeout=timeout, auth=auth)
        if res.status_code == success:
            return res.headers.get('Content-Type', 'application/octet-stream'), res.raw.data
    except:
        pass
    return None, None


@app.route("/")
def index():
    """The index page which provide information about other API end points"""

    return """
    

Image Captioning REST API

The following API end points are valid

    Inception V3

  • /inception/v3/ping -
    Description : checks availability of the service. returns "pong" with status 200 when it is available
  • /inception/v3/caption/image -
    Description This is a service that can caption images
    How to supply Image Content
    With HTTP GET : Include a query parameter url which is an http url of JPEG image
    Example: curl "localhost:8764/inception/v3/caption/image?url=http://xyz.com/example.jpg"
    With HTTP POST : POST JPEG image content as binary data in request body.
    Example: curl -X POST "localhost:8764/inception/v3/caption/image" --data-binary @example.jpg
""" @app.route("/inception/v3/ping", methods=["GET"]) def ping_pong(): """API to do health check. If this says status code 200, then healthy""" return "pong" @app.route("/inception/v3/caption/image", methods=["GET", "POST"]) def caption_image(): """API to caption images""" image_format = "not jpeg" st = current_time() # get beam_size beam_size = int(request.args.get("beam_size", "3")) # get max_caption_length max_caption_length = int(request.args.get("max_caption_length", "20")) # get image_data if request.method == 'POST': image_data = request.get_data() else: url = request.args.get("url") c_type, image_data = get_remote_file(url) if not image_data: return Response(status=400, response=jsonify(error="Could not HTTP GET %s" % url)) if 'image/jpeg' in c_type: image_format = "jpeg" # use c_type to find whether image_format is jpeg or not # if jpeg, don't convert if image_format == "jpeg": jpg_image = image_data # if not jpeg else: # open the image from raw bytes image = Image.open(BytesIO(image_data)) # convert the image to RGB format, otherwise will give errors when converting to jpeg, if the image isn't RGB rgb_image = image.convert("RGB") # convert the RGB image to jpeg image_bytes = BytesIO() rgb_image.save(image_bytes, format="jpeg", quality=95) jpg_image = image_bytes.getvalue() image_bytes.close() read_time = current_time() - st # restart counter st = current_time() generator = caption_generator.CaptionGenerator(app.model, app.vocab, beam_size=beam_size, max_caption_length=max_caption_length) captions = generator.beam_search(app.sess, jpg_image) captioning_time = current_time() - st app.logger.info("Captioning time : %d" % captioning_time) array_captions = [] for caption in captions: sentence = [app.vocab.id_to_word(w) for w in caption.sentence[1:-1]] sentence = " ".join(sentence) array_captions.append({ 'sentence': sentence, 'confidence': math.exp(caption.logprob) }) response = { 'beam_size': beam_size, 'max_caption_length': max_caption_length, 'captions': array_captions, 'time': { 'read': read_time, 'captioning': captioning_time, 'units': 'ms' } } return Response(response=json.dumps(response), status=200, mimetype="application/json") def main(_): if not app.debug: print("Serving on port %d" % FLAGS.port) app.run(host="0.0.0.0", port=FLAGS.port) if __name__ == '__main__': tf.app.run()




© 2015 - 2024 Weber Informatics LLC | Privacy Policy