// com.expleague.ml.models.ObliviousMultiClassTree Maven / Gradle / Ivy
package com.expleague.ml.models;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.BFGrid;
import java.util.Arrays;
import java.util.List;
/**
 * Multi-class transformation built on top of a single {@link ObliviousTree}.
 *
 * <p>The wrapped tree maps an input vector to a leaf (bin) and a scalar value. For every
 * output coordinate the leaf value is emitted either as is or negated, depending on a
 * boolean mask looked up as {@code masks[bin][classNo]}.
 *
 * <p>NOTE(review): {@link #ydim()} reports {@code masks.length} (the outer dimension), while
 * {@link #value(Vec, int)} uses the outer dimension for the bin and the inner one for the
 * class. Both are consistent only if the outer length also bounds the class index; confirm
 * the intended layout of {@code masks} against the code that builds it.
 *
 * <p>The mask array is stored by reference (no defensive copy) and {@link #mask(int)} exposes
 * its rows directly.
 *
 * User: solar
 * Date: 29.11.12
 * Time: 5:35
 */
public class ObliviousMultiClassTree extends Trans.Stub {
  private final ObliviousTree binaryClassifier;
  private final boolean[][] masks;

  /**
   * Creates the underlying {@link ObliviousTree} from {@code features}, {@code values} and
   * {@code basedOn}, and remembers the sign masks.
   *
   * @param features split features, forwarded unchanged to the {@link ObliviousTree} constructor
   * @param values   leaf values, forwarded unchanged to the {@link ObliviousTree} constructor
   * @param basedOn  forwarded unchanged to the {@link ObliviousTree} constructor
   * @param masks    sign masks, read as {@code masks[bin][classNo]}; kept by reference
   */
  public ObliviousMultiClassTree(final List features, final double[] values, final double[] basedOn, final boolean[][] masks) {
    binaryClassifier = new ObliviousTree(features, values, basedOn);
    this.masks = masks;
  }

  /**
   * @return the output dimension, taken as the outer length of the mask array
   */
  @Override
  public int ydim() {
    return masks.length;
  }

  /**
   * @return the input dimension of the wrapped tree
   */
  @Override
  public int xdim() {
    return binaryClassifier.xdim();
  }

  /**
   * Computes a single output coordinate for {@code x}.
   *
   * @param x       input vector
   * @param classNo index of the requested output coordinate
   * @return the value of the leaf that {@code x} falls into if {@code masks[bin][classNo]}
   *         is set, otherwise that value negated
   */
  public double value(final Vec x, final int classNo) {
    final int bin = binaryClassifier.bin(x);
    final double v = binaryClassifier.values()[bin];
    return masks[bin][classNo] ? v : -v;
  }

  /**
   * Evaluates {@link #value(Vec, int)} for every coordinate in {@code [0, ydim())}.
   *
   * @param x input vector
   * @return a new vector of length {@link #ydim()} holding the per-coordinate values
   */
  @Override
  public Vec trans(final Vec x) {
    final int dim = ydim();
    final Vec result = new ArrayVec(dim);
    for (int c = 0; c < dim; c++) {
      result.set(c, value(x, c));
    }
    return result;
  }

  /**
   * @return the textual form of the wrapped tree followed by the masks rendered as
   *         {@code <row, row, ...>}, each row being a string of {@code 0}/{@code 1} digits
   */
  @Override
  public String toString() {
    final StringBuilder builder = new StringBuilder();
    builder.append(binaryClassifier.toString());
    builder.append('<');
    for (int i = 0; i < masks.length; i++) {
      if (i > 0) {
        builder.append(", ");
      }
      for (final boolean flag : masks[i]) {
        builder.append(flag ? 1 : 0);
      }
    }
    builder.append('>');
    return builder.toString();
  }

  /**
   * Two instances are equal when their wrapped trees are equal and their masks have the same
   * contents.
   */
  @Override
  public boolean equals(final Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof ObliviousMultiClassTree)) {
      return false;
    }
    final ObliviousMultiClassTree other = (ObliviousMultiClassTree) o;
    // masks is a two-dimensional array: Arrays.equals would compare the rows by reference,
    // so content-identical masks held in different arrays would be reported as different.
    return binaryClassifier.equals(other.binaryClassifier) && Arrays.deepEquals(masks, other.masks);
  }

  /**
   * Hash code consistent with {@link #equals(Object)}: derived from the wrapped tree and the
   * contents of the masks.
   */
  @Override
  public int hashCode() {
    // deepHashCode hashes row contents; Arrays.hashCode would use the identity hash of each row.
    return 31 * binaryClassifier.hashCode() + Arrays.deepHashCode(masks);
  }

  /**
   * @return the wrapped tree
   */
  public ObliviousTree binaryClassifier() {
    return binaryClassifier;
  }

  /**
   * @param i outer index into the mask array
   * @return the {@code i}-th mask row; this is the internal array, not a copy
   */
  public boolean[] mask(final int i) {
    return masks[i];
  }
}