All downloads are free. Search and download functionality uses the official Maven repository.

org.ansj.app.crf.Model Maven / Gradle / Ivy

There is a newer version: 5.1.6
Show newest version
package org.ansj.app.crf;

import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.util.Map;
import java.util.Map.Entry;
import java.util.zip.GZIPOutputStream;

import org.ansj.app.crf.model.CRFModel;
import org.ansj.app.crf.model.CRFppTxtModel;
import org.ansj.app.crf.model.WapitiCRFModel;
import org.ansj.util.MyStaticValue;
import org.nlpcn.commons.lang.tire.domain.SmartForest;
import org.nlpcn.commons.lang.util.MapCount;
import org.nlpcn.commons.lang.util.logging.Log;
import org.nlpcn.commons.lang.util.logging.LogFactory;

/**
 * Base class for CRF models.
 *
 * <p>Holds the state shared by every concrete model format (configuration,
 * feature tree and tag transition matrix), provides static factories that
 * pick or instantiate a concrete implementation, and can serialize the
 * loaded model to disk via {@link #writeModel(String)}.
 */
public abstract class Model {

	public static final Log logger = LogFactory.getLog(Model.class);

	/** Model configuration; its template is persisted by {@link #writeModel(String)}. */
	protected Config config;

	/** Tree that maps a feature key (character sequence) to its weight array. */
	protected SmartForest<float[]> featureTree = null;

	/** Tag transition weights, indexed as {@code status[from][to]}. */
	protected float[][] status = new float[Config.TAG_NUM][Config.TAG_NUM];

	/** Feature counter; never modified in this class (presumably maintained by the concrete loaders). */
	public int allFeatureCount = 0;

	/**
	 * Checks whether the model file at the given path is in the format handled
	 * by this implementation.
	 *
	 * @param modelPath
	 *            path of the model file to inspect
	 * @return {@code true} if this implementation can load the file
	 * @throws IOException
	 *             if the file cannot be read
	 */
	public abstract boolean checkModel(String modelPath) throws IOException;

	/**
	 * Loads a model from a file, detecting its format automatically.
	 *
	 * <p>The candidates are tried in this order: {@link CRFModel},
	 * {@link CRFppTxtModel}, {@link WapitiCRFModel}. The first one whose
	 * {@link #checkModel(String)} accepts the file performs the load.
	 *
	 * @param modelPath
	 *            path of the model file
	 * @return the loaded model
	 * @throws Exception
	 *             if no implementation recognizes the file, or if loading fails
	 */
	public static Model load(String modelPath) throws Exception {
		Model model = new CRFModel();
		if (model.checkModel(modelPath)) {
			return model.loadModel(modelPath);
		}

		model = new CRFppTxtModel();
		if (model.checkModel(modelPath)) {
			return model.loadModel(modelPath);
		}

		model = new WapitiCRFModel();
		if (model.checkModel(modelPath)) {
			return model.loadModel(modelPath);
		}

		throw new Exception("I did not know what type of model by file " + modelPath);
	}

	/**
	 * Loads a model from a stream using an explicitly chosen implementation.
	 *
	 * @param c
	 *            concrete model class; it is instantiated through its no-argument constructor
	 * @param is
	 *            stream to read the model from
	 * @return the loaded model
	 * @throws Exception
	 *             if the class cannot be instantiated or loading fails
	 */
	public static Model load(Class<? extends Model> c, InputStream is) throws Exception {
		Model model = c.newInstance();
		return model.loadModel(is);
	}

	/**
	 * Loads the model from a file. Each model format supplies its own implementation.
	 *
	 * @param modelPath
	 *            path of the model file
	 * @return the loaded model
	 * @throws Exception
	 *             if loading fails
	 */
	public abstract Model loadModel(String modelPath) throws Exception;

	/**
	 * Loads the model from a stream. Each model format supplies its own implementation.
	 *
	 * @param is
	 *            stream to read the model from
	 * @return the loaded model
	 * @throws Exception
	 *             if loading fails
	 */
	public abstract Model loadModel(InputStream is) throws Exception;

	/**
	 * Looks up the weight array stored for a feature.
	 *
	 * @param chars
	 *            characters that form the feature key
	 * @return the weight array, or {@code null} if {@code chars} is
	 *         {@code null}, the feature is unknown, or it carries no weights
	 */
	public float[] getFeature(char... chars) {
		if (chars == null) {
			return null;
		}
		SmartForest<float[]> branch = featureTree.getBranch(chars);
		// A missing branch and a branch without weights both yield null.
		return branch == null ? null : branch.getParam();
	}

	public Config getConfig() {
		return this.config;
	}

	/**
	 * Returns the tag transition weight.
	 *
	 * @param s1
	 *            index of the first tag (row of the transition matrix)
	 * @param s2
	 *            index of the second tag (column of the transition matrix)
	 * @return the value stored at {@code status[s1][s2]}
	 */
	public float tagRate(int s1, int s2) {
		return status[s1][s2];
	}

	/**
	 * Prints every non-zero weight of a feature to standard output, one line per weight.
	 *
	 * <p>The printed name starts with {@code U} when the weight array has
	 * exactly four entries and with {@code *} otherwise. It is followed by the
	 * last character of {@code cs} expressed as an offset from
	 * {@link Config#FEATURE_BEGIN} (plus one), and then by the rest of {@code cs}.
	 *
	 * @param cs
	 *            feature key; its last character encodes the feature position
	 * @param tempW
	 *            weights of the feature
	 */
	protected static void printFeatureTree(String cs, float[] tempW) {
		String name = "*";
		if (tempW.length == 4) {
			name = "U";
		}
		name += "*" + ((int) cs.charAt(cs.length() - 1) - Config.FEATURE_BEGIN + 1) + ":" + cs.substring(0, cs.length() - 1);
		for (int i = 0; i < tempW.length; i++) {
			if (tempW[i] != 0) {
				System.out.println(name + "\t" + Config.getTagName(i / 4 - 1) + "\t" + Config.getTagName(i % 4) + "\t" + tempW[i]);
			}
		}
	}

	/**
	 * Serializes this model to a GZIP-compressed file.
	 *
	 * <p>Layout of the object stream: the {@link CRFModel#version} string, the
	 * transition matrix, the configuration template, then one group per
	 * distinct weight-array length. Each group consists of the length, the
	 * number of features in the group, and for every feature its key followed
	 * by its weights. Two zero ints terminate the stream.
	 *
	 * <p>I/O failures are logged as warnings and not propagated.
	 *
	 * @param path
	 *            destination file
	 */
	public void writeModel(String path) {
		// The whole stream chain is managed by try-with-resources: closing the
		// ObjectOutputStream finishes the GZIP stream, so the compressed file is
		// complete. Flushing alone does not write the GZIP trailer.
		try (FileOutputStream fso = new FileOutputStream(path);
				GZIPOutputStream gzip = new GZIPOutputStream(fso);
				ObjectOutputStream oos = new ObjectOutputStream(gzip)) {
			oos.writeUTF(CRFModel.version);
			oos.writeObject(status);
			oos.writeObject(config.getTemplate());

			Map<String, float[]> map = featureTree.toMap();

			// Count how many features exist for each weight-array length.
			MapCount<Integer> mc = new MapCount<Integer>();
			for (float[] v : map.values()) {
				mc.add(v.length);
			}

			for (Entry<Integer, Double> entry : mc.get().entrySet()) {
				int win = entry.getKey();
				oos.writeInt(win); // width of the weight arrays in this group
				oos.writeInt(entry.getValue().intValue()); // number of features in this group
				for (Entry<String, float[]> e : map.entrySet()) {
					float[] value = e.getValue();
					if (value.length == win) {
						oos.writeUTF(e.getKey());
						for (int i = 0; i < win; i++) {
							oos.writeFloat(value[i]);
						}
					}
				}
			}

			// End-of-data marker: a group with width 0 and count 0.
			oos.writeInt(0);
			oos.writeInt(0);
			oos.flush();
		} catch (FileNotFoundException e) {
			logger.warn("文件没有找到", e);
		} catch (IOException e) {
			logger.warn("IO异常", e);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy