All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.loss.blockwise.BlockwiseL2 Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.loss.blockwise;

import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.func.Factory;
import com.expleague.ml.BlockwiseFuncC1;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.data.set.DataSet;

/**
 * Pointwise squared-error (L2) loss exposed through the blockwise function API.
 *
 * <p>{@link #blockSize()} returns 1 here, so block {@code index} covers the target element at
 * {@code index * blockSize()}. For a block, {@link #value(Vec, int)} is the sum of squared
 * differences between the point and the target, and {@link #gradient(Vec, Vec, int)} writes
 * {@code 2 * (point - target)} for each element. {@link #transformResultValue(double)} maps an
 * accumulated value {@code v} to {@code sqrt(v / dim())}. NOTE(review): this presumably turns a
 * total squared error into an RMSE; confirm against callers.
 *
 * <p>The statistics-based part of the API ({@link #statsFactory()}, {@link #score(MSEStats)},
 * {@link #bestIncrement(MSEStats)}) operates on {@link MSEStats} accumulators over target values.
 *
 * <p>Originally written by solar on 21.12.2010.
 */
public class BlockwiseL2 extends BlockwiseFuncC1.Stub implements BlockwiseStatBasedLoss, TargetFunc {
  public final Vec target;
  private final DataSet owner;

  /**
   * Creates the loss for the given targets.
   *
   * @param target target values; their dimension is the dimension of this function
   * @param owner  data set returned by {@link #owner()}
   */
  public BlockwiseL2(final Vec target, final DataSet owner) {
    this.owner = owner;
    this.target = target;
  }

  /** @return the number of target values */
  @Override
  public int dim() {
    return target.dim();
  }

  /**
   * Writes the squared-error gradient of one block into {@code result}.
   *
   * @param pointBlock current values of the block's elements
   * @param result     receives {@code 2 * (pointBlock[k] - target[index * blockSize() + k])}
   *                   at position {@code k}
   * @param index      block number
   */
  @Override
  public void gradient(final Vec pointBlock, final Vec result, final int index) {
    final int size = blockSize();
    final int offset = index * size;
    int k = 0;
    while (k < size) {
      final double diff = pointBlock.get(k) - target.get(offset + k);
      result.set(k, 2 * diff);
      k++;
    }
  }

  /**
   * Computes the squared error of one block.
   *
   * @param pointBlock current values of the block's elements
   * @param index      block number
   * @return sum over the block of {@code (pointBlock[k] - target[index * blockSize() + k])^2}
   */
  @Override
  public double value(final Vec pointBlock, final int index) {
    final int size = blockSize();
    final int offset = index * size;
    double sumOfSquares = 0.0;
    for (int k = 0; k < size; k++) {
      final double diff = pointBlock.get(k) - target.get(offset + k);
      sumOfSquares += diff * diff;
    }
    return sumOfSquares;
  }

  /**
   * Rescales an accumulated loss value.
   *
   * @param value accumulated value
   * @return {@code sqrt(value / dim())}
   */
  @Override
  public double transformResultValue(final double value) {
    final double meanSquare = value / dim();
    return Math.sqrt(meanSquare);
  }

  /** @return always 1: each block holds a single element */
  @Override
  public int blockSize() {
    return 1;
  }

  /** @return a factory producing empty {@link MSEStats} bound to this loss's targets */
  @Override
  public Factory statsFactory() {
    final Vec targets = target;
    return new Factory() {
      @Override
      public MSEStats create() {
        return new MSEStats(targets);
      }
    };
  }

  /** @return the accumulated sum of squared targets, {@code stats.sum2} */
  @Override
  public double value(final MSEStats stats) {
    return stats.sum2;
  }

  /**
   * Scores a group of accumulated targets.
   *
   * @param stats accumulated statistics
   * @return {@code sum2 - sum^2 / weight} (squared deviation around the mean) when
   *     {@code weight > MathTools.EPSILON}, otherwise {@code sum2}
   */
  @Override
  public double score(final MSEStats stats) {
    if (stats.weight > MathTools.EPSILON) {
      return stats.sum2 - stats.sum * stats.sum / stats.weight;
    }
    return stats.sum2;
  }

  /**
   * Computes the increment for a group of accumulated targets.
   *
   * @param stats accumulated statistics
   * @return the mean target {@code sum / weight} when {@code weight > MathTools.EPSILON},
   *     otherwise 0
   */
  @Override
  public double bestIncrement(final MSEStats stats) {
    if (stats.weight > MathTools.EPSILON) {
      return stats.sum / stats.weight;
    }
    return 0.0;
  }

  /** @return the target value at position {@code i} */
  public double get(final int i) {
    return target.get(i);
  }

  /** @return the data set passed to the constructor */
  @Override
  public DataSet owner() {
    return owner;
  }

  /**
   * Additive accumulator over target values. It tracks the sum of targets ({@link #sum}), the sum
   * of squared targets ({@link #sum2}) and the number of accumulated occurrences
   * ({@link #weight}).
   *
   * <p>NOTE(review): the fields are {@code volatile}, but the compound {@code +=}/{@code -=}
   * updates below are not atomic. Concurrent mutation of a single instance can therefore lose
   * updates.
   */
  public static class MSEStats implements AdditiveStatistics {
    public volatile double sum;
    public volatile double sum2;
    public volatile int weight;

    private final Vec targets;

    /** @param target values looked up by index in the index-based append/remove methods */
    public MSEStats(final Vec target) {
      this.targets = target;
    }

    /**
     * Adds target {@code index}, counted {@code times} times.
     *
     * @return this accumulator
     */
    @Override
    public MSEStats append(final int index, final int times) {
      final double y = targets.get(index);
      final double scaled = times * y;
      sum += scaled;
      sum2 += scaled * y;
      weight += times;
      return this;
    }

    /**
     * Removes target {@code index}, counted {@code times} times. This is the inverse of
     * {@link #append(int, int)}.
     *
     * @return this accumulator
     */
    @Override
    public MSEStats remove(final int index, final int times) {
      final double y = targets.get(index);
      final double scaled = times * y;
      sum -= scaled;
      sum2 -= scaled * y;
      weight -= times;
      return this;
    }

    /**
     * Delegates to {@link #append(int, int)}. The weight is narrowed with an {@code (int)} cast,
     * so any fractional part is discarded.
     */
    @Override
    public AdditiveStatistics append(int index, double weight) {
      final int times = (int) weight;
      return append(index, times);
    }

    /**
     * Delegates to {@link #remove(int, int)}. The weight is narrowed with an {@code (int)} cast,
     * so any fractional part is discarded.
     */
    @Override
    public AdditiveStatistics remove(int index, double weight) {
      final int times = (int) weight;
      return remove(index, times);
    }

    /**
     * Adds all statistics accumulated in {@code that}.
     *
     * @param that another accumulator; it is cast to {@code MSEStats}, so other types throw
     *     {@link ClassCastException}
     * @return this accumulator
     */
    @Override
    public MSEStats append(final AdditiveStatistics that) {
      final MSEStats other = (MSEStats) that;
      sum += other.sum;
      sum2 += other.sum2;
      weight += other.weight;
      return this;
    }

    /**
     * Subtracts all statistics accumulated in {@code that}.
     *
     * @param that another accumulator; it is cast to {@code MSEStats}, so other types throw
     *     {@link ClassCastException}
     * @return this accumulator
     */
    @Override
    public MSEStats remove(final AdditiveStatistics that) {
      final MSEStats other = (MSEStats) that;
      sum -= other.sum;
      sum2 -= other.sum2;
      weight -= other.weight;
      return this;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy