// com.expleague.ml.func.TransJoin (Maven / Gradle / Ivy source listing)
package com.expleague.ml.func;
import com.expleague.commons.func.Evaluator;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.ArrayTools;
import org.jetbrains.annotations.Nullable;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import java.util.Arrays;
import java.util.List;
/**
 * Joins several {@link Trans} instances that are applied to the same input vector.
 *
 * <p>For an input {@code x}, {@link #trans(Vec)} returns a matrix whose row {@code c} holds
 * {@code dirs[c].trans(x)}. The joined input dimension is the largest {@code xdim()} among the
 * parts. The joined output dimension is {@code dirs.length} times the largest {@code ydim()}
 * among the parts.
 *
 * <p>Original author: solar, 21.12.2010.
 */
public class TransJoin extends Trans.Stub {
  /**
   * The joined parts, in row order. The array is stored without copying. {@link #xdim()} and
   * {@link #ydim()} are computed once in the constructor, so later changes to this array are not
   * reflected in them.
   */
  public final Trans[] dirs;
  private final int xdim;
  private final int ydim;

  /**
   * Creates a join over the given parts.
   *
   * @param dirs the parts to join, in row order; must be non-empty
   * @throws IllegalArgumentException if {@code dirs} is empty
   */
  public TransJoin(final Trans[] dirs) {
    if (dirs.length == 0) {
      // With no parts there is no dimension to take, and trans() would divide by zero.
      throw new IllegalArgumentException("TransJoin requires at least one Trans");
    }
    this.dirs = dirs;
    this.xdim = maxXdim(dirs);
    this.ydim = dirs.length * maxYdim(dirs);
  }

  /**
   * Creates a join over the parts contained in {@code models}, in list order.
   *
   * @param models the parts to join; must be non-empty
   * @throws IllegalArgumentException if {@code models} is empty
   */
  public TransJoin(final List<? extends Trans> models) {
    this(models.toArray(new Trans[models.size()]));
  }

  /** Returns the largest {@code xdim()} among {@code parts}, which must be non-empty. */
  private static int maxXdim(final Trans[] parts) {
    int max = parts[0].xdim();
    for (int i = 1; i < parts.length; i++) {
      max = Math.max(max, parts[i].xdim());
    }
    return max;
  }

  /** Returns the largest {@code ydim()} among {@code parts}, which must be non-empty. */
  private static int maxYdim(final Trans[] parts) {
    int max = parts[0].ydim();
    for (int i = 1; i < parts.length; i++) {
      max = Math.max(max, parts[i].ydim());
    }
    return max;
  }

  /** Returns {@code dirs.length} times the largest {@code ydim()} among the parts. */
  @Override
  public int ydim() {
    return ydim;
  }

  /** Returns the largest {@code xdim()} among the parts. */
  @Override
  public int xdim() {
    return xdim;
  }

  /**
   * Returns a transformation that evaluates the gradients of all parts at a point.
   *
   * @return the joined gradient, or {@code null} if any part's {@code gradient()} is {@code null}
   */
  @Nullable
  @Override
  public Trans gradient() {
    // One gradient per part.
    final Trans[] gradients = new Trans[dirs.length];
    for (int i = 0; i < dirs.length; i++) {
      gradients[i] = dirs[i].gradient();
      if (gradients[i] == null) {
        // A single part without a gradient makes the whole join non-differentiable.
        return null;
      }
    }
    return new Stub() {
      @Override
      public int xdim() {
        return TransJoin.this.xdim();
      }

      @Override
      public int ydim() {
        return xdim() * TransJoin.this.ydim();
      }

      /** Second-order derivatives of a join are not supported. */
      @Nullable
      @Override
      public Trans gradient() {
        throw new UnsupportedOperationException(
            "Gradient of TransJoin gradient is not implemented");
      }

      @Override
      public Vec trans(final Vec x) {
        // xdim() columns (the first VecBasedMx argument) and TransJoin.this.ydim() rows.
        // Row i receives the gradient of dirs[i] at x.
        // NOTE(review): only the first dirs.length rows are filled. That matches the layout
        // only when every part has ydim() == 1; confirm the intended layout for
        // multi-output parts.
        final Mx result = new VecBasedMx(xdim(), new ArrayVec(ydim()));
        for (int i = 0; i < gradients.length; i++) {
          VecTools.assign(result.row(i), gradients[i].trans(x));
        }
        return result;
      }
    };
  }

  /**
   * Applies every part to {@code x}.
   *
   * @param x the input vector
   * @return a matrix with {@code ydim() / dirs.length} columns whose row {@code c} holds
   *     {@code dirs[c].trans(x)}
   */
  @Override
  public Vec trans(final Vec x) {
    final Mx result = new VecBasedMx(ydim / dirs.length, new ArrayVec(ydim));
    for (int c = 0; c < dirs.length; c++) {
      VecTools.assign(result.row(c), dirs[c].trans(x));
    }
    return result;
  }

  /**
   * Two joins are equal when they have the same class, the same dimensions, and element-wise
   * equal parts.
   */
  @Override
  public boolean equals(final Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    final TransJoin transJoin = (TransJoin) o;
    return xdim == transJoin.xdim
        && ydim == transJoin.ydim
        && Arrays.equals(dirs, transJoin.dirs);
  }

  /** Hash consistent with {@link #equals(Object)}: combines the parts and both dimensions. */
  @Override
  public int hashCode() {
    int result = Arrays.hashCode(dirs);
    result = 31 * result + xdim;
    result = 31 * result + ydim;
    return result;
  }
}