scripts.nn.layers.max_pool2d.dml Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
/*
* Max Pooling layer.
*/
source("nn/util.dml") as util
forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,
                   int strideh, int stridew, int padh, int padw)
    return (matrix[double] out, int Hout, int Wout) {
  /*
   * Forward pass of a 2D spatial max pooling layer.
   *
   * Every row of `X` is one example: a (C, Hin, Win) volume unrolled
   * into a single vector.  Each example is processed independently:
   * it is reshaped to one row per channel, optionally padded with
   * -infinity, and then every channel slice is expanded into patch
   * columns with `util::im2col` so that a column-wise maximum yields
   * the pooled values for that channel.
   *
   * Inputs:
   *  - X: Inputs, of shape (N, C*Hin*Win).
   *  - C: Number of input channels (dimensionality of input depth).
   *  - Hin: Input height.
   *  - Win: Input width.
   *  - Hf: Filter (pooling window) height.
   *  - Wf: Filter (pooling window) width.
   *  - strideh: Stride over height.
   *  - stridew: Stride over width.
   *  - padh: Padding for top and bottom sides.
   *      A typical value is 0.
   *  - padw: Padding for left and right sides.
   *      A typical value is 0.
   *
   * Outputs:
   *  - out: Outputs, of shape (N, C*Hout*Wout).
   *  - Hout: Output height, floor((Hin + 2*padh - Hf)/strideh + 1).
   *  - Wout: Output width, floor((Win + 2*padw - Wf)/stridew + 1).
   */
  num_examples = nrow(X)

  # Spatial extent of a channel slice once padding has been applied.
  Hpadded = Hin + 2*padh
  Wpadded = Win + 2*padw

  # Number of pooling window positions along each spatial axis.
  Hout = as.integer(floor((Hpadded - Hf)/strideh + 1))
  Wout = as.integer(floor((Wpadded - Wf)/stridew + 1))

  # Padding uses -infinity so that a padded cell can never be the maximum.
  neg_inf = -1/0

  # Result volume, filled one example (row) at a time.
  out = matrix(0, rows=num_examples, cols=C*Hout*Wout)

  parfor (i in 1:num_examples) {  # one iteration per example
    # One row per channel, each row holding an unrolled (Hin, Win) slice.
    example = matrix(X[i,], rows=C, cols=Hin*Win)
    if (padh > 0 | padw > 0) {
      # Rows become unrolled (Hin+2*padh, Win+2*padw) slices.
      example = util::pad_image(example, Hin, Win, padh, padw, neg_inf)
    }

    pooled = matrix(0, rows=C, cols=Hout*Wout)
    parfor (ch in 1:C) {  # one iteration per channel
      # Patch columns for this channel, of shape (Hf*Wf, Hout*Wout).
      patches = util::im2col(example[ch,], Hpadded, Wpadded, Hf, Wf, strideh, stridew)
      # The maximum of each column is the pooled value of one window.
      pooled[ch,] = colMaxs(patches)
    }

    # Flatten the (C, Hout*Wout) result back into a single output row.
    out[i,] = matrix(pooled, rows=1, cols=C*Hout*Wout)
  }
}
backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X,
                    int C, int Hin, int Win, int Hf, int Wf,
                    int strideh, int stridew, int padh, int padw)
    return (matrix[double] dX) {
  /*
   * Computes the backward pass for a 2D spatial max pooling layer.
   * The input data has N examples, each represented as a 3D volume
   * unrolled into a single vector.
   *
   * For every pooling window the upstream gradient is routed to the
   * position(s) holding the window maximum.  If several positions in
   * a window tie for the maximum, each of them receives the full
   * upstream gradient for that window.  Where windows overlap, the
   * contributions from all windows are summed.
   *
   * Inputs:
   *  - dout: Gradient wrt `out` from upstream, of
   *      shape (N, C*Hout*Wout).
   *  - Hout: Output height.
   *  - Wout: Output width.
   *  - X: Input data matrix, of shape (N, C*Hin*Win).
   *  - C: Number of input channels (dimensionality of input depth).
   *  - Hin: Input height.
   *  - Win: Input width.
   *  - Hf: Filter height.
   *  - Wf: Filter width.
   *  - strideh: Stride over height.
   *  - stridew: Stride over width.
   *  - padh: Padding for top and bottom sides.
   *      A typical value is 0.
   *  - padw: Padding for left and right sides.
   *      A typical value is 0.
   *
   * Outputs:
   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
   */
  N = nrow(X)

  # Padded spatial dimensions; loop-invariant, so computed once up front.
  Hpad = Hin + 2*padh
  Wpad = Win + 2*padw

  pad_value = -1/0  # in max pooling we pad with -infinity

  # Create gradient volume
  dX = matrix(0, rows=N, cols=C*Hin*Win)

  # Gradient of max pooling
  parfor (n in 1:N, check=0) {  # all examples
    img = matrix(X[n,], rows=C, cols=Hin*Win)
    if (padh > 0 | padw > 0) {
      # Pad image to shape (C, (Hin+2*padh)*(Win+2*padw))
      img = util::pad_image(img, Hin, Win, padh, padw, pad_value)
    }

    # Gradient wrt the (possibly padded) image, one row per channel.
    dimg = matrix(0, rows=C, cols=Hpad*Wpad)
    parfor (c in 1:C, check=0) {  # all channels
      img_slice = matrix(img[c,], rows=Hpad, cols=Wpad)
      dimg_slice = matrix(0, rows=Hpad, cols=Wpad)
      # Sequential loops: overlapping windows accumulate into the same
      # cells of `dimg_slice`, so these iterations are order-dependent
      # updates of shared state rather than independent parfor work.
      for (hout in 1:Hout) {  # all output rows
        hin = (hout-1)*strideh + 1  # top row of this window in the input
        for (wout in 1:Wout) {  # all output columns
          win = (wout-1)*stridew + 1  # left column of this window in the input
          img_slice_patch = img_slice[hin:hin+Hf-1, win:win+Wf-1]
          # Indicator matrix: 1 wherever the patch equals its maximum.
          max_val_ind = img_slice_patch == max(img_slice_patch)
          # gradient passes through only for the max value(s) in this patch
          dimg_slice_patch = max_val_ind * dout[n, (c-1)*Hout*Wout + (hout-1)*Wout + wout]
          # Accumulate, since this window may overlap previously visited ones.
          dimg_slice[hin:hin+Hf-1, win:win+Wf-1] = dimg_slice[hin:hin+Hf-1, win:win+Wf-1] + dimg_slice_patch
        }
      }
      dimg[c,] = matrix(dimg_slice, rows=1, cols=Hpad*Wpad)
    }

    if (padh > 0 | padw > 0) {
      # Unpad image gradient; it is reshaped below as (C, Hin*Win).
      dimg = util::unpad_image(dimg, Hin, Win, padh, padw)
    }
    dX[n,] = matrix(dimg, rows=1, cols=C*Hin*Win)
  }
}