![JAR search and dependency download from the Maven repository](/logo.png)
com.expleague.ml.methods.multiclass.gradfac.GradFacMulticlass Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jmll Show documentation
Various ML methods implemented by myself and my students
package com.expleague.ml.methods.multiclass.gradfac;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.factorization.Factorization;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.methods.multiclass.MultiClassOneVsRest;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.Pair;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.func.ScaledVectorFunc;
import com.expleague.ml.loss.L2;
/**
* User: qdeee
* Date: 25.12.14
*/
public class GradFacMulticlass implements VecOptimization {
private final VecOptimization inner;
private final Factorization matrixDecomposition;
private final Class extends L2> local;
private final boolean printErrors;
public GradFacMulticlass(final VecOptimization inner, final Factorization matrixDecomposition, final Class extends L2> local) {
this(inner, matrixDecomposition, local, false);
}
public GradFacMulticlass(final VecOptimization inner, final Factorization matrixDecomposition, final Class extends L2> local, final boolean printErrors) {
this.inner = inner;
this.matrixDecomposition = matrixDecomposition;
this.local = local;
this.printErrors = printErrors;
}
@Override
public ScaledVectorFunc fit(VecDataSet learn, L2 mllLogitGradient) {
final Mx gradient = mllLogitGradient.target instanceof Mx
? (Mx)mllLogitGradient.target
: new VecBasedMx(mllLogitGradient.target.dim() / learn.length(), mllLogitGradient.target);
final Pair pair = matrixDecomposition.factorize(gradient);
final Vec h = pair.getFirst();
final Vec b = pair.getSecond();
final double normB = VecTools.norm(b);
VecTools.scale(b, 1 / normB);
VecTools.scale(h, normB);
final L2 loss = DataTools.newTarget(local, h, learn);
final Func model = MultiClassOneVsRest.extractFunc(inner.fit(learn, loss));
final ScaledVectorFunc resultModel = new ScaledVectorFunc(model, b);
if (printErrors) {
final Mx mxAfterFactor = VecTools.outer(h, b);
final Mx mxAfterFit = resultModel.transAll(learn.data());
final double gradNorm = VecTools.norm(gradient);
final double error1 = VecTools.distance(gradient, mxAfterFactor);
final double error2 = VecTools.distance(mxAfterFactor, mxAfterFit);
final double totalError = VecTools.distance(gradient, mxAfterFit);
System.out.println(String.format("grad_norm = %f, err1 = %f, err2 = %f, absErr = %f", gradNorm, error1, error2, totalError));
}
return resultModel; //not MultiClassModel, for boosting compatibility
}
}
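// Disabled diagnostics, kept for reference: prints tab-separated condition number, gradient norm, and relative factorization error (cn \t gradnorm \t rel_fact_err).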
/*
if (printErrors) {
final RealMatrix realMatrix = new Array2DRowRealMatrix(gradient.rows(), gradient.columns());
final int rows = gradient.rows();
final int columns = gradient.columns();
for (int i = 0; i < rows; i++) {
for (int j = 0; j < columns; j++) {
realMatrix.setEntry(i, j, gradient.get(i, j));
}
}
final SingularValueDecomposition singularValueDecomposition = new SingularValueDecomposition(realMatrix);
System.out.print(singularValueDecomposition.getConditionNumber() + "\t");
final Mx mxAfterFactor = VecTools.outer(h, b);
// final Mx mxAfterFit = resultModel.transAll(learn.data());
final double gradNorm = VecTools.norm(gradient);
final double error1 = VecTools.distance(gradient, mxAfterFactor);
// final double error2 = VecTools.distance(mxAfterFactor, mxAfterFit);
// final double totalError = VecTools.distance(gradient, mxAfterFit);
// System.out.println(String.format("%f\t%f\t%f\t%f", gradNorm, error1, error2, totalError));
System.out.print(gradNorm + "\t");
System.out.print(error1 / gradNorm + "\n");
}
*/
© 2015 - 2025 Weber Informatics LLC | Privacy Policy