# ForNet

This is the training code for the ForNet paper.
All our experiments and evaluations were run using this codebase.

## Requirements

This project builds heavily on [timm](https://github.com/huggingface/pytorch-image-models) and on open-source implementations of the models that are tested.
All requirements are listed in [requirements.txt](./requirements.txt).
To install them, run

```commandline
pip install -r requirements.txt
```

## Usage

After **cloning this repository**, you can train and test many different models.
By default, an `srun` command is executed to run the code on a Slurm cluster.
To run on the local machine instead, append the `--local` flag to the command.

### General Preparation

After cloning the repository on a Slurm cluster, make sure `main.py` is executable (e.g., run `chmod a+x main.py`).

To run the project on a Slurm cluster, you need to create a Docker image from the requirements file.
You will also want to adapt the default Slurm parameters in `config.py`.

Next, adjust the paths in `paths_config.py` for your system, specifically `results_folder`, `slurm_output_folder`, and the dataset folders.
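The exact contents of `paths_config.py` depend on your system; the sketch below illustrates the kind of entries to adjust. Only `results_folder` and `slurm_output_folder` are named above — the dataset mapping and all paths shown here are illustrative placeholders, not the file's actual structure.

```python
# Illustrative sketch of paths_config.py entries; adapt all paths to
# your system. Variable names beyond results_folder and
# slurm_output_folder are placeholders.
results_folder = "/data/experiments/fornet/results"
slurm_output_folder = "/data/experiments/fornet/slurm_out"

# Map each dataset name to its root directory (illustrative).
dataset_folders = {
    "imagenet": "/datasets/imagenet",
}
```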

Finally, if you want to use Weights and Biases for tracking, create the file `.wandb.apikey` in this folder and paste your API key into it.
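For example (the key below is a placeholder for your real API key):

```shell
# Store your Weights and Biases API key (placeholder value shown).
echo "0123456789abcdef" > .wandb.apikey
```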

### Training

#### Pretraining

To pretrain a `ViT-S` on a given dataset, run

```commandline
./main.py --task pre-train --model ViT-S/16 --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```

This will save a checkpoint (a `.pt` file) every `<save_epochs>` epochs (default: 10). Each checkpoint contains all the model weights, along with the optimizer and scheduler state and the current training stats.
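Checkpoints are regular PyTorch files, so they can be inspected directly. The round trip below is a minimal sketch — the key names are illustrative, not necessarily the exact keys `main.py` writes:

```python
# Illustrative checkpoint round trip; the real key names used by
# main.py may differ from the placeholders below.
import torch

checkpoint = {
    "model": {},             # model weights (a state_dict in practice)
    "optimizer": {},         # optimizer state
    "scheduler": {},         # scheduler state
    "stats": {"epoch": 10},  # current training stats
}
torch.save(checkpoint, "demo_checkpoint.pt")

loaded = torch.load("demo_checkpoint.pt", map_location="cpu")
print(sorted(loaded.keys()))
```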

#### Finetuning

A model checkpoint can be finetuned on another dataset using the following command:

```commandline
./main.py --task fine-tune --model <model_checkpoint.pt> --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```

This will also save new checkpoints during training.

### Evaluation

It is also possible to evaluate the models.
To evaluate a model's accuracy on a specific dataset, run

```commandline
./main.py -t eval -ds <dataset name> -m <model_checkpoint.pt> --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 (--local)
```

You can run our center-bias, size-bias, and foreground-focus evaluations using the `eval-attr`, `eval-center-bias`, and `eval-size-bias` tasks (via the `-t` or `--task` argument).
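If checkpoints accumulate in a results folder, a small wrapper can evaluate them in sequence. This is a hypothetical convenience loop — the folder and dataset name are placeholders, not paths defined by this project:

```shell
# Hypothetical helper: evaluate every checkpoint in results/ on one
# dataset. Folder and dataset name are placeholders; adjust as needed.
for ckpt in results/*.pt; do
    [ -e "$ckpt" ] || continue   # no-op when no checkpoints exist yet
    ./main.py -t eval -ds imagenet -m "$ckpt" --ntasks 1 -bs 512 --local
done
```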

### Further Arguments

Several further arguments and flags can be passed to the scripts.
The most important ones are:

| Arg | Description |
| :------------------------------ | :----------------------------------------------------- |
| `--model <model>` | Model name or checkpoint. |
| `--run_name <name for the run>` | Name or description of this training run. |
| `--dataset <dataset>` | Specifies the dataset to use. |
| `--task <task>` | Specifies the task. The default is `pre-train`. |
| `--local` | Run on the local machine, not on a Slurm cluster. |
| `--epochs <epochs>` | Number of epochs to train. |
| `--lr <lr>` | Learning rate. Default is 3e-3. |
| `--batch_size <bs>` | Batch size. Default is 2048. |
| `--weight_decay <wd>` | Weight decay. Default is 0.02. |
| `--imsize <image resolution>` | Resolution of the images to train on. Default is 224. |

For a list of all arguments, run

```commandline
./main.py --help
```

## Supported Models

These are the models we support. Links point to the original code sources; where no link is given, we implemented the architecture from scratch, following the corresponding paper.

| Architecture | Versions |
| :--- | :--- |
| [DeiT](https://github.com/facebookresearch/deit) | `deit_tiny_patch16_LS`, `deit_small_patch16_LS`, `deit_medium_patch16_LS`, `deit_base_patch16_LS`, `deit_large_patch16_LS`, `deit_huge_patch14_LS`, `deit_huge_patch14_52_LS`, `deit_huge_patch14_26x2_LS`, `deit_Giant_48_patch14_LS`, `deit_giant_40_patch14_LS`, `deit_small_patch16_36_LS`, `deit_small_patch16_36`, `deit_small_patch16_18x2_LS`, `deit_small_patch16_18x2`, `deit_base_patch16_18x2_LS`, `deit_base_patch16_18x2`, `deit_base_patch16_36x1_LS`, `deit_base_patch16_36x1` |
| [ResNet](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnet.py) | `resnet18`, `resnet34`, `resnet26`, `resnet50`, `resnet101`, `wide_resnet50_2` |
| [Swin](https://github.com/microsoft/Swin-Transformer) | `swin_tiny_patch4_window7`, `swin_small_patch4_window7`, `swin_base_patch4_window7`, `swin_large_patch4_window7` |
| [ViT](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py) | `ViT-{Ti,S,B,L}/<patch_size>` |

## License

We release this code under the [MIT license](./LICENSE).

## Citation

If you use this codebase in your project, please cite: