All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.loss.WeightedLoss Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.loss;

import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.func.Factory;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.ml.data.set.DataSet;
import gnu.trove.list.array.TIntArrayList;
import org.jetbrains.annotations.Nullable;

/**
 * Loss wrapper that delegates every computation to an underlying {@code metric}
 * while multiplying each per-point statistics update by an integer weight taken
 * from {@code weights[index]}.
 *
 * <p>The weights array is stored by reference (it is not copied) and the same
 * array is handed to every {@link Stat} created through {@link #statsFactory()}.
 *
 * User: solar
 * Date: 26.11.13
 * Time: 9:54
 */
public class WeightedLoss extends Func.Stub implements StatBasedLoss {
  private final BasedOn metric;
  private final int[] weights;

  /**
   * @param metric  the loss all calls are delegated to
   * @param weights per-point integer weights, indexed by point index; kept by reference
   */
  public WeightedLoss(final BasedOn metric, final int[] weights) {
    this.metric = metric;
    this.weights = weights;
  }

  /**
   * @return a factory whose every product is a {@link Stat} wrapping a freshly created
   *         statistics object of the underlying metric and sharing this loss' weights
   */
  @Override
  public Factory statsFactory() {
    return () -> new Stat(weights, (AdditiveStatistics) metric.statsFactory().create());
  }

  /** @return the target of the underlying metric */
  @Override
  public Vec target() {
    return metric.target();
  }

  /** Delegates to the underlying metric using the wrapped (already weighted) statistics. */
  @Override
  public double bestIncrement(final Stat comb) {
    return metric.bestIncrement(comb.inside);
  }

  /** Delegates to the underlying metric using the wrapped (already weighted) statistics. */
  @Override
  public double score(final Stat comb) {
    return metric.score(comb.inside);
  }

  /** Delegates to the underlying metric using the wrapped (already weighted) statistics. */
  @Override
  public double value(final Stat comb) {
    return metric.value(comb.inside);
  }

  /** @return {@code xdim()} of the underlying metric */
  @Override
  public int dim() {
    return metric.xdim();
  }

  /** @return the gradient of the underlying metric; weights are not applied here */
  @Nullable
  @Override
  public Trans gradient() {
    return metric.gradient();
  }

  /**
   * @param x point to evaluate the underlying metric at
   * @return the first component of {@code metric.trans(x)}; weights are not applied here
   */
  @Override
  public double value(final Vec x) {
    return metric.trans(x).get(0);
  }

  /**
   * @param index point index
   * @return the integer weight of that point, widened to {@code double}
   * @throws ArrayIndexOutOfBoundsException if {@code index} is outside the weights array
   */
  public double weight(final int index) {
    return weights[index];
  }

  /** @return the underlying metric this loss delegates to */
  public BasedOn base() {
    return metric;
  }

  /** @return the owner reported by the underlying metric */
  @Override
  public DataSet owner() {
    return metric.owner();
  }

  /**
   * @return indices of all points whose weight is strictly positive, in ascending order
   */
  public int[] points() {
    // At most one index is added per weight, so weights.length is a sufficient capacity.
    final TIntArrayList result = new TIntArrayList(weights.length);
    for (int i = 0; i < weights.length; i++) {
      if (weights[i] > 0) {
        result.add(i);
      }
    }
    return result.toArray();
  }

  /**
   * @return indices of all points whose weight is exactly zero, in ascending order
   */
  public int[] zeroPoints() {
    final TIntArrayList result = new TIntArrayList(weights.length);
    for (int i = 0; i < weights.length; i++) {
      if (weights[i] == 0) {
        result.add(i);
      }
    }
    return result.toArray();
  }

  /**
   * Statistics wrapper that forwards every update to {@link #inside}, scaling the
   * requested multiplicity (or real-valued weight) of a point by {@code weights[index]}.
   */
  public static class Stat implements AdditiveStatistics {
    /** The wrapped statistics that actually accumulate the (weighted) updates. */
    public AdditiveStatistics inside;
    private final int[] weights;

    /**
     * @param weights per-point integer weights, indexed by point index; kept by reference
     * @param inside  statistics object that receives the weighted updates
     */
    public Stat(final int[] weights, final AdditiveStatistics inside) {
      this.weights = weights;
      this.inside = inside;
    }

    /** Appends point {@code index} to the wrapped statistics {@code weights[index] * times} times. */
    @Override
    public Stat append(final int index, final int times) {
      final int count = weights[index];
      inside.append(index, count * times);
      return this;
    }

    /**
     * Appends the wrapped statistics of {@code other} to the wrapped statistics of this object.
     *
     * @param other must be a {@link Stat}; any other implementation fails the cast
     * @throws ClassCastException if {@code other} is not a {@link Stat}
     */
    @Override
    public Stat append(final AdditiveStatistics other) {
      inside.append(((Stat) other).inside);
      return this;
    }

    /** Removes point {@code index} from the wrapped statistics {@code weights[index] * times} times. */
    @Override
    public Stat remove(final int index, final int times) {
      final int count = weights[index];
      inside.remove(index, count * times);
      return this;
    }

    /**
     * Removes the wrapped statistics of {@code other} from the wrapped statistics of this object.
     *
     * @param other must be a {@link Stat}; any other implementation fails the cast
     * @throws ClassCastException if {@code other} is not a {@link Stat}
     */
    @Override
    public Stat remove(final AdditiveStatistics other) {
      inside.remove(((Stat) other).inside);
      return this;
    }

    /** Appends point {@code index} to the wrapped statistics with weight {@code weight * weights[index]}. */
    @Override
    public Stat append(int index, double weight) {
      final int count = weights[index];
      inside.append(index, weight * count);
      return this;
    }

    /** Removes point {@code index} from the wrapped statistics with weight {@code weight * weights[index]}. */
    @Override
    public Stat remove(int index, double weight) {
      final int count = weights[index];
      inside.remove(index, weight * count);
      return this;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy