Tobias Christian Nauen · AAAI Version (commit ff34712155) · 2026-02-24


ForNet

This is the training code for the ForNet paper. All our experiments and evaluations were run using this codebase.

Requirements

This project builds heavily on timm and on open-source implementations of the models we test. All requirements are listed in requirements.txt. To install them, run

pip install -r requirements.txt
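If you prefer an isolated environment, a standard virtualenv setup works before installing (generic Python tooling, not specific to this repository):

```shell
# create and activate a virtual environment, then install the pinned requirements
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```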

Usage

After cloning this repository, you can train and test a variety of models. By default, commands are launched via srun on a Slurm cluster. To run on the local machine instead, append the --local flag to the command.

General Preparation

After cloning the repository on a Slurm cluster, make sure main.py is executable (chmod a+x main.py).

To run the project on a Slurm cluster, you need to create a Docker image from the requirements file. You will also want to adapt the default Slurm parameters in config.py.

Next, adjust the paths in paths_config.py for your system, specifically results_folder, slurm_output_folder, and the dataset folders.
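As a sketch, paths_config.py might look like the following. The names results_folder and slurm_output_folder come from the instructions above; the dataset variable name and all concrete paths are illustrative placeholders — check your checkout for the actual variable names:

```python
# paths_config.py -- adjust these for your system (illustrative values)
results_folder = "/path/to/results"          # where checkpoints and results are written
slurm_output_folder = "/path/to/slurm_logs"  # stdout/stderr of Slurm jobs

# dataset folders; the exact variable names may differ in your checkout
imagenet_folder = "/path/to/imagenet"
```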

Finally, if you want to use Weights & Biases for tracking, create the file .wandb.apikey in this folder and paste your API key into it.
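For example, on the shell (replace the placeholder with your actual key):

```shell
# write your W&B API key into the expected file (placeholder shown)
printf '%s' "<your_wandb_api_key>" > .wandb.apikey
```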

Training

Pretraining

To pretrain a ViT-S on a given dataset, run

./main.py --task pre-train --model ViT-S/16 --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)

This will save a checkpoint (.pt file) every <save_epochs> epochs (default: 10). Each checkpoint contains the model weights, the optimizer and scheduler state, and the current training stats.

Finetuning

A model (checkpoint) can be finetuned on another dataset using the following command:

./main.py --task fine-tune --model <model_checkpoint.pt> --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)

This will also save new checkpoints during training.

Evaluation

It is also possible to evaluate trained models. To evaluate a model's accuracy on a specific dataset, run

./main.py -t eval -ds <dataset name> -m <model_checkpoint.pt> --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 (--local)

You can run our center-bias, size-bias, and foreground-focus evaluations via the eval-attr, eval-center-bias, and eval-size-bias tasks (-t/--task argument).
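Assuming these tasks share the interface of the eval command above (an assumption — check ./main.py --help for the exact flags), a center-bias run would look like:

```shell
# hypothetical invocation, mirroring the accuracy-evaluation command above
./main.py -t eval-center-bias -ds <dataset name> -m <model_checkpoint.pt> \
  --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 (--local)
```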

Further Arguments

The scripts accept a number of further arguments and flags. The most important ones are:

| Arg | Description |
| --- | --- |
| `--model <model>` | Model name or checkpoint. |
| `--run_name <name>` | Name or description of this training run. |
| `--dataset <dataset>` | Specifies the dataset to use. |
| `--task <task>` | Specifies the task. The default is pre-train. |
| `--local` | Run on the local machine, not on a Slurm cluster. |
| `--epochs <epochs>` | Number of epochs to train. |
| `--lr <lr>` | Learning rate. Default is 3e-3. |
| `--batch_size <bs>` | Batch size. Default is 2048. |
| `--weight_decay <wd>` | Weight decay. Default is 0.02. |
| `--imsize <resolution>` | Resolution of the images to train with. Default is 224. |
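Putting several of these together, a local pre-training run with a non-default batch size might look like this (the flag values are illustrative):

```shell
# illustrative combination of the flags listed above
./main.py --task pre-train --model ViT-S/16 --dataset <dataset> \
  --epochs 300 --lr 3e-3 --batch_size 1024 --weight_decay 0.02 \
  --imsize 224 --run_name example_run --local
```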

For a list of all arguments, run

./main.py --help

Supported Models

These are the models we support. Links point to the original code sources; if no link is provided, we implemented the architecture from scratch, following the respective paper.

| Architecture | Versions |
| --- | --- |
| DeiT | deit_tiny_patch16_LS, deit_small_patch16_LS, deit_medium_patch16_LS, deit_base_patch16_LS, deit_large_patch16_LS, deit_huge_patch14_LS, deit_huge_patch14_52_LS, deit_huge_patch14_26x2_LS, deit_Giant_48_patch14_LS, deit_giant_40_patch14_LS, deit_small_patch16_36_LS, deit_small_patch16_36, deit_small_patch16_18x2_LS, deit_small_patch16_18x2, deit_base_patch16_18x2_LS, deit_base_patch16_18x2, deit_base_patch16_36x1_LS, deit_base_patch16_36x1 |
| ResNet | resnet18, resnet34, resnet26, resnet50, resnet101, wide_resnet50_2 |
| Swin | swin_tiny_patch4_window7, swin_small_patch4_window7, swin_base_patch4_window7, swin_large_patch4_window7 |
| ViT | ViT-{Ti,S,B,L}/<patch_size> |
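Any of the version strings above should be usable as the --model argument; for example (a sketch, assuming the training interface shown earlier):

```shell
# pre-train a Swin-Tiny using one of the listed version strings
./main.py --task pre-train --model swin_tiny_patch4_window7 --epochs 300 \
  --run_name <name_or_description_of_the_run> (--local)
```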

License

We release this code under the MIT license.

Citation

If you use this codebase in your project, please cite: