#!/usr/bin/env python3 """Parse args and call the correct script inside slurm container. Outside the container, on the head-node, create and call the correct srun command. """ import argparse import os import subprocess from datetime import datetime from config import default_kwargs, slurm_defaults from paths_config import results_folder, slurm_output_folder _EXPNAMES = ["EfficientCVBench", "test", "recombine_imagenet"] def base_parser(): """Create the argument parser with all the choices for the training / evaluation scripts.""" parser = argparse.ArgumentParser("Transformer training and evaluation.") # Main group = parser.add_argument_group("Main") group.add_argument( "-t", "--task", nargs="?", choices=[ "pre-train", "fine-tune", "fine-tune-head", "eval", "parser-test", "eval-metrics", "eval-attr", "continue", "eval-center-bias", "eval-size-bias", "load-images", "save-images", ], required=True, help="Task to perform.", ) group.add_argument( "-m", "--model", nargs="?", type=str, required=True, help="Model to use. Either model name for a new model or weights and dicts to load for fine-tuning.", ) group.add_argument("-ds", "--dataset", nargs="?", type=str, help="Dataset to train on.") group.add_argument( "-valds", "--val-dataset", nargs="?", type=str, help="Validation dataset. Defaults to same as training." ) group.add_argument("-ep", "--epochs", nargs="?", type=int, help="Number of epochs to train.") group.add_argument( "-run", "--run-name", nargs="?", type=str, help="A name for the run. If not give, the model name is used instead.", ) group.add_argument( "--defaults", nargs="?", choices=["DeiT", "DeiTIII"], default="DeiTIII", help="Default settings to use." ) # Further model parameters group = parser.add_argument_group("Further model parameters") group.add_argument("--drop-path-rate", nargs="?", type=float, help="Drop path rate for ViT models.") group.add_argument("--layer-scale-init-values", nargs="?", type=float, help="LayerScale initial values.") group.add_argument("--layer-scale", action=argparse.BooleanOptionalAction, help="Use layer scale?") group.add_argument( "--qkv-bias", action=argparse.BooleanOptionalAction, help="Use bias in linear transformation to queries, keys, and values?", ) group.add_argument("--pre-norm", action=argparse.BooleanOptionalAction, help="Use norm first architecture?") group.add_argument("--dropout", nargs="?", type=float, help="Model dropout.") group.add_argument("-heads", "--num-heads", nargs="?", type=int, help="Number of parallel attention heads.") group.add_argument("--input-dim", nargs="?", type=int, help="Dimensionality of text encoding.") group.add_argument("--max-seq-len", nargs="?", type=int, help="Maximum sequence length for text data.") group.add_argument( "--fused-attn", action=argparse.BooleanOptionalAction, help="Use fused attention (for ViT with Timm's attention only)?", ) # group.add_argument( # "--perf-metric", nargs="?", choices=["acc", "mIoU"], help="Performance metric to use for evaluation." # ) # group.add_argument("-no_model_ema", action="store_true", # help="Don't use an exponential moving average for model parameters") # group.add_argument("-model_ema_decay", nargs='?', type=float, default=default_kwargs["model_ema_decay"], # help="Decay rate for exponential moving average of model parameters") # Experiment management group = parser.add_argument_group("Experiment management") group.add_argument("--seed", nargs="?", type=int, help="Manual RNG seed.") group.add_argument( "-exp", "--experiment-name", nargs="?", choices=_EXPNAMES, help="Name for the experiment. Is used for grouping of runs.", ) group.add_argument( "--save-epochs", nargs="?", type=int, help="Number of epochs after which to save the full training state." ) group.add_argument( "--keep-interm-states", nargs="?", type=int, help="Number of intermediate states to keep. All others (earlier ones) will be deleted automatically.", ) group.add_argument( "--custom-dataset-path", nargs="?", type=str, help="Overwrite the path to any dataset to this path." ) group.add_argument( "--results-folder", nargs="?", default=results_folder, type=str, help="Folder to put script results (mlflow data, models, etc.).", ) group.add_argument( "--gather-stats-during-training", action=argparse.BooleanOptionalAction, help="Gather training statistics from all GPUs?", ) group.add_argument("--tqdm", action=argparse.BooleanOptionalAction, help="Show tqdm for every epoch?") group.add_argument( "--debug", action=argparse.BooleanOptionalAction, help="Debug mode: lots of intermediate prints." ) group.add_argument("--wandb", action=argparse.BooleanOptionalAction, help="Use external logging via Wandb?") group.add_argument("--log-level", choices=["info", "debug"], help="Log level", metavar="LEVEL") group.add_argument("-out", "--out-dir", type=str, help="Output directory for additional outputs.") # Speedup group = parser.add_argument_group("Speedup") group.add_argument("--amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision?") group.add_argument( "--eval-amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision during evaluation?" ) group.add_argument("--compile-model", action=argparse.BooleanOptionalAction, help="Use torch.compile?") group.add_argument("--cuda", action=argparse.BooleanOptionalAction, help="Use cuda?") # Data loading group = parser.add_argument_group("Data loading") group.add_argument("-bs", "--batch-size", nargs="?", type=int, help="Batch size over all graphics cards (togeter).") group.add_argument("--num-workers", nargs="?", type=int, help="Number of dataloader worker threads. Should be >0.") group.add_argument( "--pin-memory", action=argparse.BooleanOptionalAction, help="Use pin_memory of torch Dataloader?" ) group.add_argument( "--prefetch-factor", nargs="?", type=int, help="Prefetch factor for dataloader workers (how many batches to fetch)", ) group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, help="Shuffle the training data?") group.add_argument( "--weighted-sampler", action=argparse.BooleanOptionalAction, help="Use a class-weighted sampler to sample evenly from all classes (train and val)?", ) group.add_argument("--ipc", type=int, help="How many images per class to load and save.") # Optimizer group = parser.add_argument_group("Optimizer") group.add_argument("--opt", nargs="?", type=str, help="Optimizer to use.") group.add_argument("--weight-decay", nargs="?", type=float, help="Weight decay factor for use in optimizer.") group.add_argument("-lr", "--lr", nargs="?", type=float, help="Initial learning rate.") group.add_argument( "--max-grad-norm", nargs="?", type=float, help="Maximum norm for the gradients (used for cutoff)." ) group.add_argument("--warmup-epochs", nargs="?", type=int, help="Number of epochs of linear warmup.") group.add_argument("--label-smoothing", nargs="?", type=float, help="Label smoothing factor.") group.add_argument("--loss", nargs="?", choices=["ce", "baikal"], type=str, help="Loss function to use.") group.add_argument( "--loss-weight", nargs="?", type=str, choices=["none", "linear", "log", "sqrt"], help="Per class loss weight." ) group.add_argument("--sched", nargs="?", choices=["cosine", "const"], help="Learning rate schedule.") group.add_argument("--min-lr", nargs="?", type=float, help="Minimum learning rate to be hit by scheduler.") group.add_argument("--warmup-lr", nargs="?", type=float, help="Warmup learning rate.") group.add_argument("--warmup-sched", nargs="?", choices=["linear", "const"], help="Schedule for warmup") group.add_argument( "--opt-eps", nargs="?", type=float, help="Epsilon value added in the optimizer to stabilize training." ) group.add_argument("--momentum", nargs="?", type=float, help="Optimizer momentum.") # Data augmentation group = parser.add_argument_group("Data augmentation") group.add_argument("--augment-strategy", nargs="?", type=str, help="Data augmentation strategy.") group.add_argument("--aug-rand-rot", nargs="?", type=int, help="Random rotation limit.") group.add_argument( "--aug-flip", action=argparse.BooleanOptionalAction, help="Use data augmentation: horizontal flip?" ) group.add_argument( "--aug-crop", action=argparse.BooleanOptionalAction, help="Use data augmentation: cropping. This may break the skript?", ) group.add_argument("--aug-resize", action=argparse.BooleanOptionalAction, help="Use data augmentation: resize?") group.add_argument( "--aug-grayscale", action=argparse.BooleanOptionalAction, help="Use data augmentation: grayscale?" ) group.add_argument("--aug-solarize", action=argparse.BooleanOptionalAction, help="Use data augmentation: solarize?") group.add_argument( "--aug-gauss-blur", action=argparse.BooleanOptionalAction, help="Use data augmentation: gaussian blur?" ) group.add_argument( "--aug-cutmix-alpha", type=float, help="Alpha value for using CutMix. CutMix is active when aug_cutmix_alpha > 0.", ) group.add_argument( "--aug-mixup-alpha", type=float, help="Alpha value for using Mixup. Mixup is active when aug_mixup_alpha > 0." ) group.add_argument( "--aug-color-jitter-factor", nargs="?", type=float, help="Factor to use for the data augmentation: color jitter.", ) group.add_argument( "--aug-normalize", action=argparse.BooleanOptionalAction, help="Use data augmentation: Normalization?" ) group.add_argument( "--aug-repeated-augment-repeats", type=int, help="Number of image repeats with repeat-augment from DeiT. 1 is not using repeat-augment.", ) group.add_argument( "--aug-random-erase-prob", type=float, help="For DeiT augment: Probabiliy of RandomErase augmentation." ) group.add_argument("--auto-augment-strategy", type=str, help="For DeiT augment: AutoAugment Policy to use.") group.add_argument("--imsize", nargs="?", type=int, help="Image size given to the model -> imsize x imsize.") group.add_argument( "--augment-engine", nargs="?", choices=["torchvision", "albumentations", "dali"], help="Which data augmentation engine to use.", ) return parser def partition_choices(): """Automatically create a list of all possible slurm partitions.""" potential = list(set([l.split(" ")[0] for l in os.popen("sinfo")])) # noqa: E741 if len(potential) <= 2: return slurm_defaults["partition"] return [p[:-1] if "*" in p else p for p in potential if p != "PARTITION"] def slurm_parser(parser=None): """Add srun arguments to the given parser. Args: parser (argparse.ArgumentParser, optional): base parser to extend; default is parser from *base_parser* Returns: parser (argparse.ArgumentParser): extended parser """ if parser is None: parser = base_parser() group = parser.add_argument_group("Slurm arguments") group.add_argument( "--partition", nargs="*", default=slurm_defaults["partition"], choices=partition_choices(), help="Slurm partition to use", ) group.add_argument( "--container-image", nargs="?", default=slurm_defaults["container_image"], type=str, help="Path to slurm container image (.sqsh)", ) group.add_argument( "--container-workdir", nargs="?", default=slurm_defaults["container_workdir"], type=str, help="Working directory in container", ) group.add_argument( "--container-mounts", nargs="?", default=slurm_defaults["container_mounts"], type=str, help="All slurm mounts separated by ','.", ) group.add_argument( "--job-name", nargs="?", default=slurm_defaults["job_name"], type=str, help="Slurm job name. Will default to ' '.", ) group.add_argument( "--nodes", nargs="?", default=slurm_defaults["nodes"], type=int, help="Number of cluster nodes to use." ) group.add_argument( "--ntasks", nargs="?", default=slurm_defaults["ntasks"], type=int, help="Number of GPUs to use for the job." ) group.add_argument("--gpus", action=argparse.BooleanOptionalAction, default=True, help="Use gpus for this job?") group.add_argument( "-cpus", "--cpus-per-task", "--cpus-per-gpu", nargs="?", default=slurm_defaults["cpus_per_task"], type=int, help="Number of CPUs per task/GPU.", ) group.add_argument( "-mem", "--mem-per-gpu", "--mem-per-task", nargs="?", default=slurm_defaults["mem_per_gpu"], type=int, help="Ram per GPU (in Gb) to use. Will be given as total mem in srun command.", ) group.add_argument( "--task-prolog", nargs="?", default=slurm_defaults["task_prolog"], type=str, help="Shell script for task prolog (installing packages, etc.).", ) group.add_argument("--time", nargs="?", default=slurm_defaults["time"], type=str, help="Slurm time limit.") group.add_argument( "--export", nargs="?", default=slurm_defaults["export"], type=str, help="Additional environment variables to export.", ) group.add_argument("--exclude", nargs="?", default=slurm_defaults["exclude"], type=str, help="Nodes to exclude.") group.add_argument( "--after-job", nargs="?", default=slurm_defaults["after_job"], type=int, help="Job ID to wait for." ) group.add_argument( "--interactive", action="store_true", help=( "Run using srun instead of sbatch. This will print the output into the terminal, not the slurm output file." " The logfile will still be created as usual." ), default=False, ) group = parser.add_argument_group("Run locally") group.add_argument("--local", action="store_true", help="Run locally; not in slurm", default=False) return parser def parse_args(args=None, parser=None): """Parse args from *base_parser* and insert defaults. Args: args: (Default value = None) parser: (Default value = None) Returns: dict: parsed arguments """ if args is None: parser = base_parser() args = parser.parse_args() args = dict(vars(args)) check_arg_completeness(args, parser) return args def check_arg_completeness(args, parser): """Check completeness of arguments. Args: args (dict): arguments to check parser (argparse.ArgumentParser): for raising the parser error Note: will raise a parser error if the arguments are not complete. """ if args["task"] in ["pre-train", "fine-tune", "fine-tune-head"]: if "run_name" not in args or args["run_name"] is None or len(args["run_name"]) == 0: parser.error(f"-run_name is required for task {args['task']}") if "experiment_name" not in args or args["experiment_name"] is None or len(args["experiment_name"]) == 0: parser.error(f"-experiment_name is required for task {args['task']}. Choose from {_EXPNAMES}") if "epochs" not in args or args["epochs"] is None: parser.error(f"-epochs is required for task {args['task']}") if ("dataset" not in args or args["dataset"] is None) and args["task"] in [ "pre-train", "fine-tune", "fine-tune-head", "eval-metrics", ]: parser.error(f"-dataset is required for task {args['task']}") if ( ("val_dataset" not in args or args["val_dataset"] is None) and ("dataset" not in args or args["dataset"] is None) and args["task"] in ["eval"] ): parser.error(f"-dataset or -val_dataset is required for task {args['task']}") if args["aug_repeated_augment_repeats"] is not None and args["aug_repeated_augment_repeats"] < 1: parser.error( "number of repeats for repeated augment has to be >= 1, but got -aug_repeated_augment_repeats =" f" {args['aug_repeated_augment_repeats']}" ) if args["task"] == "save-images" and ("out_dir" not in args or args["out_dir"] is None): parser.error("Need to set save directory (--out-dir) to save the images in.") def inside_slurm(): """Test for being inside a slurm container. Works by testing for environment variable 'RANK'. """ return "RANK" in os.environ # TODO: fix ./runscript.tmp: 18: Syntax error: Unterminated quoted string def create_runscript(args, file_name=None): """Create a run script for a distributed training job using SLURM. Args: args (dict): A dictionary containing various arguments for the job, including parameters for SLURM and for training. file_name (str, optional, optional): The name of the file to create. Defaults to "runscript.tmp". Returns: str: The name of the created file. str: Additional command line arguments for sbatch. Example: >>> args = {"model": "vit_large_patch16_384", "task": "pre-train", "batch_size": 256, ...} >>> file_name = "my_run_script.sh" >>> create_runscript(args, file_name) """ for key, val in slurm_defaults.items(): if key not in args and val is not None: args[key] = val if "run_name" not in args or args["run_name"] is None: model_str = args["model"] if model_str.endswith(".pt"): model_str = os.path.dirname(model_str) run_name = args["task"] + " " + model_str.split(os.sep)[-1].split("_")[0] else: run_name = args["run_name"] job_name = run_name.replace(" ", "_").replace("/", "_").replace(">", "_").replace("<", "_") if file_name is None: file_name = ( f"experiments/sbatch/run_{args['task']}_{job_name}_at_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.sbatch" ) task_args = "" # slurm_command = "echo run distributed:\necho python3 main.py {0}\n\nsrun -K \\\n" # " --gpus-per-task=1 \\\n --gpu-bind=none \\\n" srun_command = "\nsrun -K \\\n" sbatch_commands = ( # outfile name is job name, date, job id, node name "#!/bin/bash\n\n#SBATCH" f" --output={slurm_output_folder}/%x-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-%j-%N.out\n" ) sbatch_cmd_args = "" # additional command line arguments for sbatch python_command = " python3 main.py {0}\n" for key, val in args.items(): if key == "local": continue if key == "interactive": continue if key == "gpus": continue if key in slurm_defaults: # it's a parameter for srun # slurm has - instead of _ key = key.replace("_", "-") if key == "mem-per-gpu": # convert mem-per-gpu to mem # slurm_command += f" --mem={val * args['ntasks'] // args['nodes']}G \\\n" # that amount of memory is assigned on each node key = "mem" val = f"{val * args['ntasks'] // args['nodes']}G" # that amount of memory is assigned on each node # continue if key == "job-name" and val is None: # # default jobname is ' ' # model_str = args["model"] # task = args["task"] # if task == "pre-train": # # it's just the model name... # model = model_str.split("_")[0] # else: # # it's a path to the tar file # if not model_str.startswith(res_folder): # model = "" # else: # model = model_str[len(res_folder) :].split("_")[1].split(" ")[0] # if "dataset" in args and args["dataset"] is not None: # dataset = args["dataset"] # else: # dataset = "" val = run_name if key == "job-name" and not val.startswith('"'): val = f'"{val}"' if key in ["task-prolog", "nodes", "exclude", "after-job"] and val is None: continue if key == "task-prolog": srun_command += f' --{key}="{val}" \\\n' continue if key == "after-job": sbatch_cmd_args += f"--dependency=afterany:{val} " continue if key == "partition" and isinstance(val, list): val = ",".join(val) if key == "ntasks": if args["nodes"] == 1: gpus = val if args["gpus"] else 0 # slurm_command += f" --gpus={val} \\\n" sbatch_commands += f"#SBATCH --gpus={gpus}\n" else: assert ( val % args["nodes"] == 0 ), f"Number of tasks ({val}) must be a multiple of the number of nodes ({args['nodes']})." # slurm_command += f" --gpus-per-node={val // args['nodes']} \\\n" sbatch_commands += f"#SBATCH --gpus-per-node={val // args['nodes']}\n" sbatch_commands += "#SBATCH --ntasks-per-node=8\n" if "container" in key: srun_command += f" --{key}={val} \\\n" else: sbatch_commands += f"#SBATCH --{key}={val}\n" # slurm_command += f" --{key}={val} \\\n" else: # it's a parameter for the training if val is None: continue if key in ["results_folder"] and val == globals()[key]: continue key = key.replace("_", "-") if isinstance(val, bool): if val: task_args += f"--{key} " else: task_args += f"--no-{key} " continue if isinstance(val, str): task_args += f'--{key} "{val}" ' else: task_args += f"--{key} {val} " # slurm_command += "python3 main.py {0}\n" # os.umask(0) # make it possible to create an executable file # with open(file_name, "w+", opener=lambda pth, flgs: os.open(pth, flgs, 0o777)) as f: # f.write(slurm_command.format(task_args)) with open(file_name, "w+") as f: f.write(sbatch_commands + srun_command + python_command.format(task_args)) # delete all runscripts older than a month n_old_files = int(os.popen("find experiments/sbatch/ -type f -mtime +30 | wc -l").read()) if n_old_files > 0: print(f"Deleting {n_old_files} old runscripts.") os.system("find experiments/sbatch/ -type f -mtime +30 -delete") return file_name, sbatch_cmd_args if __name__ == "__main__": if not inside_slurm(): # Make execution script and execute it parser = slurm_parser() args = vars(parser.parse_args()) if not args["local"]: script_name, cmd_args = create_runscript(args) # os.system("./" + script_name) # run srun to execute this script in slurm cluster # -> the following lines will be executed there if args["interactive"]: os.system(f"python3 srun-sbatch.py {script_name}") else: os.system(f"sbatch {cmd_args} {script_name}") # sbatch to queue the job on the cluster exit(0) # local execution is wanted for key in list(args.keys()): if key.replace("_", "-") in slurm_defaults: args.pop(key) args = parse_args(args, parser) else: args = parse_args() args["branch"] = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode("utf-8") args["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8") if args["task"] == "pre-train": from train import pretrain pretrain(**args) elif args["task"] == "fine-tune": from train import finetune finetune(**args) elif args["task"] == "fine-tune-head": from train import finetune finetune(**args, head_only=True) elif args["task"] == "parser-test": from copy import copy from utils import prep_kwargs, log_args kwargs = prep_kwargs(copy(args)) log_args(kwargs) # keys = sorted(list(args.keys())) # fill_len = max(len(k) for k in keys) # for key in keys: # print(f"{key + ' ' * (fill_len - len(key))} = {args[key]} -> {kwargs[key]}") elif args["task"] == "eval-metrics": from evaluate import evaluate_metrics evaluate_metrics(**args) elif args["task"] == "eval": from evaluate import evaluate evaluate(**args) elif args["task"] == "eval-attr": from evaluate import evaluate_attributions evaluate_attributions(**args) elif args["task"] == "continue": from recover import continue_training continue_training(**args) elif args["task"] == "eval-center-bias": from evaluate import evaluate_center_bias evaluate_center_bias(**args) elif args["task"] == "eval-size-bias": from evaluate import evaluate_size_bias evaluate_size_bias(**args) elif args["task"] == "load-images": from test import load_images load_images(**args) elif args["task"] == "save-images": from test import save_images save_images(**args) else: raise NotImplementedError(f"Task {args['task']} is not implemented.")