# ForNet

This is the training code for the ForNet paper. All our experiments and evaluations were run using this codebase.

## Requirements

This project builds heavily on [timm](https://github.com/huggingface/pytorch-image-models) and on open-source implementations of the models we test. All requirements are listed in [requirements.txt](./requirements.txt). To install them, run

```commandline
pip install -r requirements.txt
```

## Usage

After **cloning this repository**, you can train and test many different models. By default, an `srun` command is executed to run the code on a Slurm cluster. To run on the local machine instead, append the `--local` flag to the command.

### General Preparation

After cloning the repository on a Slurm cluster, make sure `main.py` is executable (e.g. via `chmod a+x main.py`).

To run the project on a Slurm cluster, you also need to create a Docker image from the requirements file. You will likely want to adapt the default Slurm parameters in `config.py` as well.

Next, adjust the paths in `paths_config.py` for your system, specifically `results_folder`, `slurm_output_folder`, and the dataset folders.

Finally, if you want to use Weights & Biases for tracking, create the file `.wandb.apikey` in this folder and paste your API key into it.
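
As an illustration, the path settings in `paths_config.py` might look like the sketch below. Only `results_folder` and `slurm_output_folder` are named above; the dataset entry and all concrete paths are placeholders you must adapt to your system:

```python
# paths_config.py -- illustrative sketch only; adapt every path to your system.
# `results_folder` and `slurm_output_folder` are referenced in this README;
# the dataset entry below is a hypothetical example.

results_folder = "/path/to/results"             # checkpoints and training stats
slurm_output_folder = "/path/to/slurm/output"   # stdout/stderr of Slurm jobs
imagenet_folder = "/path/to/datasets/imagenet"  # one entry per dataset
```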
### Training

#### Pretraining

To pretrain a `ViT-S` on a given dataset, run

```commandline
./main.py --task pre-train --model ViT-S/16 --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 [--local]
```

This will save a checkpoint (a `.pt` file) every `<save_epochs>` epochs (the default is 10). Each checkpoint contains all the model weights, the optimizer and scheduler state, and the current training stats.
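
Assuming standard PyTorch serialization (`torch.save`/`torch.load`), the round trip for such a checkpoint can be sketched as follows. The dictionary keys here are illustrative, not the project's actual checkpoint schema:

```python
import torch

# Bundle model weights and optimizer state into one checkpoint dict
# (key names are hypothetical -- the real schema may differ).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.02)

checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": 10,
}
torch.save(checkpoint, "checkpoint_epoch_10.pt")

# Restoring reverses the process before training continues.
restored = torch.load("checkpoint_epoch_10.pt")
model.load_state_dict(restored["model"])
optimizer.load_state_dict(restored["optimizer"])
```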
#### Finetuning

A model checkpoint can be fine-tuned on another dataset using the following command:

```commandline
./main.py --task fine-tune --model <model_checkpoint.pt> --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 [--local]
```

This will also save new checkpoints during training.

### Evaluation

It is also possible to evaluate trained models. To evaluate a model's accuracy on a specific dataset, run

```commandline
./main.py -t eval -ds <dataset name> -m <model_checkpoint.pt> --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 [--local]
```

You can run our center-bias, size-bias, and foreground-focus evaluations using the `eval-attr`, `eval-center-bias`, and `eval-size-bias` tasks (via the `-t` or `--task` argument).

### Further Arguments

The scripts accept several further arguments and flags. The most important ones are:

| Arg                             | Description                                                 |
| :------------------------------ | :---------------------------------------------------------- |
| `--model <model>`               | Model name or checkpoint.                                   |
| `--run_name <name for the run>` | Name or description of this training run.                   |
| `--dataset <dataset>`           | Specifies a dataset to use.                                 |
| `--task <task>`                 | Specifies a task. The default is `pre-train`.               |
| `--local`                       | Run on the local machine instead of a Slurm cluster.        |
| `--epochs <epochs>`             | Number of epochs to train.                                  |
| `--lr <lr>`                     | Learning rate. The default is `3e-3`.                       |
| `--batch_size <bs>`             | Batch size. The default is `2048`.                          |
| `--weight_decay <wd>`           | Weight decay. The default is `0.02`.                        |
| `--imsize <image resolution>`   | Resolution of the images to train at. The default is `224`. |

For a full list of all arguments, run

```commandline
./main.py --help
```

## Supported Models

These are the models we support. Links point to the original code sources; where no link is given, we implemented the architecture from scratch, following the corresponding paper.

| Architecture | Versions |
| :----------- | :------- |
| [DeiT](https://github.com/facebookresearch/deit) | `deit_tiny_patch16_LS`, `deit_small_patch16_LS`, `deit_medium_patch16_LS`, `deit_base_patch16_LS`, `deit_large_patch16_LS`, `deit_huge_patch14_LS`, `deit_huge_patch14_52_LS`, `deit_huge_patch14_26x2_LS`, `deit_Giant_48_patch14_LS`, `deit_giant_40_patch14_LS`, `deit_small_patch16_36_LS`, `deit_small_patch16_36`, `deit_small_patch16_18x2_LS`, `deit_small_patch16_18x2`, `deit_base_patch16_18x2_LS`, `deit_base_patch16_18x2`, `deit_base_patch16_36x1_LS`, `deit_base_patch16_36x1` |
| [ResNet](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnet.py) | `resnet18`, `resnet34`, `resnet26`, `resnet50`, `resnet101`, `wide_resnet50_2` |
| [Swin](https://github.com/microsoft/Swin-Transformer) | `swin_tiny_patch4_window7`, `swin_small_patch4_window7`, `swin_base_patch4_window7`, `swin_large_patch4_window7` |
| [ViT](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py) | `ViT-{Ti,S,B,L}/<patch_size>` |

## License

We release this code under the [MIT license](./LICENSE).

## Citation

If you use this codebase in your project, please cite: