/*
* Copyright 2018 mayabot.com authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mayabot.nlp.algorithm.collection.bintrie;
import com.mayabot.nlp.algorithm.collection.Trie;
import com.mayabot.nlp.common.hppc.CharObjectHashMap;
import com.mayabot.nlp.common.hppc.CharObjectMap;
import kotlin.collections.AbstractIterator;
import java.util.*;
import java.util.Map.Entry;
/**
* TireTree 实现。 如果为smart模式,那么首字也是二分法查找。子节点的数字如果大于阀值,那么使用65536的数组
*
* 第一次改造为了简单,使用专用的charMap来实现smart tree的行为。smart的树的节点,如果超过一定数量,该用65536的数组
*
* @param
* @author jimichan
*/
public class BinTrieTree implements Trie, BinTrieNode {
/**
* 区别是.这个在首字也是用二分查找,意味着,更节省内存.但是在构造和查找的时候都慢一点,一般应用在.词少.或者临时词典中.
*/
boolean rootChildUseMap = false;
static final int max_width = 65536;
private AbstractTrieNode[] children;
private CharObjectMap> childrenMap;
//boolean frezz = false;// 是否冻结
TrieNodeFactory nodeFactory = ArrayTrieNode::new;
BinTrieTree(boolean rootChildUseMap, TrieNodeFactory nodeFactory) {
this.rootChildUseMap = rootChildUseMap;
this.nodeFactory = nodeFactory;
reset();
}
/**
* 清空树释放内存
*/
@SuppressWarnings("unchecked")
public void reset() {
if (rootChildUseMap) {
childrenMap = new CharObjectHashMap<>(500);
} else {
children = new AbstractTrieNode[max_width];
}
}
/**
* 计算根节点的数量
*
* @return child count
*/
public int rootChildCount() {
if (childrenMap != null) {
return childrenMap.size();
} else {
int c = 0;
for (Object o : children) {
if (o != null) {
c++;
}
}
return c;
}
}
/**
* 创建一个该树的Matcher对象
*
* @param text
* @return TrieTreeMatcher
*/
public TrieTreeMatcher newForwardMatcher(String text) {
return new TrieTreeForwardMaxMatcher<>(this, text);
}
/**
* 创建一个该树的Matcher对象
*
* @param text
* @return TrieTreeMatcher
*/
public TrieTreeMatcher newAllMatcher(String text) {
return new TrieTreeAllMatcher<>(this, text);
}
@Override
public boolean containsKey(String key) {
BinTrieNode branch = this;
int len = key.length();
for (int i = 0; i < len; i++) {
char _char = key.charAt(i);
if (branch == null) {
return false;
}
branch = branch.findChild(_char);
}
if (branch == null) {
return false;
}
// 下面这句可以保证只有成词的节点被返回
return branch.getStatus() == AbstractTrieNode.Status_End || branch.getStatus() == AbstractTrieNode.Status_Continue;
}
@Override
public T get(char[] key) {
BinTrieNode branch = this;
int len = key.length;
for (int i = 0; i < len; i++) {
char _char = key[i];
if (branch == null) {
return null;
}
branch = branch.findChild(_char);
}
if (branch == null) {
return null;
}
// 下面这句可以保证只有成词的节点被返回
if (!(branch.getStatus() == AbstractTrieNode.Status_End || branch.getStatus() == AbstractTrieNode.Status_Continue)) {
return null;
}
return branch.getValue();
}
@Override
public T get(char[] key, int offset, int len) {
BinTrieNode branch = this;
for (int i = offset; i < len; i++) {
char _char = key[i];
if (branch == null) {
return null;
}
branch = branch.findChild(_char);
}
if (branch == null) {
return null;
}
// 下面这句可以保证只有成词的节点被返回
if (!(branch.getStatus() == AbstractTrieNode.Status_End || branch.getStatus() == AbstractTrieNode.Status_Continue)) {
return null;
}
return branch.getValue();
}
@Override
public T get(CharSequence key) {
BinTrieNode branch = findNode(key);
if (branch == null) {
return null;
}
// 下面这句可以保证只有成词的节点被返回
if (!(branch.getStatus() == AbstractTrieNode.Status_End || branch.getStatus() == AbstractTrieNode.Status_Continue)) {
return null;
}
return branch.getValue();
}
/**
* 插入一个词项
*
* @param word
* @param value 参数对象
*/
public void put(String word, T value) {
word = word.toLowerCase();
BinTrieNode point = this;
int len = word.length(); // 不用toCharArray的原因是不用再复制一份
int lenIndex = len - 1;
for (int i = 0; i < len; i++) {
char theChar = word.charAt(i);
if (lenIndex == i) {
point.addChildNode(nodeFactory.create(theChar, AbstractTrieNode.Status_End, value));
} else {
point.addChildNode(nodeFactory.create(theChar, AbstractTrieNode.Status_Begin, null));
}
point = point.findChild(theChar);
}
}
/**
* 前缀查询
*
* @param key 查询串
* @return 键值对
*/
public Set> prefixSearch(String key) {
BinTrieNode node = findNode(key);
if (node == null) {
return Collections.emptySet();
}
NodeHolder holder = new NodeHolder<>();
IteratorKeys ite = new IteratorKeys(holder, (AbstractTrieNode) node, key);
Set> set = new HashSet<>();
while (ite.hasNext()) {
String k = ite.next();
AbstractTrieNode v = holder.node;
set.add(new AbstractMap.SimpleEntry(k, v.value));
}
return set;
}
/**
* 删除一个词
*
* @param key
*/
public void remove(String key) {
BinTrieNode branch = this;
char[] chars = key.toCharArray();
for (int i = 0; i < chars.length; i++) {
if (branch == null) {
return;
}
if (chars.length == i + 1) {
branch.addChildNode(nodeFactory.create(chars[i], AbstractTrieNode.Status_Null, null));
}
branch = branch.findChild(chars[i]);
}
}
public interface TireNodeAccess {
void access(AbstractTrieNode node);
}
public void accessFullTireNode(TireNodeAccess nodeAccess) {
LinkedList> stack = new LinkedList<>();
// 初始化堆栈
if (childrenMap != null) {
for (AbstractTrieNode node : childrenMap.values()) {
stack.push(node);
}
} else {
for (AbstractTrieNode x : children) {
if (x != null) {
stack.push(x);
}
}
}
// 堆栈循环访问
while (!stack.isEmpty()) {
AbstractTrieNode node = stack.pop();
nodeAccess.access(node);
List> chl = node.getChildren();
if (chl != null) {
node.getChildren().forEach(x -> stack.push(x));// 改成放到栈顶
}
}
}
/**
* 访问所有的keys 词
*
* @return Iterator
*/
public Iterator keys(NodeHolder holder) {
return new IteratorKeys(holder);
}
public Iterator keys() {
return new IteratorKeys(null);
}
public static class NodeHolder {
AbstractTrieNode node;
public AbstractTrieNode getNode() {
return node;
}
}
public Iterable> entry() {
return () -> new AbstractIterator>() {
NodeHolder holder;
Iterator ite;
{
holder = new NodeHolder<>();
ite = keys(holder);
}
@Override
protected void computeNext() {
if (!ite.hasNext()) {
done();
return;
}
String key = ite.next();
if (key != null) {
setNext(new AbstractMap.SimpleEntry<>(key, holder.node.value));
return;
}
done();
return;
}
};
}
class IteratorKeys extends AbstractIterator {
LinkedList> stack = new LinkedList<>();
char[] buffer = new char[Short.MAX_VALUE];
NodeHolder holder;
IteratorKeys(NodeHolder holder) {
this.holder = holder;
// 初始化堆栈
if (childrenMap != null) {
for (AbstractTrieNode node : childrenMap.values()) {
stack.push(node);
}
} else {
for (AbstractTrieNode x : children) {
if (x != null) {
stack.push(x);
}
}
}
stack.forEach(x -> x.level = 1);
}
IteratorKeys(NodeHolder holder, AbstractTrieNode initNode, String prefix) { //指定了初始化节点
this.holder = holder;
stack.push(initNode);
stack.forEach(x -> x.level = (short) prefix.length()); //FIXME 此处需要多测试
if (prefix.length() >= 2) {
for (int i = 0; i < prefix.length() - 1; i++) {
this.buffer[1 + i] = prefix.charAt(i);
}
}
}
@Override
protected void computeNext() {
if (stack.isEmpty()) {
done();
return;
}
String n = _next();
while (n == null) {
n = _next();
if (n != null) {
setNext(n);
return;
}
if (stack.isEmpty()) {
done();
return;
}
}
setNext(n);
}
private String _next() {
AbstractTrieNode node = stack.pop();
buffer[node.level] = node._char;
List> chl = node.getChildren();
if (chl != null) {
short level = (short) (node.level + 1);
chl.forEach(x -> {
x.level = level;
stack.push(x);
});
}
if (node.status == AbstractTrieNode.Status_Continue || node.status == AbstractTrieNode.Status_End) {
if (holder != null) {
holder.node = node;
}
return new String(buffer, 1, node.level);
}
return null;
}
}
// //////////////////////////////以下是作为NODE的行为////////////////////////////////////////////
@Override
public boolean contains(char c) {
if (rootChildUseMap) {
return childrenMap.containsKey(c);
} else {
return this.children[c] != null;
}
}
@Override
public BinTrieNode addChildNode(BinTrieNode n) {
AbstractTrieNode node = ((AbstractTrieNode) n);
AbstractTrieNode oldNode = null;
if (rootChildUseMap) {
oldNode = this.childrenMap.get(node._char);
if (oldNode == null) {
this.childrenMap.put(node._char, node);
oldNode = node;
}
} else {
oldNode = this.children[node._char];
if (oldNode == null) {
this.children[node._char] = node;
oldNode = node;
}
}
switch (node.status) {
case AbstractTrieNode.Status_Begin:
if (oldNode.status == AbstractTrieNode.Status_End) {
oldNode.status = AbstractTrieNode.Status_Continue;
}
break;
case AbstractTrieNode.Status_End:
if (oldNode.status == AbstractTrieNode.Status_Begin) {
oldNode.status = AbstractTrieNode.Status_Continue;
}
oldNode.value = node.value;
}
return oldNode;
}
@Override
public AbstractTrieNode findChild(char c) {
if (rootChildUseMap) {
return childrenMap.get(c);
} else {
if (c > max_width) {
return null;
}
return this.children[c];
}
}
@Override
public byte getStatus() {
return 0;
}
@Override
public T getValue() {
return null;
}
@Override
public int compareTo(char c) {
return 0;
}
public CharObjectMap> getChildrenMap() {
return childrenMap;
}
public interface TrieNodeFactory {
BinTrieNode create(char _char, byte status, T param);
}
}