All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.models.ObliviousMultiClassTree Maven / Gradle / Ivy

package com.expleague.ml.models;

import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.BFGrid;

import java.util.Arrays;
import java.util.List;

/**
 * Multi-class model built on top of a single {@link ObliviousTree}.
 *
 * <p>The wrapped tree maps an input vector to a bin and supplies one scalar value per bin. Each
 * bin also owns a boolean mask over classes: for class {@code c} the model outputs {@code +v}
 * when {@code masks[bin][c]} is {@code true}, and {@code -v} otherwise, where {@code v} is the
 * tree value of that bin.
 *
 * <p>NOTE(review): the {@code masks} array is stored and handed out by reference (no defensive
 * copies), so mutating it externally changes this model.
 *
 * User: solar
 * Date: 29.11.12
 * Time: 5:35
 */
public class ObliviousMultiClassTree extends Trans.Stub {
  private final ObliviousTree binaryClassifier;
  /** Per-bin class masks, indexed as {@code masks[bin][classNo]}. */
  private final boolean[][] masks;

  /**
   * Creates a multi-class tree.
   *
   * @param features forwarded unchanged to the underlying {@link ObliviousTree} constructor
   * @param values forwarded unchanged to the underlying {@link ObliviousTree} constructor
   * @param basedOn forwarded unchanged to the underlying {@link ObliviousTree} constructor
   * @param masks per-bin class masks indexed as {@code masks[bin][classNo]}; presumably every row
   *     has the same length (the number of classes) — TODO confirm with callers
   */
  public ObliviousMultiClassTree(final List features, final double[] values, final double[] basedOn, final boolean[][] masks) {
    binaryClassifier = new ObliviousTree(features, values, basedOn);
    this.masks = masks;
  }

  /**
   * Returns the output dimension, i.e. the number of classes.
   *
   * <p>Because masks are indexed as {@code masks[bin][classNo]}, the class count is the length of
   * a per-bin mask row, not {@code masks.length} (which is the number of bins).
   *
   * @return length of the first mask row, or 0 when there are no masks
   */
  @Override
  public int ydim() {
    return masks.length > 0 ? masks[0].length : 0;
  }

  /**
   * Returns the input dimension of the underlying tree.
   *
   * @return {@code binaryClassifier.xdim()}
   */
  @Override
  public int xdim() {
    return binaryClassifier.xdim();
  }

  /**
   * Computes the signed tree value for a single class.
   *
   * @param x input vector
   * @param classNo class index, expected in {@code [0, ydim())}
   * @return the tree value of the bin {@code x} falls into, negated when that bin's mask is
   *     {@code false} for {@code classNo}
   */
  public double value(final Vec x, final int classNo) {
    final int bin = binaryClassifier.bin(x);
    final double v = binaryClassifier.values()[bin];
    return masks[bin][classNo] ? v : -v;
  }

  /**
   * Computes the signed tree value for every class.
   *
   * @param x input vector
   * @return a vector of length {@link #ydim()} whose component {@code c} equals
   *     {@code value(x, c)}
   */
  @Override
  public Vec trans(final Vec x) {
    final int dim = ydim();
    final Vec result = new ArrayVec(dim);
    if (dim == 0) {
      // No masks: nothing to compute, and masks[bin] would not exist.
      return result;
    }
    // All classes share the same bin and tree value; only the mask sign differs per class,
    // so locate the bin once instead of once per class.
    final int bin = binaryClassifier.bin(x);
    final double v = binaryClassifier.values()[bin];
    final boolean[] mask = masks[bin];
    for (int c = 0; c < dim; c++) {
      result.set(c, mask[c] ? v : -v);
    }
    return result;
  }

  /**
   * Renders the underlying tree followed by all masks as {@code <bits, bits, ...>}, one group of
   * 0/1 digits per bin.
   */
  @Override
  public String toString() {
    final StringBuilder builder = new StringBuilder();
    builder.append(binaryClassifier.toString());
    builder.append('<');
    for (int i = 0; i < masks.length; i++) {
      final boolean[] mask = masks[i];
      if (i > 0)
        builder.append(", ");
      for (int j = 0; j < mask.length; j++) {
        builder.append(mask[j] ? 1 : 0);
      }
    }
    builder.append('>');
    return builder.toString();
  }

  /**
   * Two models are equal when their underlying trees are equal and their masks have equal
   * contents.
   */
  @Override
  public boolean equals(final Object o) {
    if (this == o) return true;
    if (!(o instanceof ObliviousMultiClassTree)) return false;
    final ObliviousMultiClassTree other = (ObliviousMultiClassTree) o;
    // masks is an array of arrays: deepEquals compares the row contents, whereas
    // Arrays.equals would only compare the row references.
    return binaryClassifier.equals(other.binaryClassifier) && Arrays.deepEquals(masks, other.masks);
  }

  /** Hash code consistent with {@link #equals(Object)} (content-based over the masks). */
  @Override
  public int hashCode() {
    return 31 * binaryClassifier.hashCode() + Arrays.deepHashCode(masks);
  }

  /**
   * Returns the underlying oblivious tree.
   *
   * @return the wrapped {@link ObliviousTree}
   */
  public ObliviousTree binaryClassifier() {
    return binaryClassifier;
  }

  /**
   * Returns the class mask of the given bin (the internal array, not a copy).
   *
   * @param i bin index
   * @return {@code masks[i]}
   */
  public boolean[] mask(final int i) {
    return masks[i];
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy