All downloads are free. The search and download functionality uses the official Maven repository.

com.expleague.ml.loss.CrossEntropy Maven / Gradle / Ivy

package com.expleague.ml.loss;

import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.IntSeq;
import com.expleague.ml.BlockedTargetFunc;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.func.generic.LogSoftMax;

/**
 * Target function that accumulates {@link LogSoftMax} terms over a flat vector of scores.
 *
 * <p>The input vector is read as consecutive blocks of {@code nClasses} entries; block
 * {@code i} is paired with the label stored at position {@code i} of the target vector.
 */
public class CrossEntropy extends FuncC1.Stub implements BlockedTargetFunc {
  /** Class labels stored as doubles, one entry per sample. */
  protected final Vec target;
  /** Data set handed back by {@link #owner()}. */
  private final DataSet owner;
  /** Width of one block of the input vector. */
  private final int nClasses;

  /**
   * Creates the target from labels already stored in a vector.
   *
   * @param target   labels, one per sample; kept by reference, not copied
   * @param owner    data set returned by {@link #owner()}
   * @param nClasses number of entries per sample block in the input vector
   */
  public CrossEntropy(final Vec target, final DataSet owner, int nClasses) {
    this.nClasses = nClasses;
    this.owner = owner;
    this.target = target;
  }

  /**
   * Creates the target from an integer label sequence, copying it into a new {@link ArrayVec}.
   *
   * @param target   labels, one per sample
   * @param owner    data set returned by {@link #owner()}
   * @param nClasses number of entries per sample block in the input vector
   */
  public CrossEntropy(final IntSeq target, final DataSet owner, int nClasses) {
    final int count = target.length();
    final Vec copied = new ArrayVec(count);
    int pos = 0;
    while (pos < count) {
      copied.set(pos, target.at(pos));
      pos++;
    }
    this.target = copied;
    this.nClasses = nClasses;
    this.owner = owner;
  }

  /**
   * Fills {@code to} block by block using {@link LogSoftMax#staticGrad}.
   *
   * <p>Only {@code x.dim() / nClasses} whole blocks are processed; any trailing entries of
   * {@code x} that do not form a complete block are not visited.
   *
   * @param x  flat input vector, read in blocks of {@code nClasses} entries
   * @param to destination vector, written through sub-vectors at the same offsets as {@code x}
   * @return the {@code to} argument
   */
  @Override
  public Vec gradientTo(final Vec x, Vec to) {
    final int samples = x.dim() / nClasses;
    for (int sample = 0, offset = 0; sample < samples; sample++, offset += nClasses) {
      // Resolve the label before slicing so failures surface in the same order as before.
      final int expected = label(sample);
      LogSoftMax.staticGrad(x.sub(offset, nClasses), expected, to.sub(offset, nClasses));
    }
    return to;
  }

  /**
   * Returns the number of stored labels ({@code target.dim()}).
   *
   * <p>NOTE(review): {@link #value} and {@link #gradientTo} consume {@code nClasses} entries per
   * label, so this is not the length of the vector they read; confirm that callers expect the
   * label count here.
   */
  @Override
  public int dim() {
    return target.dim();
  }

  /**
   * Sums {@link LogSoftMax#staticValue} over every whole block of {@code logits}.
   *
   * @param logits flat input vector, read in blocks of {@code nClasses} entries
   * @return the accumulated per-block values; {@code 0.} when there is no complete block
   */
  @Override
  public double value(final Vec logits) {
    final int samples = logits.dim() / nClasses;
    double total = 0.;
    for (int sample = 0, offset = 0; sample < samples; sample++, offset += nClasses) {
      final int expected = label(sample);
      total += LogSoftMax.staticValue(logits.sub(offset, nClasses), expected);
    }
    return total;
  }

  /**
   * Returns the label of one sample.
   *
   * @param idx position in the target vector
   * @return the stored value narrowed to {@code int}
   */
  public int label(final int idx) {
    final double stored = target.get(idx);
    return (int) stored;
  }

  /** Returns the underlying label vector itself, not a copy. */
  public Vec labels() {
    return target;
  }

  /** Returns the data set supplied at construction time. */
  @Override
  public DataSet owner() {
    return owner;
  }

  /**
   * Builds the function for a single sample.
   *
   * @param index position of the sample in the target vector
   * @return a new {@link LogSoftMax} over {@code nClasses} entries for that sample's label
   */
  @Override
  public FuncC1 block(int index) {
    final int expected = label(index);
    return new LogSoftMax(nClasses, expected);
  }

  /** Returns the number of blocks, which equals the number of stored labels. */
  @Override
  public int blocksCount() {
    return target.dim();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy