catastrophic forgetting

The art of learning is the art of knowing what to remember and what to forget.

Background

In this post, we are going to discuss a seminal method from the field of Continual Learning called Elastic Weight Consolidation, also known as EWC.

Continual Learning (CL) tackles learning from sequential data streams with limited retention and retraining capacity. It requires making efficient use of past data and adapting to new contexts under changing distributions without catastrophic forgetting, i.e. without abruptly losing performance on earlier tasks when learning new ones.

Let’s consider a simple example. Suppose we want a neural network to learn two tasks: first to classify images of cats versus dogs, and then to classify images of cars versus bikes. The network is trained on the first task, after which the second task is introduced and the network is trained on it alone. Since the training objective for the second task contains no term for the first, nothing prevents the weight updates from overwriting the weights that were learned for the first task, and the network forgets how to distinguish cats from dogs.

More precisely, in Continual Learning a model $f_\theta$ is trained on a sequence of $T$ tasks, where for each task $t \in \{ 1, \ldots, T \} $ the learner only gets access to samples of that task: $ D_t = \{ (x_i, y_i) \}_{i=1}^{N_t} $. However, the model is ultimately evaluated on its joint performance across all tasks, so we should aim to optimize:

$$ \theta^{\star} = \arg\min_{\theta} \sum_{t=1}^{T} \mathbb{E}_{(x, y) \sim D_t} [ \mathcal{L} (f_{\theta}(x), y)] $$

The main challenge is that at the time of task $t$, the model has no access to data from previous tasks $\tilde{t} \in \{1, \ldots, t-1\}$, therefore violating the typical IID data assumption.

Overcoming catastrophic forgetting in neural networks

One of the seminal works addressing catastrophic forgetting in neural networks is the method called Elastic Weight Consolidation (EWC), first proposed in “Overcoming catastrophic forgetting in neural networks” by Kirkpatrick et al. (2017).

In order to get a better intuition for the final formulation of the method, let us first consider the Bayesian perspective on training neural networks.

From Bayes’ rule, we have:

$$ p(\theta | D) = \frac{p(D | \theta) p(\theta)}{p(D)} $$

where $p(\theta | D)$ is the posterior distribution of the weights given the data, $p(D | \theta)$ is the likelihood of the data given the weights, $p(\theta)$ is the prior distribution of the weights, and $p(D)$ is the marginal likelihood of the data (also known as evidence).

Taking the log of the posterior, we have:

$$ \log p(\theta | D) = \log p(D | \theta) + \log p(\theta) - \log p(D) $$

The goal is to find the optimal configuration of parameters $ \theta^{\star} $ that maximizes the (log) posterior:

$$ \theta^{\star} = \arg\max_{\theta} \log p(\theta | D) $$

Suppose the data consist of two tasks, $D = \{A, B\}$, that are conditionally independent given $\theta$. Then we can rewrite the log-posterior as:

$$ \begin{align*} \log(p(\theta | D)) &= \log \left(\frac{p(B | A, \theta) p(\theta | A) p(A)}{p(B|A) p(A)}\right)\\ &= \log(p(B|\theta)) + \log (p(\theta | A)) - \log(p(B)) &\text{(conditional independence of A and B)} \\ &\approx \log(p(B|\theta)) + \log (p(\theta | A)) &\text{($\log(p(B))$ is const.)} \end{align*} $$

The log-likelihood $\log(p(B|\theta))$ corresponds to the negative of the loss on task B. Notice, however, that the posterior $p(\theta|A)$ is in general intractable for neural networks; for this reason, we will resort to Laplace’s approximation.

To begin with, consider the second-order Taylor expansion of the log-likelihood $l(\theta)$ around task $A$’s optimal parameters, $\theta_A^*$. Since $\theta_A^*$ is an optimum, the gradient vanishes there and the first-order term drops out:

$$ \begin{align*} l(\theta) &\approx \underbrace{l(\theta_{A}^{*})}_{\text{const.}} + \underbrace{\left( \left.\frac{\partial l(\theta)}{\partial \theta} \right\vert_{\theta_{A}^{*}}\right)^t}_{0} \left(\theta - \theta_{A}^{*}\right) + \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 l(\theta)}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \\ &\approx \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 l(\theta)}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) &\text{(up to a constant)} \end{align*} $$

Analogously, we can write the same kind of approximation for the log-posterior $\log(p(\theta | A))$ around its mode $\theta_A^*$:

$$ \begin{align*} \log(p(\theta|A)) \approx \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \end{align*} $$

Now, by exponentiating both sides and writing the Hessian as the inverse of its negated inverse, we see that, up to a normalizing constant, the posterior has the form of a multivariate Gaussian density:

$$ \begin{align*} p(\theta | A) &\propto \text{exp}\left( \frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \right) \\ &= \text{exp}\left(-\frac{1}{2}\left(\theta - \theta_{A}^{*}\right)^t \left(\left(\left. -\frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right)^{-1}\right)^{-1} \left(\theta - \theta_{A}^*\right) \right) \end{align*} $$

Therefore, we obtain the following Laplace approximation:

$$ \begin{align*} p(\theta | A) \approx \mathcal{N}\left(\theta_{A}^{*}, \left(- \left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right)^{-1} \right) \end{align*} $$
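To see the Laplace approximation at work, here is a minimal one-dimensional sketch under illustrative assumptions that are not part of EWC: "task A" is a coin that landed heads $k$ times in $n$ flips, the likelihood is Bernoulli and the prior on $\theta \in (0, 1)$ is uniform, so $\log(p(\theta|A)) = k \log\theta + (n-k)\log(1-\theta) + \text{const.}$ The mode is located by golden-section search and the Hessian by a central finite difference; the function names are hypothetical.

```python
import math

def log_posterior(theta, heads, flips):
    # Unnormalized log p(theta | A): Bernoulli likelihood, uniform prior on (0, 1).
    return heads * math.log(theta) + (flips - heads) * math.log(1.0 - theta)

def find_mode(f, lo=1e-6, hi=1.0 - 1e-6, iters=100):
    # Golden-section search for the maximizer of a unimodal function on [lo, hi].
    inv_phi = (math.sqrt(5.0) - 1.0) / 2.0
    a, b = lo, hi
    c, d = b - inv_phi * (b - a), a + inv_phi * (b - a)
    for _ in range(iters):
        if f(c) > f(d):
            b, d = d, c                    # maximum lies in [a, d]
            c = b - inv_phi * (b - a)
        else:
            a, c = c, d                    # maximum lies in [c, b]
            d = a + inv_phi * (b - a)
    return (a + b) / 2.0

def laplace_approximation(heads, flips, h=1e-4):
    f = lambda t: log_posterior(t, heads, flips)
    mode = find_mode(f)
    # Central finite-difference estimate of the second derivative at the mode.
    second = (f(mode + h) - 2.0 * f(mode) + f(mode - h)) / h ** 2
    variance = 1.0 / (-second)             # inverse of the negative Hessian
    return mode, variance

mode, var = laplace_approximation(heads=7, flips=10)
print(f"mean={mode:.4f}, variance={var:.4f}")   # prints: mean=0.7000, variance=0.0210
```

For $k = 7$ and $n = 10$, the mode is $\theta_A^* = 0.7$ and the negative second derivative is $k/\theta^2 + (n-k)/(1-\theta)^2 \approx 47.6$, so the Laplace approximation is $\mathcal{N}(0.7, 0.021)$.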

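Notice that the precision matrix of this Laplace approximation, $-\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}}$, is the negative Hessian of the log-posterior at $\theta_{A}^{*}$, i.e. a measure of its curvature around the optimal parameters. Under a flat prior it coincides with the negative Hessian of the log-likelihood, and its expectation over data drawn from the model is the Fisher Information matrix, which EWC uses in its place: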

$$ \begin{align*} \mathbb{I}_{A} = \mathbb{E} \left[ -\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right] \end{align*} $$

A convenient property of the Fisher Information matrix is that, under mild regularity conditions, it equals the expected outer product of the gradients of the log-likelihood, so it can be computed from first-order derivatives alone:

$$ \begin{align*} \mathbb{I}_{A} &= \mathbb{E} \left[ -\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right] \\ &= \mathbb{E} \left[ \left( \frac{\partial \log(p(\theta|A))}{\partial \theta} \right) \left.\left( \frac{\partial \log(p(\theta|A))}{\partial \theta} \right)^t \right\vert_{\theta_{A}^{*}} \right] \end{align*} $$
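The identity can be checked on a model small enough to evaluate both sides exactly. The sketch below assumes a single logistic-regression unit, $p(y=1 \mid x, \theta) = \sigma(\theta^t x)$, and takes derivatives of the per-sample log-likelihood $\log p(y \mid x, \theta)$; the expectation is over labels $y$ drawn from the model itself. The parameter and input values are arbitrary.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def outer(u, v):
    return [[ui * vj for vj in v] for ui in u]

def fisher_from_scores(theta, x):
    # E_y[s s^t], where s = d/dtheta log p(y | x, theta) = (y - p) x and y ~ Bernoulli(p).
    p = sigmoid(sum(t * xi for t, xi in zip(theta, x)))
    fisher = [[0.0] * len(x) for _ in x]
    for y, prob in ((1, p), (0, 1.0 - p)):
        score = [(y - p) * xi for xi in x]
        s_outer = outer(score, score)
        for i in range(len(x)):
            for j in range(len(x)):
                fisher[i][j] += prob * s_outer[i][j]
    return fisher

def fisher_from_hessian(theta, x):
    # Negative Hessian of log p(y | x, theta) = p (1 - p) x x^t; it does not depend on y.
    p = sigmoid(sum(t * xi for t, xi in zip(theta, x)))
    return [[p * (1.0 - p) * v for v in row] for row in outer(x, x)]

theta, x = [0.3, -1.2], [1.5, 0.5]
a, b = fisher_from_scores(theta, x), fisher_from_hessian(theta, x)
print(max(abs(a[i][j] - b[i][j]) for i in range(2) for j in range(2)))
```

The printed maximum difference is zero up to floating-point rounding, because $\mathbb{E}_y[(y - \sigma)^2] = \sigma(1 - \sigma)$. If the expectation is instead taken over the observed dataset labels (the so-called empirical Fisher), the two sides generally no longer agree.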

However, given that neural networks today can have billions of parameters, materializing the entire Fisher Information matrix is computationally infeasible due to its quadratic complexity in the number of parameters. For this reason, EWC approximates the Fisher Information matrix by only considering the diagonal elements of the matrix.

This means that we only need to compute per-sample gradients of the log-likelihood with respect to the parameters, square them elementwise, and average them over the samples!
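A minimal sketch of the diagonal estimate in plain Python, assuming (for illustration only) a logistic-regression model $p(y=1 \mid x, \theta) = \sigma(\theta^t x)$ whose per-sample gradient $(y - \sigma(\theta^t x))\,x$ can be written by hand; with a real network, the per-sample gradients would come from an autodiff framework instead. Labels are sampled from the model's own predictive distribution, and the gradients are squared elementwise and averaged. The function `diagonal_fisher` and the synthetic inputs are hypothetical.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def diagonal_fisher(theta, inputs, samples_per_input=1, rng=None):
    """Monte Carlo estimate of diag(F) for the model p(y=1 | x, theta) = sigmoid(theta . x)."""
    rng = rng or random.Random(0)
    diag = [0.0] * len(theta)
    count = 0
    for x in inputs:
        p = sigmoid(sum(t * xi for t, xi in zip(theta, x)))
        for _ in range(samples_per_input):
            y = 1 if rng.random() < p else 0     # label sampled from the model, not the dataset
            grad = [(y - p) * xi for xi in x]    # per-sample d/dtheta log p(y | x, theta)
            for i, g in enumerate(grad):
                diag[i] += g * g                 # square elementwise
            count += 1
    return [d / count for d in diag]             # average over all samples

data_rng = random.Random(42)
inputs = [[data_rng.gauss(0.0, 1.0), data_rng.gauss(0.0, 1.0)] for _ in range(2000)]
theta_a = [0.8, -0.5]   # stands in for task A's optimal parameters
estimate = diagonal_fisher(theta_a, inputs, samples_per_input=5)

# Exact diagonal for comparison: mean over inputs of p (1 - p) x_i^2.
exact = [0.0] * len(theta_a)
for x in inputs:
    p = sigmoid(sum(t * xi for t, xi in zip(theta_a, x)))
    for i, xi in enumerate(x):
        exact[i] += p * (1.0 - p) * xi ** 2 / len(inputs)
print(estimate, exact)
```

The Monte Carlo estimate should land close to the exact diagonal $\frac{1}{N}\sum_n \sigma_n(1-\sigma_n)\,x_{n,i}^2$, and storing it takes memory linear (rather than quadratic) in the number of parameters.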

Now, returning to the two-task objective and plugging in the approximation for $\log(p(\theta | A))$, we get:

$$ \begin{align*} \log(p(\theta | D)) &\approx \log(p(B|\theta)) + \log (p(\theta | A)) \\ &\approx \log(p(B|\theta)) + \frac{\lambda}{2} \left(\theta - \theta_{A}^{*}\right)^t \left(\left. \frac{\partial^2 \log(p(\theta|A))}{\partial^2 \theta} \right\vert_{\theta_{A}^{*}} \right) \left(\theta - \theta_{A}^*\right) \end{align*} $$

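where $\lambda$ is a hyperparameter that trades off learning task B against not forgetting task A ($\lambda = 1$ recovers the derivation exactly). Replacing the negative Hessian with the Fisher Information matrix $\mathbb{I}_{A}$ and negating both sides, so that maximizing the log-posterior becomes minimizing a loss, we get: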

$$ \begin{align*} \log(p(\theta | D)) &\approx \log(p(B|\theta)) - \frac{\lambda}{2}\left(\theta - \theta_{A}^{*}\right)^t \mathbb{I}_{A} \left(\theta - \theta_{A}^*\right) \\ \implies -\log(p(\theta | D)) &\approx -\log(p(B|\theta)) + \frac{\lambda}{2}\left(\theta - \theta_{A}^{*}\right)^t \mathbb{I}_{A} \left(\theta - \theta_{A}^*\right) \\ \underbrace{\mathcal{L}(\theta)}_{\text{total loss}} &\approx \underbrace{\mathcal{L}_B(\theta)}_{\text{loss on B}} + \underbrace{\frac{\lambda}{2}\left(\theta - \theta_{A}^{*}\right)^t \mathbb{I}_{A} \left(\theta - \theta_{A}^*\right)}_{\text{regularizer}} \end{align*} $$
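With the diagonal approximation, the regularizer reduces to $\frac{\lambda}{2}\sum_i \mathbb{I}_{A,i} (\theta_i - \theta^*_{A,i})^2$. The sketch below illustrates its effect under toy assumptions: two parameters, and two quadratic task losses standing in for real networks (treating `loss_a` as a negative log-likelihood, its Hessian $\text{diag}(10, 0.2)$ plays the role of the Fisher). Task A is sensitive mostly to the first parameter, task B mostly to the second. The function names, $\lambda = 1$, the learning rate, and the step count are arbitrary choices.

```python
def loss_a(t):
    return 5.0 * (t[0] - 1.0) ** 2 + 0.1 * (t[1] - 1.0) ** 2

def grad_b(t):
    # Gradient of loss_b(t) = 0.1 (t0 + 1)^2 + 5 (t1 + 1)^2.
    return [0.2 * (t[0] + 1.0), 10.0 * (t[1] + 1.0)]

def ewc_penalty_grad(t, anchor, fisher, lam):
    # Gradient of (lam / 2) * sum_i F_i (t_i - anchor_i)^2.
    return [lam * f * (ti - ai) for ti, ai, f in zip(t, anchor, fisher)]

def train_on_b(start, anchor=None, fisher=None, lam=0.0, lr=0.05, steps=2000):
    # Plain gradient descent on task B, optionally with the EWC regularizer.
    t = list(start)
    for _ in range(steps):
        g = grad_b(t)
        if anchor is not None:
            g = [gi + pi for gi, pi in zip(g, ewc_penalty_grad(t, anchor, fisher, lam))]
        t = [ti - lr * gi for ti, gi in zip(t, g)]
    return t

theta_a = [1.0, 1.0]      # optimum of task A
fisher_a = [10.0, 0.2]    # diagonal Hessian of loss_a, used as the Fisher here
plain = train_on_b(theta_a)
ewc = train_on_b(theta_a, anchor=theta_a, fisher=fisher_a, lam=1.0)
print("plain:", [round(v, 3) for v in plain], "task A loss:", round(loss_a(plain), 3))
print("EWC:  ", [round(v, 3) for v in ewc], "task A loss:", round(loss_a(ewc), 3))
```

Analytically, plain training on B moves $\theta$ to $(-1, -1)$, where task A's loss is $20.4$, while the EWC objective is minimized at $(49/51, -49/51) \approx (0.961, -0.961)$, where task A's loss is about $0.39$: the regularizer protects the parameter that task A is sensitive to and leaves the other one free for task B.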

Progress and Compress

In the previous section we looked at the case when our dataset consists only of two tasks $A$ and $B$. Now, let us consider the more general case when we have a sequence of $k$ different tasks:

$$ \begin{align*} p(\theta | T_{1:k}) &= \frac{p(T_{1:k} | \theta) p(\theta)}{p(T_{1:k})} \\ &\propto p(T_{1:k} | \theta) p(\theta) &\text{($p(T_{1:k})$ is const.)} \\ &= \left(\prod_{i=1}^k p(T_i | \theta) \right) p(\theta) &\text{(cond. independence)}\\ &= p(T_1 | \theta) p(T_2 | \theta) \ldots p(T_{k-1} | \theta) p(T_k | \theta) p(\theta) &\text{(expand prod)}\\ &= p(T_1 | \theta) p(T_2 | \theta) \ldots p(T_{k-1} | \theta) p(\theta) p(T_k | \theta) &\text{(reorder)} \\ &\propto p(\theta|T_{1:k-1})p(T_k|\theta) &\text{(Bayes' rule on $T_{1:k-1}$)} \end{align*} $$

This means that the posterior of $\theta$ given all tasks up to $k$ can be computed sequentially, by first computing it for the first $k-1$ tasks, and then updating it with the likelihood (alternatively, the loss) for the $k$-th task.
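The recursion can be checked exactly in a conjugate toy model, which is an illustrative assumption rather than anything EWC-specific: a scalar mean $\theta$ is estimated from Gaussian observations with known noise variance and a Gaussian prior, and each "task" is a batch of observations. Because every factor is Gaussian, the posterior after each task is Gaussian and can serve as the prior for the next; the result matches processing all data at once. The helper `gaussian_update` and the numbers are hypothetical.

```python
def gaussian_update(prior_mean, prior_prec, data, noise_var):
    # Combine a N(prior_mean, 1 / prior_prec) prior with a Gaussian likelihood for `data`.
    data_prec = len(data) / noise_var        # precision contributed by this task
    post_prec = prior_prec + data_prec
    post_mean = (prior_prec * prior_mean + sum(data) / noise_var) / post_prec
    return post_mean, post_prec

tasks = [[1.2, 0.8, 1.1], [2.0, 1.7], [0.4, 0.9, 1.3, 1.0]]
noise_var, prior = 0.5, (0.0, 1.0)           # prior N(0, 1)

# Sequential: the posterior after task k-1 becomes the prior for task k.
mean, prec = prior
for data in tasks:
    mean, prec = gaussian_update(mean, prec, data, noise_var)

# Batch: all tasks at once.
all_data = [x for data in tasks for x in data]
batch_mean, batch_prec = gaussian_update(*prior, all_data, noise_var)
print(mean, batch_mean, prec, batch_prec)
```

Note that the precision simply accumulates, one term per task; in this Gaussian case the Laplace approximation is exact, so nothing is lost by keeping only the latest mean and precision.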

Maximizing the posterior is equivalent to minimizing the negative log-posterior. Therefore, we can further obtain:

$$ \begin{align*} -\log(p(\theta | T_{1:k})) &\approx -\log (p(\theta | T_{1:k-1})) - \log(p(T_{k}| \theta)) \\ &= - \log(p(T_{k}| \theta)) - \log (p(\theta | T_{1:k-1})) &\text{(reorder)} \\ &\approx \underbrace{-\log(p(T_{k}|\theta))}_{\text{loss on task $T_k$}} + \underbrace{\frac{1}{2}\sum_{j=1}^{k-1} \left\Vert \theta - \theta_j^* \right\Vert_{F_j}^2}_{\text{regularizer}} &\text{(see previous section)} \end{align*} $$

Here $\left\Vert v \right\Vert_{F}^2 = v^t F v$, and $\theta_j^*$ and $F_j$ are the MAP parameters and the Fisher for task $j$. Note that this formulation requires keeping a mean and a Fisher for each task, so its memory and computation cost grows linearly with the number of tasks. Alternatively, one can apply Laplace’s approximation to the whole posterior $p(\theta|T_{1:k-1})$ at once, rather than to each likelihood term individually, resulting in:

$$ \begin{align*} -\log(p(T_k | \theta)) + \frac{1}{2} \left\Vert \theta - \theta_{k-1}^* \right\Vert_{\sum_{j=1}^{k-1}F_j}^2 \end{align*} $$

This means that we only need to keep the latest Maximum-A-Posteriori (MAP) parameters, along with a running sum of Fishers.
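The memory trade-off, and the sense in which the consolidated form approximates the per-task sum, can be illustrated with diagonal Fishers stored as plain lists (the function names and numbers are hypothetical). The per-task penalty stores one $(\theta_j^*, F_j)$ pair per task; the consolidated penalty stores a single anchor and $\sum_j F_j$. Expanding the squares shows that $\sum_j F_j(\theta - \theta_j^*)^2$ equals $(\sum_j F_j)(\theta - \bar{\theta})^2$ plus a constant when $\bar{\theta}$ is the Fisher-weighted mean of the anchors, so with that anchor the two penalties have identical gradients. With the anchor $\theta_{k-1}^*$, the two formulations agree exactly only when $\theta_{k-1}^*$ coincides with this weighted mean.

```python
def per_task_penalty(theta, anchors, fishers):
    # 1/2 * sum_j ||theta - theta_j*||^2_{F_j}: one (anchor, Fisher) pair per task.
    return 0.5 * sum(
        f * (t - a) ** 2
        for anchor, fisher in zip(anchors, fishers)
        for t, a, f in zip(theta, anchor, fisher)
    )

def consolidated_penalty(theta, anchor, fisher_sum):
    # 1/2 * ||theta - anchor||^2_{sum_j F_j}: a single anchor and a summed Fisher.
    return 0.5 * sum(f * (t - a) ** 2 for t, a, f in zip(theta, anchor, fisher_sum))

anchors = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]]   # theta_j* after three earlier tasks
fishers = [[4.0, 0.1], [1.0, 3.0], [0.5, 0.5]]    # diagonal F_j for those tasks

# Running sum of Fishers and the Fisher-weighted mean of the anchors.
fisher_sum = [sum(column) for column in zip(*fishers)]
weighted_mean = [
    sum(f[i] * a[i] for a, f in zip(anchors, fishers)) / fisher_sum[i]
    for i in range(len(fisher_sum))
]

# The two penalties differ by the same constant at every theta, so their gradients match.
points = [[0.0, 0.0], [0.3, -0.7], [2.0, 1.5], [-1.0, 3.0]]
gaps = [
    per_task_penalty(p, anchors, fishers) - consolidated_penalty(p, weighted_mean, fisher_sum)
    for p in points
]
print(gaps)
```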

What the paper “Progress & Compress: A scalable framework for continual learning” by Schwarz et al. (2018) suggests instead is to keep an exponentially decayed running sum of the Fisher Information matrices. This has the same constant memory cost as the plain running sum, but lets the influence of older tasks gradually fade.

More precisely, let $\theta_{i-1}^{*}$, $F_{i-1}^{*}$ be the MAP parameters and overall Fisher after presentation of $i-1$ tasks. Then, the loss for the $i$-th task is defined as:

$$ \begin{align*} -\log(p(T_i | \theta)) + \frac{1}{2}\left\Vert \theta - \theta_{i-1}^{*}\right\Vert_{\gamma F_{i-1}^{*}}^2 \end{align*} $$

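where $0 < \gamma < 1$ is a hyperparameter that down-weights the previously accumulated Fisher, so that the contribution of older tasks decays geometrically; Schwarz et al. describe it as removing the approximation term associated with the previous presentation of task $i$.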

If $\theta_{i}^{*}$ are the MAP parameters obtained after training on task $i$, and $F_i$ is the Fisher for task $i$, then the overall Fisher is updated as:

$$ \begin{align*} F_i^* = \gamma F_{i-1}^* + F_i \end{align*} $$

The authors refer to this modified method as online EWC.
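A minimal sketch of the online EWC bookkeeping, with diagonal Fishers stored as plain lists. It follows the two formulas above: the penalty for task $i$ uses the metric $\gamma F_{i-1}^*$ around $\theta_{i-1}^*$, and after training on task $i$ the running Fisher becomes $\gamma F_{i-1}^* + F_i$. The class name `OnlineEWC`, the zero initialization of the running Fisher, and the example numbers are assumptions for illustration; in practice each $F_i$ would come from a diagonal Fisher estimate.

```python
class OnlineEWC:
    """Keeps only the latest MAP parameters and a decayed running (diagonal) Fisher."""

    def __init__(self, num_params, gamma):
        assert 0.0 < gamma < 1.0
        self.gamma = gamma
        self.anchor = None                  # theta*_{i-1}; None before the first task
        self.fisher = [0.0] * num_params    # F*_{i-1}, assumed to start at zero

    def penalty(self, theta):
        # 1/2 * ||theta - theta*_{i-1}||^2 with metric gamma * F*_{i-1}.
        if self.anchor is None:
            return 0.0
        return 0.5 * sum(
            self.gamma * f * (t - a) ** 2
            for t, a, f in zip(theta, self.anchor, self.fisher)
        )

    def consolidate(self, theta_star, task_fisher):
        # After task i: F*_i = gamma * F*_{i-1} + F_i, and the anchor becomes theta*_i.
        self.fisher = [self.gamma * f + fi for f, fi in zip(self.fisher, task_fisher)]
        self.anchor = list(theta_star)

ewc = OnlineEWC(num_params=2, gamma=0.5)
for theta_star in ([1.0, 0.0], [0.8, 0.4], [0.5, 0.5]):
    ewc.consolidate(theta_star, task_fisher=[1.0, 2.0])
print(ewc.fisher)               # prints [1.75, 3.5]
print(ewc.penalty([0.5, 1.5]))  # prints 0.875
```

With a constant per-task Fisher, the running Fisher grows as $1, 1.5, 1.75, \ldots$ and approaches $F/(1-\gamma)$ instead of growing without bound, which is one way to see that older tasks are gradually discounted.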
