All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.tika.parser.captioning.tf.model_wrapper.py Maven / Gradle / Ivy

There is a newer version: 2024.11.18751.20241128T090041Z-241100
Show newest version
#!/usr/bin/env python
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path

import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base

# Short alias for TF-Slim; the Inception v3 subgraph in
# ShowAndTellModel.build_image_embeddings uses slim.arg_scope, slim.conv2d,
# slim.batch_norm, slim.avg_pool2d, slim.dropout, slim.flatten and
# slim.fully_connected through this name.
# NOTE(review): tf.contrib is a TensorFlow 1.x API; confirm the pinned
# TensorFlow version still provides it.
slim = tf.contrib.slim


class ModelWrapper(object):
    """
        Model wrapper class to perform image captioning with a ShowAndTellModel

        build_graph() adds the inference graph to the default graph and
        returns a function that restores the trained weights into a session.
        feed_image() computes the initial LSTM state for one encoded image,
        and inference_step() advances the LSTM by one word, returning the
        word probabilities and the next state.
    """

    def __init__(self):
        super(ModelWrapper, self).__init__()

    def build_graph(self, checkpoint_path):
        """Builds the inference graph and returns a checkpoint-restore function.

        Args:
            checkpoint_path: path to a checkpoint, or a directory in which the
                latest checkpoint will be looked up.

        Returns:
            A callable ``restore_fn(sess)`` that loads the model weights into
            the given session.

        Raises:
            ValueError: if ``checkpoint_path`` is a directory that contains no
                checkpoint.
        """

        tf.logging.info("Building model.")
        ShowAndTellModel().build()
        # The Saver is created after build() so that it covers every variable
        # the model has just added to the default graph.
        saver = tf.train.Saver()

        return self._create_restore_fn(checkpoint_path, saver)

    def _create_restore_fn(self, checkpoint_path, saver):
        """Creates a function that restores a model from checkpoint file

        Args:
            checkpoint_path: checkpoint file path, or a directory whose latest
                checkpoint should be used.
            saver: the ``tf.train.Saver`` used to perform the restore.

        Returns:
            A callable ``_restore_fn(sess)``.

        Raises:
            ValueError: if ``checkpoint_path`` is a directory without any
                checkpoint.
        """

        if tf.gfile.IsDirectory(checkpoint_path):
            # Keep the directory separately: latest_checkpoint() returns None
            # when nothing is found, and the error message must still name the
            # directory that was searched rather than printing "None".
            checkpoint_dir = checkpoint_path
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            if not checkpoint_path:
                raise ValueError("No checkpoint file found in: %s" % checkpoint_dir)

        def _restore_fn(sess):
            """Restores the saved variables into ``sess``."""
            tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
            saver.restore(sess, checkpoint_path)
            tf.logging.info("Successfully loaded checkpoint: %s",
                            os.path.basename(checkpoint_path))

        return _restore_fn

    def feed_image(self, sess, encoded_image):
        """Runs the image through the model to get the initial LSTM state.

        Args:
            sess: a session holding the graph built by build_graph().
            encoded_image: encoded image bytes for the scalar string
                placeholder ``image_feed`` (decoded with decode_jpeg by
                ShowAndTellModel.process_image).

        Returns:
            The value of ``lstm/initial_state:0``: the LSTM c and h states
            concatenated along axis 1.
        """
        initial_state = sess.run(fetches="lstm/initial_state:0",
                                 feed_dict={"image_feed:0": encoded_image})
        return initial_state

    def inference_step(self, sess, input_feed, state_feed):
        """Runs a single LSTM step.

        Args:
            sess: a session holding the graph built by build_graph().
            input_feed: batch of int64 word ids for the ``input_feed``
                placeholder.
            state_feed: batch of concatenated LSTM states, e.g. the output of
                feed_image() or of a previous inference_step().

        Returns:
            A ``(softmax_output, state_output)`` pair: the word probabilities
            for this step and the new concatenated LSTM state.
        """
        softmax_output, state_output = sess.run(
            fetches=["softmax:0", "lstm/state:0"],
            feed_dict={
                "input_feed:0": input_feed,
                "lstm/state_feed:0": state_feed,
            })
        return softmax_output, state_output


class ShowAndTellModel(object):
    """
        Image captioning implementation based on the paper,

        "Show and Tell: A Neural Image Caption Generator"
        Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan

        For more details, please visit : http://arxiv.org/abs/1411.4555

        This class builds the inference graph only: inputs are placeholders,
        no target sequences or input masks are created, and the Inception v3
        sub-network is built with trainable=False and is_training=False.
    """

    def __init__(self):

        # scale used to initialize model variables
        self.initializer_scale = 0.08

        # dimensions of Inception v3 input images
        self.image_height = 299
        self.image_width = 299

        # LSTM input and output dimensionality, respectively
        self.embedding_size = 512
        self.num_lstm_units = 512

        # number of unique words in the vocab (plus 1, for <UNK>)
        # the default value is larger than the expected actual vocab size to allow
        # for differences between tokenizer versions used in preprocessing, there is
        # no harm in using a value greater than the actual vocab size, but using a
        # value less than the actual vocab size will result in an error
        self.vocab_size = 12000

        # reader for the input data
        # NOTE(review): no method in this class reads self.reader; it looks
        # like a leftover from the training input pipeline.
        self.reader = tf.TFRecordReader()

        # to match the "Show and Tell" paper we initialize all variables with a
        # random uniform initializer
        self.initializer = tf.random_uniform_initializer(
            minval=-self.initializer_scale,
            maxval=self.initializer_scale)

        # a float32 Tensor with shape [batch_size, height, width, channels]
        self.images = None

        # an int64 Tensor with shape [batch_size, 1] in inference mode
        # (see build_inputs)
        self.input_seqs = None

        # an int32 Tensor with shape [batch_size, padded_length];
        # stays None in inference mode
        self.target_seqs = None

        # an int32 0/1 Tensor with shape [batch_size, padded_length];
        # stays None in inference mode
        self.input_mask = None

        # a float32 Tensor with shape [batch_size, embedding_size]
        self.image_embeddings = None

        # a float32 Tensor with shape [batch_size, padded_length, embedding_size]
        self.seq_embeddings = None

        # collection of variables from the inception submodel
        self.inception_variables = []

        # global step Tensor
        self.global_step = None

    def process_image(self, encoded_image, resize_height=346, resize_width=346, thread_id=0):
        """Decodes and processes an image string

        Args:
            encoded_image: scalar string Tensor holding the encoded image
                (decoded with tf.image.decode_jpeg, 3 channels).
            resize_height: height to bilinearly resize to before cropping;
                0 disables the resize step.
            resize_width: width to resize to; must be positive exactly when
                resize_height is.
            thread_id: image summaries are only emitted when this is 0.

        Returns:
            A float32 Tensor of shape [image_height, image_width, 3] with
            values rescaled to [-1, 1].

        Raises:
            ValueError: if only one of resize_height / resize_width is positive.
        """

        # helper function to log an image summary to the visualizer. Summaries are
        # only logged in thread 0
        def image_summary(name, img):
            if not thread_id:
                tf.summary.image(name, tf.expand_dims(img, 0))

        # decode image into a uint8 Tensor of shape [?, ?, 3], then convert it
        # to float32 with values in [0, 1]
        with tf.name_scope("decode", values=[encoded_image]):
            image = tf.image.decode_jpeg(encoded_image, channels=3)

        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image_summary("original_image", image)

        # resize image; an explicit check instead of assert so the validation
        # is not stripped when Python runs with -O
        if (resize_height > 0) != (resize_width > 0):
            raise ValueError(
                "resize_height and resize_width must both be positive or both "
                "be non-positive, got %r and %r" % (resize_height, resize_width))
        if resize_height:
            image = tf.image.resize_images(image,
                                           size=[resize_height, resize_width],
                                           method=tf.image.ResizeMethod.BILINEAR)

        # central crop, assuming resize_height > height, resize_width > width
        image = tf.image.resize_image_with_crop_or_pad(image, self.image_height, self.image_width)

        image_summary("resized_image", image)

        # no further processing happens in inference mode, so the "final"
        # summary shows the same image as "resized_image"
        image_summary("final_image", image)

        # rescale to [-1,1] instead of [0, 1]
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        return image

    def build_inputs(self):
        """Input prefetching, preprocessing and batching

        Creates the "image_feed" (scalar string) and "input_feed" (int64
        vector of word ids) placeholders and stores the derived tensors on
        self.images and self.input_seqs.
        """

        image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
        input_feed = tf.placeholder(dtype=tf.int64,
                                    shape=[None],  # batch_size
                                    name="input_feed")

        # process image and insert batch dimensions
        images = tf.expand_dims(self.process_image(image_feed), 0)
        input_seqs = tf.expand_dims(input_feed, 1)

        # no target sequences or input mask in inference mode
        target_seqs = None
        input_mask = None

        self.images = images
        self.input_seqs = input_seqs
        self.target_seqs = target_seqs
        self.input_mask = input_mask

    def build_image_embeddings(self):
        """Builds the image model(Inception V3) subgraph and generates image embeddings

        Reads self.images and sets self.inception_variables and
        self.image_embeddings (float32, [batch_size, embedding_size]).
        """

        # parameter initialization
        batch_norm_params = {
            "is_training": False,
            "trainable": False,
            # decay for the moving averages
            "decay": 0.9997,
            # epsilon to prevent 0s in variance
            "epsilon": 0.001,
            # collection containing the moving mean and moving variance
            "variables_collections": {
                "beta": None,
                "gamma": None,
                "moving_mean": ["moving_vars"],
                "moving_variance": ["moving_vars"],
            }
        }

        # scalar stddev for the conv weight initializer; a trailing comma here
        # would silently turn it into the tuple (0.1,)
        stddev = 0.1
        dropout_keep_prob = 0.8

        with tf.variable_scope("InceptionV3", "InceptionV3", [self.images]) as scope:
            with slim.arg_scope(
                    [slim.conv2d, slim.fully_connected],
                    weights_regularizer=None,
                    trainable=False):
                with slim.arg_scope(
                        [slim.conv2d],
                        weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                        activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params):
                    net, end_points = inception_v3_base(self.images, scope=scope)
                    with tf.variable_scope("logits"):
                        shape = net.get_shape()
                        net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
                        net = slim.dropout(
                            net,
                            keep_prob=dropout_keep_prob,
                            is_training=False,
                            scope="dropout")
                        net = slim.flatten(net, scope="flatten")

        # add summaries
        for v in end_points.values():
            tf.contrib.layers.summaries.summarize_activation(v)

        self.inception_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV3")

        # map inception output(net) into embedding space
        with tf.variable_scope("image_embedding") as scope:
            image_embeddings = tf.contrib.layers.fully_connected(
                inputs=net,
                num_outputs=self.embedding_size,
                activation_fn=None,
                weights_initializer=self.initializer,
                biases_initializer=None,
                scope=scope)

        # save the embedding size in the graph
        tf.constant(self.embedding_size, name="embedding_size")

        self.image_embeddings = image_embeddings

    def build_seq_embeddings(self):
        """Builds the input sequence embeddings

        Looks up self.input_seqs in a [vocab_size, embedding_size] embedding
        matrix (placed on the CPU) and stores the result in self.seq_embeddings.
        """

        with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
            embedding_map = tf.get_variable(
                name="map",
                shape=[self.vocab_size, self.embedding_size],
                initializer=self.initializer)
            seq_embeddings = tf.nn.embedding_lookup(embedding_map, self.input_seqs)

        self.seq_embeddings = seq_embeddings

    def build_model(self):
        """Builds the single-step LSTM decoder and the softmax output.

        Creates these named tensors, which ModelWrapper feeds and fetches:
          * lstm/initial_state - LSTM state after consuming the image
            embeddings, with c and h concatenated along axis 1
          * lstm/state_feed    - placeholder for a batch of concatenated states
          * lstm/state         - concatenated state after one LSTM step
          * softmax            - word probabilities for that step
        """

        # this LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
        # modified LSTM in the "Show and Tell" paper has no biases and outputs
        # new_c * sigmoid(o).

        lstm_cell = tf.contrib.rnn.BasicLSTMCell(
            num_units=self.num_lstm_units, state_is_tuple=True)

        with tf.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
            # feed the image embeddings to set the initial LSTM state
            zero_state = lstm_cell.zero_state(
                batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32)
            _, initial_state = lstm_cell(self.image_embeddings, zero_state)

            # allow the LSTM variables to be reused
            lstm_scope.reuse_variables()

            # because this is inference mode,
            # use concatenated states for convenient feeding and fetching
            tf.concat(axis=1, values=initial_state, name="initial_state")

            # placeholder for feeding a batch of concatenated states
            state_feed = tf.placeholder(dtype=tf.float32,
                                        shape=[None, sum(lstm_cell.state_size)],
                                        name="state_feed")
            state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)

            # run a single LSTM step
            lstm_outputs, state_tuple = lstm_cell(
                inputs=tf.squeeze(self.seq_embeddings, axis=[1]),
                state=state_tuple)

            # concatenate the resulting state
            tf.concat(axis=1, values=state_tuple, name="state")

        # stack batches vertically
        lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])

        with tf.variable_scope("logits") as logits_scope:
            logits = tf.contrib.layers.fully_connected(
                inputs=lstm_outputs,
                num_outputs=self.vocab_size,
                activation_fn=None,
                weights_initializer=self.initializer,
                scope=logits_scope)

        tf.nn.softmax(logits, name="softmax")

    def setup_global_step(self):
        """Sets up the global step Tensor"""

        global_step = tf.Variable(
            initial_value=0,
            name="global_step",
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])

        self.global_step = global_step

    def build(self):
        """Builds the complete inference graph into the current default graph.

        Runs, in dependency order: input placeholders and image preprocessing,
        Inception v3 image embeddings, word embeddings, the LSTM decoder with
        softmax output, and finally the global step variable.
        """
        self.build_inputs()
        self.build_image_embeddings()
        self.build_seq_embeddings()
        self.build_model()
        self.setup_global_step()




© 2015 - 2024 Weber Informatics LLC | Privacy Policy