All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.models.nn.layers.PoolLayerBuilder Maven / Gradle / Ivy

package com.expleague.ml.models.nn.layers;

import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.seq.ArraySeq;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.models.nn.nodes.PoolNode;

import static com.expleague.ml.models.nn.NeuralSpider.BackwardNode;
import static com.expleague.ml.models.nn.NeuralSpider.ForwardNode;

/**
 * Builds a 2D pooling layer on top of a {@link Layer3D} input.
 *
 * <p>The kernel size and the stride both default to 3x3 and can be changed with
 * {@link #ksize(int, int)} and {@link #stride(int, int)}. The output keeps the input's channel
 * count. The computation itself is delegated to a {@link PoolNode}; the layer has no weights
 * ({@code wdim()} is 0), so {@link #wStart(int)} is accepted but ignored.
 */
public class PoolLayerBuilder implements LayerBuilder {
  private int kSizeX = 3;
  private int kSizeY = 3;
  private int strideX = 3;
  private int strideY = 3;
  private PoolLayer layer;
  private int yStart;
  private LayerBuilder prevBuilder;

  private PoolLayerBuilder() {
  }

  /**
   * Creates a builder with a 3x3 kernel and a 3x3 stride.
   *
   * @return a new builder
   */
  public static PoolLayerBuilder create() {
    return new PoolLayerBuilder();
  }

  /**
   * Sets the pooling kernel size.
   *
   * @param kSizeX kernel extent paired with the input height; must be positive
   * @param kSizeY kernel extent paired with the input width; must be positive
   * @return this builder
   * @throws IllegalArgumentException if either size is not positive
   */
  public PoolLayerBuilder ksize(int kSizeX, int kSizeY) {
    // Explicit checks rather than `assert`: assertions are disabled by default and would let
    // invalid sizes through silently.
    requirePositive(kSizeX, "kSizeX");
    requirePositive(kSizeY, "kSizeY");
    this.kSizeX = kSizeX;
    this.kSizeY = kSizeY;
    return this;
  }

  /**
   * Sets the pooling stride.
   *
   * @param strideX stride paired with the input height; must be positive
   * @param strideY stride paired with the input width; must be positive
   * @return this builder
   * @throws IllegalArgumentException if either stride is not positive
   */
  public PoolLayerBuilder stride(int strideX, int strideY) {
    requirePositive(strideX, "strideX");
    requirePositive(strideY, "strideY");
    this.strideX = strideX;
    this.strideY = strideY;
    return this;
  }

  /**
   * Returns the layer created by {@link #build()}, or {@code null} if it has not been built yet.
   */
  @Override
  public Layer3D getLayer() {
    return layer;
  }

  /**
   * Sets the builder whose layer is used as this layer's input.
   *
   * @param prevBuilder builder of the input layer
   * @return this builder
   */
  @Override
  public LayerBuilder setPrevBuilder(LayerBuilder prevBuilder) {
    this.prevBuilder = prevBuilder;
    return this;
  }

  /**
   * Sets the start index that is passed to the {@link PoolNode} and reported by the built layer's
   * {@code yStart()}. It is captured by {@link #build()}, so it must be set before building.
   *
   * @param yStart start index for this layer's outputs
   * @return this builder
   */
  @Override
  public LayerBuilder yStart(int yStart) {
    this.yStart = yStart;
    return this;
  }

  /**
   * No-op: a pooling layer has no weights.
   *
   * @param wStart ignored
   * @return this builder
   */
  @Override
  public LayerBuilder wStart(int wStart) {
    return this;
  }

  /**
   * Builds the pooling layer on top of the previous builder's layer, or returns the layer built by
   * an earlier call.
   *
   * @return the pooling layer
   * @throws IllegalStateException if no previous builder is set, the previous layer has not been
   *     built yet, the previous layer is not a {@link Layer3D}, or the kernel does not fit the input
   */
  @Override
  public PoolLayer build() {
    if (prevBuilder == null) {
      throw new IllegalStateException("Previous layer builder is not set");
    }
    final Object prevLayer = prevBuilder.getLayer();
    if (prevLayer == null) {
      // The predecessor must already be built; an unbuilt predecessor is reported as a cycle.
      throw new IllegalStateException("Graph is not acyclic");
    }

    if (layer != null) {
      return layer;
    }

    if (!(prevLayer instanceof Layer3D)) {
      throw new IllegalStateException(
          "Pool layer requires a Layer3D input, got " + prevLayer.getClass().getName());
    }

    // Snapshot the current configuration so later setter calls on this builder cannot make the
    // layer's reported geometry disagree with the PoolNode constructed from it.
    layer = new PoolLayer((Layer3D) prevLayer, yStart, kSizeX, kSizeY, strideX, strideY);
    return layer;
  }

  /** Throws {@link IllegalArgumentException} unless {@code value} is strictly positive. */
  private static void requirePositive(int value, String name) {
    if (value <= 0) {
      throw new IllegalArgumentException(name + " must be positive, got " + value);
    }
  }

  /**
   * Pooling layer produced by {@link PoolLayerBuilder#build()}.
   *
   * <p>All configuration fields are final and captured at construction. The layer does not read
   * the builder's mutable state, so it always agrees with its {@link PoolNode}.
   */
  public static final class PoolLayer implements Layer3D {
    private final Layer3D input;
    private final int yStart;
    private final int kSizeX;
    private final int kSizeY;
    private final int strideX;
    private final int strideY;
    private final PoolNode node;

    private PoolLayer(Layer3D in, int yStart, int kSizeX, int kSizeY, int strideX, int strideY) {
      this.input = in;
      this.yStart = yStart;
      this.kSizeX = kSizeX;
      this.kSizeY = kSizeY;
      this.strideX = strideX;
      this.strideY = strideY;
      // Integer division truncates toward zero. Without this check an input smaller than the
      // kernel would silently get an output extent of 1 instead of being rejected.
      if (in.height() < kSizeX || in.width() < kSizeY) {
        throw new IllegalStateException("Pool kernel [" + kSizeX + ", " + kSizeY
            + "] does not fit input [" + in.height() + ", " + in.width() + "]");
      }
      // width()/height() read only the fields assigned above, so calling them here is safe.
      this.node = new PoolNode(yStart, in.yStart(), in.channels(),
          in.width(), width(), height(), kSizeX, kSizeY, strideX, strideY);
    }

    /** Pooling has no weights. */
    @Override
    public int wdim() {
      return 0;
    }

    @Override
    public int yStart() {
      return yStart;
    }

    /** Input size: the whole output of the previous layer. */
    @Override
    public int xdim() {
      return input.ydim();
    }

    /** Output size: {@code width() * height() * channels()}. */
    @Override
    public int ydim() {
      return width() * height() * channels();
    }

    /**
     * Output height: {@code (input.height() - kSizeX) / strideX + 1}.
     *
     * <p>NOTE(review): X is paired with height and Y with width, both here and in the
     * {@link PoolNode} arguments. Confirm this matches PoolNode's indexing.
     */
    @Override
    public int height() {
      return (input.height() - kSizeX) / strideX + 1;
    }

    /** Output width: {@code (input.width() - kSizeY) / strideY + 1}. */
    @Override
    public int width() {
      return (input.width() - kSizeY) / strideY + 1;
    }

    /** Pooling preserves the input's channel count. */
    @Override
    public int channels() {
      return input.channels();
    }

    /** No-op: there are no weights to initialize. */
    @Override
    public void initWeights(Vec weights) { }

    /** One forward node per output element. */
    @Override
    public Seq forwardFlow() {
      return ArraySeq.iterate(ForwardNode.class, node.forward(), ydim());
    }

    /** One backward node per input element. */
    @Override
    public Seq backwardFlow() {
      return ArraySeq.iterate(BackwardNode.class, node.backward(), xdim());
    }

    /** Gradient flow is empty because {@link #wdim()} is 0. */
    @Override
    public Seq gradientFlow() {
      return ArraySeq.iterate(BackwardNode.class, node.gradient(), wdim());
    }

    @Override
    public String toString() {
      return "Pool outSize[" + height() + ", " + width() + ", " + channels() + "] " +
          "kernel[" + kSizeX + ", " + kSizeY + "] stride[" + strideX + ", " + strideY + "]";
    }

    public int kSizeX() {
      return kSizeX;
    }

    public int kSizeY() {
      return kSizeY;
    }

    public int strideX() {
      return strideX;
    }

    public int strideY() {
      return strideY;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy