All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.MultiClass Maven / Gradle / Ivy

package com.expleague.ml.methods;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.Func;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.func.FuncJoin;
import com.expleague.ml.loss.L2;

/**
 * Multi-class wrapper around a single-output optimization method.
 *
 * <p>The target of the supplied {@link L2} loss is viewed as a matrix with one row per row of
 * {@code learn.data()}. One model is fitted per column with the {@code inner} method, each against
 * a loss built by {@code DataTools.newTarget(local, column, learn)}. The per-column models are
 * combined into a {@link FuncJoin}.
 *
 * <p>NOTE(review): {@code VecOptimization}, {@code VecOptimization.Stub} and {@code Class} are used
 * as raw types here. Confirm the intended type arguments against the upstream source.
 *
 * User: solar
 * Date: 27.11.13
 * Time: 18:55
 */
public class MultiClass extends VecOptimization.Stub {
  /** Method used to fit each column of the target matrix independently. */
  private final VecOptimization inner;
  /** Loss class passed to {@code DataTools.newTarget} to build each column's target. */
  private final Class local;
  /** When {@code true}, {@link #fit} prints the target norm and the fit error to stdout. */
  private final boolean printErrors;

  /**
   * Creates a wrapper that does not print fit diagnostics.
   *
   * @param inner method used to fit each target column
   * @param local loss class used to build each column's target
   */
  public MultiClass(final VecOptimization inner, final Class local) {
    this(inner, local, false);
  }

  /**
   * Creates a wrapper.
   *
   * @param inner method used to fit each target column
   * @param local loss class used to build each column's target
   * @param printErrors whether {@link #fit} should print the target norm and fit error to stdout
   */
  public MultiClass(final VecOptimization inner, final Class local, final boolean printErrors) {
    this.inner = inner;
    this.local = local;
    this.printErrors = printErrors;
  }

  /**
   * Fits one model per column of the target of {@code mllLogitGradient}.
   *
   * @param learn learning data set; the rows of {@code learn.data()} define the rows of the target
   *     matrix
   * @param mllLogitGradient loss whose {@code target} holds the values to fit. Per the parameter
   *     name, this is presumably the gradient of a multi-class logit loss. The target is either an
   *     {@link Mx} or a flat {@link Vec} whose length is a multiple of {@code learn.data().rows()}.
   * @return the per-column models joined into a {@link FuncJoin}
   * @throws IllegalArgumentException if the target is a flat vector and {@code learn.data()} has
   *     no rows, or the vector's length is not a multiple of the row count
   */
  @Override
  public FuncJoin fit(final VecDataSet learn, final L2 mllLogitGradient) {
    final Mx gradient;
    final Vec gradVec = mllLogitGradient.target;
    if (gradVec instanceof Mx) {
      gradient = (Mx) gradVec;
    } else {
      final int rows = learn.data().rows();
      // Integer division would otherwise silently truncate the column count (or divide by zero)
      // and reshape the target into a misaligned matrix.
      if (rows <= 0 || gradVec.dim() % rows != 0) {
        throw new IllegalArgumentException(
            "Target of dimension " + gradVec.dim() + " cannot be split into " + rows + " rows");
      }
      final int columns = gradVec.dim() / rows;
      gradient = new VecBasedMx(columns, gradVec);
    }

    final Func[] models = new Func[gradient.columns()];
    for (int c = 0; c < models.length; c++) {
      final L2 loss = DataTools.newTarget(local, gradient.col(c), learn);
      models[c] = (Func) inner.fit(learn, loss);
    }
    final FuncJoin resultModel = new FuncJoin(models);

    if (printErrors) {
      // Compare the joined model's outputs on the learning data against the target matrix.
      final Mx mxAfterFit = resultModel.transAll(learn.data());
      final double error = VecTools.distance(gradient, mxAfterFit);
      final double gradNorm = VecTools.norm(gradient);
      System.out.println("grad_norm = " + gradNorm + ", err = " + error);
    }

    return resultModel; //not MultiClassModel, for boosting compatibility
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy