All Downloads are FREE. The search and download functionality uses the official Maven repository.

com.expleague.ml.data.cherry.CherrySubset Maven / Gradle / Ivy

package com.expleague.ml.data.cherry;

import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.func.Factory;
import com.expleague.ml.BFGrid;
import com.expleague.ml.data.Aggregate;
import com.expleague.ml.data.impl.BinarizedDataSet;

/**
 * A working subset of points of a {@link BinarizedDataSet} that is split, clause by clause,
 * into an "inside" part (points matched by the conditions added so far) and an "outside" part.
 *
 * <p>Point indices are stored internally as {@code index + 1} so that the sign of an entry can be
 * used as a marker even for index 0; {@link #getPoint(int)} decodes an entry back to the index.
 * An entry counts as inside when {@code (entry > 0) == insideSign}.
 *
 * <p>Within the live range {@code [0, length)} the entries are kept partitioned:
 * {@code [0, outsideStart)} holds the inside points and {@code [outsideStart, length)} the
 * outside ones.
 *
 * <p>{@link #startClause()} must be called before {@link #addCondition}, {@link #visitAll},
 * {@link #total()}, {@link #inside()} or {@link #outside()}: those methods read fields that are
 * only initialised there.
 */
public class CherrySubset implements CherryPointsHolder {
  private final BinarizedDataSet bds;
  private final Factory<AdditiveStatistics> factory;
  /** Encoded entries ({@code index + 1}, sign used as the inside/outside marker). */
  private int[] points;
  /** Scratch buffer of the same size as {@link #points}; swapped with it by {@link #gather(int)}. */
  private int[] cache;
  /** Sign convention of the current clause: an entry is inside when {@code (entry > 0) == insideSign}. */
  private boolean insideSign = false;
  /** First position of the outside part within {@code [0, length)}. */
  private int outsideStart = 0;
  /** Aggregate over the points currently outside. */
  private Aggregate outsideAggregate;
  /** Statistics accumulated over the points added to the current clause. */
  private AdditiveStatistics inside;
  /** Number of live entries at the front of {@link #points}. */
  private int length;

  /**
   * Creates a subset over the given point indices.
   *
   * @param bds     binarized data set the indices refer to
   * @param factory factory of empty statistics accumulators
   * @param points  point indices; the array is copied and never modified
   */
  public CherrySubset(BinarizedDataSet bds, Factory<AdditiveStatistics> factory, int[] points) {
    this.bds = bds;
    this.factory = factory;
    this.length = points.length;
    // Encode into a private copy instead of incrementing the caller's array in place.
    this.points = new int[points.length];
    for (int i = 0; i < points.length; ++i) {
      this.points[i] = points[i] + 1;
    }
    this.cache = new int[points.length];
  }


  /**
   * Passes the visitor to the aggregate of the outside points.
   *
   * @param visitor visitor handed to {@code Aggregate.visit}
   */
  public void visitAll(final Aggregate.IntervalVisitor visitor) {
    outsideAggregate.visit(visitor);
  }

  /** Returns the grid of the underlying binarized data set. */
  @Override
  public BFGrid grid() {
    return bds.grid();
  }

  /**
   * Finishes the current clause: the sign convention is flipped, the live range shrinks to the
   * points that were inside this clause, and the inside part is reset to empty.
   */
  @Override
  public void endClause() {
    insideSign = !insideSign;
    length = outsideStart;
    outsideStart = 0;
  }

  /**
   * Starts a new clause: resets the inside statistics and builds the outside aggregate over all
   * {@code length} live points.
   */
  @Override
  public void startClause() {
    inside = factory.create();
    outsideAggregate = new Aggregate(bds, factory, getPoints(0, length));
  }

  /**
   * Moves every outside point whose bin on the given feature lies in
   * {@code [startBin, endBin]} (both inclusive) to the inside part.
   *
   * @param row      feature row; its {@code origFIndex} selects the bins column
   * @param startBin lowest bin accepted
   * @param endBin   highest bin accepted
   * @return the accumulated inside statistics of the current clause (the internal instance,
   *         not a copy)
   */
  @Override
  public AdditiveStatistics addCondition(final BFGrid.BFRow row,
                                         final int startBin, final int endBin) {
    final byte[] bins = bds.bins(row.origFIndex);
    final AdditiveStatistics added = factory.create();
    int count = 0;
    for (int i = outsideStart; i < length; ++i) {
      final int point = getPoint(i);
      final int bin = bins[point];
      if (startBin <= bin && bin <= endBin) {
        added.append(point, 1);
        points[i] = -points[i]; // flip the sign to mark the point as inside
        ++count;
      }
    }
    // gather keeps relative order, so the newly marked points end up in
    // [outsideStart, outsideStart + count) and the remaining outside points right after them.
    gather(outsideStart + count);

    final int remaining = length - outsideStart - count;
    if (count > remaining) {
      // More points left than stayed: rebuilding from the remaining points is the smaller job.
      // (The previous test, count > length - outsideStart, could never hold, and the branch
      // asked for `length` points, overrunning the live range.)
      outsideAggregate = new Aggregate(bds, factory, getPoints(outsideStart + count, remaining));
    } else {
      final Aggregate toRemove = new Aggregate(bds, factory, getPoints(outsideStart, count));
      outsideAggregate.remove(toRemove);
    }
    outsideStart += count;
    inside.append(added);
    return inside;
  }

  /**
   * Decodes {@code size} consecutive entries starting at position {@code from}.
   *
   * @param from first position in {@link #points}
   * @param size number of entries to decode
   * @return a new array with the decoded point indices
   */
  private int[] getPoints(final int from, final int size) {
    final int[] result = new int[size];
    for (int i = 0; i < size; ++i) {
      result[i] = getPoint(from + i);
    }
    return result;
  }


  /** Decodes the entry at position {@code i}: drops the sign marker and the {@code +1} offset. */
  private int getPoint(int i) {
    return Math.abs(points[i]) - 1;
  }

  /**
   * Re-partitions the live range so that inside entries come first, preserving relative order
   * within each part, then swaps {@link #points} and {@link #cache}.
   *
   * @param inCount total number of inside entries in {@code [0, length)}; outside entries are
   *                written starting from this position
   */
  private void gather(final int inCount) {
    int inPtr = 0;
    int outPtr = inCount;
    for (int i = 0; i < length; ++i) {
      if ((points[i] > 0) == insideSign) {
        cache[inPtr++] = points[i];
      } else {
        cache[outPtr++] = points[i];
      }
    }
    final int[] tmp = points;
    points = cache;
    cache = tmp;
  }

  /**
   * Returns the outside aggregate's total with the inside statistics appended to it.
   * NOTE(review): whether {@code Aggregate.total()} hands out a fresh object is not visible
   * here; if it does not, this call mutates the aggregate's state — confirm.
   */
  public AdditiveStatistics total() {
    return outsideAggregate.total().append(inside);
  }

  /** Returns a new statistics object, created by the factory, with the inside statistics appended. */
  public AdditiveStatistics inside() {
    return factory.create().append(inside);
  }

  /** Returns the total of the outside aggregate. */
  public AdditiveStatistics outside() {
    return outsideAggregate.total();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy