All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.tika.parser.captioning.tf.vocabulary.py Maven / Gradle / Ivy

There is a newer version: 2024.11.18751.20241128T090041Z-241100
Show newest version
#!/usr/bin/env python
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class Vocabulary(object):
    """
        Vocabulary class for an image-to-text model

        Maps words to integer ids and back. After construction the instance
        exposes:
            vocab          dict, vocab[word] = id
            reverse_vocab  list, reverse_vocab[id] = word
            start_id       id of the start word
            end_id         id of the end word
            unk_id         id of the unknown word
    """

    def __init__(self,
                 vocab_file,
                 start_word="",
                 end_word="",
                 unk_word=""):
        """Initializes the vocabulary

        The first whitespace-separated token of every non-blank line of
        `vocab_file` is taken as a word; a word's id is its position among
        those lines. If a word occurs on several lines, `vocab` keeps the id
        of its last occurrence.

        Args:
            vocab_file: path of the vocabulary file, opened through
                tf.gfile.GFile.
            start_word: word marking the start of a sentence; must occur in
                the file.
            end_word: word marking the end of a sentence; must occur in the
                file.
            unk_word: word standing for unknown words; appended with the next
                free id when the file does not contain it.

        Raises:
            AssertionError: if start_word or end_word is not in the file.
        """

        # NOTE(review): the defaults are empty strings, and str.split() never
        # yields an empty token, so with the default start_word/end_word the
        # assertions below always fail -- callers have to pass the real
        # marker words explicitly. Presumably the original defaults were
        # lost somewhere upstream; confirm before changing them.

        if not tf.gfile.Exists(vocab_file):
            tf.logging.fatal("Vocab file %s not found.", vocab_file)
        tf.logging.info("Initializing vocabulary from file: %s", vocab_file)

        with tf.gfile.GFile(vocab_file, mode="r") as f:
            lines = f.readlines()

        # Keep only the first token of each line. Blank / whitespace-only
        # lines split into an empty list and are skipped instead of raising
        # IndexError.
        tokenized = (line.split() for line in lines)
        reverse_vocab = [tokens[0] for tokens in tokenized if tokens]

        assert start_word in reverse_vocab, (
            "start word %r not found in vocab file" % (start_word,))
        assert end_word in reverse_vocab, (
            "end word %r not found in vocab file" % (end_word,))
        if unk_word not in reverse_vocab:
            reverse_vocab.append(unk_word)
        vocab = {word: word_id for word_id, word in enumerate(reverse_vocab)}

        tf.logging.info("Created vocabulary with %d words", len(vocab))

        # vocab[word] = id
        self.vocab = vocab
        # reverse_vocab[id] = word
        self.reverse_vocab = reverse_vocab

        # save special word ids
        self.start_id = vocab[start_word]
        self.end_id = vocab[end_word]
        self.unk_id = vocab[unk_word]

    def id_to_word(self, word_id):
        """Returns the word string of an integer word id

        Any id outside [0, len(reverse_vocab)) -- including negative ids,
        which would otherwise index from the end of the list -- yields the
        unknown word.
        """

        if 0 <= word_id < len(self.reverse_vocab):
            return self.reverse_vocab[word_id]
        return self.reverse_vocab[self.unk_id]




© 2015 - 2024 Weber Informatics LLC | Privacy Policy