Overcoming the Stability Gap in Continual Learning

¹Rochester Institute of Technology, ²University of Rochester

Motivation

Pre-trained deep neural networks (DNNs) are widely deployed in industry to make business decisions and serve users; however, a major problem is model decay, where a DNN's predictions become increasingly erroneous over time, resulting in lost revenue or unhappy users. To mitigate model decay, DNNs are retrained from scratch on old and new data. Because this is computationally expensive, retraining happens only once performance has significantly decreased. Here, we study how continual learning (CL) could overcome model decay in large pre-trained DNNs while greatly reducing the computational cost of keeping them up to date. We identify the stability gap as a major obstacle in this setting. The stability gap refers to a phenomenon where learning new data causes large drops in performance on past tasks before CL mitigation methods eventually compensate for the drop. We test two hypotheses about the factors that influence the stability gap and identify a method that vastly reduces it. In large-scale experiments with both easy and hard CL distributions (e.g., class incremental learning), we demonstrate that our method reduces the stability gap and greatly increases computational efficiency. Our work aligns CL with the goals of production settings, where CL is needed for many applications.

The Stability Gap

The stability gap is a phenomenon in CL where, once a new distribution is introduced, accuracy on previously learned data (Y-axis) drops sharply as a function of training iterations (X-axis). Fig. (a) illustrates this behavior during class incremental learning (CIL), where a network pre-trained on ImageNet-1K learns 365 new classes from Places365-LT over five rehearsal sessions. Each rehearsal session consists of 600 iterations that combine samples from the old and new tasks. A gray dotted vertical line marks the end of a rehearsal session, i.e., a task transition. When rehearsal begins, conventional rehearsal's accuracy on the old task drops dramatically and then recovers slowly, but it never regains its original performance on the old data. Traditional measures of catastrophic forgetting evaluate performance only at task transitions (red diamonds), ignoring the significant forgetting that occurs during learning between transitions.
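
The difference between forgetting measured at transitions and forgetting measured inside a session can be made concrete with a small helper. The sketch below is illustrative only: it assumes old-task accuracy is logged at a fixed number of evaluation points per rehearsal session, and it defines the per-session gap as the largest drop below the pre-CL baseline, which is our simplification rather than a metric defined in the paper. The function names and the toy curve are hypothetical.

```python
from typing import List, Tuple


def forgetting_at_transitions(curve: List[float], session_len: int, baseline: float) -> List[float]:
    """Drop below the pre-CL baseline measured only at the last point of each
    session, i.e. the view taken by conventional forgetting metrics."""
    session_ends = range(session_len - 1, len(curve), session_len)
    return [baseline - curve[i] for i in session_ends]


def stability_gap(curve: List[float], session_len: int, baseline: float) -> List[Tuple[float, int]]:
    """For each session, the largest drop below the baseline at any evaluation
    point, with the (0-based) position inside the session where it occurs."""
    gaps = []
    for start in range(0, len(curve), session_len):
        session = curve[start:start + session_len]
        worst = min(range(len(session)), key=session.__getitem__)
        gaps.append((baseline - session[worst], worst))
    return gaps


# Hypothetical old-task accuracies (%): two sessions, four evaluation points each.
curve = [60.0, 72.0, 76.0, 78.0, 62.0, 73.0, 77.0, 79.0]
print(forgetting_at_transitions(curve, session_len=4, baseline=80.0))
print(stability_gap(curve, session_len=4, baseline=80.0))
```

Measured only at transitions, this toy run loses 1 to 2 points; measured inside each session, it loses up to 20 points right after the session starts, which is exactly what the transition-only view hides.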

Hypotheses for Mitigating the Stability Gap

We test two hypotheses to examine the stability gap in CIL with pre-trained DNNs:

1. The stability gap is partly driven by a large loss at the output layer for the new classes. To test this hypothesis, we study two methods that mitigate this loss. The first initializes the output units responsible for the new classes with a data-driven approach rather than randomly. The second replaces the typical hard targets with a specialized form of soft targets, designed to improve performance on the new classes while minimally perturbing the others.

2. The stability gap is partly driven by excessive network plasticity. We test this hypothesis by dynamically controlling the level of plasticity in network layers. For hidden layers, we use LoRA (Low-Rank Adaptation), which reduces the number of trainable parameters in the hidden layers of the network. For rehearsal methods, the LoRA weights are folded into the original network weights after each rehearsal session. For the output layer, we freeze the output units for classes seen in earlier batches during rehearsal.

Stability Gap Mitigation Methods

1. Weight Initialization. In CIL, the output units for new classes are typically initialized randomly, which causes those units to produce a high loss during backpropagation. We hypothesize that using data-driven initialization for the new class units will reduce this loss and therefore reduce the stability gap.

2. Hard vs. Dynamic Soft Targets. Hard targets are one-hot encoded and treat classes as strictly independent, even though several classes share distributional similarities. As a result, hard targets also cause a large initial loss when learning new classes. Soft targets, by contrast, can help the network retain the joint inter-class distribution, which reduces the perturbation of previously learned classes. To test this, we use soft targets constructed so that the model's predictions on previously learned classes are largely preserved.
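
Methods 1 and 2 both target the large initial loss on new classes. The pure-Python sketch below compares that loss for a toy linear output layer under three settings. Two choices are assumptions for illustration, not the paper's exact recipes: the data-driven initialization sets the new unit's weight to the mean feature of its class (in the spirit of weight imprinting), and the soft target mixes the one-hot label with the model's own detached prediction using a hypothetical weight `lam`. The weights and features are made up.

```python
import math
import random
from typing import List

Vector = List[float]


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def cross_entropy(target: Vector, probs: Vector) -> float:
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)


def mean_feature(features: List[Vector]) -> Vector:
    """Assumed data-driven init: the new unit's weight is its class-mean feature."""
    return [sum(col) / len(features) for col in zip(*features)]


def soft_target(probs: Vector, label: int, lam: float) -> Vector:
    """Hypothetical soft target: lam * one-hot + (1 - lam) * the model's own prediction.
    Old-class probabilities keep their relative proportions; lam = 1 gives hard targets."""
    return [lam * (1.0 if k == label else 0.0) + (1.0 - lam) * p for k, p in enumerate(probs)]


# Toy output layer: two old classes with trained weights, one new class (index 2).
OLD_WEIGHTS = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
NEW_FEATURES = [[0.5, 0.2, 1.0], [0.3, 0.4, 1.0]]  # features of new-class samples
NEW_LABEL = 2

rng = random.Random(0)
random_unit = [rng.gauss(0.0, 0.01) for _ in range(3)]  # conventional random init
imprinted_unit = mean_feature(NEW_FEATURES)             # assumed data-driven init


def initial_loss(new_unit: Vector, lam: float = 1.0) -> float:
    """Mean loss on the new-class samples before any update."""
    losses = []
    for f in NEW_FEATURES:
        probs = softmax([dot(w, f) for w in OLD_WEIGHTS + [new_unit]])
        losses.append(cross_entropy(soft_target(probs, NEW_LABEL, lam), probs))
    return sum(losses) / len(losses)


print(f"random init, hard targets:      {initial_loss(random_unit):.3f}")
print(f"random init, soft targets:      {initial_loss(random_unit, lam=0.5):.3f}")
print(f"data-driven init, hard targets: {initial_loss(imprinted_unit):.3f}")
```

Beyond the smaller loss value, this particular mixed target also shrinks the first gradient: with t = λ·e_y + (1 - λ)·p and the prediction still equal to p, the gradient of the cross-entropy with respect to the logits is p - t = λ(p - e_y), so every unit, old-class units included, is pushed λ times less hard than with a hard target.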

Figure: Dynamic Soft Targets.

3. Limiting Hidden Layer Plasticity Using LoRA. We inject LoRA weights into the linear layers of the network, and only these parameters and the output layer are updated during rehearsal. At the end of each rehearsal session, the LoRA parameters are folded into the network weights. This prevents excessive perturbation of hidden representations, reducing the stability gap.

4. Limiting Output Layer Plasticity via Targeted Freezing. While LoRA restricts plasticity in hidden representations, we hypothesize that restricting plasticity in the output layer could also help. We therefore freeze the output units for classes seen in earlier batches during rehearsal, a technique we refer to as old output class freezing (OOCF).
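
Methods 3 and 4 both limit where the network is allowed to change. The pure-Python sketch below shows the standard LoRA reparameterization, y = Wx + (α/r)·B(Ax) with B initialized to zero, the fold W ← W + (α/r)·BA applied at the end of a session, and OOCF written as an output-layer SGD step that skips the rows of previously seen classes. The rank, α, initialization scale, and all helper names are assumptions for illustration; the paper's settings are not given here.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matvec(m: Matrix, x: Vector) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


class LoRALinear:
    """y = W x + (alpha / r) * B (A x). W stays frozen; only A and B would be trained."""

    def __init__(self, weight: Matrix, rank: int, alpha: float, seed: int = 0):
        rng = random.Random(seed)
        out_dim, in_dim = len(weight), len(weight[0])
        self.weight = weight
        self.a = [[rng.gauss(0.0, 0.02) for _ in range(in_dim)] for _ in range(rank)]  # r x in
        self.b = [[0.0] * rank for _ in range(out_dim)]  # out x r; zero, so y = W x at the start
        self.scale = alpha / rank

    def forward(self, x: Vector) -> Vector:
        delta = matvec(self.b, matvec(self.a, x))
        return [y + self.scale * d for y, d in zip(matvec(self.weight, x), delta)]

    def fold(self) -> Matrix:
        """End of a rehearsal session: W <- W + (alpha / r) * B A."""
        ba = matmul(self.b, self.a)
        return [[w + self.scale * u for w, u in zip(w_row, u_row)]
                for w_row, u_row in zip(self.weight, ba)]


def oocf_step(head: Matrix, grad: Matrix, num_old: int, lr: float) -> Matrix:
    """Old output class freezing: rows for classes from earlier batches (indices
    below num_old) are copied unchanged; only new-class rows take an SGD step."""
    return [list(row) if c < num_old else [w - lr * g for w, g in zip(row, g_row)]
            for c, (row, g_row) in enumerate(zip(head, grad))]


base = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
layer = LoRALinear(base, rank=1, alpha=2.0)
x = [0.5, -1.0, 2.0]
layer.b = [[0.3], [-0.4]]  # stand-in for B after some rehearsal updates
folded = layer.fold()
print(layer.forward(x), matvec(folded, x))  # equal up to floating-point rounding
```

Only r·(in + out) values per layer are trainable instead of in·out, and because the fold is exact, the network after a session has the same size and inference cost as before.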

Combining Mitigation Methods & SGM. We refer to the method that combines dynamic soft targets, data-driven weight initialization, OOCF, and LoRA as SGM (Stability Gap Mitigation).

SGM Enhances Stability and Efficiency

After pre-training on ImageNet-1K, the model learns 365 new classes from Places365-LT over five rehearsal sessions. (a) Stability Gap. SGM quickly recovers its original ImageNet-1K performance at the beginning of CL, whereas vanilla rehearsal never fully recovers. After each rehearsal session (vertical dotted gray line), the final accuracy (%) is marked with a diamond (SGM), a star (joint model), and a square (vanilla). The joint model (upper bound) is jointly trained on ImageNet-1K and all Places365-LT batches seen so far. (b) Computational Efficiency. SGM provides a 16.7X speedup in the number of network updates compared to the joint model (upper bound) trained on the combined 1,365-class dataset (ImageNet-1K and Places365-LT). For SGM and conventional rehearsal, we show the stability gap in the learning curve averaged over rehearsal sessions.

Figure panels: (a) Stability Gap; (b) Computational Efficiency.

Evaluating Our Hypotheses

Effect of each mitigation method, averaged over 5 rehearsal sessions during CIL.

(a) Loss on new classes when training only the output layer, which reveals that soft targets and data-driven weight initialization greatly reduce the initial loss.

(b) Accuracy on ImageNet-1K for hard vs. soft targets, which shows that soft targets reduce the stability gap.

(c) Network plasticity increases the stability gap.

(d-f) Stability-plasticity trade-off: accuracy on new, old, and all classes.

Learning Efficiency

Speed of acquiring new knowledge. SGM requires fewer network updates and TFLOPs than vanilla to reach 99% of the best accuracy on new classes (highlighted).
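
The "99% of the best accuracy" criterion can be stated precisely. In the sketch below, `curve` is a hypothetical list of new-class accuracies logged every `eval_every` updates, and TFLOPs are approximated as updates times an assumed constant cost per update; since SGM's sparse updates make each update cheaper, that constant would differ between methods. The names and toy curves are illustrative, not the paper's measurements.

```python
from typing import Optional, Sequence


def updates_to_reach(curve: Sequence[float], fraction: float = 0.99,
                     eval_every: int = 1) -> Optional[int]:
    """Number of updates until accuracy first reaches `fraction` of the curve's best."""
    if not curve:
        return None
    target = fraction * max(curve)
    for i, acc in enumerate(curve, start=1):
        if acc >= target:
            return i * eval_every
    return None


def tflops_to_reach(curve: Sequence[float], tflops_per_update: float,
                    fraction: float = 0.99, eval_every: int = 1) -> Optional[float]:
    """Assumes every update of a given method costs the same number of TFLOPs."""
    updates = updates_to_reach(curve, fraction, eval_every)
    return None if updates is None else updates * tflops_per_update


vanilla = [10.0, 30.0, 50.0, 60.0, 62.0, 63.0]  # hypothetical new-class accuracy curves
sgm = [40.0, 58.0, 62.0, 63.0, 63.0, 63.0]
print(updates_to_reach(vanilla, eval_every=100), updates_to_reach(sgm, eval_every=100))
```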

FLOPs Analysis

With sparse updates, SGM significantly reduces FLOPs, improving compute efficiency. SGM provides a 31.9X speedup in TFLOPs compared to the joint model (upper bound) trained on the combined 1,365-class dataset (ImageNet-1K and Places365-LT). For SGM and conventional rehearsal, we show the stability gap in the learning curve averaged over rehearsal sessions.

Class Incremental Learning & IID Orderings

The following figure shows results after learning ImageNet-1K followed by Places365-LT over 5 rehearsal sessions. Reported is the final accuracy (%) on all 1,365 classes. SGM outperforms vanilla rehearsal across various data orderings, including class incremental and IID orderings.
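
To make the two kinds of ordering concrete, the sketch below splits a labeled stream into 5 batches either class-incrementally (each batch introduces a disjoint set of classes, 73 per batch for 365 classes) or IID (samples shuffled without regard to label). The function names, seeding, and round-robin split are assumptions; the paper's exact ordering protocol is not specified here, and the toy data is a balanced stand-in for the long-tailed Places365-LT.

```python
import random
from typing import List, Sequence, Tuple

Sample = Tuple[int, int]  # (sample_id, class_label)


def class_incremental_batches(samples: Sequence[Sample], num_batches: int,
                              seed: int = 0) -> List[List[Sample]]:
    """Each batch introduces a disjoint set of classes (round-robin over shuffled classes)."""
    classes = sorted({label for _, label in samples})
    random.Random(seed).shuffle(classes)
    groups = [set(classes[i::num_batches]) for i in range(num_batches)]
    return [[s for s in samples if s[1] in group] for group in groups]


def iid_batches(samples: Sequence[Sample], num_batches: int,
                seed: int = 0) -> List[List[Sample]]:
    """Samples are shuffled regardless of label, so classes recur across batches."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    return [shuffled[i::num_batches] for i in range(num_batches)]


# Toy stand-in: 365 classes, two samples each.
samples = [(i, i % 365) for i in range(730)]
cil = class_incremental_batches(samples, num_batches=5)
iid = iid_batches(samples, num_batches=5)
print([len({c for _, c in b}) for b in cil])  # 73 new classes per batch
print([len({c for _, c in b}) for b in iid])
```

In the class-incremental split every class appears in exactly one batch, while in the IID split most classes reappear in several batches, which is why CIL is the harder distribution for avoiding forgetting.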

Offline Continual Learning

To study SGM's efficacy under storage constraints, we combine it with two popular rehearsal methods, DERpp and GDumb. The following figure shows results for class incremental learning on a combination of ImageNet-1K and Places365-Standard (3 million images). Reported is the final accuracy (%) on all 1,365 classes. SGM significantly improves the performance of each method.
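
Storage-constrained rehearsal depends on how the fixed-size buffer is filled. As an illustration, the sketch below follows our reading of GDumb's greedy class-balanced sampler: store samples while there is room, and once the buffer is full, admit a sample only if its class holds fewer than capacity / (number of classes seen), evicting a random item from the largest class to make room. It is a hypothetical pure-Python sketch, not GDumb's reference code, and it omits GDumb's retraining from the buffer as well as anything specific to DERpp or SGM.

```python
import random
from typing import Dict, Hashable, List


class GreedyBalancedMemory:
    """Fixed-capacity buffer kept roughly class-balanced (GDumb-style sampler sketch)."""

    def __init__(self, capacity: int, seed: int = 0):
        self.capacity = capacity
        self.store: Dict[Hashable, List[object]] = {}
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return sum(len(items) for items in self.store.values())

    def add(self, sample: object, label: Hashable) -> bool:
        """Return True if the sample was stored, False if it was rejected."""
        if len(self) < self.capacity:
            self.store.setdefault(label, []).append(sample)
            return True
        num_classes = len(self.store) + (label not in self.store)
        if len(self.store.get(label, [])) >= self.capacity / num_classes:
            return False  # this class already holds its share of the buffer
        largest = max(self.store, key=lambda c: len(self.store[c]))
        victims = self.store[largest]
        victims.pop(self.rng.randrange(len(victims)))  # evict a random item
        if not victims:
            del self.store[largest]
        self.store.setdefault(label, []).append(sample)
        return True


memory = GreedyBalancedMemory(capacity=6)
stream = [("img", 0)] * 6 + [("img", 1)] * 3 + [("img", 2)] * 2
for sample, label in stream:
    memory.add(sample, label)
print({c: len(items) for c, items in memory.store.items()})
```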

Online Continual Learning

We assess SGM's efficacy in an online continual learning setting using a state-of-the-art method, REMIND. After being pre-trained on ImageNet-1K, a model learns CUB-200 in a sample-by-sample manner. In terms of final accuracy (%) on all 1,200 classes, SGM combined with REMIND significantly outperforms standalone REMIND.

Rehearsal-Free Setting

To examine whether SGM helps rehearsal-free methods such as Learning without Forgetting (LwF), we compare LwF with a variant of LwF that uses SGM without rehearsal, during class incremental learning of ImageNet-1K and Places365-LT. Reported is the final accuracy (%) on all 1,365 classes. Results show that SGM is effective in both rehearsal and rehearsal-free settings.
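
For reference, the rehearsal-free baseline's objective can be sketched as follows. This follows the usual description of LwF: cross-entropy on the current label plus a distillation term that keeps the current model's temperature-softened outputs on old classes close to those recorded from the frozen pre-update model. The weighting `lam`, the temperature T = 2, and the function names are assumptions for illustration; how SGM is combined with LwF is not shown.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def lwf_loss(logits: Sequence[float], label: int, old_model_logits: Sequence[float],
             temperature: float = 2.0, lam: float = 1.0) -> float:
    """Cross-entropy on the label plus distillation on the old-class outputs.

    `old_model_logits` are the frozen previous model's outputs for the old classes,
    which occupy the first len(old_model_logits) positions of `logits`.
    """
    ce = -math.log(softmax(logits)[label])
    num_old = len(old_model_logits)
    teacher = softmax(old_model_logits, temperature)
    student = softmax(list(logits[:num_old]), temperature)
    distill = -sum(t * math.log(s) for t, s in zip(teacher, student))
    return ce + lam * distill


logits = [2.0, 0.5, 1.0]  # two old classes, one new class (index 2)
print(lwf_loss(logits, label=2, old_model_logits=[2.0, 0.5]))  # old outputs unchanged
print(lwf_loss(logits, label=2, old_model_logits=[0.5, 2.0]))  # old outputs drifted
```

The distillation term is smallest when the current old-class outputs match the frozen model's, which is what lets LwF resist forgetting without storing old data.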

SGM Enhances Non-rehearsal Methods

The Y-axis shows accuracy averaged over 6 runs, with the standard deviation shown as a shaded region. The network is trained on ImageNet-1K and then learns 365 new classes from Places365-LT over five batches. When a new batch arrives, LwF's accuracy on ImageNet-1K plummets, and LwF fails to recover, ending up with a large stability gap. In contrast, LwF with SGM does not plummet, performs better throughout the CL phase, and has a significantly smaller stability gap.
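
The shaded band is the per-iteration mean plus or minus one standard deviation across runs. A minimal sketch, assuming each run is a list of accuracies logged at the same iterations and that the sample standard deviation (n - 1 denominator) is used, which the caption does not specify; the function name and toy runs are hypothetical:

```python
import statistics
from typing import List, Sequence, Tuple


def mean_std_band(runs: Sequence[Sequence[float]]) -> List[Tuple[float, float, float]]:
    """Per evaluation point: (mean, mean - std, mean + std) across runs."""
    band = []
    for accs in zip(*runs):  # values from every run at one evaluation point
        mu = statistics.mean(accs)
        sd = statistics.stdev(accs)  # sample std; requires at least two runs
        band.append((mu, mu - sd, mu + sd))
    return band


# Six hypothetical runs, two evaluation points each.
runs = [[70.0, 71.0], [72.0, 71.0], [74.0, 71.0], [76.0, 71.0], [78.0, 71.0], [80.0, 71.0]]
print(mean_std_band(runs))
```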

Acknowledgements

This work was partly supported by NSF awards #1909696, #2326491, #2125362, and #2317706.

BibTeX

@article{harun2023overcoming,
  title     = {Overcoming the Stability Gap in Continual Learning},
  author    = {Harun, Md Yousuf and Kanan, Christopher},
  journal   = {arXiv preprint arXiv:2306.01904},
  year      = {2023}
}