All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.multiclass.gradfac.GradFacMulticlass Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.methods.multiclass.gradfac;

import com.expleague.commons.math.Func;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.factorization.Factorization;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.methods.multiclass.MultiClassOneVsRest;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.Pair;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.func.ScaledVectorFunc;
import com.expleague.ml.loss.L2;

/**
 * Multiclass optimization step that approximates the gradient matrix by a single outer product
 * {@code h * b^T} and trains one {@link Func} instead of one model per class.
 *
 * <p>The target of the incoming {@link L2} loss is viewed as a matrix and factorized into a pair
 * {@code (h, b)}. Then {@code b} is normalized to unit length, and its norm is moved into
 * {@code h}. The inner optimizer is trained on a target of type {@code local} built from
 * {@code h}. The fitted function is wrapped into a {@link ScaledVectorFunc} together with
 * {@code b}.
 *
 * User: qdeee
 * Date: 25.12.14
 */
public class GradFacMulticlass implements VecOptimization<L2> {
  /** Optimizer trained on the single factor {@code h}. */
  private final VecOptimization<L2> inner;
  /** Produces the {@code (h, b)} pair from the gradient matrix. */
  private final Factorization matrixDecomposition;
  /** Loss class instantiated (via {@link DataTools#newTarget}) on top of {@code h}. */
  private final Class<? extends L2> local;
  /** When true, {@link #fit} prints approximation errors to standard output. */
  private final boolean printErrors;

  /**
   * Creates the optimizer with error printing disabled.
   *
   * @param inner optimizer used to fit the factor {@code h}
   * @param matrixDecomposition factorization applied to the gradient matrix
   * @param local loss class used to wrap {@code h} as a learning target
   */
  public GradFacMulticlass(final VecOptimization<L2> inner, final Factorization matrixDecomposition, final Class<? extends L2> local) {
    this(inner, matrixDecomposition, local, false);
  }

  /**
   * Creates the optimizer.
   *
   * @param inner optimizer used to fit the factor {@code h}
   * @param matrixDecomposition factorization applied to the gradient matrix
   * @param local loss class used to wrap {@code h} as a learning target
   * @param printErrors whether {@link #fit} prints factorization/fit errors to standard output
   */
  public GradFacMulticlass(final VecOptimization<L2> inner, final Factorization matrixDecomposition, final Class<? extends L2> local, final boolean printErrors) {
    this.inner = inner;
    this.matrixDecomposition = matrixDecomposition;
    this.local = local;
    this.printErrors = printErrors;
  }

  /**
   * Fits one factorized step against the gradient held in {@code mllLogitGradient.target}.
   *
   * @param learn learning set; its {@code length()} is used to reshape a flat gradient target
   * @param mllLogitGradient L2 loss whose target is the gradient. The target is either already an
   *     {@link Mx} or a flat vector whose dimension is a multiple of {@code learn.length()}.
   * @return the function fitted on {@code h}, scaled by the unit-norm vector {@code b}.
   *     It is deliberately not a MultiClassModel, for boosting compatibility.
   */
  @Override
  public ScaledVectorFunc fit(VecDataSet learn, L2 mllLogitGradient) {
    // A flat target is reshaped using dim / learn.length() as VecBasedMx's first argument
    // (presumably the per-point column count -- TODO confirm against VecBasedMx).
    final Mx gradient = mllLogitGradient.target instanceof Mx
        ? (Mx) mllLogitGradient.target
        : new VecBasedMx(mllLogitGradient.target.dim() / learn.length(), mllLogitGradient.target);
    final Pair<Vec, Vec> pair = matrixDecomposition.factorize(gradient);

    final Vec h = pair.getFirst();
    final Vec b = pair.getSecond();

    // Make b unit-length and move its norm into h, so the product h * b^T is unchanged.
    // For a zero b, scaling by 1 / normB would multiply by Infinity and fill b with NaN,
    // which would then leak into the returned model; leave b as the zero vector instead.
    final double normB = VecTools.norm(b);
    if (normB > 0) {
      VecTools.scale(b, 1 / normB);
      VecTools.scale(h, normB);
    }

    final L2 loss = DataTools.newTarget(local, h, learn);
    final Func model = MultiClassOneVsRest.extractFunc(inner.fit(learn, loss));
    final ScaledVectorFunc resultModel = new ScaledVectorFunc(model, b);

    if (printErrors) {
      // err1: factorization error, err2: fit error w.r.t. the factorization, absErr: total error.
      final Mx mxAfterFactor = VecTools.outer(h, b);
      final Mx mxAfterFit = resultModel.transAll(learn.data());
      final double gradNorm = VecTools.norm(gradient);
      final double error1 = VecTools.distance(gradient, mxAfterFactor);
      final double error2 = VecTools.distance(mxAfterFactor, mxAfterFit);
      final double totalError = VecTools.distance(gradient, mxAfterFit);

      System.out.printf("grad_norm = %f, err1 = %f, err2 = %f, absErr = %f%n", gradNorm, error1, error2, totalError);
    }

    return resultModel; //not MultiClassModel, for boosting compatibility
  }
}


// Output columns: condition number \t gradient norm \t relative factorization error (err1 / grad_norm)
/*
if (printErrors) {
      final RealMatrix realMatrix = new Array2DRowRealMatrix(gradient.rows(), gradient.columns());
      final int rows = gradient.rows();
      final int columns = gradient.columns();
      for (int i = 0; i < rows; i++) {
        for (int j = 0; j < columns; j++) {
          realMatrix.setEntry(i, j, gradient.get(i, j));
        }
      }
      final SingularValueDecomposition singularValueDecomposition = new SingularValueDecomposition(realMatrix);
      System.out.print(singularValueDecomposition.getConditionNumber() + "\t");


      final Mx mxAfterFactor = VecTools.outer(h, b);
//      final Mx mxAfterFit = resultModel.transAll(learn.data());
      final double gradNorm = VecTools.norm(gradient);
      final double error1 = VecTools.distance(gradient, mxAfterFactor);
//      final double error2 = VecTools.distance(mxAfterFactor, mxAfterFit);
//      final double totalError = VecTools.distance(gradient, mxAfterFit);

//      System.out.println(String.format("%f\t%f\t%f\t%f", gradNorm, error1, error2, totalError));
      System.out.print(gradNorm + "\t");
      System.out.print(error1 / gradNorm + "\n");
    }

*/




© 2015 - 2024 Weber Informatics LLC | Privacy Policy