All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.loss.pairwise.PairwiseLLLogit Maven / Gradle / Ivy

package com.expleague.ml.loss.pairwise;

import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.VecBuilder;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.seq.IntSeqBuilder;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.meta.GroupedDSItem;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TObjectIntHashMap;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.DoubleStream;

import static java.lang.Math.*;

/**
 * Pairwise logistic log-likelihood target over grouped data-set items.
 *
 * <p>Items of the owning data set are partitioned by {@code groupId()}. Every unordered pair
 * {@code (a, b)} of items inside one group contributes a logistic term weighted by
 * {@code |target[a] - target[b]|}; the pair label is {@code +1} when {@code target[a] > target[b]}
 * and {@code -1} otherwise.
 *
 * <p>The argument of {@link #value(Vec)} and {@link #gradientTo(Vec, Vec)} has one coordinate per
 * data-set item (see {@link #dim()}), so it is always addressed by data-set index, never by the
 * position of an item inside its group.
 */
@SuppressWarnings("unused")
public class PairwiseLLLogit extends FuncC1.Stub implements TargetFunc {
  /** Target value of every data-set item, addressed by data-set index. */
  protected final Vec target;
  private final DataSet owner;
  /** {@code groups[g]} holds the data-set indices of the items that share the g-th group id. */
  private final int[][] groups;

  /**
   * @param target per-item target values, in data-set order
   * @param owner  data set the targets belong to; its items are cast to {@link GroupedDSItem}
   */
  public PairwiseLLLogit(final Vec target, final DataSet owner) {
    this(target.stream(), owner);
  }

  /**
   * @param target per-item integer targets, in data-set order; widened to {@code double}
   * @param owner  data set the targets belong to; its items are cast to {@link GroupedDSItem}
   */
  public PairwiseLLLogit(final IntSeq target, final DataSet owner) {
    this(target.stream().mapToDouble(i -> i), owner);
  }

  /**
   * Buckets data-set indices by group id (groups are numbered in order of first appearance)
   * and materializes the target stream into a vector.
   */
  private PairwiseLLLogit(final DoubleStream target, final DataSet owner) {
    // Maps group id -> 1-based group number, so the map's "no entry" value can mean "not seen yet".
    final TObjectIntMap<String> groupIndex = new TObjectIntHashMap<>();
    final List<IntSeqBuilder> builders = new ArrayList<>();
    int groupCount = 0;
    for (int i = 0; i < owner.length(); i++) {
      // The owner is declared as a raw DataSet, so the element type is not visible here;
      // a non-GroupedDSItem element fails with ClassCastException.
      final String group = ((GroupedDSItem) owner.at(i)).groupId();
      int index = groupIndex.get(group);
      if (index == groupIndex.getNoEntryValue()) {
        groupIndex.put(group, index = ++groupCount);
        builders.add(new IntSeqBuilder());
      }
      builders.get(index - 1).append(i);
    }

    this.groups = new int[groupCount][];
    for (int i = 0; i < builders.size(); i++) {
      this.groups[i] = builders.get(i).build().toArray();
    }
    this.target = target.collect(VecBuilder::new, VecBuilder::append, VecBuilder::addAll).build();
    this.owner = owner;
  }

  /** @return the number of target values, i.e. the expected dimension of the argument vector */
  @Override
  public int dim() {
    return target.dim();
  }

  /**
   * Computes {@code exp(S / count)}, where {@code S} is the sum over all in-group pairs
   * {@code (a, b)} with {@code |target[a] - target[b]| > MathTools.EPSILON} of
   * {@code -w * log(1 + exp(-y * (point[a] - point[b])))} and {@code count} is the number of
   * such pairs.
   *
   * <p>When no pair is counted the division is {@code 0.0 / 0} and the result is {@code NaN}.
   *
   * @param point per-item scores, addressed by data-set index
   * @return the exponent of the mean weighted pairwise log-likelihood
   */
  @Override
  public double value(final Vec point) {
    double result = 0;
    int count = 0;
    for (final int[] group : groups) {
      for (int i = 0; i < group.length; i++) {
        final int a = group[i];
        for (int j = i + 1; j < group.length; j++) {
          final int b = group[j];
          final double y = target.get(a) > target.get(b) ? 1 : -1;
          final double w = abs(target.get(a) - target.get(b));
          if (w > MathTools.EPSILON) {
            count++;
            // Scores are read at the data-set indices a/b, not at the in-group positions i/j.
            result += -w * log(1. + exp(-y * (point.get(a) - point.get(b))));
          }
        }
      }
    }
    return exp(result / count);
  }

  /**
   * Accumulates into {@code to}, for every in-group pair {@code (a, b)}, the derivative of
   * {@code w * log(1 + exp(-y * (point[a] - point[b])))} with respect to {@code point[a]} and
   * {@code point[b]}. {@code to} is adjusted in place and is not cleared first.
   *
   * @param point per-item scores, addressed by data-set index
   * @param to    vector receiving the accumulated contributions, addressed by data-set index
   * @return {@code to}
   */
  @Override
  public Vec gradientTo(final Vec point, Vec to) {
    for (final int[] group : groups) {
      for (int i = 0; i < group.length; i++) {
        final int a = group[i];
        for (int j = i + 1; j < group.length; j++) {
          final int b = group[j];
          final double y = target.get(a) > target.get(b) ? 1 : -1;
          final double w = abs(target.get(a) - target.get(b));
          // Pairs with equal targets have w == 0 and therefore contribute nothing.
          final double oneMinusP = 1. / (1. + exp(y * (point.get(a) - point.get(b))));
          to.adjust(a, -w * y * oneMinusP);
          to.adjust(b, w * y * oneMinusP);
        }
      }
    }
    return to;
  }


  /** @return the target of item {@code idx} truncated to {@code int} */
  public int label(final int idx) {
    return (int)target.get(idx);
  }

  /** @return the target vector held by this instance (not a copy) */
  public Vec labels() {
    return target;
  }

  /** @return the data set this target was built for */
  @Override
  public DataSet owner() {
    return owner;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy