All downloads are free. The search and download functionality uses the official Maven repository.

com.expleague.ml.BlockwiseFuncC1 Maven / Gradle / Ivy

package com.expleague.ml;

import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.ThreadTools;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicReference;

/**
 * A {@link FuncC1} whose value and gradient are evaluated block by block: the argument is viewed
 * as a matrix and every row of that matrix is one block.
 *
 * User: solar
 * Date: 21.12.2010
 * Time: 22:07:07
 */
public interface BlockwiseFuncC1 extends FuncC1 {
  /**
   * Computes the gradient for a single block.
   *
   * @param pointBlock the block of the argument, i.e. row {@code index} of the matrix view of the point
   * @param result     destination for this block's gradient. {@link Stub#gradient(Mx)} passes row
   *                   {@code index} of a copy of the point, so on entry it holds the point's own values;
   *                   presumably implementations overwrite it -- confirm against implementors
   * @param index      row number of the block inside the matrix view
   */
  void gradient(Vec pointBlock, Vec result, int index);

  /**
   * Computes the value contributed by a single block.
   *
   * @param pointBlock the block of the argument, i.e. row {@code index} of the matrix view of the point
   * @param index      row number of the block inside the matrix view
   * @return the contribution of this block; {@link Stub#value(Vec)} sums these over all rows
   */
  double value(Vec pointBlock, int index);

  /**
   * Post-processes the sum of all per-block values.
   *
   * @param value the sum of {@link #value(Vec, int)} over all blocks
   * @return the final function value reported by {@link Stub#value(Vec)}
   */
  double transformResultValue(double value);

  /**
   * Size of one block. {@link Stub} passes it as the first constructor argument of
   * {@link VecBasedMx} when a flat vector has to be viewed as a matrix -- presumably the number
   * of columns (elements per block); confirm against {@code VecBasedMx}.
   */
  int blockSize();

  /**
   * Skeleton implementation that derives the whole-point value and gradient from the per-block
   * methods. Gradient rows are computed on the shared {@link #pool}.
   */
  abstract class Stub extends FuncC1.Stub implements BlockwiseFuncC1 {
    /** Executor shared by all instances; per-block gradient tasks are submitted to it. */
    protected static ThreadPoolExecutor pool = ThreadTools.createBGExecutor("Gradient calculator tg", ThreadTools.COMPUTE_UNITS);

    /**
     * Computes the gradient for every row of {@code x}, splitting the rows between
     * {@code ThreadTools.COMPUTE_UNITS} tasks (task {@code t} handles rows {@code t},
     * {@code t + COMPUTE_UNITS}, ...), and blocks until all tasks are done.
     *
     * @param x the point, one block per row
     * @return a matrix initialised as a copy of {@code x} whose rows were handed to
     *         {@link #gradient(Vec, Vec, int)} as the output argument
     * @throws RuntimeException if the calling thread is interrupted while waiting (the interrupt
     *         status is restored first), or if a per-block computation threw one
     * @throws Error if a per-block computation threw one
     */
    public final Mx gradient(final Mx x) {
      final Mx result = VecTools.copy(x);
      final int workers = ThreadTools.COMPUTE_UNITS;
      final CountDownLatch latch = new CountDownLatch(workers);
      // First failure raised by any worker; rethrown in the caller after all workers finish.
      final AtomicReference<Throwable> failure = new AtomicReference<>();
      for (int t = 0; t < workers; t++) {
        final int finalT = t;
        pool.execute(() -> {
          try {
            for (int i = finalT; i < x.rows(); i += workers) {
              gradient(x.row(i), result.row(i), i);
            }
          } catch (RuntimeException | Error e) {
            failure.compareAndSet(null, e);
          } finally {
            // Always release the latch: a failing block must not leave the caller waiting forever.
            latch.countDown();
          }
        });
      }
      try {
        latch.await();
      } catch (InterruptedException e) {
        // Keep the interrupt visible to code further up the stack.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
      final Throwable error = failure.get();
      if (error instanceof Error) {
        throw (Error) error;
      }
      if (error != null) {
        throw (RuntimeException) error;
      }
      return result;
    }

    /**
     * Computes the gradient at {@code x}. If {@code x} is already an {@link Mx} it is used as is;
     * otherwise it is wrapped in a {@link VecBasedMx} built from {@link #blockSize()} and {@code x}.
     *
     * @param x the point
     * @return the result of {@link #gradient(Mx)} on the matrix view of {@code x}
     */
    @Override
    public final Mx gradient(final Vec x) {
      return gradient(x instanceof Mx ? (Mx)x : new VecBasedMx(blockSize(), x));
    }

    /**
     * Sums {@link #value(Vec, int)} over all rows of {@code blocks}, sequentially in row order.
     *
     * @param blocks the point, one block per row
     * @return the untransformed sum of the per-block values ({@code 0.0} when there are no rows)
     */
    protected double value(final Mx blocks) {
      double result = 0.0;
      for (int i = 0; i < blocks.rows(); i ++) {
        result += value(blocks.row(i), i);
      }
      return result;
    }

    /**
     * Computes the function value at {@code x}: views {@code x} as a matrix (same rule as
     * {@link #gradient(Vec)}), sums the per-block values and passes the sum through
     * {@link #transformResultValue(double)}.
     *
     * @param x the point
     * @return the transformed sum of the per-block values
     */
    @Override
    public final double value(final Vec x) {
      final double value = value(x instanceof Mx ? (Mx) x : new VecBasedMx(blockSize(), x));
      return transformResultValue(value);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy