All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.arosbio.ml.algorithms.svm.LinearSVC Maven / Gradle / Ivy

Go to download

Conformal AI package, including all data IO, transformations, machine learning models and predictor classes. Does not include chemistry-dependent code.

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright (C) Aros Bio AB.
 *
 * CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
 *
 * 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
 *
 * 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
 */
package com.arosbio.ml.algorithms.svm;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Random;

import com.arosbio.commons.GlobalConfig;
import com.arosbio.commons.config.EnumConfig;
import com.arosbio.data.DataRecord;
import com.arosbio.data.FeatureVector;
import com.arosbio.ml.algorithms.MultiLabelClassifier;
import com.arosbio.ml.algorithms.impl.DefaultMLParameterSettings;
import com.arosbio.ml.algorithms.impl.LibLinear;

import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.SolverType;

/**
 * Linear Support Vector Classifier (SVC) backed by LIBLINEAR.
 * <p>
 * Only the LIBLINEAR solvers listed in {@link #ALLOWED_SOLVERS} are accepted. All
 * hyper-parameters (cost, epsilon, solver type, max iterations and the RNG) are kept in a single
 * LIBLINEAR {@link Parameter} instance; the trained model is kept as a LIBLINEAR {@link Model}.
 */
public class LinearSVC implements SVC, MultiLabelClassifier {

	/** Name of this algorithm; returned by {@link #getName()} and {@link #toString()}. */
	public static final String ALG_NAME = "LinearSVC";
	/** Numeric identifier of this algorithm; returned by {@link #getID()}. */
	public static final int ALG_ID = 11;
	/**
	 * The LIBLINEAR solvers accepted by {@link #setSolverType(SolverType)}.
	 * <p>
	 * NOTE: {@link EnumSet} is mutable - never modify this field, as it is used for validation.
	 * Use {@link #getAllowedSolvers()} to obtain a copy that is safe to change.
	 */
	public final static EnumSet ALLOWED_SOLVERS = EnumSet.of(
			SolverType.L1R_L2LOSS_SVC,
			SolverType.L2R_L1LOSS_SVC_DUAL, 
			SolverType.L2R_L2LOSS_SVC,
			SolverType.L2R_L2LOSS_SVC_DUAL);
	/** The solver used unless another one is explicitly set. */
	public final static SolverType DEFAULT_SOLVER = SolverType.L2R_L2LOSS_SVC;

	/**
	 * Parameters that holds all info (cost, epsilon, solver type, max iterations and RNG)
	 */
	private Parameter parameters = LibLinear.getDefaultParams(DEFAULT_SOLVER);
	/** The trained model, or {@code null} until trained or loaded from a stream. */
	private Model svm;

	/**
	 * Creates an unfitted instance with default parameters, seeding the LIBLINEAR random
	 * generator from the global RNG seed in {@link GlobalConfig}.
	 */
	public LinearSVC(){
		this.parameters.setRandom(new Random(GlobalConfig.getInstance().getRNGSeed()));
	}

	/**
	 * Returns the solvers that this algorithm accepts.
	 * 
	 * @return a fresh copy of {@link #ALLOWED_SOLVERS}; modifying it does not affect this class
	 */
	public static EnumSet getAllowedSolvers(){
		// Defensive copy: handing out the shared mutable set would let callers change validation
		return EnumSet.copyOf(ALLOWED_SOLVERS);
	}
	
	@Override
	public String getName() {
		return ALG_NAME;
	}

	@Override
	public int getID() {
		return ALG_ID;
	}
	
	@Override
	public String toString() {
		return ALG_NAME;
	}
	
	@Override
	public String getDescription() {
		return "Support Vector Classification (SVC) implemented in LIBLINEAR. Restricted to a linear kernel and optimized for fast training and predictions for linear kernel SVM. Should be preferred over LIBSVM implementation with a linear kernel.";
	}

	// Cost

	/**
	 * @return the cost parameter C
	 */
	public double getC() {
		return parameters.getC();
	}

	/**
	 * Sets the cost parameter C.
	 * 
	 * @param cost the new cost value (passed on to LIBLINEAR unvalidated)
	 */
	public void setC(double cost) {
		parameters.setC(cost);
	}

	/**
	 * Fluent variant of {@link #setC(double)}.
	 * 
	 * @param cost the new cost value
	 * @return this instance
	 */
	public LinearSVC withC(double cost){
		setC(cost);
		return this;
	}

	// Epsilon

	/**
	 * @return the epsilon value stored in the LIBLINEAR parameters
	 */
	public double getEpsilon() {
		return parameters.getEps();
	}

	/**
	 * Sets the epsilon value of the LIBLINEAR parameters.
	 * 
	 * @param eps the new epsilon (passed on to LIBLINEAR unvalidated)
	 */
	public void setEpsilon(double eps) {
		parameters.setEps(eps);
	}

	/**
	 * Fluent variant of {@link #setEpsilon(double)}.
	 * 
	 * @param eps the new epsilon
	 * @return this instance
	 */
	public LinearSVC withEpsilon(double eps){
		setEpsilon(eps);
		return this;
	}

	// Solver type

	/**
	 * @return the currently configured LIBLINEAR solver
	 */
	public SolverType getSolverType() {
		return parameters.getSolverType();
	}

	/**
	 * Sets the LIBLINEAR solver.
	 * 
	 * @param type the solver, must be one of {@link #ALLOWED_SOLVERS}
	 * @throws IllegalArgumentException if {@code type} is {@code null} or not an allowed solver
	 */
	public void setSolverType(SolverType type) {
		// EnumSet.contains(null) is false, so null is rejected here as well
		if (!ALLOWED_SOLVERS.contains(type)) {
			throw new IllegalArgumentException("SolverType not allowed for algorithm " + ALG_NAME + ": " + type);
		} 
		parameters.setSolverType(type);
	}

	/**
	 * Fluent variant of {@link #setSolverType(SolverType)}.
	 * 
	 * @param type the solver, must be one of {@link #ALLOWED_SOLVERS}
	 * @return this instance
	 * @throws IllegalArgumentException if {@code type} is not an allowed solver
	 */
	public LinearSVC withSolverType(SolverType type){
		setSolverType(type);
		return this;
	}

	// Max iterations

	/**
	 * @return the maximum number of iterations used by LIBLINEAR
	 */
	public int getMaxNumIterations() {
		return parameters.getMaxIters();
	}
	
	/**
	 * Sets the maximum number of iterations. Values below 1 reset the limit to
	 * {@code LibLinear.DEFAULT_MAX_ITERATIONS}.
	 * 
	 * @param maxIterations the iteration limit
	 */
	public void setMaxNumIterations(int maxIterations) {
		if (maxIterations < 1)
			parameters.setMaxIters(LibLinear.DEFAULT_MAX_ITERATIONS);
		else
			parameters.setMaxIters(maxIterations);
	}

	/**
	 * Fluent variant of {@link #setMaxNumIterations(int)}.
	 * 
	 * @param maxIterations the iteration limit (values below 1 reset to the default)
	 * @return this instance
	 */
	public LinearSVC withMaxNumIterations(int maxIterations){
		setMaxNumIterations(maxIterations);
		return this;
	}

	/**
	 * Cannot access the seed from LibLinear, this method returns {@code null}
	 * @return {@code null}
	 */
	@Override
	public Long getSeed() {
		return null;
	}
	
	/**
	 * Replaces the LIBLINEAR random generator with a new one seeded by {@code seed}.
	 * 
	 * @param seed the RNG seed
	 */
	@Override
	public void setSeed(long seed) {
		this.parameters.setRandom(new Random(seed));
	}

	/**
	 * Fluent variant of {@link #setSeed(long)}.
	 * 
	 * @param seed the RNG seed
	 * @return this instance
	 */
	public LinearSVC withSeed(long seed){
		setSeed(seed);
		return this;
	}

	/**
	 * @return {@code true} if a model has been trained or loaded
	 */
	@Override
	public boolean isFitted() {
		return svm!=null;
	}

	/**
	 * Returns the LIBLINEAR parameters as properties (via {@code LibLinear.toProperties}),
	 * extended with the algorithm name and ID.
	 * 
	 * @return the property map
	 */
	@Override
	public Map getProperties() {
		Map prop = LibLinear.toProperties(parameters);
		prop.put(ML_NAME_PARAM_KEY, ALG_NAME);
		prop.put(ML_ID_PARAM_KEY, ALG_ID);
		return prop;
	}
	
	/**
	 * Returns the labels of the trained model, as reported by {@code LibLinear.getLabels}.
	 * NOTE(review): behavior when no model is fitted depends on that helper - confirm.
	 * 
	 * @return the model labels
	 */
	@Override
	public List getLabels() {
		return LibLinear.getLabels(svm);
	}

	/**
	 * Creates a new, unfitted instance with a clone of this instance's parameters. The trained
	 * model is intentionally not copied.
	 * 
	 * @return an unfitted copy of this classifier
	 */
	@Override
	public LinearSVC clone() {
		LinearSVC clone = new LinearSVC();
		// Only copy the actual parameters 
		clone.parameters = parameters.clone();
		return clone;
	}

	/**
	 * Applies the given configuration to a copy of the current parameters and only replaces
	 * the parameters once {@code LibLinear.setConfigParameters} has returned normally, so a
	 * failed call leaves this instance unchanged.
	 * 
	 * @param params the configuration to apply
	 * @throws IllegalStateException propagated from {@code LibLinear.setConfigParameters}
	 * @throws IllegalArgumentException propagated from {@code LibLinear.setConfigParameters}
	 */
	@Override
	public void setConfigParameters(Map params) throws IllegalStateException, IllegalArgumentException {
		Parameter clone = parameters.clone();
		LibLinear.setConfigParameters(clone, ALLOWED_SOLVERS, params);
		parameters = clone; 
	}

	/**
	 * @return the configurable parameters: cost, epsilon and solver type (restricted to
	 *         {@link #ALLOWED_SOLVERS}, default {@link #DEFAULT_SOLVER})
	 */
	@Override
	public List getConfigParameters(){
		return Arrays.asList(
			DefaultMLParameterSettings.COST_CONFIG,
			DefaultMLParameterSettings.EPSILON_CONFIG,
			new EnumConfig.Builder<>(
				LibLinear.SOLVER_TYPE_PARAM_NAMES,
				ALLOWED_SOLVERS,DEFAULT_SOLVER).build());
	}
	
	/* 
	 * =================================================
	 * 			TRAIN
	 * =================================================
	 */
	
	/**
	 * Same as {@link #fit(List)}.
	 * 
	 * @param trainingset the training data
	 * @throws IllegalArgumentException propagated from the LIBLINEAR training helper
	 */
	@Override
	public void train(List trainingset) throws IllegalArgumentException {
		fit(trainingset);
	}

	/**
	 * Trains a new model on {@code trainingset} with the current parameters, replacing any
	 * previously trained or loaded model.
	 * 
	 * @param trainingset the training data
	 * @throws IllegalArgumentException propagated from {@code LibLinear.train}
	 */
	@Override
	public void fit(List trainingset) throws IllegalArgumentException {
		svm = LibLinear.train(parameters, trainingset);
	}
	
	/* 
	 * =================================================
	 * 			PREDICTIONS
	 * =================================================
	 */

	/**
	 * Predicts the class label of {@code feature} using the trained model.
	 * 
	 * @param feature the example to classify
	 * @return the predicted label
	 * @throws IllegalStateException declared by this API; presumably thrown when no model is fitted
	 */
	@Override
	public int predictClass(FeatureVector feature) throws IllegalStateException {
		return LibLinear.predictClass(svm,feature);
	}

	/**
	 * Scores for this classifier are the distances to the separating hyperplane(s); identical
	 * to {@link #predictDistanceToHyperplane(FeatureVector)}.
	 * 
	 * @param example the example to score
	 * @return the per-label scores
	 * @throws IllegalStateException declared by this API; presumably thrown when no model is fitted
	 */
	@Override
	public Map predictScores(FeatureVector example) throws IllegalStateException {
		return predictDistanceToHyperplane(example);
	}

	/**
	 * Computes the distance(s) from {@code example} to the separating hyperplane(s).
	 * 
	 * @param example the example to score
	 * @return the per-label distances
	 * @throws IllegalStateException declared by this API; presumably thrown when no model is fitted
	 */
	@Override
	public Map predictDistanceToHyperplane(FeatureVector example) throws IllegalStateException {
		return LibLinear.predictDistanceToHyperplane(svm,example);
	}
	
	/* 
	 * =================================================
	 * 			I/O
	 * =================================================
	 */

	/**
	 * Writes the trained model to {@code ostream} via {@code LibLinear.saveToStream}.
	 * 
	 * @param ostream the destination stream
	 * @throws IOException on write failure
	 * @throws IllegalStateException declared by this API; presumably thrown when no model is fitted
	 */
	@Override
	public void saveToStream(OutputStream ostream) throws IOException, IllegalStateException {
		LibLinear.saveToStream(svm, ostream);
	}

	/**
	 * Replaces the current model with one read from {@code istream} via
	 * {@code LibLinear.loadFromStream}. The hyper-parameters of this instance are not changed.
	 * 
	 * @param istream the source stream
	 * @throws IOException on read failure
	 */
	@Override
	public void loadFromStream(InputStream istream) throws IOException {
		svm = LibLinear.loadFromStream(istream);
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy