Fishr: Invariant Gradient Variances for Out-of-Distribution Generalization
Alexandre Rame, Corentin Dancette, Matthieu Cord
Code
- github.com/facebookresearch/DomainBed (official, in paper, PyTorch, ★ 1,604)
- github.com/alexrame/fishr (official, in paper, PyTorch, ★ 89)
Abstract
Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been growing interest in learning simultaneously from multiple training domains while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under controlled evaluation protocols. In this paper, we introduce a new regularization, named Fishr, that enforces domain invariance in the space of the gradients of the loss: specifically, the domain-level variances of gradients are matched across training domains. Our approach is based on the close relations between the gradient covariance, the Fisher Information and the Hessian of the loss: in particular, we show that Fishr eventually aligns the domain-level loss landscapes locally around the final weights. Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. Notably, Fishr improves the state of the art on the DomainBed benchmark and performs consistently better than Empirical Risk Minimization. Our code is available at https://github.com/alexrame/fishr.
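To make "matching domain-level gradient variances" concrete, here is a minimal pure-Python sketch of such a penalty. It assumes per-sample gradients are already available as flat vectors for each training domain (in practice an autodiff library would produce them), uses the diagonal (elementwise) variance within each domain, and penalizes the squared distance of each domain's variance vector from the mean over domains. The function names (`domain_gradient_variance`, `fishr_penalty`, `fishr_objective`) and the weight `lam` are hypothetical illustrations; the official repository linked above is the reference implementation.

```python
from typing import List, Sequence

Vector = List[float]


def domain_gradient_variance(per_sample_grads: Sequence[Vector]) -> Vector:
    """Elementwise (diagonal) variance of per-sample gradients in one domain.

    Assumption: biased estimator (divide by n), all vectors of equal length.
    """
    n = len(per_sample_grads)
    dim = len(per_sample_grads[0])
    mean = [sum(g[j] for g in per_sample_grads) / n for j in range(dim)]
    return [sum((g[j] - mean[j]) ** 2 for g in per_sample_grads) / n for j in range(dim)]


def fishr_penalty(grads_per_domain: Sequence[Sequence[Vector]]) -> float:
    """Average squared L2 distance between each domain's gradient variance
    and the mean variance across domains (zero when all domains match)."""
    variances = [domain_gradient_variance(d) for d in grads_per_domain]
    num_domains = len(variances)
    dim = len(variances[0])
    mean_var = [sum(v[j] for v in variances) / num_domains for j in range(dim)]
    return sum(
        sum((v[j] - mean_var[j]) ** 2 for j in range(dim)) for v in variances
    ) / num_domains


def fishr_objective(
    domain_risks: Sequence[float],
    grads_per_domain: Sequence[Sequence[Vector]],
    lam: float,
) -> float:
    """Hypothetical training objective: mean per-domain risk (ERM) plus
    lam times the variance-matching penalty."""
    erm = sum(domain_risks) / len(domain_risks)
    return erm + lam * fishr_penalty(grads_per_domain)


if __name__ == "__main__":
    # Two toy domains with 2-D per-sample gradients.
    domain_a = [[1.0, 0.0], [3.0, 0.0]]  # variance [1.0, 0.0]
    domain_b = [[0.0, 0.0], [0.0, 2.0]]  # variance [0.0, 1.0]
    print(fishr_penalty([domain_a, domain_b]))  # 0.5
    print(fishr_objective([0.4, 0.6], [domain_a, domain_b], lam=2.0))  # 1.5
```

In this toy case the two domains have gradients that vary along different coordinates, so the penalty is positive; if both domains had identical gradient variances, the penalty would be exactly zero and the objective would reduce to plain ERM.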
Benchmark Results
| Dataset | Model | Metric | Claimed (%) | Verified | Status |
|---|---|---|---|---|---|
| DomainNet | Fishr (ResNet-50) | Average Accuracy | 41.8 | — | Unverified |
| Office-Home | Fishr (ResNet-50) | Average Accuracy | 68.2 | — | Unverified |
| PACS | Fishr (ResNet-50, DomainBed) | Average Accuracy | 86.9 | — | Unverified |
| TerraIncognita | Fishr (ResNet-50) | Average Accuracy | 47.4 | — | Unverified |
| VLCS | Fishr (ResNet-50) | Average Accuracy | 78.2 | — | Unverified |