All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.loss.CompositeFunc Maven / Gradle / Ivy

package com.expleague.ml.loss;

import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.ThreadLocalArrayVec;
import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.TransC1;

import static com.expleague.commons.math.vectors.VecTools.assign;

/**
 * Scalar function obtained by chaining a {@link FuncC1} with a sequence of {@link TransC1} transformations.
 *
 * <p>For {@code after = {t0, t1, ..., tn}} the composite evaluates {@code first(t0(t1(... tn(x))))}:
 * the last transformation of the array is applied to the argument first and {@code first} is applied last.
 * With no transformations the composite evaluates exactly like {@code first}.</p>
 *
 * User: solar
 * Date: 01.06.15
 * Time: 13:52
 */
public class CompositeFunc extends FuncC1.Stub {
  private final FuncC1 first;
  private final TransC1[] after;

  /**
   * Builds the composite and verifies that neighbouring stages have matching dimensions.
   *
   * @param first outermost function, applied to the output of {@code after[0]}
   * @param after transformations in outer-to-inner order; may be empty
   * @throws IllegalArgumentException if the y dimension of a transformation differs from the dimension
   *                                  expected by the stage that consumes its output
   */
  public CompositeFunc(FuncC1 first, TransC1... after) {
    this.first = first;
    // Defensive copy: the caller keeps a reference to the varargs array and could otherwise
    // replace stages after the dimension check below has already passed.
    this.after = after.clone();
    int dim = first.dim();
    for (final TransC1 trans : this.after) {
      if (dim != trans.ydim())
        throw new IllegalArgumentException("Composite dimensions does not match: " + trans + " must have y dimension of " + dim);
      dim = trans.xdim();
    }
  }

  // Scratch buffers used by gradient(). Judging by the type name they are per-thread;
  // NOTE(review): gradient() accumulates into the vector returned by nextTemp.get(..) without
  // clearing it, so it relies on get(..) handing out a zero-filled vector - confirm against ThreadLocalArrayVec.
  final ThreadLocalArrayVec nextTemp = new ThreadLocalArrayVec();
  final ThreadLocalArrayVec gradTemp = new ThreadLocalArrayVec();
  final ThreadLocalArrayVec resultTemp = new ThreadLocalArrayVec();

  /**
   * Computes the gradient of the composite at {@code x} by propagating the gradient of {@code first}
   * back through every transformation.
   *
   * <p>The returned vector is obtained from an internal scratch buffer; presumably it is reused by later
   * calls on the same thread, so copy it if it has to outlive the next call (TODO confirm).</p>
   *
   * @param x point at which the gradient is taken
   * @return gradient with respect to {@code x}
   * @throws RuntimeException if the output of any transformation has a NaN norm
   */
  @Override
  public Vec gradient(Vec x) {
    // values[after.length] is the original argument; values[i] is the output of after[i],
    // i.e. after[i] applied to values[i + 1].
    final Vec[] values = new Vec[after.length + 1];
    values[after.length] = x;
    for (int i = after.length - 1; i >= 0; i--) {
      values[i] = after[i].trans(x);
      if (Double.isNaN(VecTools.norm(values[i]))) {
        // The transformation is evaluated once more only to render the failure message.
        throw new RuntimeException("" + after[i].trans(x));
      }
      x = values[i];
    }
    // Here x is values[0], the argument actually seen by first.
    Vec result = first.gradientTo(x, resultTemp.get(first.xdim()));
    for (int i = 0; i < after.length; i++) {
      final Vec next = nextTemp.get(after[i].xdim());
      for (int j = 0; j < result.length(); j++) {
        // Components that are numerically zero contribute nothing, so their rows are not computed.
        if (Math.abs(result.get(j)) < MathTools.EPSILON)
          continue;
        final Vec grad = gradTemp.get(next.dim());
        // gradientRowTo(.., j) is taken to produce row j of the derivative of after[i] at its
        // input values[i + 1] - confirm against TransC1.
        final Vec gradientRowJ = after[i].gradientRowTo(values[i + 1], grad, j);
        VecTools.incscale(next, gradientRowJ, result.get(j));
      }
      // Move the accumulated vector into the result buffer so that `next` can be reused by the following stage.
      result = resultTemp.get(next.dim());
      assign(result, next);
    }
    return result;
  }

  /**
   * Evaluates the composite: applies the transformations from the last array element to the first,
   * then applies {@code first} to the outcome.
   *
   * @param x argument of the composite
   * @return value of {@code first} at the transformed point
   */
  @Override
  public double value(Vec x) {
    for (int i = after.length - 1; i >= 0; i--) {
      x = after[i].trans(x);
    }
    return first.value(x);
  }

  /**
   * Returns the input dimension of the composite: the x dimension of the innermost transformation,
   * or the dimension of {@code first} when there are no transformations.
   */
  @Override
  public int dim() {
    // Without this guard an empty composite would index after[-1] and throw.
    if (after.length == 0)
      return first.dim();
    return after[after.length - 1].xdim();
  }

  /** Returns the outermost function of the composite. */
  public FuncC1 first() {
    return first;
  }

  /** Renders the chain as {@code first o after[0] o after[1] ...}. */
  @Override
  public String toString() {
    final StringBuilder builder = new StringBuilder();
    builder.append(first.toString());
    for (final TransC1 trans : after) {
      builder.append(" o ").append(trans);
    }
    return builder.toString();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy