All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.models.ObliviousTree Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.models;

import com.expleague.commons.math.DiscontinuousTrans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.BFGrid;
import com.expleague.ml.BinModelWithGrid;
import com.expleague.ml.BinOptimizedModel;
import com.expleague.ml.data.impl.BinarizedDataSet;
import org.jetbrains.annotations.NotNull;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * User: solar
 * Date: 29.11.12
 * Time: 5:35
 */
public class ObliviousTree extends BinOptimizedModel.Stub implements BinModelWithGrid {
  private final BFGrid.Feature[] features;
  private final double[] values;
  private final double[] basedOn;

  public ObliviousTree(final List features, final double[] values) {
    this(features, values, new double[values.length]);
  }

  public ObliviousTree(final List features, final double[] values, final double[] basedOn) {
    if (features.size() == 0)
      throw new RuntimeException("Creating oblivious tree of zero depth");
    this.basedOn = basedOn;
    this.features = features.toArray(new BFGrid.Feature[features.size()]);
    this.values = values;
  }

  @Override
  public int dim() {
    return features[0].row().grid().size();
  }

  @Override
  public double value(final Vec x) {
    final int index = bin(x);
    return values[index];
  }

  @Override
  public DiscontinuousTrans subgradient() {
    return new ObliviousGrad();
  }

  @Override
  public String toString() {
    final StringBuilder builder = new StringBuilder();
    builder.append(values.length);
    builder.append("->(");
    for (int i = 0; i < features.length; i++) {
      builder.append(i > 0 ? ", " : "")
          .append(features[i]).append("@").append(basedOn[i]);
    }
    builder.append(")");
    builder.append("+[");
    for (final double feature : values) {
      builder.append(feature).append(", ");
    }
    builder.delete(builder.length() - 2, builder.length());
    builder.append("]");
    return builder.toString();
  }

  public int bin(final Vec x) {
    int index = 0;
    for (int i = 0; i < features.length; i++) {
      index <<= 1;
      if (features[i].value(x))
        index++;
    }
    return index;
  }

  public List features() {
    return new ArrayList<>(Arrays.asList(features));
  }

  public double[] values() {
    return values;
  }

  public double[] based() {
    return basedOn;
  }

  @Override
  public boolean equals(final Object o) {
    if (this == o) return true;
    if (!(o instanceof ObliviousTree)) return false;

    final ObliviousTree that = (ObliviousTree) o;

    if (!Arrays.equals(features, that.features)) return false;
    return Arrays.equals(values, that.values);
  }

  @Override
  public int hashCode() {
    int result = Arrays.hashCode(features);
    result = 31 * result + Arrays.hashCode(values);
    result = 31 * result + Arrays.hashCode(basedOn);
    return result;
  }

  public BFGrid grid() {
    return features[0].row().grid();
  }

  @Override
  public double value(final BinarizedDataSet bds, final int pindex) {
    int index = 0;
    for (int i = 0; i < features.length; i++) {
      index <<= 1;
      if (features[i].value(bds.bins(features[i].findex())[pindex]))
        index++;
    }
    return values[index];
  }

  private class ObliviousGrad extends DiscontinuousTrans.Stub {
    @NotNull
    @Override
    public Vec leftTo(Vec x, Vec gradient) {
      final int numFeatures = x.dim();
      final byte[] folds = new byte[numFeatures];
      final BFGrid grid = grid();
      grid.binarizeTo(x, folds);

      final Vec x0 = VecTools.assign(new ArrayVec(numFeatures), x);
      final double value = ObliviousTree.this.value(x);

      for (int fIdx = 0; fIdx < numFeatures; fIdx++) {
        final BFGrid.Row row = grid.row(fIdx);
        final int leftIdx = folds[fIdx] - 1;

        if (leftIdx == -1) {
          gradient.adjust(fIdx, -1000);
        } else {
          final double leftBorder = row.condition(leftIdx);
          x0.set(fIdx, leftBorder);
          final double leftValue = ObliviousTree.this.value(x0);

          if (Math.abs(x.get(fIdx) - leftBorder) > 1e-7)
            gradient.adjust(fIdx, (value - leftValue) / (x.get(fIdx) - leftBorder));
          else
            gradient.adjust(fIdx, (value - leftValue));

          assert(!Double.isNaN(gradient.get(fIdx)));

          x0.set(fIdx, x.get(fIdx));
        }
      }

      return gradient;
    }

    @NotNull
    @Override
    public Vec rightTo(Vec x, Vec gradient) {
      final int numFeatures = x.dim();
      final byte[] folds = new byte[numFeatures];
      final BFGrid grid = grid();
      grid.binarizeTo(x, folds);

      final Vec x0 = VecTools.assign(new ArrayVec(numFeatures), x);
      final double value = ObliviousTree.this.value(x);

      for (int fIdx = 0; fIdx < numFeatures; fIdx++) {
        final BFGrid.Row row = grid.row(fIdx);
        final int rightIdx = folds[fIdx] + 1;
        if (rightIdx >= row.size()) {
          gradient.adjust(fIdx, 1000);
        } else {
          final double rightBorder = (row.condition(rightIdx) + row.condition(rightIdx - 1)) / 2.;
          x0.set(fIdx, rightBorder);

          final double rightValue = ObliviousTree.this.value(x0);
          if (Math.abs(x.get(fIdx) - rightBorder) > 1e-7)
            gradient.adjust(fIdx, (rightValue - value) / (rightBorder - x.get(fIdx)));
          else {
            gradient.adjust(fIdx, (rightValue - value));
          }
          assert(!Double.isNaN(gradient.get(fIdx)));

          x0.set(fIdx, x.get(fIdx));
        }
      }

      return gradient;
    }

    @Override
    public int xdim() {
      return ObliviousTree.this.dim();
    }

    @Override
    public int ydim() {
      return xdim();
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy