All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.models.TransObliviousTree Maven / Gradle / Ivy

package com.expleague.ml.models;

import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.ml.BFGrid;
import com.expleague.ml.BinModelWithGrid;
import com.expleague.ml.BinOptimizedModel;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Tree-like model whose leaves hold {@link Trans} instances instead of constants.
 *
 * <p>An ordered array of binary features is evaluated on the input; the outcomes are packed
 * into an integer (first feature = most significant bit) that selects one entry of
 * {@code values}. The selected transformation is applied to the input vector and the first
 * component of its result is returned as the model value.
 *
 * <p>NOTE(review): {@code values} is presumably expected to contain
 * {@code 2^features.length} entries so that every leaf index is valid - this is not checked here.
 *
 * User: noxoomo
 */

public class TransObliviousTree extends BinOptimizedModel.Stub implements BinModelWithGrid {

  /** Split conditions, in evaluation order (index 0 contributes the most significant bit). */
  private final BFGrid.Feature[] features;
  /** Leaf transformations, indexed by the leaf index computed from {@link #features}. */
  private final Trans[] values;
  /** Grid taken from the first feature's row at construction time. */
  private final BFGrid grid;

  /**
   * Creates a tree from its split features and leaf transformations.
   *
   * @param features split conditions in evaluation order; must contain at least one element,
   *                 because the grid is taken from the first feature's row
   * @param values   leaf transformations indexed by leaf index; the array is stored by
   *                 reference, not copied
   */
  public TransObliviousTree(final List<BFGrid.Feature> features, final Trans[] values) {
    grid = features.get(0).row().grid();
    this.features = features.toArray(new BFGrid.Feature[features.size()]);
    this.values = values;
  }

  /**
   * @return the number of rows of the underlying grid
   */
  @Override
  public int dim() {
    return grid.rows();
  }

  /**
   * Evaluates the model on a raw vector.
   *
   * @param x input vector
   * @return first component of the selected leaf transformation applied to {@code x}
   */
  @Override
  public double value(final Vec x) {
    final int index = bin(x);
    return values[index].trans(x).get(0);
  }

  /**
   * Renders the tree as {@code <leafCount>->(<f0>@, <f1>@, ...)+[<v0>, <v1>, ...]}.
   *
   * @return textual description of features and leaf transformations
   */
  @Override
  public String toString() {
    final StringBuilder builder = new StringBuilder();
    builder.append(values.length).append("->(");
    for (int i = 0; i < features.length; i++) {
      if (i > 0)
        builder.append(", ");
      builder.append(features[i]).append("@");
    }
    builder.append(")+[");
    // Separator is written before each element except the first, so nothing has to be
    // trimmed afterwards and an empty values array still yields a well-formed "+[]".
    for (int i = 0; i < values.length; i++) {
      if (i > 0)
        builder.append(", ");
      builder.append(values[i].toString());
    }
    builder.append("]");
    return builder.toString();
  }

  /**
   * Computes the leaf index for a raw vector.
   *
   * @param x input vector
   * @return integer whose bits are the feature outcomes on {@code x}, the first feature
   *         being the most significant bit
   */
  public int bin(final Vec x) {
    int index = 0;
    for (int i = 0; i < features.length; i++) {
      index <<= 1;
      if (features[i].value(x))
        index++;
    }
    return index;
  }

  /**
   * @return a new mutable list containing the split features in evaluation order; changes to
   *         the returned list do not affect this tree
   */
  public List<BFGrid.Feature> features() {
    final List<BFGrid.Feature> ret = new ArrayList<BFGrid.Feature>(features.length);
    Collections.addAll(ret, features);
    return ret;
  }

  /**
   * @return the internal array of leaf transformations (not a copy)
   */
  public Trans[] values() {
    return values;
  }

  /**
   * Two trees are equal when their feature arrays and leaf-transformation arrays are
   * element-wise equal.
   */
  @Override
  public boolean equals(final Object o) {
    if (this == o) return true;
    if (!(o instanceof TransObliviousTree)) return false;
    final TransObliviousTree that = (TransObliviousTree) o;
    return Arrays.equals(features, that.features) && Arrays.equals(values, that.values);
  }

  @Override
  public int hashCode() {
    int result = Arrays.hashCode(features);
    result = 31 * result + Arrays.hashCode(values);
    return result;
  }

  /**
   * @return the grid obtained from the first feature's row when this tree was constructed
   */
  public BFGrid grid() {
    return grid;
  }

  /**
   * Evaluates the model on one point of a binarized data set.
   *
   * <p>A feature counts as "on" when the point's bin for that feature's index is strictly
   * greater than the feature's own bin.
   *
   * @param bds    binarized data set; its {@code original()} must be a {@link VecDataSet},
   *               otherwise the cast below fails with {@link ClassCastException}
   * @param pindex index of the point inside {@code bds}
   * @return first component of the selected leaf transformation applied to the original row
   */
  @Override
  public double value(final BinarizedDataSet bds, final int pindex) {
    int index = 0;
    for (int i = 0; i < features.length; i++) {
      index <<= 1;
      if (bds.bins(features[i].findex())[pindex] > features[i].bin())
        index++;
    }
    //dirty hack with cast
    // The leaf transformation needs the raw vector, which is only reachable through the
    // original (non-binarized) data set.
    return values[index].trans(((VecDataSet)bds.original()).data().row(pindex)).get(0);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy