ForNet
This is the training code for the ForNet paper. All our experiments and evaluations were run using this codebase.
Requirements
This project builds heavily on timm and on open-source implementations of the models that are tested. All requirements are listed in requirements.txt. To install them, run
pip install -r requirements.txt
Usage
After cloning this repository, you can train and evaluate any of the models listed under Supported Models.
By default, an srun command is executed to run the code on a Slurm cluster.
To run on the local machine instead, append the --local flag to the command.
General Preparation
After cloning the repository on a Slurm cluster, make sure main.py is executable (by using "chmod a+x main.py").
To run the project on a Slurm cluster, you need to create a Docker image from the requirements file.
You will also want to adapt the default Slurm parameters in config.py.
Next, adjust the paths in paths_config.py for your system, specifically results_folder, slurm_output_folder and dataset folders.
Finally, if you want to use Weights and Biases for experiment tracking, create the file ".wandb.apikey" in this folder and paste your API key into it.
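Before launching a run, the preparation steps above can be checked with a short script. This is a minimal sketch and not part of the codebase: the helper names read_wandb_key and missing_folders and the example paths are hypothetical. Only the file name .wandb.apikey and the setting names results_folder and slurm_output_folder come from this README.

```python
from pathlib import Path


def read_wandb_key(key_file=".wandb.apikey"):
    """Return the API key stored in key_file, or None if it is missing or empty."""
    path = Path(key_file)
    if not path.is_file():
        return None
    key = path.read_text().strip()  # drop the trailing newline most editors add
    return key or None


def missing_folders(folders):
    """Return the names of all configured folders that do not exist on disk."""
    return [name for name, folder in folders.items() if not Path(folder).is_dir()]


if __name__ == "__main__":
    # Hypothetical values; use the ones you set in paths_config.py.
    configured = {
        "results_folder": "/path/to/results",
        "slurm_output_folder": "/path/to/slurm_out",
    }
    for name in missing_folders(configured):
        print(f"{name} does not exist yet")
    if read_wandb_key() is None:
        print("No .wandb.apikey file with a key was found in this folder")
```

Whether the training code creates missing folders on its own is not stated here, so treat the output as a hint rather than an error.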
Training
Pretraining
To pretrain a ViT-S on a given dataset, run
./main.py --task pre-train --model ViT-S/16 --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
This will save a checkpoint (.pt file) every <save_epochs> epochs (the default is 10), which contains all the model weights, along with the optimizer and scheduler state, and the current training stats.
Finetuning
A model (checkpoint) can be finetuned on another dataset using the following command:
./main.py --task fine-tune --model <model_checkpoint.pt> --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
This will also save new checkpoints during training.
Evaluation
To evaluate a model's accuracy on a specific dataset, run
./main.py -t eval -ds <dataset name> -m <model_checkpoint.pt> --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 (--local)
You can run our foreground-focus, center-bias, and size-bias evaluations using the eval-attr, eval-center-bias, and eval-size-bias tasks, respectively (-t or --task argument).
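To run the accuracy evaluation and the three additional evaluations for one checkpoint, the command lines can be generated with a short script. This is a sketch under assumptions: it presumes the eval-attr, eval-center-bias, and eval-size-bias tasks accept the same -ds and -m flags as eval, which this README does not confirm. The function name eval_command and the placeholder dataset and checkpoint names are hypothetical.

```python
import shlex

# Task names as listed in this README.
BIAS_TASKS = ["eval-attr", "eval-center-bias", "eval-size-bias"]


def eval_command(task, dataset, checkpoint, local=True):
    """Build one ./main.py call using only flags shown in this README."""
    cmd = ["./main.py", "-t", task, "-ds", dataset, "-m", checkpoint]
    if local:
        cmd.append("--local")  # run on this machine instead of via srun
    return cmd


if __name__ == "__main__":
    # "imagenet" and "checkpoint.pt" are placeholders, not names from the repo.
    for task in ["eval"] + BIAS_TASKS:
        print(shlex.join(eval_command(task, "imagenet", "checkpoint.pt")))
```

The script only prints the commands, so they can be reviewed before being pasted into a shell. On a Slurm cluster, pass local=False and add the resource flags from the evaluation command above as needed.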
Further Arguments
The scripts accept a number of further arguments and flags. The most important ones are
| Arg | Description |
|---|---|
| --model <model> | Model name or checkpoint. |
| --run_name <name for the run> | Name or description of this training run. |
| --dataset <dataset> | Specifies a dataset to use. |
| --task <task> | Specifies a task. The default is pre-train. |
| --local | Run on the local machine, not on a Slurm cluster. |
| --epochs <epochs> | Epochs to train. |
| --lr <lr> | Learning rate. Default is 3e-3. |
| --batch_size <bs> | Batch size. Default is 2048. |
| --weight_decay <wd> | Weight decay. Default is 0.02. |
| --imsize <image resolution> | Resolution of the images to train with. Default is 224. |
For a list of all arguments, run
./main.py --help
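To make the defaults in the table above concrete, here is a minimal argparse sketch that mirrors them. It is not the parser used in main.py. The function name build_parser is hypothetical, and pairing the short flags -m, -ds, -t, and -bs from the evaluation command with these long names is an assumption (only -t for --task is stated in this README).

```python
import argparse


def build_parser():
    """Parser covering only the arguments and defaults listed in the table above."""
    p = argparse.ArgumentParser(prog="main.py")
    p.add_argument("--model", "-m", help="Model name or checkpoint.")
    p.add_argument("--run_name", help="Name or description of this training run.")
    p.add_argument("--dataset", "-ds", help="Dataset to use.")
    p.add_argument("--task", "-t", default="pre-train", help="Task to run.")
    p.add_argument("--local", action="store_true", help="Run locally, not on Slurm.")
    p.add_argument("--epochs", type=int, help="Epochs to train.")
    p.add_argument("--lr", type=float, default=3e-3, help="Learning rate.")
    p.add_argument("--batch_size", "-bs", type=int, default=2048, help="Batch size.")
    p.add_argument("--weight_decay", type=float, default=0.02, help="Weight decay.")
    p.add_argument("--imsize", type=int, default=224, help="Image resolution.")
    return p


if __name__ == "__main__":
    # Values not given on the command line fall back to the documented defaults.
    args = build_parser().parse_args(["--task", "fine-tune", "--lr", "1e-3", "--local"])
    print(args.task, args.lr, args.batch_size, args.local)
```

The real script accepts more arguments than this sketch (for example the Slurm options used in the evaluation command), so ./main.py --help remains the authoritative list.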
Supported Models
These are the models we support. Links are to original code sources. If no link is provided, we implemented the architecture from scratch, following the specific paper.
| Architecture | Versions |
|---|---|
| DeiT | deit_tiny_patch16_LS, deit_small_patch16_LS, deit_medium_patch16_LS, deit_base_patch16_LS, deit_large_patch16_LS, deit_huge_patch14_LS, deit_huge_patch14_52_LS, deit_huge_patch14_26x2_LS, deit_Giant_48_patch14_LS, deit_giant_40_patch14_LS, deit_small_patch16_36_LS, deit_small_patch16_36, deit_small_patch16_18x2_LS, deit_small_patch16_18x2, deit_base_patch16_18x2_LS, deit_base_patch16_18x2, deit_base_patch16_36x1_LS, deit_base_patch16_36x1 |
| ResNet | resnet18, resnet34, resnet26, resnet50, resnet101, wide_resnet50_2 |
| Swin | swin_tiny_patch4_window7, swin_small_patch4_window7, swin_base_patch4_window7, swin_large_patch4_window7 |
| ViT | ViT-{Ti,S,B,L}/<patch_size> |
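The ViT entry in the table above is a naming pattern rather than a fixed list: a size tag (Ti, S, B, or L) followed by a patch size, as in the ViT-S/16 used in the pretraining command. The sketch below shows one way to read such a name. It is illustrative only; the helper name parse_vit_name is hypothetical and the codebase may resolve model names differently.

```python
import re

# Size tags taken from the table row "ViT-{Ti,S,B,L}/<patch_size>".
_VIT_NAME = re.compile(r"ViT-(Ti|S|B|L)/(\d+)")


def parse_vit_name(name):
    """Split a name such as 'ViT-S/16' into (size, patch_size)."""
    match = _VIT_NAME.fullmatch(name)
    if match is None:
        raise ValueError(f"not a ViT-{{Ti,S,B,L}}/<patch_size> name: {name!r}")
    return match.group(1), int(match.group(2))


if __name__ == "__main__":
    size, patch_size = parse_vit_name("ViT-S/16")
    print(size, patch_size)
```

Names from the other rows, such as resnet50, do not follow this pattern and raise a ValueError here.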
License
We release this code under the MIT license.
Citation
If you use this codebase in your project, please cite: