All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.func.ScaledVectorFunc Maven / Gradle / Ivy

package com.expleague.ml.func;

import com.expleague.commons.math.DiscontinuousTrans;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import org.jetbrains.annotations.NotNull;

/**
 * Vector-valued transformation built from a scalar {@link Func} and a weight vector.
 * <p>
 * {@link #transTo(Vec, Vec)} fills the destination with {@code weights} and then scales it by
 * {@code function.value(argument)}; the reported output dimension is {@code weights.dim()} and
 * the reported input dimension is {@code function.xdim()}.
 * <p>
 * User: qdeee
 * Date: 18.05.17
 */
public class ScaledVectorFunc extends Trans.Stub {
  /** Weight vector: assigned into the output of {@link #transTo(Vec, Vec)} and summed by {@link #subgradient()}. */
  public final Vec weights;
  /** Wrapped function: supplies the scale factor, the input dimension and the underlying subgradient. */
  public final Func function;

  /**
   * Stores both references as given; neither argument is copied or validated here.
   *
   * @param function wrapped function
   * @param weights  weight vector
   */
  public ScaledVectorFunc(Func function, Vec weights) {
    this.weights = weights;
    this.function = function;
  }

  /**
   * Assigns {@code weights} into {@code to} and scales {@code to} by {@code function.value(argument)}.
   *
   * @param argument point handed to the wrapped function
   * @param to       destination vector; it is overwritten and returned
   * @return the same {@code to} instance
   */
  @Override
  public Vec transTo(Vec argument, Vec to) {
    // NOTE(review): the function is evaluated only after `to` has been overwritten, so passing
    // the same vector as both `argument` and `to` would presumably feed the weights to the
    // function. Statement order is kept as-is; confirm callers never alias the two.
    VecTools.assign(to, weights);
    VecTools.scale(to, function.value(argument));
    return to;
  }

  /**
   * Wraps the subgradient of the underlying function so that both its left and right values
   * are scaled by {@code VecTools.sum(weights)}.
   * <p>
   * The returned transformation reports {@code xdim()} of this object as both its input and
   * its output dimension.
   *
   * @return subgradient of the wrapped function scaled by the sum of the weights
   * @throws UnsupportedOperationException if {@code function.subgradient()} returns {@code null}
   */
  @Override
  public DiscontinuousTrans subgradient() {
    final DiscontinuousTrans inner = function.subgradient();
    if (inner == null) {
      // Name the offending function type so the failure can be diagnosed from the stack trace alone.
      throw new UnsupportedOperationException(
          "No subgradient available for wrapped function of type " + function.getClass().getName());
    }

    return new DiscontinuousTrans.Stub() {
      @NotNull
      @Override
      public Vec leftTo(Vec x, Vec to) {
        inner.leftTo(x, to);
        return ScaledVectorFunc.this.scaleByWeightSum(to);
      }

      @NotNull
      @Override
      public Vec rightTo(Vec x, Vec to) {
        inner.rightTo(x, to);
        return ScaledVectorFunc.this.scaleByWeightSum(to);
      }

      @Override
      public int xdim() {
        return ScaledVectorFunc.this.xdim();
      }

      @Override
      public int ydim() {
        return xdim();
      }
    };
  }

  /**
   * Scales {@code target} by {@code VecTools.sum(weights)}.
   * The sum is evaluated on every call rather than cached.
   */
  private Vec scaleByWeightSum(Vec target) {
    return VecTools.scale(target, VecTools.sum(weights));
  }

  /** @return input dimension, delegated to {@code function.xdim()} */
  @Override
  public int xdim() {
    return function.xdim();
  }

  /** @return output dimension, taken from {@code weights.dim()} */
  @Override
  public int ydim() {
    return weights.dim();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy