Robust Training for Image Classification
Overview
This project contains the code for adversarial training on classification models. Heavy augmentations can be used in adversarial training. The well-known timm project is used as the default classification library.
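As a rough illustration of what adversarial training means in practice, the sketch below runs a single training step on PGD adversarial examples. It is a minimal sketch, not this project's implementation: the choice of an L-infinity PGD attack, the function names `pgd_attack` and `adversarial_training_step`, the values of `eps`, `step_size`, and `steps`, and the tiny stand-in model are all assumptions made for illustration. The attack and hyperparameters actually used are defined by the project's config files and may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pgd_attack(model, images, labels, eps=4 / 255, step_size=1 / 255, steps=3):
    """Craft L-infinity PGD adversarial examples for inputs scaled to [0, 1].

    Hypothetical helper; eps, step_size, and steps are illustrative values.
    """
    # Random start inside the eps-ball around the clean images.
    delta = torch.empty_like(images).uniform_(-eps, eps)
    adv = (images + delta).clamp(0.0, 1.0).detach()
    for _ in range(steps):
        adv.requires_grad_(True)
        loss = F.cross_entropy(model(adv), labels)
        # Gradient with respect to the input only; model gradients are untouched.
        (grad,) = torch.autograd.grad(loss, adv)
        # Ascend the loss, then project back into the eps-ball and pixel range.
        adv = adv.detach() + step_size * grad.sign()
        adv = torch.min(torch.max(adv, images - eps), images + eps).clamp(0.0, 1.0)
    return adv.detach()


def adversarial_training_step(model, optimizer, images, labels):
    """Run one optimizer update on adversarial examples instead of clean ones."""
    adv = pgd_attack(model, images, labels)
    model.train()
    optimizer.zero_grad()
    loss = F.cross_entropy(model(adv), labels)
    loss.backward()
    optimizer.step()
    return loss.item()


# Tiny stand-in model and random batch so the sketch runs on CPU.
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
images = torch.rand(4, 3, 8, 8)
labels = torch.randint(0, 10, (4,))
loss_value = adversarial_training_step(model, optimizer, images, labels)
```

In a real run, the stand-in model would be replaced by a classification network and the random batch by augmented training images.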
Preparation
Dataset
We use the ImageNet Validation Set as the default dataset to evaluate the adversarial robustness of classification models. Please download the ImageNet dataset first. If you want to use your own datasets, please define their torch.utils.data.Dataset class and the corresponding transform.
Classification Models
To build an image classification model, you can create a model class from the timm library, or you can define a custom network as a torch.nn.Module.
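The two customization points described in this section, a custom torch.utils.data.Dataset with its transform and a custom torch.nn.Module, can be sketched as follows. The names `TensorImageDataset`, `to_unit_float`, and `SmallConvNet` are hypothetical and exist only for this example, and the random tensors stand in for real images. How the project's training script expects these classes to be registered is not shown here and is not assumed.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class TensorImageDataset(Dataset):
    """Hypothetical dataset holding uint8 images of shape (N, 3, H, W)."""

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, int(self.labels[index])


def to_unit_float(image):
    """Hypothetical transform: scale uint8 pixels to floats in [0, 1]."""
    return image.float() / 255.0


class SmallConvNet(nn.Module):
    """Hypothetical custom network; any torch.nn.Module returning logits works."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        # Pool to one value per channel, flatten, and map to class logits.
        return self.classifier(torch.flatten(self.features(x), 1))


# Random stand-in data so the sketch runs without downloading anything.
images = torch.randint(0, 256, (8, 3, 32, 32), dtype=torch.uint8)
labels = torch.randint(0, 10, (8,))
dataset = TensorImageDataset(images, labels, transform=to_unit_float)
loader = DataLoader(dataset, batch_size=4)

model = SmallConvNet(num_classes=10)
batch_images, batch_labels = next(iter(loader))
logits = model(batch_images)  # one row of class scores per image
```

For the timm route, a call such as `timm.create_model("resnet50", pretrained=False)` returns a torch.nn.Module that can be used in place of `SmallConvNet`.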
Robust Training
Before starting, modify the default config files in train_configs. They contain the augmentation parameters, dataset directory, etc.
Then, you can run the following command to start training ResNet50.
cd robust_training
python -m torch.distributed.launch --nproc_per_node=<num-of-gpus-to-use> adversarial_training.py --configs=./train_configs/resnet50.yaml