Ensemble of Averages: Improving Model Selection and Boosting Performance in Domain Generalization

2021-10-21

Devansh Arpit, Huan Wang, Yingbo Zhou, Caiming Xiong

Abstract

In Domain Generalization (DG) settings, models trained independently on a given set of training domains have notoriously chaotic performance on distribution-shifted test domains, and stochasticity in optimization (e.g., the random seed) plays a big role. This makes deep learning models unreliable in real-world settings. We first show that this chaotic behavior exists even along the training optimization trajectory of a single model. We then propose a simple model averaging protocol that both significantly boosts domain generalization and diminishes the impact of stochasticity by improving the rank correlation between in-domain validation accuracy and out-of-domain test accuracy, which is crucial for reliable early stopping. Building on this observation, we show that ensembling moving-average models from independent runs (an ensemble of averages, EoA), instead of ensembling unaveraged models as is typical in practice, further boosts performance. We theoretically explain the performance gains from ensembling and model averaging by adapting the well-known bias-variance trade-off to the domain generalization setting. On the DomainBed benchmark, this ensemble of averages achieves an average accuracy of 68.0% with a pre-trained ResNet-50, beating vanilla ERM (without averaging or ensembling) by 4%, and 76.6% with a pre-trained RegNetY-16GF, beating vanilla ERM by 6%. Our code is available at https://github.com/salesforce/ensemble-of-averages.
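
The abstract describes two ingredients: averaging a model's weights along its own training trajectory, and ensembling such averaged models from independent runs. The NumPy sketch below illustrates both on a toy problem. It is not the authors' implementation (their code is in the repository linked above), and it does not simulate a domain shift. A linear softmax classifier stands in for a deep network. `avg_start` and the other hyperparameters are hypothetical, and the average is a plain running mean of all weights visited since `avg_start`. The ensemble averages predicted probabilities, which is one common combination rule assumed here. `spearman` computes the kind of rank correlation the abstract uses to judge model selection.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def train_with_weight_averaging(x, y, num_classes, steps=500, lr=0.1,
                                batch_size=32, avg_start=100, seed=0):
    """Minibatch SGD on a linear softmax classifier (a toy stand-in for a deep
    network) that also keeps a running mean of the weights from `avg_start` on.

    Returns (final_weights, averaged_weights). All hyperparameters here,
    including `avg_start`, are hypothetical illustration values.
    """
    rng = np.random.default_rng(seed)  # the seed controls init and batch order
    w = rng.normal(scale=0.01, size=(x.shape[1], num_classes))
    w_avg, n_avg = np.zeros_like(w), 0
    for step in range(steps):
        idx = rng.choice(len(x), size=batch_size, replace=False)
        probs = softmax(x[idx] @ w)
        probs[np.arange(batch_size), y[idx]] -= 1.0  # d(cross-entropy)/d(logits)
        w -= lr * x[idx].T @ probs / batch_size      # SGD step on the weights
        if step >= avg_start:
            n_avg += 1
            w_avg += (w - w_avg) / n_avg  # incremental mean of visited iterates
    return w, w_avg


def ensemble_predict(weight_list, x):
    """Ensemble by averaging class probabilities (an assumed combination rule)."""
    probs = np.mean([softmax(x @ w) for w in weight_list], axis=0)
    return probs.argmax(axis=1)


def spearman(a, b):
    """Spearman rank correlation of two 1-D arrays, assuming no ties."""
    rank_a = np.argsort(np.argsort(a))
    rank_b = np.argsort(np.argsort(b))
    return np.corrcoef(rank_a, rank_b)[0, 1]


# Toy, linearly separable 3-class data (hypothetical, not a DomainBed dataset).
data_rng = np.random.default_rng(123)
x = data_rng.normal(size=(400, 5))
y = (x @ data_rng.normal(size=(5, 3))).argmax(axis=1)

# Three independent runs (different seeds); keep only each run's averaged weights.
averaged = [train_with_weight_averaging(x, y, num_classes=3, seed=s)[1]
            for s in range(3)]
preds = ensemble_predict(averaged, x)
acc = float((preds == y).mean())
print(f"ensemble-of-averages training accuracy: {acc:.3f}")
```

In the abstract's terms, `spearman` could be applied to in-domain validation accuracies and out-of-domain test accuracies recorded at several checkpoints. A higher value means that picking the checkpoint with the best validation accuracy is more likely to also pick a checkpoint that does well on the unseen domain.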

Benchmark Results

Dataset         Model                                    Metric                Claimed  Verified  Status
DomainNet       Ensemble of Averages (RegNetY-16GF)      Average Accuracy (%)  60.9     -         Unverified
DomainNet       Ensemble of Averages (ResNeXt-50 32x4d)  Average Accuracy (%)  54.6     -         Unverified
DomainNet       Ensemble of Averages (ResNet-50)         Average Accuracy (%)  47.4     -         Unverified
Office-Home     Ensemble of Averages (RegNetY-16GF)      Average Accuracy (%)  83.9     -         Unverified
Office-Home     Ensemble of Averages (ResNeXt-50 32x4d)  Average Accuracy (%)  80.2     -         Unverified
Office-Home     Ensemble of Averages (ResNet-50)         Average Accuracy (%)  72.5     -         Unverified
PACS            Ensemble of Averages (RegNetY-16GF)      Average Accuracy (%)  95.8     -         Unverified
PACS            Ensemble of Averages (ResNeXt-50 32x4d)  Average Accuracy (%)  93.2     -         Unverified
PACS            Ensemble of Averages (ResNet-50)         Average Accuracy (%)  88.6     -         Unverified
TerraIncognita  Ensemble of Averages (RegNetY-16GF)      Average Accuracy (%)  61.1     -         Unverified
TerraIncognita  Ensemble of Averages (ResNeXt-50 32x4d)  Average Accuracy (%)  55.2     -         Unverified
TerraIncognita  Ensemble of Averages (ResNet-50)         Average Accuracy (%)  52.3     -         Unverified
VLCS            Ensemble of Averages (RegNetY-16GF)      Average Accuracy (%)  81.1     -         Unverified
VLCS            Ensemble of Averages (ResNeXt-50 32x4d)  Average Accuracy (%)  80.4     -         Unverified
VLCS            Ensemble of Averages (ResNet-50)         Average Accuracy (%)  79.1     -         Unverified
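
As a quick arithmetic check, the averages quoted in the abstract appear to be unweighted means of the five per-dataset claimed accuracies in the table above. The Python snippet below recomputes them. The dictionary simply copies the claimed numbers and is not part of the authors' code.

```python
# Claimed average accuracies (%) copied from the table above, in the order
# DomainNet, Office-Home, PACS, TerraIncognita, VLCS.
claimed = {
    "ResNet-50": [47.4, 72.5, 88.6, 52.3, 79.1],
    "ResNeXt-50 32x4d": [54.6, 80.2, 93.2, 55.2, 80.4],
    "RegNetY-16GF": [60.9, 83.9, 95.8, 61.1, 81.1],
}

# Unweighted mean over the five datasets, rounded to one decimal place.
means = {name: round(sum(accs) / len(accs), 1) for name, accs in claimed.items()}
for name, mean in means.items():
    print(f"{name}: {mean}")
# ResNet-50: 68.0
# ResNeXt-50 32x4d: 72.7
# RegNetY-16GF: 76.6
```

The ResNet-50 and RegNetY-16GF means match the 68.0% and 76.6% quoted in the abstract. The abstract does not quote a ResNeXt-50 32x4d figure.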
