All downloads are free. Search and download functionality uses the official Maven repository.

org.ansj.app.crf.model.CRFModel Maven / Gradle / Ivy

There is a newer version: 5.1.6
Show newest version
package org.ansj.app.crf.model;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.util.zip.GZIPInputStream;
import java.util.zip.ZipException;

import org.ansj.app.crf.Config;
import org.ansj.app.crf.Model;
import org.nlpcn.commons.lang.tire.domain.SmartForest;
import org.nlpcn.commons.lang.util.IOUtil;

/**
 * Loads a CRF model stored in ansj's own model format. Models in this format are currently
 * produced from CRF++ or Wapiti models.
 * 
 * @author Ansj
 *
 */
public class CRFModel extends Model {

	/**
	 * Version tag expected as the first UTF string of an ansj-format model stream; compared
	 * against the file contents by {@link #checkModel(String)}.
	 */
	public static final String version = "ansj1";

	/**
	 * Creates a named CRF model. Its status matrix, config and feature tree are populated by
	 * {@link #loadModel(String)} or {@link #loadModel(InputStream)}.
	 *
	 * @param name model name, forwarded to {@link Model}
	 */
	public CRFModel(String name) {
		super(name);
	}

	/**
	 * Opens {@code modelPath} through {@link IOUtil#getInputStream(String)} and decodes it with
	 * {@link #loadModel(InputStream)}.
	 *
	 * @param modelPath location of a gzip-compressed ansj-format model
	 * @throws Exception if the stream cannot be opened or the model cannot be decoded
	 */
	@Override
	public void loadModel(String modelPath) throws Exception {
		// The stream is opened here, so close it here too: loadModel(InputStream) can only close
		// it once its GZIP/Object wrappers have been constructed successfully.
		try (InputStream is = IOUtil.getInputStream(modelPath)) {
			loadModel(is);
		}
	}

	/**
	 * Reads an ansj-format CRF model from {@code is}, replacing this model's status matrix,
	 * config and feature tree. Once the gzip wrapper has been created, the stream is closed when
	 * this method returns or throws.
	 *
	 * <p>
	 * Layout read from the gzip-compressed object stream:
	 * <ol>
	 * <li>a UTF version string (read and ignored here; see {@link #checkModel(String)}),</li>
	 * <li>the {@code float[][]} status matrix,</li>
	 * <li>the {@code int[][]} feature template, used to build the {@link Config},</li>
	 * <li>feature groups, each an {@code int} vector width {@code win}, an {@code int} count
	 * {@code size}, then {@code size} entries of (UTF feature name, {@code win} floats),</li>
	 * <li>a terminating group header with {@code win == 0} and {@code size == 0}.</li>
	 * </ol>
	 * NOTE(review): the (0, 0) terminator is assumed to be emitted by the model writer, which is
	 * not part of this class — confirm against Model#writeModel.
	 *
	 * @param is gzip-compressed model stream
	 * @throws Exception if the data is not gzip, is truncated, or contains unexpected objects
	 */
	public void loadModel(InputStream is) throws Exception {
		long start = System.currentTimeMillis();
		// Declare the gzip wrapper as its own resource so its inflater is released even if the
		// ObjectInputStream constructor fails on a corrupt header.
		try (GZIPInputStream gzip = new GZIPInputStream(is);
				ObjectInputStream ois = new ObjectInputStream(gzip)) {
			ois.readUTF(); // version tag; format compatibility is checked by checkModel
			this.status = (float[][]) ois.readObject();
			int[][] template = (int[][]) ois.readObject();
			this.config = new Config(template);
			featureTree = new SmartForest();
			int win;
			int size;
			// Feature vectors are grouped by width. Keep reading groups until the (0, 0) header;
			// stopping after the first non-empty group would drop every feature of another width.
			do {
				win = ois.readInt();
				size = ois.readInt();
				for (int i = 0; i < size; i++) {
					String name = ois.readUTF();
					float[] value = new float[win];
					for (int j = 0; j < win; j++) {
						value[j] = ois.readFloat();
					}
					featureTree.add(name, value);
				}
			} while (win != 0 || size != 0);
			logger.info("load crf model ok ! use time :" + (System.currentTimeMillis() - start));
		}
	}

	/**
	 * Checks whether {@code modelPath} holds an ansj-format model, i.e. a gzip-compressed object
	 * stream whose first UTF string equals {@link #version}.
	 *
	 * @param modelPath filesystem path of the candidate model file
	 * @return {@code true} if the version tag matches; {@code false} otherwise, including when the
	 *         file is missing, is not gzip data, or cannot be read (those cases are logged as
	 *         warnings)
	 */
	@Override
	public boolean checkModel(String modelPath) {
		// Every wrapper is a resource so the GZIP inflater is released as well as the file handle.
		try (FileInputStream fis = new FileInputStream(modelPath);
				GZIPInputStream gzip = new GZIPInputStream(fis);
				ObjectInputStream ois = new ObjectInputStream(gzip)) {
			String fileVersion = ois.readUTF();
			if (version.equals(fileVersion)) { // ansj-format model
				return true;
			}
		} catch (ZipException ze) {
			logger.warn("解压异常", ze);
		} catch (FileNotFoundException e) {
			logger.warn("文件没有找到", e);
		} catch (IOException e) {
			logger.warn("IO异常", e);
		}
		return false;
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy