All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.models.multiclass.JoinedBinClassModelSeq Maven / Gradle / Ivy

package com.expleague.ml.models.multiclass;

import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.util.ArrayTools;

import java.util.function.Function;

public class JoinedBinClassModelSeq implements Function,Vec> {
  protected final Function, Vec>[] internalModel;

  public JoinedBinClassModelSeq(final Function, Vec>[] dirs) {
    internalModel = dirs;
  }

  public Vec probs(final Seq x) {
    final Vec sum = computeSum(x);
    final Vec probs = new ArrayVec(sum.dim());
    for (int i = 0; i < sum.dim(); i++) {
      probs.set(i, MathTools.sigmoid(sum.get(i)));
    }
    return probs;
  }

  public int bestClass(final Seq x) {
    final double[] trans = computeSum(x).toArray();
    return ArrayTools.max(trans);
  }

  @Override
  public Vec apply(Seq x) {
    return new SingleValueVec(bestClass(x));
  }

  private Vec computeSum(final Seq x) {
    Vec[] values = ArrayTools.map(internalModel, Vec.class, func -> func.apply(x));
    if (values[0].dim() != 1) {
      throw new IllegalArgumentException(); //todo is it right?
    }
    final Vec sum = new ArrayVec(internalModel.length);
    for (int i = 0; i < internalModel.length; i++) {
      sum.set(i, values[i].get(0));
    }
    return sum;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy