com.arosbio.ml.algorithms.impl.LibLinear Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of confai Show documentation
Show all versions of confai Show documentation
Conformal AI package, including all data IO, transformations, machine learning models and predictor classes, but excluding chemistry-dependent code.
/*
* Copyright (C) Aros Bio AB.
*
* CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
*
* 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
*
* 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
*/
package com.arosbio.ml.algorithms.impl;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.arosbio.commons.CollectionUtils;
import com.arosbio.commons.StringUtils;
import com.arosbio.commons.TypeUtils;
import com.arosbio.data.DataRecord;
import com.arosbio.data.DataUtils;
import com.arosbio.data.FeatureVector;
import com.arosbio.data.MissingDataException;
import com.arosbio.data.MissingValueFeature;
import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
/**
* Wrapper class for LIBLINEAR (https://www.csie.ntu.edu.tw/~cjlin/liblinear/)
*
* @author staffan
*
*/
public class LibLinear {
/** Logger shared by all static helpers in this class. */
private static final Logger LOGGER = LoggerFactory.getLogger(LibLinear.class);

/** Maximum number of iterations used by {@link #getDefaultParams(SolverType)}. */
public static final int DEFAULT_MAX_ITERATIONS = 1000;

/** Value of the bias term used for the training problems and feature arrays built by this class. */
private static final double bias = 1;

// TUNABLE PARAMETERS

/** Names accepted for the solver-type parameter; the first entry is the key written by {@link #toProperties(Parameter)}. */
public static final List<String> SOLVER_TYPE_PARAM_NAMES = Arrays.asList("solverType", "solver");

/** Names for the max-iterations parameter; the first entry is the key written by {@link #toProperties(Parameter)}. */
public static final List<String> MAX_ITERATIONS_PARAM_NAMES = Arrays.asList("maxIterations");

// Turn off LIBLINEAR's own debug output
static {
    Linear.setDebugOutput(null);
}

/** Static utility class - should never be instantiated. */
private LibLinear() {}

/**
 * Creates a {@link Parameter} instance for the given solver, populated with the
 * default cost, epsilon, max iterations and SVR-epsilon values.
 *
 * @param type the solver type to use
 * @return a new {@link Parameter} with default settings
 */
public static Parameter getDefaultParams(SolverType type) {
    return new Parameter(type,
            DefaultMLParameterSettings.DEFAULT_C,
            DefaultMLParameterSettings.DEFAULT_EPSILON,
            DEFAULT_MAX_ITERATIONS,
            DefaultMLParameterSettings.DEFAULT_SVR_EPSILON);
}

/**
 * Converts a {@link Parameter} into a property map, keyed by the first (primary)
 * name of each parameter. The solver type is stored as its numeric id.
 *
 * @param p the parameters to convert
 * @return a new, mutable map holding cost, epsilon, SVR-epsilon, solver type id and max iterations
 */
public static Map<String, Object> toProperties(Parameter p) {
    Map<String, Object> props = new HashMap<>();
    props.put(DefaultMLParameterSettings.COST_PARAM_NAMES.get(0), p.getC());
    props.put(DefaultMLParameterSettings.EPSILON_PARAM_NAMES.get(0), p.getEps());
    props.put(DefaultMLParameterSettings.SVR_EPSILON_PARAM_NAMES.get(0), p.getP());
    props.put(SOLVER_TYPE_PARAM_NAMES.get(0), p.getSolverType().getId());
    props.put(MAX_ITERATIONS_PARAM_NAMES.get(0), p.getMaxIters());
    return props;
}
/**
 * Returns the class labels of a model as a list.
 *
 * @param model the model to read labels from, may be {@code null}
 * @return a new list with the labels, or an empty list if the model is {@code null}
 *         or its labels cannot be read
 */
public static List<Integer> getLabels(Model model) {
    if (model == null) {
        return new ArrayList<>();
    }
    try {
        int[] modelLabels = model.getLabels();
        List<Integer> labels = new ArrayList<>(modelLabels.length);
        for (int label : modelLabels) {
            labels.add(label);
        }
        return labels;
    } catch (NullPointerException npe) {
        // Deliberate best-effort: a NullPointerException while reading the labels is
        // treated as "no labels". NOTE(review): presumably this covers models without
        // class labels (e.g. regression models) - confirm against the liblinear Model API.
        return new ArrayList<>();
    }
}
/**
 * Updates {@code original} in place with the recognized entries of {@code params}.
 * Recognized parameters are cost, epsilon, SVR-epsilon and solver type; entries with
 * any other name are ignored. A solver type can be given as an integer id, the name
 * of a {@link SolverType} constant or a {@link SolverType} instance, and must be part
 * of {@code allowedSolvers}.
 *
 * @param original the parameter object to update
 * @param allowedSolvers the solver types that are accepted
 * @param params parameter names mapped to their new values
 * @throws IllegalArgumentException if a recognized parameter has an invalid value,
 *         the original failure is attached as the cause
 */
public static void setConfigParameters(Parameter original, EnumSet<SolverType> allowedSolvers, Map<String, ?> params)
        throws IllegalArgumentException {
    for (Map.Entry<String, ?> p : params.entrySet()) {
        String key = p.getKey();
        Object value = p.getValue();
        try {
            if (CollectionUtils.containsIgnoreCase(DefaultMLParameterSettings.COST_PARAM_NAMES, key)) {
                original.setC(TypeUtils.asDouble(value));
            } else if (CollectionUtils.containsIgnoreCase(DefaultMLParameterSettings.EPSILON_PARAM_NAMES, key)) {
                original.setEps(TypeUtils.asDouble(value));
            } else if (CollectionUtils.containsIgnoreCase(DefaultMLParameterSettings.SVR_EPSILON_PARAM_NAMES, key)) {
                original.setP(TypeUtils.asDouble(value));
            } else if (CollectionUtils.containsIgnoreCase(SOLVER_TYPE_PARAM_NAMES, key)) {
                SolverType newType;
                if (TypeUtils.isInt(value)) {
                    newType = SolverType.getById(TypeUtils.asInt(value));
                } else if (value instanceof String) {
                    newType = SolverType.valueOf(value.toString());
                } else if (value instanceof SolverType) {
                    newType = (SolverType) value;
                } else {
                    throw new IllegalArgumentException("Parameter '"+key+"' cannot be interpreted as a proper SolverType, was: " + value);
                }
                // Verify that it's an allowed solver type
                if (! allowedSolvers.contains(newType)) {
                    throw new IllegalArgumentException("Parameter '"+key+"' not allowed to take value: " + value);
                }
                original.setSolverType(newType);
            }
            // Fall through on parameters that are not used.
            // NOTE(review): MAX_ITERATIONS_PARAM_NAMES is not handled here although
            // toProperties writes it - confirm whether that is handled elsewhere.
        } catch (Exception e) {
            LOGGER.debug("Failed setting parameter {} with value: {}", key, value);
            // Keep the original failure as the cause so the stack trace is not lost
            throw new IllegalArgumentException("Invalid argument for parameter '" + key + "': " + e.getMessage(), e);
        }
    }
}
/*
* =================================================
* TRAINING
* =================================================
*/
/**
 * Trains a LIBLINEAR model on a list of records, by first converting the records
 * into a {@link Problem} using {@link #createLibLinearTrainProblem(List)}.
 *
 * @param params the LIBLINEAR parameters to train with
 * @param trainingSet the records to train on
 * @return the trained model
 * @throws IllegalArgumentException if the training set is empty
 */
public static Model train(Parameter params, List<DataRecord> trainingSet) throws IllegalArgumentException {
    return train(params, createLibLinearTrainProblem(trainingSet));
}
/**
 * Trains a LIBLINEAR model on an already prepared {@link Problem}.
 *
 * @param params the LIBLINEAR parameters to train with
 * @param problem the training problem
 * @return the trained model
 * @throws IllegalArgumentException if the problem contains no records ({@code problem.l == 0})
 */
public static Model train(Parameter params, Problem problem) throws IllegalArgumentException {
    if (problem.l == 0) {
        throw new IllegalArgumentException("Training set cannot be empty");
    }
    // Pass the parameter object itself so its string form is only built when trace logging is enabled
    LOGGER.trace("Training liblinear model with #records={}, #attributes={}, using parameters={}",
            problem.l, problem.n, params);

    // Do the training!
    Model model = Linear.train(problem, params);
    LOGGER.debug("Finished training the linear model");
    return model;
}
/*
* =================================================
* UTILS
* =================================================
*/
/**
 * Converts a list of records into a LIBLINEAR {@link Problem}. Feature indices are
 * shifted by one (LIBLINEAR indices start at 1) and a bias column is added after the
 * largest feature index.
 *
 * @param trainingset the records to convert
 * @return a new {@link Problem} with one row per record
 * @throws MissingDataException if any record contains missing or non-finite feature values
 */
public static Problem createLibLinearTrainProblem(
        List<DataRecord> trainingset) {
    LOGGER.debug("trainingset.size={}", trainingset.size());
    int maxFeatIndex = DataUtils.getMaxFeatureIndex(trainingset) + 1; // +1 for indices starting at 1 instead of 0
    int biasColumn = maxFeatIndex + 1;

    Problem trainProblem = new Problem();
    trainProblem.l = trainingset.size();
    trainProblem.n = biasColumn; // number of columns, including the bias term
    trainProblem.x = new Feature[trainProblem.l][];
    trainProblem.y = new double[trainProblem.l];
    trainProblem.bias = bias;

    try {
        for (int ex = 0; ex < trainProblem.l; ex++) {
            DataRecord rec = trainingset.get(ex);
            // Copy the target value
            trainProblem.y[ex] = rec.getLabel();
            // Convert the feature vector
            trainProblem.x[ex] = createFeatureArray(rec.getFeatures(), biasColumn);
        }
    } catch (MissingDataException e) {
        // The detailed exception (with feature indices) is only logged, callers get a general message
        LOGGER.debug("Failed setting up LibLinear problem due to missing data: ",e);
        throw new MissingDataException("Failed training LibLinear model due to missing data - please revise your pre-processing");
    }

    LOGGER.trace("prob.l={}, prob.n={}, prob.x.len={}, prob.y.len={}",
            trainProblem.l, trainProblem.n, trainProblem.x.length, trainProblem.y.length);
    return trainProblem;
}
/**
 * Creates a copy of a {@link Problem}. The {@code y} array and the outer {@code x}
 * array are cloned, but the individual rows of {@code x} are shared between the
 * original and the copy (array {@code clone()} is shallow).
 *
 * @param problem the problem to copy
 * @return a new {@link Problem} with the same size, bias, targets and rows
 */
public static Problem clone(Problem problem){
    final Problem copy = new Problem();
    // Scalar settings
    copy.bias = problem.bias;
    copy.n = problem.n;
    copy.l = problem.l;
    // Data arrays
    copy.y = problem.y.clone();
    copy.x = problem.x.clone();
    return copy;
}
/**
 * Converts a feature vector into a LIBLINEAR feature array for use with the given
 * model, placing the bias term in the column after the model's last feature.
 *
 * @param feats the feature vector to convert
 * @param m the model the array should be used with
 * @return the converted features followed by the bias node
 */
public static Feature[] createFeatureArray(FeatureVector feats, Model m){
    final int biasColumn = m.getNrFeature() + 1;
    return createFeatureArray(feats, biasColumn);
}
/**
 * Converts a feature vector into a LIBLINEAR feature array. Every feature index is
 * shifted by one, and a bias node is appended as the last element.
 *
 * @param feats the feature vector to convert
 * @param biasCol the (LIBLINEAR) column index to use for the bias term
 * @return the converted features followed by the bias node
 * @throws MissingDataException if any feature is a {@link MissingValueFeature} or has
 *         a non-finite value; the message lists all offending (original) indices
 */
public static Feature[] createFeatureArray(FeatureVector feats, int biasCol){
    // One slot per explicit feature, plus one for the bias column
    final Feature[] converted = new Feature[feats.getNumExplicitFeatures() + 1];
    final List<Integer> missing = new ArrayList<>();

    int pos = 0;
    for (FeatureVector.Feature feature : feats) {
        final int featIndex = feature.getIndex();
        final double featValue = feature.getValue();

        boolean hasValue = !(feature instanceof MissingValueFeature) && Double.isFinite(featValue);
        if (!hasValue) {
            // Collect all offenders so they can be reported together
            missing.add(featIndex);
        }
        // Features start at 0, but liblinear requires indices starting at 1
        converted[pos++] = new FeatureNode(featIndex + 1, featValue);
    }
    converted[pos] = new FeatureNode(biasCol, bias);

    if (missing.isEmpty()) {
        return converted;
    }
    throw new MissingDataException("Encountered feature(s) with missing data (index): " + StringUtils.toStringNoBrackets(missing));
}
/*
* =================================================
* PREDICTIONS
* =================================================
*/
/**
 * Guard used by the predict methods: verifies that a model is available.
 *
 * @param model the model to check
 * @throws IllegalStateException if {@code model} is {@code null}
 */
private static void assertFittedModel(Model model) throws IllegalStateException {
    if (model != null) {
        return;
    }
    throw new IllegalStateException("Model not fitted");
}
/**
 * Predicts the value for a single example.
 *
 * @param model the fitted model
 * @param example the example to predict
 * @return the value returned by {@code Linear.predict}
 * @throws IllegalStateException if {@code model} is {@code null}
 */
public static double predictValue(Model model, FeatureVector example) throws IllegalStateException {
    // Check the model before it is dereferenced when building the feature array,
    // otherwise a null model gives a NullPointerException instead
    assertFittedModel(model);
    return predictValue(model, createFeatureArray(example, model));
}
/**
 * Predicts the value for an already converted LIBLINEAR feature array.
 *
 * @param model the fitted model
 * @param instance the feature array to predict
 * @return the value returned by {@code Linear.predict}
 * @throws IllegalStateException if {@code model} is {@code null}
 */
public static double predictValue(Model model, Feature[] instance) throws IllegalStateException {
    assertFittedModel(model);
    final double predicted = Linear.predict(model, instance);
    return predicted;
}
/**
 * Predicts the class label for a single example.
 *
 * @param model the fitted model
 * @param example the example to predict
 * @return the prediction of {@code Linear.predict}, cast to {@code int}
 * @throws IllegalStateException if {@code model} is {@code null}
 */
public static int predictClass(Model model,FeatureVector example)
        throws IllegalStateException {
    // Check the model before it is dereferenced when building the feature array,
    // otherwise a null model gives a NullPointerException instead
    assertFittedModel(model);
    return predictClass(model, createFeatureArray(example, model));
}
/**
 * Predicts the class label for an already converted LIBLINEAR feature array.
 *
 * @param model the fitted model
 * @param instance the feature array to predict
 * @return the prediction of {@code Linear.predict}, cast to {@code int}
 * @throws IllegalStateException if {@code model} is {@code null}
 */
public static int predictClass(Model model,Feature[] instance)
        throws IllegalStateException {
    assertFittedModel(model);
    final double rawPrediction = Linear.predict(model, instance);
    return (int) rawPrediction;
}
/**
 * Computes the decision values for a single example, keyed by class label.
 *
 * @param model the fitted model
 * @param example the example to predict
 * @return a map from class label to decision value
 * @throws IllegalStateException if {@code model} is {@code null}
 */
public static Map<Integer, Double> predictDistanceToHyperplane(Model model,FeatureVector example) throws IllegalStateException {
    // Check the model before it is dereferenced when building the feature array,
    // otherwise a null model gives a NullPointerException instead
    assertFittedModel(model);
    return predictDistanceToHyperplane(model, createFeatureArray(example, model));
}
public static Map predictDistanceToHyperplane(Model model,Feature[] example) throws IllegalStateException {
assertFittedModel(model);
int[] labels = model.getLabels();
double decValues[] = new double[labels.length];
Linear.predictValues(model, example, decValues);
// Convert to the labels used
Map prediction = new HashMap<>();
if (model.getNrClass() ==2) {
// Special treat binary classification - only gives a single value
prediction.put(labels[0], decValues[0]);
prediction.put(labels[1], -1*decValues[0]);
} else {
for (int i=0; i predictProbabilities(Model model, FeatureVector example){
assertFittedModel(model);
return predictProbabilities(model,createFeatureArray(example, model.getNrFeature()+1));
}
public static Map predictProbabilities(Model model, Feature[] example){
assertFittedModel(model);
if (!model.isProbabilityModel()) {
throw new IllegalStateException("The model was not trained for predicting probabilities");
}
int[] labels = model.getLabels();
double[] probs = new double[labels.length];
Linear.predictProbability(model, example, probs);
Map prediction = new HashMap<>();
for (int i=0; i