# ForNet

This is the training code for the ForNet paper. All our experiments and evaluations were run using this codebase.

## Requirements

This project builds heavily on [timm](https://github.com/huggingface/pytorch-image-models) and on open-source implementations of the models that are tested. All requirements are listed in [requirements.txt](./requirements.txt). To install them, run

```commandline
pip install -r requirements.txt
```

## Usage

After **cloning this repository**, you can train and test a wide range of models. By default, an `srun` command is executed to run the code on a slurm cluster. To run on the local machine instead, append the `--local` flag to the command.

### General Preparation

After cloning the repository on a slurm cluster, make sure `main.py` is executable (`chmod a+x main.py`). To run the project on a slurm cluster, you need to create a docker image from the requirements file. You will also want to adapt the default slurm parameters in `config.py`.

Next, adjust the paths in `paths_config.py` for your system, specifically `results_folder`, `slurm_output_folder`, and the dataset folders.

Finally, if you want to use Weights & Biases for tracking, create the file `.wandb.apikey` in this folder and paste your API key into it.

### Training

#### Pretraining

To pretrain a `ViT-S` on a given dataset, run

```commandline
./main.py --task pre-train --model ViT-S/16 --epochs 300 --run_name <run name> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```

This will save a checkpoint (`.pt` file) every `<n>` epochs (the default is 10), which contains all the model weights, along with the optimizer and scheduler state and the current training stats.

#### Finetuning

A model (checkpoint) can be finetuned on another dataset using the following command:

```commandline
./main.py --task fine-tune --model <checkpoint> --epochs 300 --run_name <run name> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```

This will also save new checkpoints during training.
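The preparation steps above can be sketched as a short shell session. This is only an illustrative sketch run from the repository root; the API key value is a placeholder you must replace with your own:

```shell
# Run once after cloning, from the repository root.
chmod a+x main.py              # make the entry point executable

# Optional: enable Weights & Biases tracking.
# Replace the placeholder below with your real API key.
printf '%s' "YOUR_WANDB_API_KEY" > .wandb.apikey
chmod 600 .wandb.apikey        # keep the key readable only by you
```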
### Evaluation

It is also possible to evaluate the models. To evaluate a model's accuracy on a specific dataset, run

```commandline
./main.py -t eval -ds <dataset> -m <checkpoint> --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 (--local)
```

You can run our center-bias, size-bias, and foreground-focus evaluations using the `eval-attr`, `eval-center-bias`, and `eval-size-bias` tasks (`-t` or `--task` argument).

### Further Arguments

Several further arguments and flags can be given to the scripts. The most important ones are:

| Arg                             | Description                                            |
| :------------------------------ | :----------------------------------------------------- |
| `--model <model>`               | Model name or checkpoint.                              |
| `--run_name <run name>`         | Name or description of this training run.              |
| `--dataset <dataset>`           | Specifies a dataset to use.                            |
| `--task <task>`                 | Specifies a task. The default is `pre-train`.          |
| `--local`                       | Run on the local machine, not on a slurm cluster.      |
| `--epochs <epochs>`             | Epochs to train.                                       |
| `--lr <learning rate>`          | Learning rate. Default is 3e-3.                        |
| `--batch_size <batch size>`     | Batch size. Default is 2048.                           |
| `--weight_decay <weight decay>` | Weight decay. Default is 0.02.                         |
| `--imsize <image size>`         | Resolution of the images to train with. Default is 224. |

For a list of all arguments, run

```commandline
./main.py --help
```

## Supported Models

These are the models we support. Links point to the original code sources. If no link is provided, we implemented the architecture from scratch, following the respective paper.
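To illustrate how these flags combine, here is a hypothetical wrapper (not part of the codebase) that assembles a fine-tuning command and prints it for inspection before you submit it to the cluster; the checkpoint path, dataset name, and hyperparameter values are made-up examples:

```shell
# Hypothetical helper: builds a fine-tune command from a checkpoint,
# dataset, and run name, and prints it instead of executing it.
build_finetune_cmd() {
  checkpoint="$1"; dataset="$2"; run_name="$3"
  echo "./main.py --task fine-tune --model $checkpoint --dataset $dataset" \
       "--run_name $run_name --epochs 100 --batch_size 512 --lr 3e-3"
}

# Example (made-up paths/names); pipe into `sh` only once it looks right.
build_finetune_cmd results/vit_s_pretrain.pt imagenet vit_s_finetune
```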
| Architecture                                                                                           | Versions |
| :----------------------------------------------------------------------------------------------------- | :------- |
| [DeiT](https://github.com/facebookresearch/deit)                                                       | `deit_tiny_patch16_LS`, `deit_small_patch16_LS`, `deit_medium_patch16_LS`, `deit_base_patch16_LS`, `deit_large_patch16_LS`, `deit_huge_patch14_LS`, `deit_huge_patch14_52_LS`, `deit_huge_patch14_26x2_LS`, `deit_Giant_48_patch14_LS`, `deit_giant_40_patch14_LS`, `deit_small_patch16_36_LS`, `deit_small_patch16_36`, `deit_small_patch16_18x2_LS`, `deit_small_patch16_18x2`, `deit_base_patch16_18x2_LS`, `deit_base_patch16_18x2`, `deit_base_patch16_36x1_LS`, `deit_base_patch16_36x1` |
| [ResNet](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnet.py)          | `resnet18`, `resnet34`, `resnet26`, `resnet50`, `resnet101`, `wide_resnet50_2` |
| [Swin](https://github.com/microsoft/Swin-Transformer)                                                  | `swin_tiny_patch4_window7`, `swin_small_patch4_window7`, `swin_base_patch4_window7`, `swin_large_patch4_window7` |
| [ViT](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py) | `ViT-{Ti,S,B,L}/<patch size>` |

## License

We release this code under the [MIT license](./LICENSE).

## Citation

If you use this codebase in your project, please cite: