#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
/*
* 2D (Spatial) Batch Normalization layer.
*/
source("nn/util.dml") as util
forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
int C, int Hin, int Win, string mode,
matrix[double] ema_mean, matrix[double] ema_var,
double mu, double epsilon)
return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
/*
* Computes the forward pass for a 2D (spatial) batch normalization
* layer. The input data has N examples, each represented as a 3D
* volume unrolled into a single vector.
*
* A spatial batch normalization layer uses the per-channel sample
* mean and per-channel uncorrected sample variance during training
* to normalize each channel of the input data. Additionally, it
* introduces learnable parameters (gamma, beta) to control the
* amount of normalization.
*
* `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`
*
* This implementation maintains exponential moving averages of the
* mean and variance during training for use during testing.
*
* Reference:
* - Batch Normalization: Accelerating Deep Network Training by
* Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
* - https://arxiv.org/abs/1502.03167
*
* Inputs:
* - X: Inputs, of shape (N, C*Hin*Win).
* - gamma: Scale parameters, of shape (C, 1).
* - beta: Shift parameters, of shape (C, 1).
* - C: Number of input channels (dimensionality of input depth).
* - Hin: Input height.
* - Win: Input width.
* - mode: 'train' or 'test' to indicate if the model is currently
* being trained or tested. During training, the current batch
* mean and variance will be used to normalize the inputs, while
* during testing, the exponential average of the mean and
* variance over all previous batches will be used.
* - ema_mean: Exponential moving average of the mean, of
* shape (C, 1).
* - ema_var: Exponential moving average of the variance, of
* shape (C, 1).
* - mu: Momentum value for moving averages.
* Typical values are in the range of [0.9, 0.999].
* - epsilon: Smoothing term to avoid divide by zero errors.
* Typical values are in the range of [1e-5, 1e-3].
*
* Outputs:
* - out: Outputs, of shape (N, C*Hin*Win).
* - ema_mean_upd: Updated exponential moving average of the mean,
* of shape (C, 1).
* - ema_var_upd: Updated exponential moving average of the variance,
* of shape (C, 1).
* - cache_mean: Cache of the batch mean, of shape (C, 1).
* Note: This is used for performance during training.
* - cache_var: Cache of the batch variance, of shape (C, 1).
* Note: This is used for performance during training.
* - cache_norm: Cache of the normalized inputs, of
* shape (C, N*Hin*Win). Note: This is used for performance
* during training.
*/
N = nrow(X)
if (mode == 'train') {
# Compute channel-wise mean and variance.
# Since a tensor type is not available, the per-channel statistics are assembled piece-wise
# from the per-(channel, pixel) subgroups of size N:
#  - the mean of the whole group is the mean of the subgroup means
#  - the (uncorrected) variance of the whole group is the mean of the uncorrected subgroup
#    variances plus the uncorrected variance of the subgroup means
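#
# A quick numeric check of this decomposition (all variances uncorrected): splitting
# {1,3,5,7} into subgroups {1,3} and {5,7} gives subgroup means {2,6} and subgroup
# variances {1,1}; the mean of the subgroup variances (1) plus the variance of the
# subgroup means (4) is 5, which matches the variance of {1,3,5,7} computed directly.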
subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) # uncorrected variances
mean = rowMeans(subgrp_means) # shape (C, 1)
var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) # shape (C, 1)
# Update moving averages
ema_mean_upd = mu*ema_mean + (1-mu)*mean
ema_var_upd = mu*ema_var + (1-mu)*var
}
else {
# Use moving averages of mean and variance during testing
mean = ema_mean
var = ema_var
ema_mean_upd = ema_mean
ema_var_upd = ema_var
}
# Normalize, scale, and shift
# norm = (X-mean)*(var+epsilon)^(-1/2)
# = (X-mean) / sqrt(var+epsilon)
centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
norm = bias_multiply(centered, 1/sqrt(var+epsilon)) # shape (N, C*Hin*Win)
# out = norm*gamma + beta
scaled = bias_multiply(norm, gamma) # shape (N, C*Hin*Win)
out = bias_add(scaled, beta) # shape (N, C*Hin*Win)
# Save variables for the backward pass
cache_mean = mean
cache_var = var
cache_norm = norm
}
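
/*
 * Illustrative sketch (not part of the layer itself): one training-mode forward call on a
 * small random batch. The dimensions N=2, C=3, Hin=Win=4 and the variable names below are
 * assumptions chosen for the example.
 *
 *   [gamma, beta, ema_mean, ema_var] = init(3)
 *   X = rand(rows=2, cols=3*4*4)
 *   [out, ema_mean, ema_var, cache_mean, cache_var, cache_norm] = forward(X, gamma, beta,
 *       3, 4, 4, 'train', ema_mean, ema_var, 0.9, 1e-5)
 */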
backward = function(matrix[double] dout, matrix[double] out,
matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm,
matrix[double] X, matrix[double] gamma, matrix[double] beta,
int C, int Hin, int Win, string mode,
matrix[double] ema_mean, matrix[double] ema_var,
double mu, double epsilon)
return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
/*
* Computes the backward pass for a 2D (spatial) batch normalization
* layer.
*
* Inputs:
* - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
* - out: Outputs from the forward pass, of shape (N, C*Hin*Win).
* - ema_mean_upd: Updated exponential moving average of the mean
* from the forward pass, of shape (C, 1).
* - ema_var_upd: Updated exponential moving average of the variance
* from the forward pass, of shape (C, 1).
* - cache_mean: Cache of the batch mean from the forward pass, of
* shape (C, 1). Note: This is used for performance during
* training.
* - cache_var: Cache of the batch variance from the forward pass,
* of shape (C, 1). Note: This is used for performance during
* training.
* - cache_norm: Cache of the normalized inputs from the forward
* pass, of shape (C, N*Hin*Win). Note: This is used for
* performance during training.
* - X: Input data matrix to the forward pass, of
* shape (N, C*Hin*Win).
* - gamma: Scale parameters, of shape (C, 1).
* - beta: Shift parameters, of shape (C, 1).
* - C: Number of input channels (dimensionality of input depth).
* - Hin: Input height.
* - Win: Input width.
* - mode: 'train' or 'test' to indicate if the model is currently
* being trained or tested. During training, the current batch
* mean and variance will be used to normalize the inputs, while
* during testing, the exponential average of the mean and
* variance over all previous batches will be used.
* - ema_mean: Exponential moving average of the mean, of
* shape (C, 1).
* - ema_var: Exponential moving average of the variance, of
* shape (C, 1).
* - mu: Momentum value for moving averages.
* Typical values are in the range of [0.9, 0.999].
* - epsilon: Smoothing term to avoid divide by zero errors.
* Typical values are in the range of [1e-5, 1e-3].
*
* Outputs:
* - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
* - dgamma: Gradient wrt `gamma`, of shape (C, 1).
* - dbeta: Gradient wrt `beta`, of shape (C, 1).
*
*/
N = nrow(X)
mean = cache_mean
var = cache_var
norm = cache_norm
centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
if (mode == 'train') {
# Compute gradients during training
dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1)
dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1)
dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win)
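# Backpropagate through norm = (X-mean)/sqrt(var+epsilon): X enters norm directly and
# also through the batch mean and variance, so dX accumulates three chain-rule branches
# (through norm, through the mean, and through the variance), computed below.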
dvar = util::channel_sums((-1/2) * bias_multiply(centered, (var+epsilon)^(-3/2)) * dnorm,
C, Hin, Win) # shape (C, 1)
dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -1/sqrt(var+epsilon)), C, Hin, Win)
dmean_var_branch = util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)
dmean_var_branch = dmean_var_branch * dvar # we can't use a function within an expression yet
dmean = dmean_norm_branch + dmean_var_branch # shape (C, 1)
dX_norm_branch = bias_multiply(dnorm, 1/sqrt(var+epsilon))
dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
dX = dX_norm_branch + dX_mean_branch + dX_var_branch # shape (N, C*Hin*Win)
}
else {
# Compute gradients during testing
dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1)
dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1)
dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win)
dX = bias_multiply(dnorm, 1/sqrt(var+epsilon)) # shape (N, C*Hin*Win)
}
}
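
/*
 * Illustrative sketch (not part of the layer itself): pairing forward and backward during
 * training. `dout` stands in for the gradient arriving from the next layer; it and the
 * other names here are assumptions chosen for the example.
 *
 *   [out, ema_mean, ema_var, cache_mean, cache_var, cache_norm] = forward(X, gamma, beta,
 *       C, Hin, Win, 'train', ema_mean, ema_var, mu, epsilon)
 *   [dX, dgamma, dbeta] = backward(dout, out, ema_mean, ema_var, cache_mean, cache_var,
 *       cache_norm, X, gamma, beta, C, Hin, Win, 'train', ema_mean, ema_var, mu, epsilon)
 */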
init = function(int C)
return (matrix[double] gamma, matrix[double] beta,
matrix[double] ema_mean, matrix[double] ema_var) {
/*
* Initialize the parameters of this layer.
*
* Note: This is just a convenience function, and parameters
* may be initialized manually if needed.
*
* Inputs:
* - C: Number of input channels (dimensionality of input depth).
*
* Outputs:
* - gamma: Scale parameters, of shape (C, 1).
* - beta: Shift parameters, of shape (C, 1).
* - ema_mean: Exponential moving average of the mean, of
* shape (C, 1).
* - ema_var: Exponential moving average of the variance, of
* shape (C, 1).
*/
gamma = matrix(1, rows=C, cols=1)
beta = matrix(0, rows=C, cols=1)
ema_mean = matrix(0, rows=C, cols=1)
ema_var = matrix(1, rows=C, cols=1)
}
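
/*
 * Illustrative sketch (not part of the layer itself): at test time the running averages
 * accumulated during training are passed in unchanged, and the returned ema values simply
 * echo them. Names and dimensions are assumptions chosen for the example.
 *
 *   [gamma, beta, ema_mean, ema_var] = init(C)
 *   # ... training updates ema_mean and ema_var via forward(..., 'train', ...) ...
 *   [out_test, ema_mean, ema_var, cm, cv, cn] = forward(X_test, gamma, beta, C, Hin, Win,
 *       'test', ema_mean, ema_var, mu, epsilon)
 */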