Files
ForAug/AAAI Supplementary Material/Model Training Code/main.py
Tobias Christian Nauen ff34712155 AAAI Version
2026-02-24 12:22:44 +01:00

680 lines
26 KiB
Python

#!/usr/bin/env python3
"""Parse args and call the correct script inside slurm container.
Outside the container, on the head-node, create and call the correct srun command.
"""
import argparse
import os
import subprocess
from datetime import datetime
from config import default_kwargs, slurm_defaults
from paths_config import results_folder, slurm_output_folder
# Allowed values for --experiment-name; runs are grouped by this name.
_EXPNAMES = [
    "EfficientCVBench",
    "test",
    "recombine_imagenet",
]
def base_parser():
    """Create the argument parser with all options of the training / evaluation scripts.

    Apart from ``--defaults`` (``"DeiTIII"``) and ``--results-folder`` (``results_folder``),
    options are ``None`` when not given on the command line.

    Returns:
        argparse.ArgumentParser: Parser with the main, model, experiment-management, speedup,
        data-loading, optimizer and data-augmentation option groups.
    """
    # The text is the parser *description*; passed positionally it would become `prog`
    # and prefix every usage / error line.
    parser = argparse.ArgumentParser(description="Transformer training and evaluation.")

    # Main
    group = parser.add_argument_group("Main")
    group.add_argument(
        "-t",
        "--task",
        nargs="?",
        choices=[
            "pre-train",
            "fine-tune",
            "fine-tune-head",
            "eval",
            "parser-test",
            "eval-metrics",
            "eval-attr",
            "continue",
            "eval-center-bias",
            "eval-size-bias",
            "load-images",
            "save-images",
        ],
        required=True,
        help="Task to perform.",
    )
    group.add_argument(
        "-m",
        "--model",
        nargs="?",
        type=str,
        required=True,
        help="Model to use. Either model name for a new model or weights and dicts to load for fine-tuning.",
    )
    group.add_argument("-ds", "--dataset", nargs="?", type=str, help="Dataset to train on.")
    group.add_argument(
        "-valds", "--val-dataset", nargs="?", type=str, help="Validation dataset. Defaults to same as training."
    )
    group.add_argument("-ep", "--epochs", nargs="?", type=int, help="Number of epochs to train.")
    group.add_argument(
        "-run",
        "--run-name",
        nargs="?",
        type=str,
        help="A name for the run. If not given, the model name is used instead.",
    )
    group.add_argument(
        "--defaults", nargs="?", choices=["DeiT", "DeiTIII"], default="DeiTIII", help="Default settings to use."
    )

    # Further model parameters
    group = parser.add_argument_group("Further model parameters")
    group.add_argument("--drop-path-rate", nargs="?", type=float, help="Drop path rate for ViT models.")
    group.add_argument("--layer-scale-init-values", nargs="?", type=float, help="LayerScale initial values.")
    group.add_argument("--layer-scale", action=argparse.BooleanOptionalAction, help="Use layer scale?")
    group.add_argument(
        "--qkv-bias",
        action=argparse.BooleanOptionalAction,
        help="Use bias in linear transformation to queries, keys, and values?",
    )
    group.add_argument("--pre-norm", action=argparse.BooleanOptionalAction, help="Use norm first architecture?")
    group.add_argument("--dropout", nargs="?", type=float, help="Model dropout.")
    group.add_argument("-heads", "--num-heads", nargs="?", type=int, help="Number of parallel attention heads.")
    group.add_argument("--input-dim", nargs="?", type=int, help="Dimensionality of text encoding.")
    group.add_argument("--max-seq-len", nargs="?", type=int, help="Maximum sequence length for text data.")
    group.add_argument(
        "--fused-attn",
        action=argparse.BooleanOptionalAction,
        help="Use fused attention (for ViT with Timm's attention only)?",
    )

    # Experiment management
    group = parser.add_argument_group("Experiment management")
    group.add_argument("--seed", nargs="?", type=int, help="Manual RNG seed.")
    group.add_argument(
        "-exp",
        "--experiment-name",
        nargs="?",
        choices=_EXPNAMES,
        help="Name for the experiment. Is used for grouping of runs.",
    )
    group.add_argument(
        "--save-epochs", nargs="?", type=int, help="Number of epochs after which to save the full training state."
    )
    group.add_argument(
        "--keep-interm-states",
        nargs="?",
        type=int,
        help="Number of intermediate states to keep. All others (earlier ones) will be deleted automatically.",
    )
    group.add_argument(
        "--custom-dataset-path", nargs="?", type=str, help="Overwrite the path to any dataset to this path."
    )
    group.add_argument(
        "--results-folder",
        nargs="?",
        default=results_folder,
        type=str,
        help="Folder to put script results (mlflow data, models, etc.).",
    )
    group.add_argument(
        "--gather-stats-during-training",
        action=argparse.BooleanOptionalAction,
        help="Gather training statistics from all GPUs?",
    )
    group.add_argument("--tqdm", action=argparse.BooleanOptionalAction, help="Show tqdm for every epoch?")
    group.add_argument(
        "--debug", action=argparse.BooleanOptionalAction, help="Debug mode: lots of intermediate prints."
    )
    group.add_argument("--wandb", action=argparse.BooleanOptionalAction, help="Use external logging via Wandb?")
    group.add_argument("--log-level", choices=["info", "debug"], help="Log level", metavar="LEVEL")
    group.add_argument("-out", "--out-dir", type=str, help="Output directory for additional outputs.")

    # Speedup
    group = parser.add_argument_group("Speedup")
    group.add_argument("--amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision?")
    group.add_argument(
        "--eval-amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision during evaluation?"
    )
    group.add_argument("--compile-model", action=argparse.BooleanOptionalAction, help="Use torch.compile?")
    group.add_argument("--cuda", action=argparse.BooleanOptionalAction, help="Use cuda?")

    # Data loading
    group = parser.add_argument_group("Data loading")
    group.add_argument(
        "-bs", "--batch-size", nargs="?", type=int, help="Batch size over all graphics cards (together)."
    )
    group.add_argument(
        "--num-workers", nargs="?", type=int, help="Number of dataloader worker processes. Should be >0."
    )
    group.add_argument(
        "--pin-memory", action=argparse.BooleanOptionalAction, help="Use pin_memory of torch Dataloader?"
    )
    group.add_argument(
        "--prefetch-factor",
        nargs="?",
        type=int,
        help="Prefetch factor for dataloader workers (how many batches to fetch)",
    )
    group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, help="Shuffle the training data?")
    group.add_argument(
        "--weighted-sampler",
        action=argparse.BooleanOptionalAction,
        help="Use a class-weighted sampler to sample evenly from all classes (train and val)?",
    )
    group.add_argument("--ipc", type=int, help="How many images per class to load and save.")

    # Optimizer
    group = parser.add_argument_group("Optimizer")
    group.add_argument("--opt", nargs="?", type=str, help="Optimizer to use.")
    group.add_argument("--weight-decay", nargs="?", type=float, help="Weight decay factor for use in optimizer.")
    group.add_argument("-lr", "--lr", nargs="?", type=float, help="Initial learning rate.")
    group.add_argument(
        "--max-grad-norm", nargs="?", type=float, help="Maximum norm for the gradients (used for cutoff)."
    )
    group.add_argument("--warmup-epochs", nargs="?", type=int, help="Number of epochs of linear warmup.")
    group.add_argument("--label-smoothing", nargs="?", type=float, help="Label smoothing factor.")
    group.add_argument("--loss", nargs="?", choices=["ce", "baikal"], type=str, help="Loss function to use.")
    group.add_argument(
        "--loss-weight", nargs="?", type=str, choices=["none", "linear", "log", "sqrt"], help="Per class loss weight."
    )
    group.add_argument("--sched", nargs="?", choices=["cosine", "const"], help="Learning rate schedule.")
    group.add_argument("--min-lr", nargs="?", type=float, help="Minimum learning rate to be hit by scheduler.")
    group.add_argument("--warmup-lr", nargs="?", type=float, help="Warmup learning rate.")
    group.add_argument("--warmup-sched", nargs="?", choices=["linear", "const"], help="Schedule for warmup")
    group.add_argument(
        "--opt-eps", nargs="?", type=float, help="Epsilon value added in the optimizer to stabilize training."
    )
    group.add_argument("--momentum", nargs="?", type=float, help="Optimizer momentum.")

    # Data augmentation
    group = parser.add_argument_group("Data augmentation")
    group.add_argument("--augment-strategy", nargs="?", type=str, help="Data augmentation strategy.")
    group.add_argument("--aug-rand-rot", nargs="?", type=int, help="Random rotation limit.")
    group.add_argument(
        "--aug-flip", action=argparse.BooleanOptionalAction, help="Use data augmentation: horizontal flip?"
    )
    group.add_argument(
        "--aug-crop",
        action=argparse.BooleanOptionalAction,
        help="Use data augmentation: cropping? This may break the script.",
    )
    group.add_argument("--aug-resize", action=argparse.BooleanOptionalAction, help="Use data augmentation: resize?")
    group.add_argument(
        "--aug-grayscale", action=argparse.BooleanOptionalAction, help="Use data augmentation: grayscale?"
    )
    group.add_argument(
        "--aug-solarize", action=argparse.BooleanOptionalAction, help="Use data augmentation: solarize?"
    )
    group.add_argument(
        "--aug-gauss-blur", action=argparse.BooleanOptionalAction, help="Use data augmentation: gaussian blur?"
    )
    group.add_argument(
        "--aug-cutmix-alpha",
        type=float,
        help="Alpha value for using CutMix. CutMix is active when aug_cutmix_alpha > 0.",
    )
    group.add_argument(
        "--aug-mixup-alpha", type=float, help="Alpha value for using Mixup. Mixup is active when aug_mixup_alpha > 0."
    )
    group.add_argument(
        "--aug-color-jitter-factor",
        nargs="?",
        type=float,
        help="Factor to use for the data augmentation: color jitter.",
    )
    group.add_argument(
        "--aug-normalize", action=argparse.BooleanOptionalAction, help="Use data augmentation: Normalization?"
    )
    group.add_argument(
        "--aug-repeated-augment-repeats",
        type=int,
        help="Number of image repeats with repeat-augment from DeiT. 1 is not using repeat-augment.",
    )
    group.add_argument(
        "--aug-random-erase-prob", type=float, help="For DeiT augment: Probability of RandomErase augmentation."
    )
    group.add_argument("--auto-augment-strategy", type=str, help="For DeiT augment: AutoAugment Policy to use.")
    group.add_argument("--imsize", nargs="?", type=int, help="Image size given to the model -> imsize x imsize.")
    group.add_argument(
        "--augment-engine",
        nargs="?",
        choices=["torchvision", "albumentations", "dali"],
        help="Which data augmentation engine to use.",
    )
    return parser
def partition_choices():
    """Return the Slurm partition names reported by ``sinfo``.

    The first space-separated token of each ``sinfo`` output line is taken as a partition
    name. The ``PARTITION`` header is dropped, and the ``*`` that marks the default
    partition is stripped. If no partition is found (e.g. ``sinfo`` printed nothing),
    ``slurm_defaults["partition"]`` is returned instead.

    Returns:
        Sorted, de-duplicated partition names, or the configured default partition value.
    """
    with os.popen("sinfo") as proc:  # close the pipe instead of leaving it to the GC
        lines = proc.read().splitlines()
    # Strip "*" before de-duplicating so "gpu*" and "gpu" collapse into one entry.
    names = {line.split(" ")[0].rstrip("*") for line in lines}
    names.discard("PARTITION")
    names.discard("")
    if not names:
        # Nothing discovered. Counting only real partitions (not the header) keeps a
        # single-partition cluster from being mistaken for "no sinfo output".
        return slurm_defaults["partition"]
    return sorted(names)  # deterministic order for argparse choices / help text
def slurm_parser(parser=None):
    """Add the Slurm / launcher options to an argument parser.

    Args:
        parser (argparse.ArgumentParser, optional): Parser to extend. Defaults to a new
            parser from ``base_parser()``.

    Returns:
        argparse.ArgumentParser: The extended parser.

    Note:
        The ``--partition`` choices come from ``partition_choices()``, which runs ``sinfo``
        every time this function is called.
    """
    if parser is None:
        parser = base_parser()
    group = parser.add_argument_group("Slurm arguments")
    group.add_argument(
        "--partition",
        nargs="*",
        default=slurm_defaults["partition"],
        choices=partition_choices(),
        help="Slurm partition(s) to use.",
    )
    group.add_argument(
        "--container-image",
        nargs="?",
        default=slurm_defaults["container_image"],
        type=str,
        help="Path to slurm container image (.sqsh)",
    )
    group.add_argument(
        "--container-workdir",
        nargs="?",
        default=slurm_defaults["container_workdir"],
        type=str,
        help="Working directory in container",
    )
    group.add_argument(
        "--container-mounts",
        nargs="?",
        default=slurm_defaults["container_mounts"],
        type=str,
        help="All slurm mounts separated by ','.",
    )
    group.add_argument(
        "--job-name",
        nargs="?",
        default=slurm_defaults["job_name"],
        type=str,
        help="Slurm job name. If unset, the run name is used ('<task> <model prefix>' when no --run-name is given).",
    )
    group.add_argument(
        "--nodes", nargs="?", default=slurm_defaults["nodes"], type=int, help="Number of cluster nodes to use."
    )
    group.add_argument(
        "--ntasks", nargs="?", default=slurm_defaults["ntasks"], type=int, help="Number of GPUs to use for the job."
    )
    group.add_argument("--gpus", action=argparse.BooleanOptionalAction, default=True, help="Use gpus for this job?")
    group.add_argument(
        "-cpus",
        "--cpus-per-task",
        "--cpus-per-gpu",
        nargs="?",
        default=slurm_defaults["cpus_per_task"],
        type=int,
        help="Number of CPUs per task/GPU.",
    )
    group.add_argument(
        "-mem",
        "--mem-per-gpu",
        "--mem-per-task",
        nargs="?",
        default=slurm_defaults["mem_per_gpu"],
        type=int,
        help=(
            "RAM per GPU/task (in GB). Written to the sbatch script as a per-node --mem of"
            " mem_per_gpu * ntasks / nodes GB."
        ),
    )
    group.add_argument(
        "--task-prolog",
        nargs="?",
        default=slurm_defaults["task_prolog"],
        type=str,
        help="Shell script for task prolog (installing packages, etc.).",
    )
    group.add_argument("--time", nargs="?", default=slurm_defaults["time"], type=str, help="Slurm time limit.")
    group.add_argument(
        "--export",
        nargs="?",
        default=slurm_defaults["export"],
        type=str,
        help="Additional environment variables to export.",
    )
    group.add_argument("--exclude", nargs="?", default=slurm_defaults["exclude"], type=str, help="Nodes to exclude.")
    group.add_argument(
        "--after-job",
        nargs="?",
        default=slurm_defaults["after_job"],
        type=int,
        help="Job ID to wait for (submitted as sbatch --dependency=afterany:<id>).",
    )
    group.add_argument(
        "--interactive",
        action="store_true",
        help=(
            "Run using srun instead of sbatch. This will print the output into the terminal, not the slurm output file."
            " The logfile will still be created as usual."
        ),
        default=False,
    )

    group = parser.add_argument_group("Run locally")
    group.add_argument("--local", action="store_true", help="Run locally; not in slurm", default=False)
    return parser
def parse_args(args=None, parser=None):
    """Parse the command line (or take pre-parsed arguments) and validate them.

    Args:
        args (argparse.Namespace | dict, optional): Already-parsed arguments. If ``None``,
            ``sys.argv`` is parsed with *parser*.
        parser (argparse.ArgumentParser, optional): Parser used for parsing (when *args* is
            ``None``) and for reporting validation errors. Defaults to ``base_parser()``.

    Returns:
        dict: The parsed arguments (a new dict; a passed-in dict is not modified).
    """
    if parser is None:
        # Also needed when *args* is given: validation errors are reported via parser.error.
        parser = base_parser()
    if args is None:
        args = parser.parse_args()
    # Accept a plain dict (as produced for local runs) as well as a Namespace;
    # vars() would raise TypeError on a dict.
    args = dict(args) if isinstance(args, dict) else dict(vars(args))
    check_arg_completeness(args, parser)
    return args
def check_arg_completeness(args, parser):
    """Check that *args* contains everything the requested task needs.

    Args:
        args (dict): Parsed arguments, keyed by argparse destination (e.g. ``run_name``).
            Only ``task`` is required; other keys may be missing.
        parser (argparse.ArgumentParser): Parser whose ``error`` method reports problems.

    Note:
        On the first missing or invalid argument ``parser.error`` is called, which prints
        the usage plus the message and exits with ``SystemExit``.
    """
    task = args["task"]
    if task in ("pre-train", "fine-tune", "fine-tune-head"):
        if not args.get("run_name"):  # None and "" are both rejected
            parser.error(f"--run-name is required for task {task}")
        if not args.get("experiment_name"):
            parser.error(f"--experiment-name is required for task {task}. Choose from {_EXPNAMES}")
        if args.get("epochs") is None:
            parser.error(f"--epochs is required for task {task}")
    if args.get("dataset") is None and task in ("pre-train", "fine-tune", "fine-tune-head", "eval-metrics"):
        parser.error(f"--dataset is required for task {task}")
    if task == "eval" and args.get("val_dataset") is None and args.get("dataset") is None:
        parser.error(f"--dataset or --val-dataset is required for task {task}")
    repeats = args.get("aug_repeated_augment_repeats")
    if repeats is not None and repeats < 1:
        parser.error(
            "number of repeats for repeated augment has to be >= 1, but got --aug-repeated-augment-repeats ="
            f" {repeats}"
        )
    if task == "save-images" and args.get("out_dir") is None:
        parser.error("Need to set save directory (--out-dir) to save the images in.")
def inside_slurm():
    """Report whether this process runs as a launched task rather than on the head node.

    The check only looks at whether a ``RANK`` environment variable exists (its value is
    ignored). Presumably the Slurm / distributed launcher sets it inside the job; that is
    not verified here.

    Returns:
        bool: ``True`` if ``RANK`` is set, else ``False``.
    """
    environment = os.environ
    return environment.get("RANK") is not None
def create_runscript(args, file_name=None):
    """Write an sbatch script that runs ``main.py`` through ``srun`` with the given arguments.

    Keys present in ``slurm_defaults`` become ``#SBATCH`` directives, except that
    ``container*`` options and ``task_prolog`` become ``srun`` flags and ``after_job``
    becomes an ``sbatch`` command-line dependency. Every other non-``None`` key is forwarded
    to ``main.py`` as a command-line flag. Afterwards, runscripts in ``experiments/sbatch/``
    older than 30 days are deleted.

    Args:
        args (dict): Parsed arguments (task options plus Slurm options). The dict itself is
            not modified; Slurm options missing from it are taken from ``slurm_defaults``.
        file_name (str, optional): Path of the script to write. Defaults to
            ``experiments/sbatch/run_<task>_<job name>_at_<timestamp>.sbatch``; the directory
            is created if needed.

    Returns:
        str: The path of the written script.
        str: Additional command-line arguments for ``sbatch`` (e.g. a job dependency).

    Raises:
        ValueError: If more than one node is requested and ``ntasks`` is not a multiple of
            ``nodes``.

    Example:
        >>> args = {"model": "vit_large_patch16_384", "task": "pre-train", "batch_size": 256, ...}
        >>> create_runscript(args, "my_run_script.sh")
    """
    import shlex  # shell-quote forwarded values so quotes / `$` in them cannot break the script

    args = dict(args)  # don't leak the filled-in Slurm defaults into the caller's dict
    for key, val in slurm_defaults.items():
        if key not in args and val is not None:
            args[key] = val

    # Run name: explicit --run-name, otherwise "<task> <model name prefix>".
    if args.get("run_name") is None:
        model_str = args["model"]
        if model_str.endswith(".pt"):
            # a checkpoint path: name the run after the directory containing it
            model_str = os.path.dirname(model_str)
        run_name = args["task"] + " " + model_str.split(os.sep)[-1].split("_")[0]
    else:
        run_name = args["run_name"]
    job_name = run_name.translate(str.maketrans(" /><", "____"))  # file-name safe variant

    # One timestamp, so the script name and the Slurm log name always match.
    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    if file_name is None:
        file_name = f"experiments/sbatch/run_{args['task']}_{job_name}_at_{timestamp}.sbatch"

    nodes = args.get("nodes") or 1  # an unset node count means a single node
    use_gpus = args.get("gpus", True)

    # Slurm log file name: job name, date, job id, node name
    sbatch_lines = ["#!/bin/bash", "", f"#SBATCH --output={slurm_output_folder}/%x-{timestamp}-%j-%N.out"]
    srun_lines = ["srun -K \\"]
    task_args = []
    sbatch_cmd_args = ""  # additional command line arguments for sbatch itself

    for key, val in args.items():
        if key in ("local", "interactive", "gpus"):
            continue  # launcher-only switches; handled here or in __main__, not forwarded

        if key not in slurm_defaults:
            # a parameter for the training / evaluation script
            if val is None:
                continue
            if key == "results_folder" and val == results_folder:
                continue  # equals the parser default, no need to pass it
            opt = key.replace("_", "-")
            if isinstance(val, bool):
                task_args.append(f"--{opt}" if val else f"--no-{opt}")
            elif isinstance(val, str):
                task_args.append(f"--{opt} {shlex.quote(val)}")
            else:
                task_args.append(f"--{opt} {val}")
            continue

        # a Slurm option; Slurm spells options with '-' instead of '_'
        opt = key.replace("_", "-")
        if opt == "job-name":
            if val is None:
                val = run_name
            if not val.startswith('"'):
                val = f'"{val}"'
        elif val is None and opt in ("task-prolog", "nodes", "exclude", "after-job", "mem-per-gpu"):
            continue
        if opt == "mem-per-gpu":
            # Slurm's --mem is per node: memory per GPU times the GPUs on each node
            opt = "mem"
            val = f"{val * args['ntasks'] // nodes}G"
        if opt == "task-prolog":
            srun_lines.append(f' --{opt}="{val}" \\')
            continue
        if opt == "after-job":
            sbatch_cmd_args += f"--dependency=afterany:{val} "
            continue
        if opt == "partition" and isinstance(val, list):
            val = ",".join(val)
        if opt == "ntasks":
            if nodes == 1:
                sbatch_lines.append(f"#SBATCH --gpus={val if use_gpus else 0}")
            else:
                if val % nodes != 0:
                    raise ValueError(
                        f"Number of tasks ({val}) must be a multiple of the number of nodes ({nodes})."
                    )
                # one task per GPU, spread evenly over the nodes
                per_node = val // nodes
                sbatch_lines.append(f"#SBATCH --gpus-per-node={per_node if use_gpus else 0}")
                sbatch_lines.append(f"#SBATCH --ntasks-per-node={per_node}")
        if "container" in opt:
            srun_lines.append(f" --{opt}={val} \\")
        else:
            sbatch_lines.append(f"#SBATCH --{opt}={val}")

    script = (
        "\n".join(sbatch_lines)
        + "\n\n"
        + "\n".join(srun_lines)
        + "\n python3 main.py "
        + " ".join(task_args)
        + "\n"
    )
    script_dir = os.path.dirname(file_name)
    if script_dir:
        os.makedirs(script_dir, exist_ok=True)
    with open(file_name, "w") as f:
        f.write(script)

    # delete all runscripts older than a month
    with os.popen("find experiments/sbatch/ -type f -mtime +30 | wc -l") as proc:
        n_old_files = int(proc.read())
    if n_old_files > 0:
        print(f"Deleting {n_old_files} old runscripts.")
        os.system("find experiments/sbatch/ -type f -mtime +30 -delete")
    return file_name, sbatch_cmd_args
if __name__ == "__main__":
    if not inside_slurm():
        # Head node: write a runscript and hand it to Slurm, unless --local is given.
        parser = slurm_parser()
        args = vars(parser.parse_args())
        if not args["local"]:
            script_name, cmd_args = create_runscript(args)
            # The runscript calls this file again inside the allocation, where the
            # `inside_slurm()` branch below is taken.
            if args["interactive"]:
                submit_cmd = ["python3", "srun-sbatch.py", script_name]
            else:
                submit_cmd = ["sbatch", *cmd_args.split(), script_name]  # queue the job on the cluster
            # Argument list (no shell): unusual characters in the script name stay harmless.
            # The submitter's exit status is propagated instead of always reporting success.
            raise SystemExit(subprocess.run(submit_cmd, check=False).returncode)
        # Local execution: drop the Slurm-only options, then validate the rest.
        # slurm_defaults is keyed like the argparse destinations (underscores), so test
        # the key as-is as well as its dashed spelling.
        for key in list(args):
            if key in slurm_defaults or key.replace("_", "-") in slurm_defaults:
                del args[key]
        check_arg_completeness(args, parser)
    else:
        args = parse_args()

    args["branch"] = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode("utf-8")
    args["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")

    # Task modules are imported lazily so only the selected task's dependencies load.
    task = args["task"]
    if task == "pre-train":
        from train import pretrain

        pretrain(**args)
    elif task == "fine-tune":
        from train import finetune

        finetune(**args)
    elif task == "fine-tune-head":
        from train import finetune

        finetune(**args, head_only=True)
    elif task == "parser-test":
        from copy import copy

        from utils import log_args, prep_kwargs

        log_args(prep_kwargs(copy(args)))
    elif task == "eval-metrics":
        from evaluate import evaluate_metrics

        evaluate_metrics(**args)
    elif task == "eval":
        from evaluate import evaluate

        evaluate(**args)
    elif task == "eval-attr":
        from evaluate import evaluate_attributions

        evaluate_attributions(**args)
    elif task == "continue":
        from recover import continue_training

        continue_training(**args)
    elif task == "eval-center-bias":
        from evaluate import evaluate_center_bias

        evaluate_center_bias(**args)
    elif task == "eval-size-bias":
        from evaluate import evaluate_size_bias

        evaluate_size_bias(**args)
    elif task == "load-images":
        from test import load_images

        load_images(**args)
    elif task == "save-images":
        from test import save_images

        save_images(**args)
    else:
        raise NotImplementedError(f"Task {task} is not implemented.")