All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.actelion.research.calc.regression.svm.SVMRegression Maven / Gradle / Ivy

There is a newer version: 2024.11.2
Show newest version
package com.actelion.research.calc.regression.svm;

import com.actelion.research.calc.CorrelationCalculator;
import com.actelion.research.calc.Matrix;
import com.actelion.research.calc.ProgressController;
import com.actelion.research.util.Formatter;
import com.actelion.research.util.datamodel.DoubleArray;
import com.actelion.research.util.datamodel.ModelXYIndex;
import com.actelion.research.calc.regression.ARegressionMethod;
import org.machinelearning.svm.libsvm.svm_model;
import org.machinelearning.svm.libsvm.svm_problem;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * SVMRegression
 *
 * Calculates the SVM regression for one Y column.
 *
 * 

Modest v. Korff

*

* Created by korffmo1 on 27.11.18. * * http://www.svms.org/parameters/ */ public class SVMRegression extends ARegressionMethod implements Comparable { public static final int LIMIT_ROWS_ANALYTICAL = 1000; public static boolean VERBOSE = false; public static final double NU = 0.1; private svm_model modelSVM; public SVMRegression() { setParameterRegressionMethod(new ParameterSVM(SVMParameterHelper.regressionEpsilonSVR())); } public SVMRegression(ParameterSVM parameterSVM) { setParameterRegressionMethod(parameterSVM); } public Matrix createModel(ModelXYIndex modelXYIndex) { if(parameterRegressionMethod.getSvmParameter().degree==SVMParameterHelper.DEGREE_ANALYTICALLY_PARAMETER_CALC) { ModelXYIndex modelXYIndexAnalytical = null; int rows = modelXYIndex.X.rows(); if(rows>LIMIT_ROWS_ANALYTICAL) { List liIndex = new ArrayList<>(rows); for (int i = 0; i < rows; i++) { liIndex.add(i); } Collections.shuffle(liIndex); List liIndexSub = new ArrayList<>(liIndex.subList(0, LIMIT_ROWS_ANALYTICAL)); modelXYIndexAnalytical = modelXYIndex.sub(liIndexSub); } else{ modelXYIndexAnalytical = modelXYIndex; } ParameterSVM parameterSVM = AnalyticalParameterCalculatorSVM.calculate(modelXYIndexAnalytical); setParameterRegressionMethod(parameterSVM); } int rows = modelXYIndex.X.rows(); int colsY = modelXYIndex.Y.cols(); if(colsY!=1){ throw new RuntimeException("Only one y column allowed!"); } Matrix YHat = new Matrix(rows, colsY); svm_problem svm_problem = new svm_problem(); svm_problem.l = rows; svm_problem.x = Matrix2SVMNodeConverter.convert(modelXYIndex.X); if(getParameter().getGamma()<=0) { getParameter().setGamma(1.0 / modelXYIndex.X.cols()); } boolean failed = false; try { double [] arrY = modelXYIndex.Y.getColAsDouble(0); DoubleArray daY = new DoubleArray(arrY); svm_problem.y = arrY; String error_msg = svm.svm_check_parameter(svm_problem, getParameter().getSvmParameter()); if(error_msg != null) { System.err.print("SVMRegressionMultiY svm_check_parameter error: "+error_msg+"\n"); failed = true; } ProgressController pg = getProgressController(); modelSVM = svm.svm_train(svm_problem, getParameter().getSvmParameter(), pg); // System.out.println("SVM model created"); if(pg!=null) { pg.startProgress("Calculate train data fit", 0, rows); } DoubleArray daYHat = new DoubleArray(rows); for (int j = 0; j < rows; j++) { double yHat = svm.svm_predict(modelSVM, svm_problem.x[j]); daYHat.add(yHat); YHat.set(j,0,yHat); if(pg!=null) { pg.updateProgress(j); if (pg.threadMustDie()) { failed = true; break; } } } if(VERBOSE) { double r = new CorrelationCalculator().calculateCorrelation(daY, daYHat, CorrelationCalculator.TYPE_BRAVAIS_PEARSON); double r2 = r * r; System.out.println("SVMRegressionMultiY model r2 " + Formatter.format4(r2) + "."); } } catch (Exception e) { e.printStackTrace(); if(VERBOSE) System.err.println("SVMRegressionMultiY break."); failed = true; } if(failed) { YHat=null; } return YHat; } @Override public Matrix calculateYHat(Matrix X) { int rows = X.rows(); Matrix YHat = new Matrix(rows, 1); svm_problem svm_problem = new svm_problem(); svm_problem.l = rows; svm_problem.x = Matrix2SVMNodeConverter.convert(X); for (int j = 0; j < rows; j++) { double yHat = svm.svm_predict(modelSVM, svm_problem.x[j]); YHat.set(j,0,yHat); } return YHat; } @Override public double calculateYHat(double[] arrRow) { double yHat; synchronized (this) { svm_problem svm_problem = new svm_problem(); svm_problem.l = 1; svm_problem.x = Matrix2SVMNodeConverter.convertSingleRow(arrRow); yHat = svm.svm_predict(modelSVM, svm_problem.x[0]); } return yHat; } public void setNu(double nu){ getParameter().setNu(nu); } public void setGamma(double gamma){ getParameter().setGamma(gamma); } @Override public int compareTo(SVMRegression o) { return getParameter().compareTo(o.getParameter()); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy