com.mayabot.nlp.common.matrix.TransformMatrix Maven / Gradle / Ivy
Show all versions of mynlp Show documentation
/*
* Copyright 2018 mayabot.com authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mayabot.nlp.common.matrix;
import com.google.common.base.Splitter;
import com.google.common.collect.ArrayTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.common.io.ByteSource;
import com.mayabot.nlp.common.QuickStringDoubleTable;
import com.mayabot.nlp.common.QuickStringIntTable;
import com.mayabot.nlp.resources.NlpResource;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* 概率转移矩阵
* 这是一个通用的数据结构
*
*
* 放弃了HanLp中使用Enum的做法,直接使用了string。
*
*
* @author jimichan
*/
public class TransformMatrix {
/**
* 储存转移矩阵
*/
private QuickStringIntTable matrix;
/**
* 储存每个标签出现的次数
*/
private ImmutableMap total;
/**
* 所有标签出现的总次数
*/
private long totalFrequency;
// HMM的五元组
/**
* 隐状态
*/
public ImmutableList states;
/**
* 初始概率
*/
public ImmutableMap start_probability;
/**
* 转移概率
*/
public QuickStringDoubleTable transititon_probability;
public double getTP(String a, String b) {
double d = transititon_probability.get(a, b);
if (d == Double.MIN_VALUE) {
return 0;
}
return d;
}
/**
* 获取转移频次
*
* @param from
* @param to
* @return int
*/
public int getFrequency(String from, String to) {
int v = matrix.get(from, to);
if (v == Integer.MIN_VALUE) {
return 0;
}
return v;
}
public boolean load(ByteSource source) throws IOException {
try (InputStream inputStream = source.openBufferedStream()) {
return load(inputStream);
}
}
public boolean load(NlpResource resource) throws IOException {
try (InputStream inputStream = resource.inputStream()) {
return load(inputStream);
}
}
public boolean load(InputStream in) throws IOException {
Splitter splitter = Splitter.on(',').trimResults();
try (BufferedReader br = new BufferedReader(new InputStreamReader(in,
StandardCharsets.UTF_8))) {
String firstLine = br.readLine();
List lablist = splitter.splitToList(firstLine);
// 为了制表方便,第一个label是空白,所以要抹掉它
// //之后就描述了矩阵
ArrayTable matrix = ArrayTable.create(lablist.subList(1, lablist.size()),
lablist.subList(1, lablist.size()));
String line;
while ((line = br.readLine()) != null) {
List paramArray = splitter.splitToList(line);
String row_lable = paramArray.get(0);
Map row = matrix.row(row_lable);
for (int i = 1; i < paramArray.size(); i++) {
row.put(lablist.get(i), Integer.parseInt(paramArray.get(i)));
}
}
this.matrix = new QuickStringIntTable(matrix);
tongji(matrix);
}
return true;
}
private void tongji(ArrayTable matrix) {
// 需要统计一下每个标签出现的次数
HashMap _total = Maps.newHashMap();
for (String label : matrix.rowKeyList()) {
long v1 = 0, v2 = 0;
for (int x : matrix.row(label).values()) {
v1 += x;
}
for (int x : matrix.column(label).values()) {
v2 += x;
}
_total.put(label, v1 + v2 - matrix.get(label, label));
}
this.total = ImmutableMap.copyOf(_total);
// 总计频率
long _tf = 0;
for (long x : this.total.values()) {
_tf += x;
}
totalFrequency = _tf;
// 下面计算HMM四元组
// 状态标签数组
states = matrix.rowKeyList();
// 初始概率
HashMap _start_probability = Maps.newHashMap();
for (String label : states) {
double frequency = total.get(label) + 1e-8;
_start_probability
.put(label, -Math.log(frequency / totalFrequency));
}
this.start_probability = ImmutableMap.copyOf(_start_probability);
ArrayTable transititon_probability = ArrayTable.create(matrix.rowKeyList(),
matrix.columnKeyList());
for (String from : states) {
for (String to : states) {
double frequency = matrix.get(from, to) + 1e-8;
transititon_probability.put(from, to,
-Math.log(frequency / total.get(from)));
}
}
this.transititon_probability = new QuickStringDoubleTable(transititon_probability);
}
/**
* 获取e的总频次
*
* @param from
* @return long
*/
public long getTotalFrequency(String from) {
Long v = total.get(from);
if (v == null) {
// FIXME 这里会不会有问题
return 0;
}
return v.longValue();
}
/**
* 获取所有标签的总频次
*
* @return long
*/
public long getTotalFrequency() {
return totalFrequency;
}
}