All downloads are free. Search and download functionality uses the official Maven repository.

com.atilika.kuromoji.viterbi.ViterbiSearcher Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta7
Show newest version
/*-*
 * Copyright © 2010-2015 Atilika Inc. and contributors (see CONTRIBUTORS.md)
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.  A copy of the
 * License is distributed with this work in the LICENSE.md file.  You may
 * also obtain a copy of the License from
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.atilika.kuromoji.viterbi;

import com.atilika.kuromoji.TokenizerBase;
import com.atilika.kuromoji.dict.ConnectionCosts;
import com.atilika.kuromoji.dict.UnknownDictionary;

import java.util.LinkedList;
import java.util.List;

/**
 * Finds the lowest-cost segmentation path through a {@link ViterbiLattice}.
 *
 * <p>Path cost is the sum of word costs and connection costs between adjacent nodes. In
 * {@code SEARCH} and {@code EXTENDED} mode, long words additionally receive a length penalty.
 * In {@code EXTENDED} mode, unknown words on the best path are split into one node per character.
 */
public class ViterbiSearcher {

    /** Initial "best so far" cost before any left candidate has been examined. */
    private static final int DEFAULT_COST = Integer.MAX_VALUE;

    private final ConnectionCosts costs;
    private final UnknownDictionary unknownDictionary;

    // Length-penalty parameters used in SEARCH/EXTENDED mode; all stay 0 when no penalties are supplied.
    private int kanjiPenaltyLengthThreshold;
    private int otherPenaltyLengthThreshold;
    private int kanjiPenalty;
    private int otherPenalty;

    private final TokenizerBase.Mode mode;

    /**
     * Creates a searcher.
     *
     * @param mode              tokenizer mode; SEARCH and EXTENDED apply length penalties, EXTENDED also
     *                          splits unknown words on the best path into single-character nodes
     * @param costs             connection costs between a left node's right id and a right node's left id
     * @param unknownDictionary dictionary attached to the single-character nodes created in EXTENDED mode
     * @param penalties         either empty (no penalties), or at least four values in this order:
     *                          kanji length threshold, kanji penalty, other length threshold, other penalty
     * @throws IllegalArgumentException if {@code penalties} is non-empty but has fewer than four values
     */
    public ViterbiSearcher(TokenizerBase.Mode mode, ConnectionCosts costs, UnknownDictionary unknownDictionary,
                    List<Integer> penalties) {
        if (!penalties.isEmpty()) {
            if (penalties.size() < 4) {
                throw new IllegalArgumentException("Expected 4 penalty values (kanji length threshold, kanji penalty, "
                                + "other length threshold, other penalty) but got " + penalties.size());
            }
            this.kanjiPenaltyLengthThreshold = penalties.get(0);
            this.kanjiPenalty = penalties.get(1);
            this.otherPenaltyLengthThreshold = penalties.get(2);
            this.otherPenalty = penalties.get(3);
        }

        this.mode = mode;
        this.costs = costs;
        this.unknownDictionary = unknownDictionary;
    }

    /**
     * Finds the best path through the input lattice.
     *
     * <p>The lattice's nodes are updated in place with their best path cost and best left node.
     *
     * @param lattice the lattice produced by the lattice builder
     * @return the nodes on the best path, ordered left to right, ending with the node stored at
     *         {@code endIndexArr[0][0]} (treated as the EOS node)
     */
    public List<ViterbiNode> search(ViterbiLattice lattice) {
        ViterbiNode[][] endIndexArr = calculatePathCosts(lattice);
        return backtrackBestPath(endIndexArr[0][0]);
    }

    /**
     * Forward pass: for every position {@code i >= 1}, links each node starting at {@code i} to its
     * cheapest predecessor among the nodes ending at {@code i}.
     *
     * @return the lattice's end-index array (its nodes now carry path costs and left links)
     */
    private ViterbiNode[][] calculatePathCosts(ViterbiLattice lattice) {
        ViterbiNode[][] startIndexArr = lattice.getStartIndexArr();
        ViterbiNode[][] endIndexArr = lattice.getEndIndexArr();

        for (int i = 1; i < startIndexArr.length; i++) {
            // Skip positions where no node starts, or where no node ends (so no predecessor exists).
            if (startIndexArr[i] == null || endIndexArr[i] == null) {
                continue;
            }

            for (ViterbiNode node : startIndexArr[i]) {
                if (node == null) { // Arrays are null-padded: the first null marks the end of the nodes.
                    break;
                }
                updateNode(endIndexArr[i], node);
            }
        }
        return endIndexArr;
    }

    /**
     * Selects the cheapest left neighbour of {@code node} and records it, together with the resulting
     * path cost, on {@code node}. If no candidate is found, {@code node} is left unchanged.
     *
     * @param viterbiNodes candidate left nodes (null-padded array)
     * @param node         the node being updated
     */
    private void updateNode(ViterbiNode[] viterbiNodes, ViterbiNode node) {
        int backwardConnectionId = node.getLeftId();

        // Word cost and length penalty depend only on the current node, so compute them once.
        int nodeCost = node.getWordCost();
        if (mode == TokenizerBase.Mode.SEARCH || mode == TokenizerBase.Mode.EXTENDED) {
            nodeCost += getPenaltyCost(node);
        }

        int leastPathCost = DEFAULT_COST;
        ViterbiNode bestLeftNode = null;

        for (ViterbiNode leftNode : viterbiNodes) {
            if (leftNode == null) { // No more candidates in this null-padded array.
                break;
            }
            // cost = [total cost from BOS to left node] + [connection cost left -> current] + [current node cost]
            int pathCost = leftNode.getPathCost() + costs.get(leftNode.getRightId(), backwardConnectionId)
                            + nodeCost;

            if (pathCost < leastPathCost) {
                leastPathCost = pathCost;
                bestLeftNode = leftNode;
            }
        }

        if (bestLeftNode != null) {
            node.setPathCost(leastPathCost);
            node.setLeftNode(bestLeftNode);
        }
    }

    /**
     * Computes the extra cost applied to long words in SEARCH/EXTENDED mode.
     *
     * <p>Kanji-only words longer than the kanji threshold are charged {@code kanjiPenalty} per extra
     * char. All other words longer than the other threshold are charged {@code otherPenalty} per extra
     * char. The two thresholds are checked independently of each other.
     *
     * @return the penalty (0 if the word is not long enough for its category)
     */
    private int getPenaltyCost(ViterbiNode node) {
        String surface = node.getSurface();
        int length = surface.length();

        boolean exceedsKanjiThreshold = length > kanjiPenaltyLengthThreshold;
        boolean exceedsOtherThreshold = length > otherPenaltyLengthThreshold;
        if (!exceedsKanjiThreshold && !exceedsOtherThreshold) {
            return 0; // Cheap exit that avoids scanning short surfaces.
        }

        if (isKanjiOnly(surface)) {
            return exceedsKanjiThreshold ? (length - kanjiPenaltyLengthThreshold) * kanjiPenalty : 0;
        }
        return exceedsOtherThreshold ? (length - otherPenaltyLengthThreshold) * otherPenalty : 0;
    }

    /**
     * Returns true if every char of {@code surface} lies in the CJK Unified Ideographs block
     * (extension blocks and surrogate chars do not count). Returns true for an empty string.
     */
    private boolean isKanjiOnly(String surface) {
        for (int i = 0; i < surface.length(); i++) {
            if (Character.UnicodeBlock.of(surface.charAt(i)) != Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS) {
                return false;
            }
        }
        return true;
    }

    /**
     * Follows left links from {@code eos} back to the first node (the one without a left link).
     *
     * @param eos the rightmost node of the path
     * @return the path ordered left to right; in EXTENDED mode, unknown words are replaced by
     *         single-character nodes
     */
    private LinkedList<ViterbiNode> backtrackBestPath(ViterbiNode eos) {
        LinkedList<ViterbiNode> result = new LinkedList<>();
        result.add(eos);

        for (ViterbiNode leftNode = eos.getLeftNode(); leftNode != null; leftNode = leftNode.getLeftNode()) {
            if (mode == TokenizerBase.Mode.EXTENDED && leftNode.getType() == ViterbiNode.Type.UNKNOWN) {
                // We walk right-to-left, so the (left-to-right ordered) unigrams must be prepended,
                // not appended after EOS.
                result.addAll(0, convertUnknownWordToUnigramNode(leftNode));
            } else {
                result.addFirst(leftNode);
            }
        }
        return result;
    }

    /**
     * Splits an unknown-word node into one UNKNOWN node per character, in left-to-right order.
     *
     * <p>Splitting is by code point, so a supplementary character is never broken into lone surrogates.
     * Each node's start index is its char offset added to {@code node.getStartIndex()}.
     */
    private LinkedList<ViterbiNode> convertUnknownWordToUnigramNode(ViterbiNode node) {
        LinkedList<ViterbiNode> uniGramNodes = new LinkedList<>();
        int unigramWordId = 0;
        String surface = node.getSurface();
        int baseIndex = node.getStartIndex();

        int offset = 0;
        while (offset < surface.length()) {
            // offsetByCodePoints always advances by at least one char, even over a lone surrogate.
            int end = surface.offsetByCodePoints(offset, 1);
            String word = surface.substring(offset, end);

            uniGramNodes.add(new ViterbiNode(unigramWordId, word, unknownDictionary, baseIndex + offset,
                            ViterbiNode.Type.UNKNOWN));
            offset = end;
        }

        return uniGramNodes;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy