Gradient clipping to replace batch normalization layers.

Adaptive Gradient Clipping was introduced in the paper “High-Performance Large-Scale Image Recognition Without Normalization” from DeepMind by Brock et al. (2021).

Pascanu, Mikolov, and Bengio (2013) first introduced the gradient clipping technique: for model parameters \(\theta\) and loss function \(L\), the gradient for a particular layer/group of weights \(l\) is \(G^l = \frac{\partial L}{\partial \theta^l}\). Gradient clipping scales down the gradient based on its norm:

\[ G^l \rightarrow \begin{cases} \lambda \frac{G^l}{\vert\vert G^l\vert\vert},& \text{if } \vert\vert G^l\vert\vert > \lambda\\ G^l, & \text{otherwise} \end{cases} \]
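As a concrete illustration, here is a minimal sketch of this rule for a single gradient vector. The helper name `clip_gradient` and the plain-Python representation are assumptions for illustration, not the paper's code:

```python
import math

def clip_gradient(grad, lam):
    # Rescale grad to have norm lam when its norm exceeds lam;
    # otherwise return it unchanged.
    norm = math.sqrt(sum(g * g for g in grad))
    if norm > lam:
        return [lam * g / norm for g in grad]
    return grad

# A gradient of norm 5 clipped to norm 1:
clip_gradient([3.0, 4.0], 1.0)  # -> [0.6, 0.8]
```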

Here, the gradient clipping is performed independently of the weights it affects, i.e. it depends only on \(G\). Brock et al. (2021) propose **Adaptive Gradient Clipping**, which modifies the clipping condition by introducing the Frobenius norm of the weights \(W^l\) that the gradient updates. For each block \(i\) of the parameters \(\theta\), the gradient \(G_i^l\) becomes:

\[ G_i^l \rightarrow \begin{cases} \lambda\frac{\vert\vert W_i^l\vert\vert_F^*}{\vert\vert G_i^l\vert\vert_F} G_i^l,& \text{if } \frac{\vert\vert G_i^l\vert\vert_F}{\vert\vert W_i^l\vert\vert_F^*} > \lambda \\ G_i^l,& \text{otherwise } \end{cases} \\ \text{where} \hspace{1mm} {\vert\vert W_i^l\vert\vert_F^*} = \max({\vert\vert W_i^l\vert\vert_F}, \epsilon) \]

Notice that the condition regulating the gradient norm now depends on the norm of the block of weights the gradient is being used to update. Hence, if the gradient is too big for the weights, or the weights are too small for the gradient, this clipping strategy scales the gradient down. This property makes the clipping *adaptive*.
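The update above can be sketched for one parameter block as follows. The helper name `agc` and the default \(\lambda\) and \(\epsilon\) values are illustrative assumptions (the paper tunes \(\lambda\) per model and batch size):

```python
import math

def agc(weight, grad, lam=0.01, eps=1e-3):
    # Adaptive gradient clipping for a single block of weights (sketch).
    w_norm = max(math.sqrt(sum(w * w for w in weight)), eps)  # ||W||_F^*
    g_norm = math.sqrt(sum(g * g for g in grad))              # ||G||_F
    if g_norm / w_norm > lam:
        # Scale the gradient so its norm becomes lam * ||W||_F^*.
        return [lam * (w_norm / g_norm) * g for g in grad]
    return grad
```

A gradient that is large relative to its weights gets scaled down in proportion to the weight norm, while a small gradient passes through untouched.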

You might find yourself wondering now, *“This is cool, but this doesn’t provide the normalization of features that BatchNorm does, so where’s that?”*. Here enters the second trick, adapted from an earlier paper by the first author (Brock, De, and Smith 2021), which introduces **Weight Standardization** and **Activation Scaling**:

\[ \text{Weight standardization: } \hat W_{ij}= \frac{W_{ij} - \mu_i}{\sqrt N \sigma_i} \]

where the mean \(\mu_i = (1/N) \sum_j W_{ij}\), the variance \(\sigma_i^2 = (1/N)\sum_j (W_{ij} - \mu_i)^2\), and \(N\) is the fan-in, i.e. the number of input units. In *activation scaling*, a constant \(\gamma\) is used to scale the activation outputs, where \(\gamma = \sqrt{2/(1 - (1/\pi))}\) for ReLUs.
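Here is a minimal sketch of weight standardization for a 2D weight, with rows indexed by \(i\) and fan-in entries indexed by \(j\). The helper name `weight_standardize` and the small `eps` added inside the square root for numerical stability are assumptions:

```python
import math

def weight_standardize(W, eps=1e-4):
    # W: list of rows; row i holds the N fan-in weights of output unit i.
    out = []
    for row in W:
        n = len(row)                               # fan-in N
        mu = sum(row) / n                          # mean over j
        var = sum((w - mu) ** 2 for w in row) / n  # variance over j
        # (W_ij - mu_i) / (sqrt(N) * sigma_i)
        out.append([(w - mu) / math.sqrt(n * (var + eps)) for w in row])
    return out
```

Each standardized row then has zero mean, and the extra \(\sqrt N\) factor keeps the output variance independent of the fan-in.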

You can find this code in `class WSConv2d` here. Notice that the weight used in the convolution is standardized on every forward pass. The weight is reshaped to a 2D tensor - `output_channels * (input_channels * (kernel ** 2))` - where `i` is `output_channels` and the rest is `j`. Hence, the number of inputs is `input_channels * (kernel ** 2)`, i.e. the `fan-in` parameter. PyTorch’s ONNX export didn’t support `mean_var` as of that release, hence the mean and variance are calculated separately.

Residual connections have traditionally been \(h_{i+1} = h_i + f_i(h_i)\), where \(h_i\) is the input to the residual block \(f_i\). In the NF family of networks, this is modified to \(h_{i+1} = h_i + \alpha f_i(h_i/\beta_i)\). Intuitively, \(\alpha\) scales the residual block activations to increase variance, while \(\beta_i\) scales down the input of the function inside the residual block, as opposed to setting it as identity. \(\alpha\) is set to `0.2` and \(\beta_i\) is predicted as \(\beta_i=\sqrt{\mathrm{Var}(h_i)}\).
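The modified residual update can be sketched as follows; `nf_residual` is a hypothetical helper, and `f` stands in for the residual branch \(f_i\):

```python
def nf_residual(h, f, alpha=0.2, beta=1.0):
    # h_{i+1} = h_i + alpha * f(h_i / beta), elementwise over a vector h.
    branch = f([x / beta for x in h])           # downscaled input into f
    return [hi + alpha * bi for hi, bi in zip(h, branch)]
```

With `f` as the identity, `nf_residual([1.0, 2.0], lambda xs: xs)` adds a fifth of the input back onto itself, showing how \(\alpha = 0.2\) grows the variance gradually across blocks.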

As Yannic Kilcher explains, AGC has an *implicit dependence on the batch size*, while `BatchNorm` has an *explicit dependence on the batch size*. However, the paper doesn’t clearly quantify how disentangling the above components affects accuracy.

To summarize the contributions: `Weight Standardization` and `Activation Scaling` in combination control the mean-shift at initialization that `BatchNorm` otherwise handles, while `Adaptive Gradient Clipping` helps prevent the shift by making sure the parameters don’t grow significantly. These techniques are used in a NAS pipeline to discover the family of architectures the authors term `NFNets`. Hence, all of the above techniques combined **eliminate the mean-shift** - the central role of BatchNorm - and scale well to large training batch sizes. The PyTorch code is available on GitHub: `https://github.com/vballoli/nfnets-pytorch`

There are interesting future avenues for these tricks. One is meta-learning for classification, where BatchNorm plays a significant role, and where it remains to be seen how pre-training with these techniques affects and transfers to task-specific adaptation.

Brock, Andrew, Soham De, and Samuel L Smith. 2021. “Characterizing Signal Propagation to Close the Performance Gap in Unnormalized ResNets.” *arXiv Preprint arXiv:2101.08692*.

Brock, Andrew, Soham De, Samuel L Smith, and Karen Simonyan. 2021. “High-Performance Large-Scale Image Recognition Without Normalization.” *arXiv Preprint arXiv:2102.06171*.

Pascanu, Razvan, Tomas Mikolov, and Yoshua Bengio. 2013. “On the Difficulty of Training Recurrent Neural Networks.” In *International Conference on Machine Learning*, 1310–18. PMLR.

For attribution, please cite this work as

Balloli (2021, March 31). Tour de ML: Adaptive Gradient Clipping. Retrieved from https://tourdeml.github.io/blog/posts/2021-03-31-adaptive-gradient-clipping/

BibTeX citation

    @misc{balloli2021adaptive,
      author = {Balloli, Vaibhav},
      title = {Tour de ML: Adaptive Gradient Clipping},
      url = {https://tourdeml.github.io/blog/posts/2021-03-31-adaptive-gradient-clipping/},
      year = {2021}
    }