All downloads are free. The search and download features use the official Maven repository.

com.expleague.ml.loss.DSSumFuncComposite Maven / Gradle / Ivy

package com.expleague.ml.loss;

import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.BlockedTargetFunc;
import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.TransC1;
import com.expleague.ml.data.set.DataSet;

import java.util.function.Function;

/**
 * Sum-over-data-set function whose {@code index}-th component is the composition of
 * the {@code index}-th block of a {@link BlockedTargetFunc} with the transformation
 * that the decision factory builds for the {@code index}-th data set item.
 *
 * <p>Original header: User: solar, Date: 01.06.15, Time: 13:23
 *
 * @param <Item> type of the data set elements handed to the decision factory
 */
public class DSSumFuncComposite<Item> extends DSSumFuncC1<Item> implements FuncC1 {
  /** Target whose blocks are composed with the per-item transformations. */
  private final BlockedTargetFunc tgt;
  /** Builds, for a given item, the transformation that is applied to the argument vector. */
  private final Function<Item, ? extends TransC1> decisionFactory;
  /** Value reported by {@link #dim()}; taken from the transformation built for the first item. */
  private final int dim;

  /**
   * Creates the composite over the given data set.
   *
   * @param ds data set whose items are fed to {@code decisionFactory}; item 0 is read
   *     unconditionally, so the data set is expected to be non-empty
   * @param tgt target function providing one block per item index
   * @param decisionFactory maps an item to the transformation used for that item
   */
  public DSSumFuncComposite(DataSet<Item> ds, BlockedTargetFunc tgt, Function<Item, ? extends TransC1> decisionFactory) {
    super(ds);
    this.tgt = tgt;
    this.decisionFactory = decisionFactory;
    // Only the first item's transformation is probed for its input dimension.
    // NOTE(review): presumably every item yields the same xdim() - confirm with callers.
    this.dim = decisionFactory.apply(ds.at(0)).xdim();
  }

  /**
   * Builds the component for one item: the target block at {@code index} composed with
   * the transformation produced for the item at {@code index}.
   *
   * @param index position of the item and of the target block
   * @return a freshly constructed composite function for that position
   */
  @Override
  public CompositeFunc component(int index) {
    return new CompositeFunc((FuncC1) tgt.block(index), decisionFactory.apply(ds.at(index)));
  }

  /**
   * Wraps an argument vector into a {@link Decision}.
   *
   * @param x vector that the returned decision passes to each item's transformation
   * @return a decision bound to {@code x}
   */
  public Decision decision(final Vec x) {
    return new Decision(x);
  }

  /**
   * Combines the gradients of all components at {@code x} into one vector.
   *
   * @param x point at which every component gradient is evaluated
   * @return a new vector of size {@link #dim()} holding the combined gradient
   */
  @Override
  public Vec gradient(Vec x) {
    final Vec result = new ArrayVec(dim());
    final int length = length();
    for (int i = 0; i < length; i++) {
      // The return value of append is discarded, so this relies on append updating
      // result in place. NOTE(review): presumably an element-wise add - confirm.
      VecTools.append(result, component(i).gradient(x));
    }
    return result;
  }

  /**
   * @return the dimension captured at construction time
   */
  public int dim() {
    return dim;
  }

  /**
   * Function from an item to the result of applying that item's transformation to a
   * fixed vector {@link #x}.
   */
  public class Decision implements Function<Item, Vec> {
    /** Vector passed to every transformation evaluated by this decision. */
    public final Vec x;

    /**
     * @param x vector to which the per-item transformations are applied
     */
    public Decision(Vec x) {
      this.x = x;
    }

    /**
     * Builds the transformation for {@code argument} and applies it to {@link #x}.
     *
     * @param argument item for which the transformation is built
     * @return the transformation's output for {@code x}
     */
    @Override
    public Vec apply(Item argument) {
      final Trans compute = decisionFactory.apply(argument);
      return compute.trans(x);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy