LEARNING AN INVERTIBLE OUTPUT MAPPING CAN MITIGATE SIMPLICITY BIAS IN NEURAL NETWORKS
Abstract
Deep Neural Networks (DNNs) are known to be brittle under even minor shifts away
from the training distribution. Simplicity Bias (SB) of DNNs – a bias
towards learning a small number of simplest features – has been demonstrated
to be a key reason for this brittleness. Prior works have shown that the effect of
Simplicity Bias is extreme – even when the features learned are diverse, training
the classification head again selects only a few of the simplest features, leading to
similarly brittle models. In this work, we introduce Feature Reconstruction Regularizer
(FRR) in the linear classification head, with the aim of reducing Simplicity
Bias, thereby improving Out-Of-Distribution (OOD) robustness. The proposed
regularizer, when used during linear layer training (termed FRR-L), enforces that
the features can be reconstructed back from the logit layer, ensuring that diverse
features participate in the classification task. We further propose to finetune the
full network by freezing the weights of the linear layer trained using FRR-L. This
approach, termed FRR-FLFT or Fixed Linear FineTuning, improves the quality
of the learned features, making them more suitable for the classification task.
Using this simple solution, we demonstrate up to 12% gain in accuracy on the
recently introduced synthetic datasets with extreme distribution shifts. Moreover,
on the standard OOD benchmarks recommended on DomainBed, our technique
can provide up to 5% gains over the existing SOTA methods.
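
To make the idea of reconstructing features from the logit layer concrete, the following is a minimal sketch rather than the authors' implementation. It assumes a linear reconstruction map `V`, a squared-error penalty, and a weight `lam`; the abstract specifies none of these, and the names `frr_loss`, `W`, `V`, and `lam` are hypothetical. It compares a head that uses only one feature with a head that uses both.

```python
import math

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def cross_entropy(logits, label):
    """Numerically stable softmax cross-entropy for one example."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def frr_loss(features, label, W, V, lam=1.0):
    """Cross-entropy plus lam times the squared error of reconstructing
    the features from the logits (assumed form, not taken from the paper)."""
    logits = matvec(W, features)   # linear classification head
    recon = matvec(V, logits)      # assumed linear map from logits back to features
    recon_err = sum((f - r) ** 2 for f, r in zip(features, recon))
    return cross_entropy(logits, label) + lam * recon_err, recon_err

features = [1.0, 2.0]

# Head that reads only the first feature: the logits carry no
# information about the second feature, so it cannot be recovered.
W_simple = [[1.0, 0.0], [-1.0, 0.0]]
V_simple = [[0.5, -0.5], [0.0, 0.0]]

# Head that reads both features: the logits determine the features exactly.
W_diverse = [[1.0, 0.0], [0.0, 1.0]]
V_diverse = [[1.0, 0.0], [0.0, 1.0]]

_, err_simple = frr_loss(features, 0, W_simple, V_simple)
_, err_diverse = frr_loss(features, 0, W_diverse, V_diverse)
print(err_simple, err_diverse)  # prints: 4.0 0.0
```

Under these assumptions, the penalty is nonzero whenever the head discards a feature, which is the sense in which such a regularizer would push more features to participate in classification.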