com.expleague.ml.methods.multiclass.spoc.impl.CodingMatrixLearning Maven / Gradle / Ivy
package com.expleague.ml.methods.multiclass.spoc.impl;
import com.expleague.commons.math.vectors.*;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.methods.multiclass.spoc.CMLHelper;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.ml.methods.multiclass.spoc.AbstractCodingMatrixLearning;
import java.util.Random;
/**
* User: qdeee
* Date: 07.05.14
*/
public class CodingMatrixLearning extends AbstractCodingMatrixLearning {
private static final double MX_IGNORE_THRESHOLD = 0.1;
private static final double MX_LEARN_EPS = 1e-3;
private final Mx initB;
private final double mxLearnStep;
public CodingMatrixLearning(final Mx initB, final double mxLearnStep, final double lambdaC, final double lambdaR, final double lambda1) {
super(initB.rows(), initB.columns(), lambdaC, lambdaR, lambda1);
this.initB = initB;
this.mxLearnStep = mxLearnStep;
}
public CodingMatrixLearning(final Mx initB, final double mxLearnStep) {
this(initB, mxLearnStep, initB.rows(), 1.0, initB.rows());
}
public CodingMatrixLearning(final int k, final int l, final double lambdaC, final double lambdaR, final double lambda1, final double mxLearnStep) {
this(new VecBasedMx(k, l), mxLearnStep, lambdaC, lambdaR, lambda1);
final Random rand = new FastRandom(100500);
do {
for (int i = 0; i < k; i++) {
for (int j = 0; j < l; j++) {
initB.set(i, j, rand.nextInt(3) - 1);
}
}
} while (!CMLHelper.checkConstraints(initB));
}
public CodingMatrixLearning(final int k, final int l, final double mxLearnStep) {
this(k, l, k, 1.0, k, mxLearnStep);
}
@Override
public Mx findMatrixB(final Mx S) {
Mx mxB = initB;
final Vec b = new ArrayVec(2*k*l + 2*l + k);
{
for (int i = 0; i < 2*k*l; i++)
b.set(i, 1.);
for (int i = 2* k * l; i < 2*k*l + 2*l; i++)
b.set(i, -2.);
for (int i = 2* k * l + 2* l; i < 2*k*l + 2*l + k; i++)
b.set(i, -1.);
}
final Mx Inv = new VecBasedMx(k, k);
{
final double mult = 1 / (k * lambdaR * lambdaC + lambdaC * lambdaC);
VecTools.fill(Inv, -lambdaR * mult);
for (int i = 0; i < Inv.columns(); i++)
Inv.adjust(i, i, (k * lambdaR + lambdaC) * mult);
VecTools.scale(Inv, 0.5); //see algorithm's iteration process
}
final Vec gamma = new ArrayVec(2*k*l + 2*l + k);
{
// init gamma
for (int i = 0; i < gamma.dim(); i++) {
gamma.set(i, 0.5);
}
}
final Vec mu = new ArrayVec(k*l);
{
// init mu
for (int i = 0; i < mu.dim(); i++) {
mu.set(i, lambda1 / 2);
}
}
int iter = 0;
double error = 100500;
while (error > MX_LEARN_EPS) {
/**
* B^{i+1} = Inv * (2S * B^{i} - (transpose(A) * gamma - mu))
* def: m1 = 2S * B^{i}
* m2 = transpose(A) * gamma
* sub1 = m2 - mu
* sub2 = m1 - Mx(sub1)
*/
final Mx A = createConstraintsMatrix(mxB);
{
final Mx m1 = MxTools.multiply(S, mxB);
VecTools.scale(m1, 2.);
final Vec m2 = MxTools.multiply(MxTools.transpose(A), gamma);
final Vec sub1 = VecTools.subtract(m2, mu);
final Mx sub1Mx = vec2mx(sub1, m1.columns());
final Mx sub2 = VecTools.subtract(m1, sub1Mx);
final Mx newMxB = MxTools.multiply(Inv, sub2);
error = VecTools.infNorm(VecTools.subtract(mxB, newMxB));
mxB = newMxB;
}
/**
* Projections:
* gamma = Pr_{gamma >= 0} (gamma - t * (b - A * vec(mxB)))
* def: m1 = A * vec(mxB)
* sub = b - m1
*
* mu = Pr_{infnorm(mu) <= lambda1} (mu - t * vec(mxB))
*/
{
final Vec vecB = mx2vec(mxB);
final Vec m1 = MxTools.multiply(A, vecB);
final Vec sub = VecTools.subtract(b, m1);
VecTools.incscale(gamma, sub, -1 * mxLearnStep);
for (final VecIterator iterator = gamma.nonZeroes(); iterator.advance(); ) {
if (iterator.value() < 0)
iterator.setValue(0);
}
VecTools.incscale(mu, vecB, -1 * mxLearnStep);
for (final VecIterator iterator = mu.nonZeroes(); iterator.advance(); ) {
if (Math.abs(iterator.value()) > lambda1) {
iterator.setValue(lambda1);
}
}
}
if (iter++ > 1000)
break;
// if (!checkConstraints(mxB))
// throw new IllegalStateException("out of contraints!");
}
normalizeMx(mxB);
return mxB;
}
private static void normalizeMx(final Mx codingMatrix) {
for (final MxIterator iter = codingMatrix.nonZeroes(); iter.advance(); ) {
final double value = iter.value();
if (Math.abs(value) > MX_IGNORE_THRESHOLD)
iter.setValue(Math.signum(value));
else
iter.setValue(0.0);
}
}
protected static Mx vec2mx(final Vec vec, final int columns) {
final Mx result = new VecBasedMx(columns, new ArrayVec(vec.dim()));
final int rows = result.rows();
for (int i = 0; i < vec.dim(); i++) {
result.set(i % rows, i / rows, vec.get(i));
}
return result;
}
protected static Vec mx2vec(final Mx mx) {
final Vec result = new ArrayVec(mx.dim());
final int rows = mx.rows();
for (int i = 0; i < result.dim(); i++) {
result.set(i, mx.get(i % rows, i / rows));
}
return result;
}
//
// public static boolean checkConstraints(final Mx B) {
// final int k = B.rows();
// final int l = B.columns();
// final Mx A = createConstraintsMatrix(B);
// final Vec vecB = mx2vec(B);
// final Vec checkVec = VecTools.multiply(A, vecB);
// for (int i = 0; i < 2*k*l; i++)
// if (checkVec.at(i) > 1.)
// return false;
// for (int i = 2* k * l; i < 2*k*l + 2*l; i++)
// if (checkVec.at(i) > -2.)
// return false;
// for (int i = 2* k * l + 2* l; i < 2*k*l + 2*l + k; i++)
// if (checkVec.at(i) > -1)
// return false;
// return true;
// }
/**
*
* @param B Coding matrix that was obtained at the last iteration, size = [k,l]
* @return Matrix of constraints
*/
public static Mx createConstraintsMatrix(final Mx B) {
final int k = B.rows();
final int l = B.columns();
// final Mx A = new SparseMx(new MxBasisImpl(2*k*l + 2*l +k, k*l));
final Mx A = new VecBasedMx(2* k * l + 2* l + k, k * l);
for (int j = 0; j < k * l; j++) {
A.set(j, j, -1.0);
A.set(k * l + j, j, 1.0);
final double signum = Math.signum(B.get(j % k, j / k));
A.set(2*k*l + j/ k, j, -1 - signum);
A.set(2*k*l + l + j/ k, j, 1 -signum);
A.set(2*k*l + 2*l + (j % k), j, -signum);
}
return A;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy