All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.models.nn.layers.ConvLayerBuilder Maven / Gradle / Ivy

package com.expleague.ml.models.nn.layers;

import com.expleague.commons.math.AnalyticFunc;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.seq.ArraySeq;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.models.nn.Identity;
import com.expleague.ml.models.nn.NeuralSpider.BackwardNode;
import com.expleague.ml.models.nn.nodes.ConvNode;

import static com.expleague.ml.models.nn.NeuralSpider.ForwardNode;

public class ConvLayerBuilder implements LayerBuilder {
  private int kSizeX = 3;
  private int kSizeY = 3;
  private int strideX = 1;
  private int strideY = 1;
  private int paddX = 0;
  private int paddY = 0;
  private int outChannels = 0;
  private FillerType fillerType = FillerType.NORMAL;
  private LayerBuilder prevBuilder;
  private ConvLayer layer;
  private int yStart;
  private int wStart;
  private AnalyticFunc activation = new Identity();

  protected ConvLayerBuilder() {}

  public static ConvLayerBuilder create() {
    return new ConvLayerBuilder();
  }

  public ConvLayerBuilder ksize(int kSizeX, int kSizeY) {
    assert(kSizeX > 0);
    assert(kSizeY > 0);

    this.kSizeX = kSizeX;
    this.kSizeY = kSizeY;

    if (paddX != 0 || paddY != 0) {
      paddX = (kSizeX - 1) / 2;
      paddY = (kSizeY - 1) / 2;
    }

    return this;
  }

  public ConvLayerBuilder stride(int strideX, int strideY) {
    assert(strideX > 0);
    assert(strideY > 0);
    this.strideX = strideX;
    this.strideY = strideY;
    return this;
  }

  public ConvLayerBuilder channels(int channels) {
    this.outChannels = channels;
    return this;
  }

  public ConvLayerBuilder samePadd() {
    paddX = (kSizeX - 1) / 2;
    paddY = (kSizeY - 1) / 2;
    return this;
  }

  public ConvLayerBuilder padd(int paddX, int paddY) {
    this.paddX = paddX;
    this.paddY = paddY;
    return this;
  }

  @Override
  public ConvLayer getLayer() {
    return layer;
  }

  @Override
  public LayerBuilder setPrevBuilder(LayerBuilder prevBuilder) {
    if (this.prevBuilder != null) {
      throw new IllegalStateException("Conv layer can have only one previous layer");
    }

    this.prevBuilder = prevBuilder;
    return this;
  }

  @Override
  public LayerBuilder yStart(int yStart) {
    this.yStart = yStart;
    return this;
  }

  @Override
  public LayerBuilder wStart(int wStart) {
    this.wStart = wStart;
    return this;
  }

  public ConvLayerBuilder weightFill(FillerType fillerType) {
    this.fillerType = fillerType;
    return this;
  }

  @Override
  public Layer3D build() {
    if (prevBuilder.getLayer() == null) {
      throw new IllegalStateException("Graph is not acyclic");
    }

    if (layer != null) {
      return layer;
    }

    if (outChannels == 0) {
      throw new IllegalStateException("The number of output channels is not provided");
    }

    layer = new ConvLayer((Layer3D) prevBuilder.getLayer());

    return layer;
  }

  public ConvLayerBuilder activation(Class actClass) {
    try {
      activation = actClass.newInstance();
    }
    catch (InstantiationException | IllegalAccessException e) {
      throw new RuntimeException(e);
    }
    return this;
  }

  public class ConvLayer implements Layer3D {
    protected final Layer3D input;
    private final Filler filler;
    private final ConvNode node;

    protected ConvLayer(Layer3D input) {
      this.input = input;
      filler = FillerType.getInstance(fillerType, this);
      node = new ConvNode(
          yStart, wStart, input.yStart(),
          input.width(), input.height(), width(), height(),
          kSizeX, kSizeY, strideX, strideY, paddX, paddY,
          input.channels(), outChannels, activation);
    }

    public void initWeights(Vec weights) {
      filler.apply(weights.sub(wStart, wdim()));
    }

    @Override
    public int height() {
      return (input.height() + 2 * paddX - kSizeX) / strideX + 1;
    }

    @Override
    public int width() {
      return (input.width() + 2 * paddY - kSizeY) / strideY + 1;
    }

    @Override
    public int channels() {
      return outChannels;
    }

    @Override
    public int xdim() {
      return input.channels() * input.width() * input.height();
    }

    @Override
    public int ydim() {
      return width() * height() * outChannels;
    }

    @Override
    public int wdim() {
      return (input.channels() * kSizeX * kSizeY + 1) * outChannels;
    }

    @Override
    public int yStart() {
      return yStart;
    }

    public int kSizeX() {
      return kSizeX;
    }

    public int kSizeY() {
      return kSizeY;
    }

    public int strideX() {
      return strideX;
    }

    public int strideY() {
      return strideY;
    }

    @Override
    public Seq forwardFlow() {
      return ArraySeq.iterate(ForwardNode.class, node.forward(), ydim());
    }

    @Override
    public Seq backwardFlow() {
      return ArraySeq.iterate(BackwardNode.class, node.backward(), xdim());
    }

    @Override
    public Seq gradientFlow() {
      return  ArraySeq.iterate(BackwardNode.class, node.gradient(), wdim());
    }

    @Override
    public String toString() {
      return "Conv outSize[" + height() + ", " + width() + ", " + channels() + "] " +
          "kernel[" + kSizeX + ", " + kSizeY + "] " +
          "stride[" + strideX + ", " + strideY + "]";
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy