All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.dynamicGrid.models.ObliviousTreeDynamicBin Maven / Gradle / Ivy

package com.expleague.ml.dynamicGrid.models;

import com.expleague.commons.math.Func;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.ml.dynamicGrid.interfaces.BinaryFeature;
import com.expleague.ml.dynamicGrid.impl.BinarizedDynamicDataSet;
import com.expleague.ml.dynamicGrid.interfaces.DynamicGrid;

import java.util.Arrays;
import java.util.List;

public class ObliviousTreeDynamicBin extends Func.Stub implements BinDynamicOptimizedModel {
  private final BinaryFeature[] features;
  private final double[] values;
  private final DynamicGrid grid;

  public int depth() {
    return features.length;
  }


  public ObliviousTreeDynamicBin(final List features, final double[] values) {
    this.grid = features.get(0).row().grid();
    this.features = features.toArray(new BinaryFeature[features.size()]);
    this.values = values;
  }

  @Override
  public int dim() {
    return grid.rows();
  }

  @Override
  public double value(final Vec x) {
    final int index = bin(x);
    return values[index];
  }

  @Override
  public String toString() {
    final StringBuilder builder = new StringBuilder();
    builder.append(values.length);
    builder.append("->(");
    for (int i = 0; i < features.length; i++) {
      builder.append(i > 0 ? ", " : "")
              .append(features[i]).append("@");
    }
    builder.append(")");
    builder.append("+[");
    for (final double feature : values) {
      builder.append(feature).append(", ");
    }
    builder.delete(builder.length() - 2, builder.length());
    builder.append("]");
    return builder.toString();
  }

  public int bin(final Vec x) {
    int index = 0;
    for (int i = 0; i < features.length; i++) {
      index <<= 1;
      if (features[i].value(x))
        index++;
    }
    return index;
  }

  public BinaryFeature[] features() {
    return features;
  }

  public double[] values() {
    return values;
  }


  @Override
  public boolean equals(final Object o) {
    if (this == o) return true;
    if (!(o instanceof ObliviousTreeDynamicBin)) return false;

    final ObliviousTreeDynamicBin that = (ObliviousTreeDynamicBin) o;

    if (!Arrays.equals(features, that.features)) return false;
    if (!Arrays.equals(values, that.values)) return false;

    return true;
  }

  @Override
  public int hashCode() {
    int result = Arrays.hashCode(features);
    result = 31 * result + Arrays.hashCode(values);
    return result;
  }

  public DynamicGrid grid() {
    return grid;
  }

  @Override
  public double value(final BinarizedDynamicDataSet bds, final int pindex) {
    int index = 0;
    for (int i = 0; i < features.length; i++) {
      index <<= 1;
      if (bds.bins(features[i].fIndex())[pindex] > features[i].binNo())
        index++;
    }
    return values[index];
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy