
com.github.chungkwong.classifier.SvmClassifierFactory Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of text-classifier-collection Show documentation
Show all versions of text-classifier-collection Show documentation
A full fledged text classification toolkit for Java
The newest version!
/*
* Copyright (C) 2018 Chan Chung Kwong
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package com.github.chungkwong.classifier;
import com.github.chungkwong.classifier.util.*;
import de.bwaldvogel.liblinear.*;
import java.util.*;
import java.util.stream.*;
/**
*
* Factory for SVM classifier
* @author Chan Chung Kwong
* @param <T> the type of the objects to be classified
*/
public class SvmClassifierFactory extends StreamClassifierFactory>,DocumentVectorsModel,T>{
// TF-IDF weighting used to turn token frequencies into feature values.
private TfIdfFormula tfIdfFormula=TfIdfFormula.STANDARD;
// liblinear training parameters; per liblinear's Parameter(solverType, C, eps) constructor
// the default is the L2R_L2LOSS_SVC_DUAL solver with C=1 and eps=0.1.
private Parameter parameter=new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL,1,0.1);
/**
* Create an SVM classifier factory that uses {@code TfIdfFormula.STANDARD} and the
* default liblinear parameters until they are replaced through the setters.
*/
public SvmClassifierFactory(){
}
/**
* Get the liblinear parameters used for training. The factory's own instance is returned,
* not a copy.
* @return parameters of liblinear
*/
public Parameter getParameter(){
return parameter;
}
/**
* Set parameters of liblinear
* @param parameter parameters of liblinear to be used by subsequent training
* @return this
*/
public SvmClassifierFactory setParameter(Parameter parameter){
this.parameter=parameter;
return this;
}
/**
* Set TF-IDF formula
* @param tfIdfFormula TF-IDF formula
* @return this
*/
public SvmClassifierFactory setTfIdfFormula(TfIdfFormula tfIdfFormula){
this.tfIdfFormula=tfIdfFormula;
return this;
}
/**
* @return TF-IDF formula
*/
public TfIdfFormula getTfIdfFormula(){
return tfIdfFormula;
}
/**
* Train a liblinear model from the document vectors held by the given model and wrap it
* in a classifier. Every entry of {@code model.getProfiles()} is given a numeric label
* equal to its position in the iteration order of that map, and each document vector of
* the entry becomes one training sample carrying that label.
* @param model the model holding the document vectors of every category
* @return a classifier backed by the trained liblinear model
*/
@Override
public Classifier> createClassifier(DocumentVectorsModel model){
ImmutableFrequencies totalDocumentFrequencies=model.getTotalDocumentFrequencies();
Problem problem=new Problem();
// NOTE(review): the (int) cast suggests getSampleCount() returns a wider type
// (presumably long); a larger count would be truncated here - confirm.
problem.l=(int)model.getSampleCount();
problem.n=totalDocumentFrequencies.getTokenCount();
// Maps each token to its feature index; filled in lazily by toFeatureArray below.
Map tokenIndex=new HashMap<>();
int sampleCount=(int)model.getSampleCount();
problem.y=new double[sampleCount];
problem.x=new Feature[sampleCount][];
// i indexes training samples, j is the numeric label of the current category.
int i=0,j=0;
Iterator>> iterator=model.getProfiles().entrySet().iterator();
while(iterator.hasNext()){
Map.Entry> next=iterator.next();
for(ImmutableFrequencies sample:next.getValue().getDocumentVectors()){
problem.y[i]=j;
problem.x[i]=toFeatureArray(sample,tokenIndex,totalDocumentFrequencies,sampleCount,tfIdfFormula);
++i;
}
++j;
}
// The category array is indexed by the labels assigned above, so this relies on keySet()
// iterating in the same order as the entrySet() walked in the loop.
return new SvmClassifier<>(Linear.train(problem,parameter),
tokenIndex,totalDocumentFrequencies,sampleCount,tfIdfFormula,
model.getProfiles().keySet().toArray(new Category[0]));
}
/**
* @return a new, empty document vectors model
*/
@Override
public DocumentVectorsModel createModel(){
return new DocumentVectorsModel<>();
}
/**
* Convert the token frequencies of one training document into a liblinear feature array
* that is scaled to unit Euclidean length and sorted by feature index.
* Tokens missing from {@code tokenIndex} are assigned the next free index (starting at 1)
* and added to the map, so {@code tokenIndex} is modified by this call.
* @param object token frequencies of the document
* @param tokenIndex mapping from token to feature index, extended in place
* @param documentFrequencies frequencies looked up per token and passed to the formula
* @param documentCount count passed as the last argument of {@code formula.calculate}
* @param formula TF-IDF formula producing the raw feature values
* @return the features of the document in ascending index order
*/
private static Feature[] toFeatureArray(ImmutableFrequencies object,Map tokenIndex,
ImmutableFrequencies documentFrequencies,long documentCount,TfIdfFormula formula){
Feature[] features=new Feature[object.getTokenCount()];
int i=0;
for(Map.Entry e:object.toMap().entrySet()){
T token=e.getKey();
Integer index=tokenIndex.get(token);
if(index==null){
// First time this token is seen: allocate the next index, counting from 1.
index=tokenIndex.size()+1;
tokenIndex.put(token,index);
}
features[i++]=new FeatureNode(index,formula.calculate(e.getValue(),documentFrequencies.getFrequency(token),documentCount));
}
// Normalise the vector: divide every value by the square root of the sum of squares.
// NOTE(review): if every value is 0 the divisor is 0 and the values become NaN -
// confirm whether formula.calculate can yield 0 for all tokens of a document.
double factor=0;
for(Feature feature:features)
factor+=feature.getValue()*feature.getValue();
factor=Math.sqrt(factor);
for(Feature feature:features)
feature.setValue(feature.getValue()/factor);
Arrays.sort(features,(f,g)->Integer.compare(f.getIndex(),g.getIndex()));
return features;
}
/**
* Classifier that feeds token streams to a trained liblinear model.
*/
private static class SvmClassifier implements Classifier>{
// Trained liblinear model used for prediction.
private final Model model;
// Frequencies, count and formula handed to the feature conversion at prediction time.
private final ImmutableFrequencies documentFrequencies;
private final long documentCount;
private final TfIdfFormula tfIdfFormula;
// Token to feature index mapping built during training.
private final Map tokenIndex;
// Categories indexed by the numeric label predicted by the model.
private final Category[] categories;
/**
* Create a classifier from a trained model and the data needed to build feature vectors.
* @param model trained liblinear model
* @param tokenIndex mapping from token to feature index
* @param documentFrequencies frequencies passed to the TF-IDF formula
* @param documentCount count passed to the TF-IDF formula
* @param tfIdfFormula TF-IDF formula producing the feature values
* @param categories categories indexed by predicted label
*/
public SvmClassifier(Model model,Map tokenIndex,
ImmutableFrequencies documentFrequencies,long documentCount,
TfIdfFormula tfIdfFormula,Category[] categories){
this.model=model;
this.tokenIndex=tokenIndex;
this.categories=categories;
this.documentCount=documentCount;
this.documentFrequencies=documentFrequencies;
this.tfIdfFormula=tfIdfFormula;
}
@Override
public List getCandidates(Stream object,int max){
Feature[] features=toFeatureArray(object,tokenIndex,documentFrequencies,documentCount,tfIdfFormula);
// Adding 0.5 before the cast rounds the predicted label to the nearest int index.
int categoryIndex=(int)(Linear.predict(model,features)+0.5);
// NOTE(review): the next line is damaged in this copy of the file. The text after the
// second "categoryIndex" (the rest of the bounds check, the remainder of this method and
// the start of the declaration of the helper that follows) is missing, and the generic
// type arguments throughout the file appear to have been lost the same way. Restore them
// from the upstream source before compiling.
// The remaining lines are the helper that builds the feature array for a token stream at
// prediction time.
if(categoryIndex>=0&&categoryIndex Feature[] toFeatureArray(Stream tokens,Map tokenIndex,
ImmutableFrequencies documentFrequencies,long documentCount,TfIdfFormula formula){
ImmutableFrequencies object=new ImmutableFrequencies<>(tokens);
// Tokens absent from tokenIndex are dropped; the map is only read here, never extended.
Feature[] features=object.toMap().entrySet().stream().filter((e)->tokenIndex.containsKey(e.getKey())).
map((e)->new FeatureNode(tokenIndex.get(e.getKey()),formula.calculate(e.getValue(),documentFrequencies.getFrequency(e.getKey()),documentCount))).toArray(Feature[]::new);
// Normalise the vector: divide every value by the square root of the sum of squares.
// NOTE(review): an all-zero vector gives a divisor of 0 and NaN values - confirm.
double factor=0;
for(Feature feature:features)
factor+=feature.getValue()*feature.getValue();
factor=Math.sqrt(factor);
for(Feature feature:features)
feature.setValue(feature.getValue()/factor);
Arrays.sort(features,(f,g)->Integer.compare(f.getIndex(),g.getIndex()));
return features;
}
}
/**
* @return the fixed name "SVM"
*/
@Override
protected String getName(){
return "SVM";
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy