All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.chungkwong.classifier.SvmClassifierFactory Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (C) 2018 Chan Chung Kwong
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package com.github.chungkwong.classifier;
import com.github.chungkwong.classifier.util.*;
import de.bwaldvogel.liblinear.*;
import java.util.*;
import java.util.stream.*;
/**
 * Factory for classifiers backed by a linear SVM trained with liblinear.
 * <p>
 * Every training document vector is weighted with a {@link TfIdfFormula},
 * scaled to unit Euclidean length and handed to liblinear. The label of a
 * sample is the position of its category in the iteration order of the
 * model's profile map.
 * @author Chan Chung Kwong
 * @param <T> the type of the objects to be classified
 */
public class SvmClassifierFactory extends StreamClassifierFactory>,DocumentVectorsModel,T>{
	/** Weighting applied to term frequencies; {@link TfIdfFormula#STANDARD} unless changed. */
	private TfIdfFormula tfIdfFormula=TfIdfFormula.STANDARD;
	/** liblinear training parameters; default solver L2R_L2LOSS_SVC_DUAL with C=1 and eps=0.1. */
	private Parameter parameter=new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL,1,0.1);
	/**
	 * Create a SVM classifier factory using the default TF-IDF formula and
	 * the default liblinear parameters.
	 */
	public SvmClassifierFactory(){
	}
	/**
	 * @return parameters of liblinear used when training
	 */
	public Parameter getParameter(){
		return parameter;
	}
	/**
	 * Set parameters of liblinear
	 * @param parameter the liblinear parameters used by subsequent calls to
	 * {@link #createClassifier}, must not be null
	 * @return this
	 * @throws NullPointerException if {@code parameter} is null
	 */
	public SvmClassifierFactory setParameter(Parameter parameter){
		this.parameter=Objects.requireNonNull(parameter,"parameter");
		return this;
	}
	/**
	 * Set TF-IDF formula
	 * @param tfIdfFormula TF-IDF formula used to weight tokens, must not be null
	 * @return this
	 * @throws NullPointerException if {@code tfIdfFormula} is null
	 */
	public SvmClassifierFactory setTfIdfFormula(TfIdfFormula tfIdfFormula){
		this.tfIdfFormula=Objects.requireNonNull(tfIdfFormula,"tfIdfFormula");
		return this;
	}
	/**
	 * @return TF-IDF formula used to weight tokens
	 */
	public TfIdfFormula getTfIdfFormula(){
		return tfIdfFormula;
	}
	/**
	 * Train a linear SVM on every document vector stored in the model.
	 * <p>
	 * Categories are numbered 0,1,2,... in the iteration order of
	 * {@code model.getProfiles()}. That number is the liblinear label of each
	 * sample and also the index into the category array handed to the
	 * resulting classifier.
	 * @param model the training data
	 * @return a classifier wrapping the trained liblinear model
	 */
	@Override
	public Classifier> createClassifier(DocumentVectorsModel model){
		ImmutableFrequencies totalDocumentFrequencies=model.getTotalDocumentFrequencies();
		int sampleCount=(int)model.getSampleCount();
		Problem problem=new Problem();
		problem.l=sampleCount;
		problem.n=totalDocumentFrequencies.getTokenCount();
		// token -> 1-based liblinear feature index, filled while converting the training vectors
		Map tokenIndex=new HashMap<>();
		problem.y=new double[sampleCount];
		problem.x=new Feature[sampleCount][];
		int i=0,j=0;
		Iterator>> iterator=model.getProfiles().entrySet().iterator();
		while(iterator.hasNext()){
			Map.Entry> next=iterator.next();
			for(ImmutableFrequencies sample:next.getValue().getDocumentVectors()){
				// j is the ordinal of the current category
				problem.y[i]=j;
				problem.x[i]=toFeatureArray(sample,tokenIndex,totalDocumentFrequencies,sampleCount,tfIdfFormula);
				++i;
			}
			++j;
		}
		return new SvmClassifier<>(Linear.train(problem,parameter),
				tokenIndex,totalDocumentFrequencies,sampleCount,tfIdfFormula,
				model.getProfiles().keySet().toArray(new Category[0]));
	}
	/**
	 * @return a new, empty training model
	 */
	@Override
	public DocumentVectorsModel createModel(){
		return new DocumentVectorsModel<>();
	}
	/**
	 * Convert a training document vector to liblinear features, assigning a new
	 * feature index to every token not yet present in {@code tokenIndex}.
	 * @param object term frequencies of one training document
	 * @param tokenIndex token to feature index map, updated in place
	 * @param documentFrequencies document frequencies over the training set
	 * @param documentCount number of training documents
	 * @param formula weighting formula
	 * @return the features, scaled to unit length (unless all weights are zero)
	 * and sorted by index
	 */
	private static  Feature[] toFeatureArray(ImmutableFrequencies object,Map tokenIndex,
			ImmutableFrequencies documentFrequencies,long documentCount,TfIdfFormula formula){
		Feature[] features=new Feature[object.getTokenCount()];
		int i=0;
		for(Map.Entry e:object.toMap().entrySet()){
			T token=e.getKey();
			Integer index=tokenIndex.get(token);
			if(index==null){
				// liblinear feature indices start at 1
				index=tokenIndex.size()+1;
				tokenIndex.put(token,index);
			}
			features[i++]=new FeatureNode(index,formula.calculate(e.getValue(),documentFrequencies.getFrequency(token),documentCount));
		}
		return normalizeAndSort(features);
	}
	/**
	 * Scale features in place to unit Euclidean norm, then sort them by
	 * ascending index as liblinear requires.
	 * <p>
	 * A vector whose norm is zero (no features, or every weight zero) is left
	 * unscaled. Dividing by the zero norm would turn every weight into NaN.
	 * @param features the features to normalize
	 * @return {@code features}
	 */
	private static Feature[] normalizeAndSort(Feature[] features){
		double squaredNorm=0;
		for(Feature feature:features)
			squaredNorm+=feature.getValue()*feature.getValue();
		double norm=Math.sqrt(squaredNorm);
		if(norm>0){
			for(Feature feature:features)
				feature.setValue(feature.getValue()/norm);
		}
		Arrays.sort(features,Comparator.comparingInt(Feature::getIndex));
		return features;
	}
	/**
	 * Classifier that converts a token stream to liblinear features and
	 * predicts its category with a trained liblinear model.
	 */
	private static class SvmClassifier implements Classifier>{
		private final Model model;
		private final ImmutableFrequencies documentFrequencies;
		private final long documentCount;
		private final TfIdfFormula tfIdfFormula;
		private final Map tokenIndex;
		private final Category[] categories;
		public SvmClassifier(Model model,Map tokenIndex,
				ImmutableFrequencies documentFrequencies,long documentCount,
				TfIdfFormula tfIdfFormula,Category[] categories){
			this.model=model;
			this.tokenIndex=tokenIndex;
			this.categories=categories;
			this.documentCount=documentCount;
			this.documentFrequencies=documentFrequencies;
			this.tfIdfFormula=tfIdfFormula;
		}
		@Override
		public List getCandidates(Stream object,int max){
			Feature[] features=toFeatureArray(object,tokenIndex,documentFrequencies,documentCount,tfIdfFormula);
			// Labels were assigned as 0,1,2,... during training; round the predicted label back to an index
			int categoryIndex=(int)(Linear.predict(model,features)+0.5);
			// NOTE(review): the remainder of this method and the header of the following
			// toFeatureArray overload appear to have been lost on the next line (text
			// between '<' and '>' was stripped); restore them from upstream.
			if(categoryIndex>=0&&categoryIndex Feature[] toFeatureArray(Stream tokens,Map tokenIndex,
				ImmutableFrequencies documentFrequencies,long documentCount,TfIdfFormula formula){
			// Features for a document to classify: tokens that never occurred during
			// training have no index and are dropped.
			ImmutableFrequencies object=new ImmutableFrequencies<>(tokens);
			Feature[] features=object.toMap().entrySet().stream().filter((e)->tokenIndex.containsKey(e.getKey())).
					map((e)->new FeatureNode(tokenIndex.get(e.getKey()),formula.calculate(e.getValue(),documentFrequencies.getFrequency(e.getKey()),documentCount))).toArray(Feature[]::new);
			return normalizeAndSort(features);
		}
	}
	/**
	 * @return the name of this classifier type, {@code "SVM"}
	 */
	@Override
	protected String getName(){
		return "SVM";
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy