Neural networks: tricks of the trade

Pieter-Jan Hoedt @ OPTIMA 2022
@ml_hoedt


newcomers to the field waste much time wondering why their networks train so slowly and perform so poorly
(Orr & Müller, 1998)
cover page of the book 'Neural Networks: Tricks of the Trade'
cover of the 1998 edition



NIST form to collect handwritten character data
Data form for NIST Special Database 19 (Grother, 1995)


examples of MNIST digits
MNIST digits as 8-bit grayscale images (Bottou et al., 1994; Yadav & Bottou, 2019)

Looking at

examples of MNIST digits
MNIST digits as 8-bit grayscale images (Bottou et al., 1994; Yadav & Bottou, 2019)

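								import numpy as np
								import matplotlib.pyplot as plt

								# mnist: uint8 array of shape (N, 28, 28), assumed to be loaded already
								mean = np.mean(mnist, axis=0)  # per-pixel mean over all images
								plt.imshow(mean, vmin=0, vmax=255, cmap='gray')

								std = np.std(mnist, axis=0)  # per-pixel standard deviation
								plt.imshow(std, vmin=0, cmap='viridis')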

Looking at

examples of scaled MNIST digits
MNIST digits as 8-bit grayscale images
mean and standard deviation over scaled MNIST digits
Statistics over 8-bit MNIST digits

Looking at

comparison of different dimensionality reduction methods on MNIST
Different dimensionality reduction methods on MNIST digits (Wang et al., 2021)

Statistics of the

examples of scaled MNIST digits
MNIST digits as 8-bit grayscale images
mean and standard deviation over scaled MNIST digits
Statistics over 8-bit MNIST digits


examples of scaled MNIST digits
MNIST digits as floating point values $\in [0, 1]$
mean and standard deviation over scaled MNIST digits
Statistics over $[0, 1]$ MNIST digits


examples of centred MNIST digits
Mean-centred MNIST digits
mean and standard deviation over centred MNIST digits
Statistics over centred MNIST digits


examples of normalised MNIST digits
Mean- and variance-normalised MNIST digits
mean and standard deviation over normalised MNIST digits
Statistics over normalised MNIST digits


examples of clipped normalised MNIST digits
Mean- and variance-normalised MNIST digits (clipped)
clipped mean and standard deviation over normalised MNIST digits
Statistics over normalised MNIST digits (clipped)
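
The preprocessing steps shown so far (scaling, centring, normalising, clipping) can be sketched in a few lines of NumPy. This is only a minimal sketch: the random `images` array is a stand-in for MNIST, and the `eps` guard and the clipping range of ±3 are assumptions, not values taken from the slides. On real MNIST, some border pixels are constant, which is why a guard against zero standard deviation is needed.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for MNIST: random 8-bit "images" of shape (N, 28, 28).
images = rng.integers(0, 256, size=(1000, 28, 28), dtype=np.uint8)

scaled = images.astype(np.float32) / 255.0   # floating point values in [0, 1]
mean = scaled.mean(axis=0)                   # per-pixel mean
std = scaled.std(axis=0)                     # per-pixel standard deviation

centred = scaled - mean                      # mean-centred digits
eps = 1e-8                                   # assumed guard against constant pixels
normalised = centred / (std + eps)           # mean- and variance-normalised digits
clipped = np.clip(normalised, -3.0, 3.0)     # assumed clipping range
```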


examples of PCA-whitened MNIST digits
PCA-whitened MNIST digits (clipped)
mean and standard deviation over PCA-whitened MNIST digits
Statistics over PCA-whitened MNIST digits (clipped)


examples of ZCA-whitened MNIST digits
ZCA-whitened MNIST digits (clipped)
mean and standard deviation over ZCA-whitened MNIST digits
Statistics over ZCA-whitened MNIST digits (clipped)
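
A possible way to compute PCA and ZCA whitening, assuming the covariance matrix is decomposed as $C = V \Lambda V^\top$. The toy data, the fixed `mixing` matrix and the `eps` regulariser are illustrative assumptions. Both transforms produce an identity covariance; ZCA additionally rotates back to the original axes, which is why whitened digits still look like digits.

```python
import numpy as np

rng = np.random.default_rng(0)
# Correlated toy data of shape (N, D); the mixing matrix is an arbitrary choice.
mixing = np.array([[2.0, 0.5, 0.0],
                   [0.0, 1.0, 0.3],
                   [0.0, 0.0, 0.5]])
data = rng.normal(size=(2000, 3)) @ mixing

centred = data - data.mean(axis=0)
cov = centred.T @ centred / len(centred)
eigvals, eigvecs = np.linalg.eigh(cov)     # cov = V diag(eigvals) V^T
eps = 1e-8                                 # assumed regulariser for tiny eigenvalues
scale = 1.0 / np.sqrt(eigvals + eps)

w_pca = (eigvecs * scale).T                # diag(eigvals^-1/2) V^T
w_zca = eigvecs @ w_pca                    # V diag(eigvals^-1/2) V^T = C^-1/2

pca_white = centred @ w_pca.T
zca_white = centred @ w_zca.T
```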

Why normalise/whiten?

$$\begin{aligned} \mathrm{E}_{\vec{w}} &= \frac{1}{|\mathcal{D}|} \sum_{\vec{x} \in \mathcal{D}} \frac{1}{2} (\vec{w} \cdot \vec{x} - y)^2 \\ \operatorname{\nabla} \mathrm{E}_{\vec{w}} &= \frac{1}{|\mathcal{D}|} \sum_{\vec{x} \in \mathcal{D}} (\vec{w} \cdot \vec{x} - y) \vec{x} \\ \operatorname{\nabla}^2 \mathrm{E}_{\vec{w}} &= \frac{1}{|\mathcal{D}|} \sum_{\vec{x} \in \mathcal{D}} \vec{x} \otimes \vec{x} \end{aligned}$$

Taken from section 5.1 in (LeCun et al., 1998)
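
Since the Hessian of this linear model depends only on the inputs, its conditioning can be checked directly. The sketch below uses made-up inputs with a large offset and unequal scales (both assumptions) and compares the condition number of $\frac{1}{|\mathcal{D}|} \sum \vec{x} \otimes \vec{x}$ before and after normalisation.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy inputs with non-zero mean and different scales per feature.
inputs = rng.normal(loc=[5.0, -3.0], scale=[1.0, 0.1], size=(1000, 2))

def hessian(x):
    """Hessian of the mean squared error of a linear model: mean of x (outer) x."""
    return x.T @ x / len(x)

raw_cond = np.linalg.cond(hessian(inputs))

normalised = (inputs - inputs.mean(axis=0)) / inputs.std(axis=0)
norm_cond = np.linalg.cond(hessian(normalised))
```

A condition number close to one means that gradient descent makes similar progress in all directions of weight space.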

Why normalise/whiten?

$$\begin{aligned} \mathrm{E}_{\vec{w}^*} &\approx \mathrm{E}_{\vec{w}} + \operatorname{\nabla} \mathrm{E}_{\vec{w}} \cdot (\vec{w}^* - \vec{w}) \\ \operatorname{\nabla} \mathrm{E}_{\vec{w}^*} &\approx \operatorname{\nabla} \mathrm{E}_{\vec{w}} + \operatorname{\nabla}^2 \mathrm{E}_{\vec{w}} \cdot (\vec{w}^* - \vec{w}) \end{aligned}$$

$$\vec{w}^* \approx \vec{w} - \bigg(\operatorname{\nabla}^2 \mathrm{E}_{\vec{w}}\bigg)^{-1}\operatorname{\nabla} \mathrm{E}_{\vec{w}}$$

Taken from section 5.1 in (LeCun et al., 1998)
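
For the quadratic error of a linear model, the Newton update reaches the minimum in a single step. A minimal sketch on noiseless toy data (the data and the true weights are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
inputs = rng.normal(size=(200, 3))
true_w = np.array([1.0, -2.0, 0.5])
targets = inputs @ true_w                  # noiseless linear targets

w = np.zeros(3)                            # arbitrary starting point
grad = inputs.T @ (inputs @ w - targets) / len(inputs)
hess = inputs.T @ inputs / len(inputs)

# Newton step: w* = w - H^-1 grad (solve instead of explicit inverse)
w_star = w - np.linalg.solve(hess, grad)
```

With whitened inputs the Hessian is the identity, so a plain gradient step with learning rate one would do the same job.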

Why not to whiten?

Whitening and Second Order Optimization Both Make Information in the Dataset Unusable During Training, and Can Reduce or Prevent Generalization

BUT:  $\mat{W}_\mathrm{ZCA} = \mat{C}^{-1/2}$

(Wadia et al., 2021)

Normalisation of

examples of clipped normalised MNIST digits
Mean- and variance-normalised MNIST digits (clipped)
clipped mean and standard deviation over normalised MNIST digits
Statistics over normalised MNIST digits (clipped)

Practical Normalisation of

examples of channel-normalised MNIST digits
Mean- and variance-normalised MNIST digits (channel)
mean and standard deviation over channel-normalised MNIST digits
Statistics over normalised MNIST digits (channel)
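
Channel-wise normalisation uses one mean and one standard deviation per channel instead of one per pixel. A minimal sketch, assuming images in `(N, C, H, W)` layout; the random `images` array is a stand-in for the scaled MNIST digits.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for scaled MNIST in (N, C, H, W) layout with a single channel.
images = rng.random(size=(1000, 1, 28, 28)).astype(np.float32)

# One statistic per channel: reduce over images, height and width.
mean = images.mean(axis=(0, 2, 3), keepdims=True)
std = images.std(axis=(0, 2, 3), keepdims=True)
normalised = (images - mean) / std
```

Because the statistics are scalars per channel, no division by zero occurs for constant pixels and the relative brightness between pixels is preserved.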

Invariances in

example augmentations of an MNIST digit
Example augmentations of an MNIST digit


  • be aware of origins
  • gain insights (e.g. using visualisation)
  • use statistics to your advantage
  • consider inductive biases of the model


  1. Data
  2. Model
  3. Learning


Debugging your

test error as function of complexity in traditional vs modern view of overfitting
Complexity curves for two models of overfitting (Belkin et al., 2019)

Debugging your

								import numpy as np

								def my_custom_dl_implementation():
no! god please! no! no!

Debugging your

								for epoch in range(max_epochs):
									for x, y in train_loader:
										pred = model(x)
										err = loss_func(pred, y)

									with torch.no_grad():
										for x, y in valid_loader:
											pred = model(x)
											val_err = loss_func(pred, y)
no. no!

Debugging your

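								import torch

								def update(model, loader, loss_func, opt):
									model.train()  # enable training behaviour (dropout, batch-norm)
									for x, y in loader:
										opt.zero_grad()
										pred = model(x)
										err = loss_func(pred, y)
										err.backward()
										opt.step()
									return err.item()  # loss of the last mini-batch

								@torch.no_grad()
								def evaluate(model, loader, loss_func):
									model.eval()  # disable training behaviour
									errs = []
									for x, y in loader:
										pred = model(x)
										val_err = loss_func(pred, y)
										errs.append(val_err.item())
									return errs

								# evaluate the untrained model first to get a baseline
								val_errs = evaluate(model, valid_loader, loss_func)
								for epoch in range(1, max_epochs + 1):
									err = update(model, train_loader, loss_func, opt)
									val_errs = evaluate(model, valid_loader, loss_func)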

Debugging your

								from torch.utils.data import Subset

								if debug:
									train_data = Subset(data, [0, 1])
									valid_data = Subset(data, range(2, len(data)))
loss curves for overfitting model

Choosing your

examples of inductive biases in different architectures
Image taken from a blog post by Sam Finlayson

Choosing your

attention system in the transformer architecture
Transformer architecture (Vaswani et al., 2017)
visualisation of the BERT model
BERT language model (Devlin et al., 2019; image from a blog post by Jay Alammar).
visualisation of the vision transformer model
Vision transformer architecture (Dosovitskiy et al., 2021)
visualisation of the MLP mixer model
MLP mixer architecture (Tolstikhin et al., 2021)

Hopfield networks is all you need (Ramsauer et al., 2021)

Choosing your

(van der Smagt & Hirzinger, 1998; Srivastava et al., 2015; He et al., 2016)

Initialising your

Initial biases in your

$$\vec{x} = \vec{0} \Rightarrow \hat{\vec{y}} = \sigma(\vec{b})$$
  • Balanced data: $\vec{b} = \vec{0}$
  • Unbalanced data: $b_k \propto \ln\Big(\frac{p(y_k)}{1 - p(y_k)}\Big)$
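
The bias for unbalanced data can be sketched as the log-odds of the class prior, assuming a proportionality constant of one and a sigmoid output. The 1% prior below is a hypothetical example.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def initial_bias(prior):
    """Log-odds of the class prior, so that sigmoid(bias) equals the prior."""
    return math.log(prior / (1.0 - prior))

bias = initial_bias(0.01)        # hypothetical data set with 1% positives
initial_output = sigmoid(bias)   # prediction for a zero input
```

With this choice the untrained network already predicts the class frequencies, so the first updates do not have to be spent on learning the bias.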

"open" LSTM gate: $\vec{b} \gg 0$
"closed" LSTM gate: $\vec{b} \ll 0$

Normalising your

Batch-normalisation: $$\begin{aligned} \hat{\vec{x}} &= \frac{\vec{x} - \vec{\mu}_\mathcal{B}}{\vec{\sigma}_\mathcal{B}}\\ \vec{y} &= \vec{\gamma} \odot \hat{\vec{x}} + \vec{\beta} \end{aligned}$$

(Schraudolph, 1998; Ioffe & Szegedy, 2015)
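
A minimal sketch of the batch-normalisation forward pass in training mode, where scale and shift are applied to the normalised activations. The `eps` term and the toy batch are assumptions; running statistics for inference are left out.

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalise each feature over the batch axis, then scale and shift."""
    mu = x.mean(axis=0)                      # batch mean per feature
    var = x.var(axis=0)                      # batch variance per feature
    x_hat = (x - mu) / np.sqrt(var + eps)    # eps is an assumed stabiliser
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
batch = rng.normal(loc=3.0, scale=2.0, size=(64, 4))
out = batch_norm(batch, gamma=np.ones(4), beta=np.zeros(4))
```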

Normalising your

visualisation of multiple normalisation methods
Different axes of normalisation (Hoedt et al., 2022; Wu & He, 2018)
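
The normalisation variants differ mainly in the axes over which statistics are computed. The sketch below assumes features in `(N, C, H, W)` layout and omits the learnable scale and shift; the helper `normalise` and the group count are illustrative choices.

```python
import numpy as np

def normalise(x, axes, eps=1e-5):
    """Zero mean and unit variance over the given axes."""
    mu = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
features = rng.normal(loc=2.0, size=(8, 6, 5, 5))         # (N, C, H, W)

batch_normed = normalise(features, axes=(0, 2, 3))        # per channel, over the batch
layer_normed = normalise(features, axes=(1, 2, 3))        # per sample, over all features
instance_normed = normalise(features, axes=(2, 3))        # per sample and channel

# Group norm: split the channels into groups and normalise within each group.
groups = 2
grouped = features.reshape(8, groups, 6 // groups, 5, 5)
group_normed = normalise(grouped, axes=(2, 3, 4)).reshape(features.shape)
```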



  • start as simple as possible
  • use prior knowledge (or don't)
  • mind the signal flow!


  1. Data
  2. Model
  3. Learning

Optimisation algorithms

animation of various optimisation algorithms on the Beale function (contours)
A collection of optimisation algorithms on the Beale function.
animation of various optimisation algorithms in a saddle point (3D)
Same collection of optimisation algorithms in a saddle point.

Images from the OG blog post by Sebastian Ruder

Saddle points during

LSTM training curve with plateau
Typical LSTM training curve (image from the pytorch forums)

Rate of

learning curves for different learning rates
Learning curves for different learning rates
If you have time to tune only one hyperparameter, tune the learning rate.

Deep Learning Book (Goodfellow et al., 2016)

Rate of

sqrt learning rate schedule
$\eta(t) = \frac{1}{\sqrt{t + 1}} \eta$
exponential learning rate schedule
$\eta(t) = \alpha^t \eta$
stepping learning rate schedule
$\eta(t) = \alpha^{\lfloor t / N \rfloor} \eta$
cosine learning rate schedule
$\eta(t) = \eta_T + \frac{\eta_0 - \eta_T}{2} \big(1 + \cos(\pi \frac{t}{T})\big)$
cosine learning rate schedule

(Loshchilov & Hutter, 2017; © images from (Zhang et al., 2021))
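
The four schedules above translate directly into code. This is a plain-Python sketch; the function and parameter names are chosen for illustration, with `period` playing the role of $N$.

```python
import math

def sqrt_schedule(t, eta):
    """eta(t) = eta / sqrt(t + 1)"""
    return eta / math.sqrt(t + 1)

def exponential_schedule(t, eta, alpha):
    """eta(t) = alpha^t * eta"""
    return alpha ** t * eta

def step_schedule(t, eta, alpha, period):
    """eta(t) = alpha^floor(t / N) * eta, with N = period"""
    return alpha ** (t // period) * eta

def cosine_schedule(t, eta_0, eta_T, T):
    """eta(t) = eta_T + (eta_0 - eta_T) / 2 * (1 + cos(pi * t / T))"""
    return eta_T + (eta_0 - eta_T) / 2 * (1 + math.cos(math.pi * t / T))

# Learning rates over 100 steps, starting from 0.1 and ending at 0.001.
rates = [cosine_schedule(t, eta_0=0.1, eta_T=0.001, T=100) for t in range(101)]
```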

Rate of

								ce = nn.CrossEntropyLoss(reduction="mean")
learning curves as functions of epoch vs update
Learning curves for different batch sizes (and number of epochs)

Rate of

learning curves as functions of epoch vs update
Learning curves for different batch sizes (and number of epochs)

								ce = nn.CrossEntropyLoss(reduction="sum")
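
The choice of reduction changes the scale of the gradient, and therefore the effective learning rate. The sketch below re-implements the gradient of softmax cross-entropy for a linear model in NumPy (it is not the PyTorch implementation) and shows that the summed gradient is the averaged gradient times the batch size.

```python
import numpy as np

def ce_grad(weights, x, labels, reduction="mean"):
    """Gradient of softmax cross-entropy w.r.t. the weights of a linear model."""
    logits = x @ weights
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    probs[np.arange(len(labels)), labels] -= 1.0          # softmax minus one-hot
    grad = x.T @ probs                                    # summed over the batch
    return grad / len(labels) if reduction == "mean" else grad

rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 3))
x = rng.normal(size=(32, 4))
labels = rng.integers(0, 3, size=32)

grad_mean = ce_grad(weights, x, labels, reduction="mean")
grad_sum = ce_grad(weights, x, labels, reduction="sum")
```

With plain SGD, switching from "mean" to "sum" therefore acts like multiplying the learning rate by the batch size.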

Rate of

diagram depicting multi-task learning
Multi-task learning (image from a blog post by Sebastian Ruder)
example of semantic segmentation for self-driving cars
Segmentation (image from pydata 2017 session)

(Caruana, 1998)


test error as function of complexity in traditional vs modern view of overfitting
Complexity curves for two models of overfitting (Belkin et al., 2019)

(Prechelt, 1998)
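
Prechelt (1998) discusses when to stop training based on the validation error. One common variant is a patience-based rule, sketched below; the rule, the function name and the validation curve are illustrative assumptions rather than the criteria from the paper.

```python
def early_stopping_epoch(val_errors, patience):
    """Return (best_epoch, stop_epoch) for a patience-based stopping rule."""
    best_err = float("inf")
    best_epoch = -1
    for epoch, err in enumerate(val_errors):
        if err < best_err:
            best_err, best_epoch = err, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch          # no improvement for `patience` epochs
    return best_epoch, len(val_errors) - 1    # ran out of epochs

# Made-up validation curve with a late second descent.
curve = [1.0, 0.8, 0.7, 0.72, 0.69, 0.71, 0.73, 0.75, 0.6]
impatient = early_stopping_epoch(curve, patience=3)
patient = early_stopping_epoch(curve, patience=5)
```

On this curve the impatient rule stops before the late improvement, which is one reason to keep the best checkpoint and to choose the patience generously.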


standard neural network vs network with dropout
Dropout regularisation (Srivastava et al., 2014)

(Gal & Ghahramani, 2016; Li et al., 2019)
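
A minimal sketch of dropout on a layer's activations. It uses the "inverted" variant, which rescales the kept units during training so that nothing changes at test time; the original paper instead rescales the weights at test time. The function name and the toy activations are assumptions.

```python
import numpy as np

def dropout(x, rate, rng, training=True):
    """Inverted dropout: zero units with probability `rate`, rescale the rest."""
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate       # Boolean mask of kept units
    return x * keep / (1.0 - rate)           # keeps the expected value unchanged

rng = np.random.default_rng(0)
activations = np.ones((1000, 100))
dropped = dropout(activations, rate=0.5, rng=rng)
unchanged = dropout(activations, rate=0.5, rng=rng, training=False)
```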


multi-class labels without and with label smoothing
Label smoothing (image from a blog post by Parthvi Shah)
2D projections of penultimate activations without and with label smoothing
Label smoothing (Müller et al., 2019)

(Szegedy et al., 2016)
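
Label smoothing mixes the one-hot target with a uniform distribution over the classes. A minimal sketch, where `alpha` is the smoothing strength and the labels are made up:

```python
import numpy as np

def smooth_labels(labels, num_classes, alpha):
    """Return (1 - alpha) * one_hot + alpha / num_classes for integer labels."""
    one_hot = np.eye(num_classes)[labels]
    return (1.0 - alpha) * one_hot + alpha / num_classes

targets = smooth_labels(np.array([0, 2]), num_classes=4, alpha=0.1)
```

Each row still sums to one, but the correct class is capped at `1 - alpha + alpha / num_classes`, so the network is no longer pushed towards infinitely large logits.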


unit balls in L1, L2 norms and how they affect parameters
unit balls in different norms and how they affect parameters
(image taken from wikipedia)

$$\min_\theta L(g(\vec{x} \mathbin{;} \theta), y) + \lambda \|\theta\|$$

(Hanson & Pratt, 1989; Rögnvaldsson, 1998; Loshchilov & Hutter, 2019)
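
For plain SGD, a squared L2 penalty and decoupled weight decay lead to the same update. The sketch below assumes the penalty $\frac{\lambda}{2} \|\theta\|_2^2$ (the factor one half and the squared norm are assumptions) and uses made-up parameters and gradients.

```python
import numpy as np

def sgd_l2_step(theta, grad, lr, lam):
    """SGD on loss + lam / 2 * ||theta||^2: the penalty enters via the gradient."""
    return theta - lr * (grad + lam * theta)

def sgd_decoupled_step(theta, grad, lr, lam):
    """Decoupled weight decay: shrink the weights, then take the plain step."""
    return (1.0 - lr * lam) * theta - lr * grad

theta = np.array([1.0, -2.0, 0.5])
grad = np.array([0.1, 0.0, -0.3])            # gradient of the unregularised loss

penalised = sgd_l2_step(theta, grad, lr=0.1, lam=0.01)
decayed = sgd_decoupled_step(theta, grad, lr=0.1, lam=0.01)
```

The equivalence breaks for adaptive optimisers, where the penalty gradient gets rescaled together with the loss gradient; this is the motivation for decoupling the decay (Loshchilov & Hutter, 2019).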


  • adaptive optimisers are easier (SGD is better)
  • maximise the rate of learning
  • use regularisation (after overfitting)


  1. Understand your data
  2. Keep your model simple
  3. Learning benefits from tuning


