All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.multiclass.spoc.impl.CodingMatrixLearningGreedy Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.methods.multiclass.spoc.impl;

import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.Combinatorics;
import com.expleague.ml.methods.multiclass.spoc.AbstractCodingMatrixLearning;
import com.expleague.ml.methods.multiclass.spoc.CMLHelper;

/**
 * User: qdeee
 * Date: 07.05.14
 */
public class CodingMatrixLearningGreedy extends AbstractCodingMatrixLearning {
  public CodingMatrixLearningGreedy(final int k, final int l, final double lambdaC, final double lambdaR, final double lambda1) {
    super(k, l, lambdaC, lambdaR, lambda1);
  }

  protected double calcLoss(final Mx B, final Mx S) {
    double result = 0;
    final Mx mult = MxTools.multiply(B, MxTools.transpose(B));
    result -= MxTools.trace(MxTools.multiply(mult, S));
    result += lambdaR * VecTools.sum(mult);
    result += lambdaC * VecTools.sum2(B);
    result += lambda1 * VecTools.l1(B);
    return result;
  }

  @Override
  public Mx findMatrixB(final Mx S) {
    final Mx mxB = new VecBasedMx(k, l);
    for (int j = 0; j < l; j++) {
      final Combinatorics.PartialPermutations permutationsGenerator = new Combinatorics.PartialPermutations(2, k);
      int[] bestPerm = null;
      double bestLoss = Double.MAX_VALUE;
      while (permutationsGenerator.hasNext()) {
        final int[] perm = permutationsGenerator.next();
        for (int i = 0; i < k; i++) {
          mxB.set(i, j, 2 * perm[i] - 1);
        }
        final Mx sub = mxB.sub(0, 0, k, j + 1);
        if (CMLHelper.checkConstraints(sub) && CMLHelper.checkColumnsIndependence(sub)) {
          final double loss = calcLoss(sub, S);
          if (loss < bestLoss) {
            bestLoss = loss;
            bestPerm = perm;
          }
        }
      }
      if (bestPerm != null) {
        for (int i = 0; i < k; i++) {
          mxB.set(i, j, 2 * bestPerm[i] - 1);
        }
      }
      else
        throw new IllegalStateException("Not found appreciate column #" + j);
//      System.out.println("Column " + j + " is over!");
    }
    return mxB;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy