All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.dynamicGrid.AggregateDynamic Maven / Gradle / Ivy

package com.expleague.ml.dynamicGrid;

import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.func.Factory;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.dynamicGrid.impl.BinarizedDynamicDataSet;
import com.expleague.ml.dynamicGrid.interfaces.BinaryFeature;
import com.expleague.ml.dynamicGrid.interfaces.DynamicGrid;
import com.expleague.ml.dynamicGrid.interfaces.DynamicRow;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Per-row histograms of {@link AdditiveStatistics} over a subset of the points of a
 * {@link BinarizedDynamicDataSet}.
 *
 * <p>After a rebuild, {@code bins[f][b]} has had {@code append(i, 1)} called for every point index
 * {@code i} of the current point set with {@code bds.bins(f)[i] == b}. A non-empty row {@code f}
 * gets {@code grid.row(f).size() + 1} bins. Rows that are empty keep whatever array they had
 * before, which is a zero-length array right after construction.
 *
 * <p>No method is synchronized. Presumably an instance is used by one thread at a time
 * (TODO: confirm with callers).
 */
public class AggregateDynamic {
  /** Background executor shared by all instances; per-row rebuild tasks are submitted to it. */
  private static final ThreadPoolExecutor exec = ThreadTools.createBGExecutor("Aggregator thread", -1);

  private final BinarizedDynamicDataSet bds;
  private final DynamicGrid grid;
  /** {@code bins[f][b]}: accumulated statistics of bin {@code b} of grid row {@code f}. */
  public final AdditiveStatistics[][] bins;
  /** Creates every accumulator used here: bins, running split sums and totals. */
  private final Factory<? extends AdditiveStatistics> factory;
  /** Point indices the bins are aggregated over; stored without copying. */
  private int[] points;

  /**
   * Creates the aggregate and builds the bins of every grid row over {@code points}.
   *
   * @param bds     binarized data set providing the grid and the per-row bin of each point
   * @param factory creates the statistics accumulators
   * @param points  indices of the points to aggregate; stored without copying
   */
  public AggregateDynamic(final BinarizedDynamicDataSet bds,
                          final Factory<? extends AdditiveStatistics> factory,
                          final int[] points) {
    this.bds = bds;
    this.grid = bds.grid();
    this.factory = factory;
    this.points = points;
    this.bins = new AdditiveStatistics[grid.rows()][];
    for (int feat = 0; feat < bins.length; feat++) {
      bins[feat] = new AdditiveStatistics[0];
    }
    rebuild(points, ArrayTools.sequence(0, grid.rows()));
  }

  /**
   * Replaces the point set used by subsequent {@link #rebuild(int...)} calls.
   * The bins are NOT rebuilt here, and the array is stored without copying.
   *
   * @param points new point indices
   */
  public void updatePoints(final int[] points) {
    this.points = points;
  }

  /**
   * Returns the sum of bins {@code 0..bf.binNo()} (inclusive) of the row {@code bf} belongs to.
   * This is the same value {@link #visit} passes as {@code left} for {@code bf}.
   *
   * @param bf binary feature (row + bin threshold)
   * @return a freshly created statistic; the stored bins are not modified
   */
  public AdditiveStatistics combinatorForFeature(final BinaryFeature bf) {
    final AdditiveStatistics result = factory.create();
    final DynamicRow row = bf.row();
    final int binNo = bf.binNo();
    final int origFIndex = row.origFIndex();
    for (int b = 0; b <= binNo; b++) {
      result.append(bins[origFIndex][b]);
    }
    return result;
  }

  /**
   * Returns the sum of all bins of {@code grid.nonEmptyRow()}. Presumably every non-empty row
   * partitions the same point set, so this is the statistic over all current points
   * (TODO: confirm).
   *
   * @return a freshly created statistic
   */
  public AdditiveStatistics total() {
    final AdditiveStatistics myTotal = factory.create();
    final DynamicRow row = grid.nonEmptyRow();
    final AdditiveStatistics[] myBins = bins[row.origFIndex()];
    for (int bin = 0; bin < myBins.length; ++bin) {
      myTotal.append(myBins[bin]);
    }
    return myTotal;
  }

  /**
   * Subtracts {@code aggregate}'s bins from this aggregate's bins, in place and element by element.
   * Bounds come from this instance, so {@code aggregate} must have at least as many rows, and
   * bins per row, as this one.
   *
   * @param aggregate the aggregate to subtract
   */
  public void remove(final AggregateDynamic aggregate) {
    final AdditiveStatistics[][] my = bins;
    final AdditiveStatistics[][] other = aggregate.bins;
    for (int i = 0; i < my.length; i++) {
      for (int j = 0; j < my[i].length; ++j) {
        my[i][j].remove(other[i][j]);
      }
    }
  }

  /**
   * Callback invoked by {@link #visit} once per candidate split.
   *
   * @param <T> concrete statistics type produced by the factory
   */
  public interface SplitVisitor<T> {
    /**
     * @param bf    the split feature
     * @param left  statistics of bins {@code 0..bf.binNo()} of the row
     * @param right {@link #total()} minus {@code left}
     */
    void accept(BinaryFeature bf, T left, T right);
  }

  /**
   * Enumerates every candidate split of every grid row.
   *
   * <p>For row {@code f} the bins are walked in order. After bin {@code b} is moved from
   * {@code right} to {@code left}, {@code visitor.accept(row.bf(b), left, right)} is called, for
   * {@code b < row.size()}.
   *
   * <p>The same {@code left}/{@code right} instances are reused and mutated across the calls
   * for one row. A visitor that needs to keep them must copy them.
   *
   * @param visitor receives each split
   * @param <T>     concrete statistics type; the factory is assumed to create {@code T} instances
   */
  @SuppressWarnings("unchecked") // casts to the caller-chosen T cannot be checked at runtime
  public <T extends AdditiveStatistics> void visit(final SplitVisitor<T> visitor) {
    final T total = (T) total();
    for (int f = 0; f < grid.rows(); f++) {
      final T left = (T) factory.create();
      final T right = (T) factory.create().append(total);
      final DynamicRow row = grid.row(f);
      final AdditiveStatistics[] rowBins = bins[row.origFIndex()];
      for (int b = 0; b < row.size(); b++) {
        left.append(rowBins[b]);
        right.remove(rowBins[b]);
        visitor.accept(row.bf(b), left, right);
      }
    }
  }

  /**
   * Rebuilds the bins of the given rows over the current point set, in parallel.
   * Blocks until every row is done.
   *
   * @param features indices of the grid rows to rebuild
   * @throws IllegalStateException if building any row failed; the first failure is the cause
   */
  public void rebuild(final int... features) {
    rebuild(this.points, features);
  }

  /**
   * Submits one task per row to {@link #exec} and waits for all of them.
   *
   * <p>Every task counts the latch down in {@code finally}, so a failing row can no longer hang
   * the caller. The first failure is rethrown here once all tasks have finished.
   *
   * <p>If the calling thread is interrupted while waiting, its interrupt flag is restored and the
   * method returns early. Tasks still running may then write their rows later.
   */
  private void rebuild(final int[] indices, final int... features) {
    final CountDownLatch latch = new CountDownLatch(features.length);
    final AtomicReference<Throwable> firstFailure = new AtomicReference<Throwable>();
    for (final int findex : features) {
      exec.execute(new Runnable() {
        @Override
        public void run() {
          try {
            final short[] binOf = bds.bins(findex);
            final DynamicRow row = grid.row(findex);
            if (!row.empty()) {
              bins[findex] = buildRowBins(binOf, indices, row.size() + 1);
            }
          } catch (Throwable t) {
            firstFailure.compareAndSet(null, t);
          } finally {
            latch.countDown();
          }
        }
      });
    }
    try {
      latch.await();
    } catch (InterruptedException e) {
      // Keep the historical best-effort behaviour (return without a full rebuild),
      // but do not lose the interrupt for the caller.
      Thread.currentThread().interrupt();
      return;
    }
    final Throwable failure = firstFailure.get();
    if (failure != null) {
      throw new IllegalStateException("Failed to rebuild aggregate bins", failure);
    }
  }

  /**
   * Creates {@code binCount} accumulators and appends each point {@code i} of {@code indices}
   * to accumulator {@code binOf[i]}.
   */
  private AdditiveStatistics[] buildRowBins(final short[] binOf, final int[] indices, final int binCount) {
    final AdditiveStatistics[] result = new AdditiveStatistics[binCount];
    for (int i = 0; i < binCount; ++i) {
      result[i] = factory.create();
    }
    // Hot loop manually unrolled by 4: the four bin lookups are issued before the appends.
    // Appends still happen in index order, exactly as a plain loop would do them.
    final int unrolledEnd = 4 * (indices.length / 4);
    for (int i = 0; i < unrolledEnd; i += 4) {
      final int idx1 = indices[i];
      final int idx2 = indices[i + 1];
      final int idx3 = indices[i + 2];
      final int idx4 = indices[i + 3];
      final AdditiveStatistics bin1 = result[binOf[idx1]];
      final AdditiveStatistics bin2 = result[binOf[idx2]];
      final AdditiveStatistics bin3 = result[binOf[idx3]];
      final AdditiveStatistics bin4 = result[binOf[idx4]];
      bin1.append(idx1, 1);
      bin2.append(idx2, 1);
      bin3.append(idx3, 1);
      bin4.append(idx4, 1);
    }
    for (int i = unrolledEnd; i < indices.length; i++) {
      final int idx = indices[i];
      result[binOf[idx]].append(idx, 1);
    }
    return result;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy