SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization
Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, Tuo Zhao
Code
- github.com/namisan/mt-dnn (official, referenced in paper; PyTorch, ★ 2,257)
- github.com/microsoft/MT-DNN (PyTorch, ★ 167)
- github.com/archinetai/vat-pytorch (PyTorch, ★ 17)
- github.com/cliang1453/camero (PyTorch, ★ 9)
- github.com/chunhuililili/mt_dnn (PyTorch, ★ 2)
- github.com/archinetai/smart-pytorch (PyTorch, ★ 0)
Abstract
Transfer learning has fundamentally changed the landscape of natural language processing (NLP) research. Many existing state-of-the-art models are first pre-trained on a large text corpus and then fine-tuned on downstream tasks. However, due to limited data from downstream tasks and the extremely large capacity of pre-trained models, aggressive fine-tuning often causes the adapted model to overfit the downstream data and forget the knowledge of the pre-trained model. To address this issue in a more principled manner, we propose a new computational framework for robust and efficient fine-tuning of pre-trained language models. Specifically, the proposed framework contains two important ingredients: (1) smoothness-inducing regularization, which effectively manages the capacity of the model, and (2) Bregman proximal point optimization, a class of trust-region methods that prevents knowledge forgetting. Our experiments demonstrate that the proposed method achieves state-of-the-art performance on multiple NLP benchmarks.
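The two ingredients can be illustrated with a small sketch. Below is a minimal NumPy version of the SMART-style objective for a toy softmax classifier: the smoothness-inducing term penalizes the symmetric KL divergence between the model's output at an input and at a small perturbation of it, and the Bregman proximal term keeps the current model's outputs close to those of the previous iterate. This is not the paper's implementation (the official code is in the mt-dnn repository above); the toy linear model, the single random perturbation projected onto an ℓ∞ ball (the paper uses adversarial ascent steps to find the worst-case perturbation), and all hyperparameter values are illustrative assumptions.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sym_kl(p, q, eps=1e-12):
    # Symmetric KL divergence between rows of p and q, averaged over the batch.
    return np.sum(p * (np.log(p + eps) - np.log(q + eps))
                  + q * (np.log(q + eps) - np.log(p + eps)), axis=-1).mean()

def model(W, x):
    # Toy stand-in for the fine-tuned network: a linear softmax classifier.
    return softmax(x @ W)

def smart_loss(W, W_prev, x, y, lam=1.0, mu=1.0, eps_ball=1e-3, rng=None):
    rng = rng or np.random.default_rng()
    p = model(W, x)
    # Standard task loss (cross-entropy).
    ce = -np.log(p[np.arange(len(y)), y] + 1e-12).mean()
    # Smoothness-inducing term: compare outputs at x and at a perturbed input.
    # Here a single random perturbation projected onto the l-infinity ball of
    # radius eps_ball stands in for the paper's inner adversarial maximization.
    delta = rng.normal(size=x.shape)
    delta = eps_ball * delta / (np.abs(delta).max() + 1e-12)
    r_s = sym_kl(model(W, x + delta), p)
    # Bregman proximal point term: penalize divergence from the previous
    # iterate's outputs, acting as a trust region against knowledge forgetting.
    d_breg = sym_kl(model(W_prev, x), p)
    return ce + lam * r_s + mu * d_breg
```

In the actual method, the objective is minimized over W while W_prev is updated only once per proximal step (optionally with momentum, giving the "mean-teacher" style variant the paper calls SMART with momentum).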
Benchmark Results
| Dataset | Model | Metric | Claimed | Verified | Status |
|---|---|---|---|---|---|
| AX | T5 | Accuracy (%) | 53.1 | — | Unverified |
| MNLI + SNLI + ANLI + FEVER | SMART-RoBERTa-Large | Dev Accuracy (%) | 57.1 | — | Unverified |
| MultiNLI | SMART-BERT | Dev Matched Accuracy (%) | 85.6 | — | Unverified |
| MultiNLI | SMART + BERT-Base | Accuracy (%) | 85.6 | — | Unverified |
| MultiNLI | SMART-RoBERTa | Dev Matched Accuracy (%) | 91.1 | — | Unverified |
| MultiNLI | T5 | Matched Accuracy (%) | 92 | — | Unverified |
| MultiNLI | MT-DNN-SMARTv0 | Accuracy (%) | 85.7 | — | Unverified |
| MultiNLI | MT-DNN-SMART | Accuracy (%) | 85.7 | — | Unverified |
| QNLI | ALICE | Accuracy (%) | 99.2 | — | Unverified |
| QNLI | MT-DNN-SMART | Accuracy (%) | 99.2 | — | Unverified |
| RTE | SMART | Accuracy (%) | 71.2 | — | Unverified |
| RTE | T5-XXL 11B | Accuracy (%) | 92.5 | — | Unverified |
| RTE | SMART-RoBERTa | Accuracy (%) | 92 | — | Unverified |
| RTE | SMART-BERT | Accuracy (%) | 71.2 | — | Unverified |
| SciTail | MT-DNN-SMART (1% of training data) | Dev Accuracy (%) | 88.6 | — | Unverified |
| SciTail | MT-DNN-SMART (0.1% of training data) | Dev Accuracy (%) | 82.3 | — | Unverified |
| SciTail | MT-DNN-SMART-LARGEv0 | Dev Accuracy (%) | 96.6 | — | Unverified |
| SciTail | MT-DNN-SMART (100% of training data) | Dev Accuracy (%) | 96.1 | — | Unverified |
| SciTail | MT-DNN-SMART (10% of training data) | Dev Accuracy (%) | 91.3 | — | Unverified |
| SNLI | MT-DNN-SMART (1% of training data) | Dev Accuracy (%) | 86.0 | — | Unverified |
| SNLI | MT-DNN-SMART (0.1% of training data) | Dev Accuracy (%) | 82.7 | — | Unverified |
| SNLI | MT-DNN-SMART-LARGEv0 | Test Accuracy (%) | 91.7 | — | Unverified |
| SNLI | MT-DNN-SMART (100% of training data) | Dev Accuracy (%) | 91.6 | — | Unverified |
| SNLI | MT-DNN-SMART (10% of training data) | Dev Accuracy (%) | 88.7 | — | Unverified |