All downloads are free. The search and download functionality uses the official Maven repository.

com.mayabot.nlp.classification.FasttextClassification Maven / Gradle / Ivy

There is a newer version: 3.1.7
Show newest version
package com.mayabot.nlp.classification;

import com.google.common.base.CharMatcher;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.mayabot.mynlp.fasttext.*;
import com.mayabot.nlp.segment.LexerReader;
import com.mayabot.nlp.segment.Lexers;
import org.jetbrains.annotations.NotNull;

import java.io.File;
import java.util.List;

/**
 * 使用Fasttext算法的文本分类工具
 * 

* 文本文件格式 * * @author jimichan */ public class FasttextClassification { /** * 训练Fasttext模型 */ public static FastText train(File file, int dim, double learnRate, int epoch) throws Exception { TrainArgs trainArgs = new TrainArgs(); trainArgs.setDim(dim); trainArgs.setLr(learnRate); trainArgs.setEpoch(epoch); return train(file, trainArgs); } /** * 训练Fasttext模型 */ public static FastText train(File file, TrainArgs trainArgs) throws Exception { Splitter splitter = Splitter.on(CharMatcher.whitespace()).omitEmptyStrings().trimResults(); LexerReader analyzer = Lexers.coreBuilder().build().filterReader(true, true); FileTrainExampleSource source = new FileTrainExampleSource( new WordSplitter() { @NotNull @Override public List split(String text) { List result = Lists.newArrayList(); splitter.split(text).forEach(part -> { if (part.startsWith("__label__")) { result.add(part); } else { for (String word : analyzer.scan(part).toWordSequence()) { result.add(word); } } }); return result; } } , file ); return FastText.train(source, ModelName.sup, trainArgs); } /** * 使用乘积量化压缩Fasttext模型,稍微损失一些精度. * * @param fastText * @return FastText * @throws Exception */ public static FastText compress(FastText fastText) throws Exception { return FastText.quantize(fastText); } /** * 便捷方法,返回预测的第一个结果 * * @param model * @param tokenizer 分词器 * @param text 需要分析的原始文本 * @return 分类目标. 
null表示没有结果 */ public static List predict(FastText model, LexerReader tokenizer, String text, int top) { List inputList = Lists.newArrayList(tokenizer.scan(text).toWordSequence()); if (inputList.isEmpty()) { return null; } return model.predict(inputList, top); } /** * 预测一个唯一的目标,要求概率大于0.5f * * @param model * @param tokenizer * @param text * @return FloatStringPair */ public static FloatStringPair predictOne(FastText model, LexerReader tokenizer, String text) { List inputList = Lists.newArrayList(tokenizer.scan(text).toWordSequence()); if (inputList.isEmpty()) { return null; } List list = model.predict(inputList, 5); if (list.isEmpty()) { return null; } FloatStringPair first = list.get(0); if (first.first < 0.5f) { return null; } else { return first; } } private static LexerReader lexerReader; public static List predict(FastText model, String text, int top) { if (lexerReader == null) { lexerReader = Lexers.coreBuilder().build().filterReader(true, true); } return predict(model, lexerReader, text, top); } public static FloatStringPair predictOne(FastText model, String text) { if (lexerReader == null) { lexerReader = Lexers.coreBuilder().build().filterReader(true, true); } return predictOne(model, lexerReader, text); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy