AAAI Version
This commit is contained in:
679
AAAI Supplementary Material/Model Training Code/main.py
Normal file
679
AAAI Supplementary Material/Model Training Code/main.py
Normal file
@@ -0,0 +1,679 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Parse args and call the correct script inside slurm container.
|
||||
|
||||
Outside the container, on the head-node, create and call the correct srun command.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
from config import default_kwargs, slurm_defaults
|
||||
from paths_config import results_folder, slurm_output_folder
|
||||
|
||||
_EXPNAMES = ["EfficientCVBench", "test", "recombine_imagenet"]
|
||||
|
||||
|
||||
def base_parser():
    """Create the argument parser with all the choices for the training / evaluation scripts.

    Returns:
        argparse.ArgumentParser: parser covering task selection, model/optimizer
        hyperparameters, experiment management, speedups, data loading, and
        data augmentation. Slurm-specific options are added separately by
        *slurm_parser*.
    """
    parser = argparse.ArgumentParser("Transformer training and evaluation.")

    # Main
    group = parser.add_argument_group("Main")
    group.add_argument(
        "-t",
        "--task",
        nargs="?",
        choices=[
            "pre-train",
            "fine-tune",
            "fine-tune-head",
            "eval",
            "parser-test",
            "eval-metrics",
            "eval-attr",
            "continue",
            "eval-center-bias",
            "eval-size-bias",
            "load-images",
            "save-images",
        ],
        required=True,
        help="Task to perform.",
    )
    group.add_argument(
        "-m",
        "--model",
        nargs="?",
        type=str,
        required=True,
        help="Model to use. Either model name for a new model or weights and dicts to load for fine-tuning.",
    )
    group.add_argument("-ds", "--dataset", nargs="?", type=str, help="Dataset to train on.")
    group.add_argument(
        "-valds", "--val-dataset", nargs="?", type=str, help="Validation dataset. Defaults to same as training."
    )
    group.add_argument("-ep", "--epochs", nargs="?", type=int, help="Number of epochs to train.")
    group.add_argument(
        "-run",
        "--run-name",
        nargs="?",
        type=str,
        help="A name for the run. If not given, the model name is used instead.",
    )
    group.add_argument(
        "--defaults", nargs="?", choices=["DeiT", "DeiTIII"], default="DeiTIII", help="Default settings to use."
    )

    # Further model parameters
    group = parser.add_argument_group("Further model parameters")
    group.add_argument("--drop-path-rate", nargs="?", type=float, help="Drop path rate for ViT models.")
    group.add_argument("--layer-scale-init-values", nargs="?", type=float, help="LayerScale initial values.")
    group.add_argument("--layer-scale", action=argparse.BooleanOptionalAction, help="Use layer scale?")
    group.add_argument(
        "--qkv-bias",
        action=argparse.BooleanOptionalAction,
        help="Use bias in linear transformation to queries, keys, and values?",
    )
    group.add_argument("--pre-norm", action=argparse.BooleanOptionalAction, help="Use norm first architecture?")
    group.add_argument("--dropout", nargs="?", type=float, help="Model dropout.")
    group.add_argument("-heads", "--num-heads", nargs="?", type=int, help="Number of parallel attention heads.")
    group.add_argument("--input-dim", nargs="?", type=int, help="Dimensionality of text encoding.")
    group.add_argument("--max-seq-len", nargs="?", type=int, help="Maximum sequence length for text data.")
    group.add_argument(
        "--fused-attn",
        action=argparse.BooleanOptionalAction,
        help="Use fused attention (for ViT with Timm's attention only)?",
    )

    # Experiment management
    group = parser.add_argument_group("Experiment management")
    group.add_argument("--seed", nargs="?", type=int, help="Manual RNG seed.")
    group.add_argument(
        "-exp",
        "--experiment-name",
        nargs="?",
        choices=_EXPNAMES,
        help="Name for the experiment. Is used for grouping of runs.",
    )
    group.add_argument(
        "--save-epochs", nargs="?", type=int, help="Number of epochs after which to save the full training state."
    )
    group.add_argument(
        "--keep-interm-states",
        nargs="?",
        type=int,
        help="Number of intermediate states to keep. All others (earlier ones) will be deleted automatically.",
    )
    group.add_argument(
        "--custom-dataset-path", nargs="?", type=str, help="Overwrite the path to any dataset to this path."
    )
    group.add_argument(
        "--results-folder",
        nargs="?",
        default=results_folder,
        type=str,
        help="Folder to put script results (mlflow data, models, etc.).",
    )
    group.add_argument(
        "--gather-stats-during-training",
        action=argparse.BooleanOptionalAction,
        help="Gather training statistics from all GPUs?",
    )
    group.add_argument("--tqdm", action=argparse.BooleanOptionalAction, help="Show tqdm for every epoch?")
    group.add_argument(
        "--debug", action=argparse.BooleanOptionalAction, help="Debug mode: lots of intermediate prints."
    )
    group.add_argument("--wandb", action=argparse.BooleanOptionalAction, help="Use external logging via Wandb?")
    group.add_argument("--log-level", choices=["info", "debug"], help="Log level", metavar="LEVEL")
    group.add_argument("-out", "--out-dir", type=str, help="Output directory for additional outputs.")

    # Speedup
    group = parser.add_argument_group("Speedup")
    group.add_argument("--amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision?")
    group.add_argument(
        "--eval-amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision during evaluation?"
    )
    group.add_argument("--compile-model", action=argparse.BooleanOptionalAction, help="Use torch.compile?")
    group.add_argument("--cuda", action=argparse.BooleanOptionalAction, help="Use cuda?")

    # Data loading
    group = parser.add_argument_group("Data loading")
    group.add_argument("-bs", "--batch-size", nargs="?", type=int, help="Batch size over all graphics cards (together).")
    group.add_argument("--num-workers", nargs="?", type=int, help="Number of dataloader worker threads. Should be >0.")
    group.add_argument(
        "--pin-memory", action=argparse.BooleanOptionalAction, help="Use pin_memory of torch Dataloader?"
    )
    group.add_argument(
        "--prefetch-factor",
        nargs="?",
        type=int,
        help="Prefetch factor for dataloader workers (how many batches to fetch)",
    )
    group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, help="Shuffle the training data?")
    group.add_argument(
        "--weighted-sampler",
        action=argparse.BooleanOptionalAction,
        help="Use a class-weighted sampler to sample evenly from all classes (train and val)?",
    )
    group.add_argument("--ipc", type=int, help="How many images per class to load and save.")

    # Optimizer
    group = parser.add_argument_group("Optimizer")
    group.add_argument("--opt", nargs="?", type=str, help="Optimizer to use.")
    group.add_argument("--weight-decay", nargs="?", type=float, help="Weight decay factor for use in optimizer.")
    group.add_argument("-lr", "--lr", nargs="?", type=float, help="Initial learning rate.")
    group.add_argument(
        "--max-grad-norm", nargs="?", type=float, help="Maximum norm for the gradients (used for cutoff)."
    )
    group.add_argument("--warmup-epochs", nargs="?", type=int, help="Number of epochs of linear warmup.")
    group.add_argument("--label-smoothing", nargs="?", type=float, help="Label smoothing factor.")
    group.add_argument("--loss", nargs="?", choices=["ce", "baikal"], type=str, help="Loss function to use.")
    group.add_argument(
        "--loss-weight", nargs="?", type=str, choices=["none", "linear", "log", "sqrt"], help="Per class loss weight."
    )
    group.add_argument("--sched", nargs="?", choices=["cosine", "const"], help="Learning rate schedule.")
    group.add_argument("--min-lr", nargs="?", type=float, help="Minimum learning rate to be hit by scheduler.")
    group.add_argument("--warmup-lr", nargs="?", type=float, help="Warmup learning rate.")
    group.add_argument("--warmup-sched", nargs="?", choices=["linear", "const"], help="Schedule for warmup")
    group.add_argument(
        "--opt-eps", nargs="?", type=float, help="Epsilon value added in the optimizer to stabilize training."
    )
    group.add_argument("--momentum", nargs="?", type=float, help="Optimizer momentum.")

    # Data augmentation
    group = parser.add_argument_group("Data augmentation")
    group.add_argument("--augment-strategy", nargs="?", type=str, help="Data augmentation strategy.")
    group.add_argument("--aug-rand-rot", nargs="?", type=int, help="Random rotation limit.")
    group.add_argument(
        "--aug-flip", action=argparse.BooleanOptionalAction, help="Use data augmentation: horizontal flip?"
    )
    group.add_argument(
        "--aug-crop",
        action=argparse.BooleanOptionalAction,
        help="Use data augmentation: cropping. This may break the script?",
    )
    group.add_argument("--aug-resize", action=argparse.BooleanOptionalAction, help="Use data augmentation: resize?")
    group.add_argument(
        "--aug-grayscale", action=argparse.BooleanOptionalAction, help="Use data augmentation: grayscale?"
    )
    group.add_argument("--aug-solarize", action=argparse.BooleanOptionalAction, help="Use data augmentation: solarize?")
    group.add_argument(
        "--aug-gauss-blur", action=argparse.BooleanOptionalAction, help="Use data augmentation: gaussian blur?"
    )
    group.add_argument(
        "--aug-cutmix-alpha",
        type=float,
        help="Alpha value for using CutMix. CutMix is active when aug_cutmix_alpha > 0.",
    )
    group.add_argument(
        "--aug-mixup-alpha", type=float, help="Alpha value for using Mixup. Mixup is active when aug_mixup_alpha > 0."
    )
    group.add_argument(
        "--aug-color-jitter-factor",
        nargs="?",
        type=float,
        help="Factor to use for the data augmentation: color jitter.",
    )
    group.add_argument(
        "--aug-normalize", action=argparse.BooleanOptionalAction, help="Use data augmentation: Normalization?"
    )
    group.add_argument(
        "--aug-repeated-augment-repeats",
        type=int,
        help="Number of image repeats with repeat-augment from DeiT. 1 is not using repeat-augment.",
    )
    group.add_argument(
        "--aug-random-erase-prob", type=float, help="For DeiT augment: Probability of RandomErase augmentation."
    )
    group.add_argument("--auto-augment-strategy", type=str, help="For DeiT augment: AutoAugment Policy to use.")
    group.add_argument("--imsize", nargs="?", type=int, help="Image size given to the model -> imsize x imsize.")
    group.add_argument(
        "--augment-engine",
        nargs="?",
        choices=["torchvision", "albumentations", "dali"],
        help="Which data augmentation engine to use.",
    )

    return parser
|
||||
|
||||
|
||||
def partition_choices():
    """Automatically create a list of all possible slurm partitions.

    Reads the first column of ``sinfo`` output. When ``sinfo`` is unavailable
    or yields no partitions (header only), falls back to the configured
    default partitions.

    Returns:
        list[str]: partition names, with the trailing ``*`` that marks the
        default partition removed, sorted for a stable ``choices`` list.
    """
    # First whitespace-separated token of each sinfo line is the partition name.
    names = {line.split(" ")[0] for line in os.popen("sinfo")}
    if len(names) <= 2:
        # sinfo produced (at most) a header: not on a slurm head-node -> use defaults.
        return slurm_defaults["partition"]
    # Drop the header row and strip only the *trailing* default marker.
    # (The old `p[:-1] if "*" in p` removed the last character even when the
    # "*" was not the last character.)
    return sorted(p.rstrip("*") for p in names if p != "PARTITION")
|
||||
|
||||
|
||||
def slurm_parser(parser=None):
    """Add srun arguments to the given parser.

    Args:
        parser (argparse.ArgumentParser, optional): base parser to extend; default is parser from *base_parser*

    Returns:
        parser (argparse.ArgumentParser): extended parser

    """
    if parser is None:
        parser = base_parser()
    # All slurm options default to the cluster-specific values from config.slurm_defaults.
    group = parser.add_argument_group("Slurm arguments")
    group.add_argument(
        "--partition",
        nargs="*",
        default=slurm_defaults["partition"],
        # choices are discovered live from `sinfo` on the head-node
        choices=partition_choices(),
        help="Slurm partition to use",
    )
    group.add_argument(
        "--container-image",
        nargs="?",
        default=slurm_defaults["container_image"],
        type=str,
        help="Path to slurm container image (.sqsh)",
    )
    group.add_argument(
        "--container-workdir",
        nargs="?",
        default=slurm_defaults["container_workdir"],
        type=str,
        help="Working directory in container",
    )
    group.add_argument(
        "--container-mounts",
        nargs="?",
        default=slurm_defaults["container_mounts"],
        type=str,
        help="All slurm mounts separated by ','.",
    )
    group.add_argument(
        "--job-name",
        nargs="?",
        default=slurm_defaults["job_name"],
        type=str,
        help="Slurm job name. Will default to '<model> <task> <dataset>'.",
    )
    group.add_argument(
        "--nodes", nargs="?", default=slurm_defaults["nodes"], type=int, help="Number of cluster nodes to use."
    )
    # one task == one GPU in this setup (see create_runscript's ntasks handling)
    group.add_argument(
        "--ntasks", nargs="?", default=slurm_defaults["ntasks"], type=int, help="Number of GPUs to use for the job."
    )
    group.add_argument("--gpus", action=argparse.BooleanOptionalAction, default=True, help="Use gpus for this job?")
    group.add_argument(
        "-cpus",
        "--cpus-per-task",
        "--cpus-per-gpu",
        nargs="?",
        default=slurm_defaults["cpus_per_task"],
        type=int,
        help="Number of CPUs per task/GPU.",
    )
    group.add_argument(
        "-mem",
        "--mem-per-gpu",
        "--mem-per-task",
        nargs="?",
        default=slurm_defaults["mem_per_gpu"],
        type=int,
        help="Ram per GPU (in Gb) to use. Will be given as total mem in srun command.",
    )
    group.add_argument(
        "--task-prolog",
        nargs="?",
        default=slurm_defaults["task_prolog"],
        type=str,
        help="Shell script for task prolog (installing packages, etc.).",
    )
    group.add_argument("--time", nargs="?", default=slurm_defaults["time"], type=str, help="Slurm time limit.")
    group.add_argument(
        "--export",
        nargs="?",
        default=slurm_defaults["export"],
        type=str,
        help="Additional environment variables to export.",
    )
    group.add_argument("--exclude", nargs="?", default=slurm_defaults["exclude"], type=str, help="Nodes to exclude.")
    # translated to sbatch --dependency=afterany:<id> in create_runscript
    group.add_argument(
        "--after-job", nargs="?", default=slurm_defaults["after_job"], type=int, help="Job ID to wait for."
    )
    group.add_argument(
        "--interactive",
        action="store_true",
        help=(
            "Run using srun instead of sbatch. This will print the output into the terminal, not the slurm output file."
            " The logfile will still be created as usual."
        ),
        default=False,
    )

    group = parser.add_argument_group("Run locally")
    group.add_argument("--local", action="store_true", help="Run locally; not in slurm", default=False)

    return parser
|
||||
|
||||
|
||||
def parse_args(args=None, parser=None):
    """Parse args from *base_parser* and insert defaults.

    Args:
        args (argparse.Namespace | dict, optional): pre-parsed arguments; when
            None, a fresh parse of ``sys.argv`` is performed. (Default value = None)
        parser (argparse.ArgumentParser, optional): parser used for parsing and
            for error reporting in *check_arg_completeness*; created via
            *base_parser* when None. (Default value = None)

    Returns:
        dict: parsed arguments

    """
    if parser is None:
        # Previously a missing parser was only created when args was None,
        # leaving check_arg_completeness with parser=None on the other path.
        parser = base_parser()
    if args is None:
        args = parser.parse_args()
    # Accept both an argparse.Namespace and an already-converted dict:
    # vars() raises TypeError on a plain dict (which is what the local-run
    # path in __main__ passes in).
    if not isinstance(args, dict):
        args = dict(vars(args))

    check_arg_completeness(args, parser)

    return args
|
||||
|
||||
|
||||
def check_arg_completeness(args, parser):
    """Check completeness of arguments.

    Args:
        args (dict): arguments to check
        parser (argparse.ArgumentParser): for raising the parser error
    Note:
        will raise a parser error if the arguments are not complete.
    """
    task = args["task"]

    def _missing(key):
        # Key absent or explicitly None.
        return key not in args or args[key] is None

    def _blank(key):
        # Missing, None, or an empty string.
        return _missing(key) or len(args[key]) == 0

    # Training tasks need a run name, an experiment name, and an epoch count.
    if task in ("pre-train", "fine-tune", "fine-tune-head"):
        if _blank("run_name"):
            parser.error(f"-run_name is required for task {task}")

        if _blank("experiment_name"):
            parser.error(f"-experiment_name is required for task {task}. Choose from {_EXPNAMES}")

        if _missing("epochs"):
            parser.error(f"-epochs is required for task {task}")

    # Tasks that iterate a training dataset need one explicitly.
    if task in ("pre-train", "fine-tune", "fine-tune-head", "eval-metrics") and _missing("dataset"):
        parser.error(f"-dataset is required for task {task}")

    # Plain evaluation accepts either the training or a dedicated validation dataset.
    if task in ("eval",) and _missing("val_dataset") and _missing("dataset"):
        parser.error(f"-dataset or -val_dataset is required for task {task}")

    repeats = args["aug_repeated_augment_repeats"]
    if repeats is not None and repeats < 1:
        parser.error(
            "number of repeats for repeated augment has to be >= 1, but got -aug_repeated_augment_repeats ="
            f" {repeats}"
        )

    if task == "save-images" and _missing("out_dir"):
        parser.error("Need to set save directory (--out-dir) to save the images in.")
|
||||
|
||||
|
||||
def inside_slurm():
    """Test for being inside a slurm container.

    Works by testing for environment variable 'RANK'.
    """
    # torchrun / slurm launchers set RANK for each distributed worker.
    return os.environ.get("RANK") is not None
|
||||
|
||||
|
||||
# TODO: fix ./runscript.tmp: 18: Syntax error: Unterminated quoted string
|
||||
def create_runscript(args, file_name=None):
    """Create a run script for a distributed training job using SLURM.

    Splits *args* into slurm directives (keys present in ``slurm_defaults``,
    written as ``#SBATCH`` lines or srun container flags) and training
    arguments (everything else, forwarded to ``python3 main.py``), then writes
    the combined sbatch script to *file_name*.

    Args:
        args (dict): A dictionary containing various arguments for the job, including parameters for SLURM and for training.
        file_name (str, optional): The name of the file to create. Defaults to a
            timestamped path under ``experiments/sbatch/``.

    Returns:
        str: The name of the created file.
        str: Additional command line arguments for sbatch.

    Example:
        >>> args = {"model": "vit_large_patch16_384", "task": "pre-train", "batch_size": 256, ...}

        >>> file_name = "my_run_script.sh"

        >>> create_runscript(args, file_name)

    """
    # Fill in any slurm settings the caller did not supply.
    for key, val in slurm_defaults.items():
        if key not in args and val is not None:
            args[key] = val

    # Derive a run name: "<task> <model base name>" when none was given.
    if "run_name" not in args or args["run_name"] is None:
        model_str = args["model"]
        if model_str.endswith(".pt"):
            # model is a checkpoint path; use its directory name instead
            model_str = os.path.dirname(model_str)
        run_name = args["task"] + " " + model_str.split(os.sep)[-1].split("_")[0]
    else:
        run_name = args["run_name"]
    # Sanitize for use in file names and slurm job names.
    job_name = run_name.replace(" ", "_").replace("/", "_").replace(">", "_").replace("<", "_")
    if file_name is None:
        # NOTE(review): assumes experiments/sbatch/ already exists — open() below
        # fails otherwise; confirm the directory is created elsewhere.
        file_name = (
            f"experiments/sbatch/run_{args['task']}_{job_name}_at_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.sbatch"
        )

    task_args = ""
    srun_command = "\nsrun -K \\\n"
    sbatch_commands = (  # outfile name is job name, date, job id, node name
        "#!/bin/bash\n\n#SBATCH"
        f" --output={slurm_output_folder}/%x-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-%j-%N.out\n"
    )
    sbatch_cmd_args = ""  # additional command line arguments for sbatch
    python_command = " python3 main.py {0}\n"
    for key, val in args.items():
        # launcher-only flags: never forwarded to slurm or the training script
        if key == "local":
            continue
        if key == "interactive":
            continue
        if key == "gpus":
            continue
        if key in slurm_defaults:
            # it's a parameter for srun
            # slurm has - instead of _
            key = key.replace("_", "-")
            if key == "mem-per-gpu":
                # convert mem-per-gpu to mem
                key = "mem"
                val = f"{val * args['ntasks'] // args['nodes']}G"  # that amount of memory is assigned on each node
            if key == "job-name" and val is None:
                # default job name is the run name
                val = run_name
            if key == "job-name" and not val.startswith('"'):
                # quote so names with spaces survive the shell
                val = f'"{val}"'
            if key in ["task-prolog", "nodes", "exclude", "after-job"] and val is None:
                continue
            if key == "task-prolog":
                srun_command += f' --{key}="{val}" \\\n'
                continue
            if key == "after-job":
                # becomes a dependency flag on the sbatch command line, not a directive
                sbatch_cmd_args += f"--dependency=afterany:{val} "
                continue
            if key == "partition" and isinstance(val, list):
                val = ",".join(val)
            if key == "ntasks":
                # one task per GPU: request GPUs to match the task count
                if args["nodes"] == 1:
                    gpus = val if args["gpus"] else 0
                    sbatch_commands += f"#SBATCH --gpus={gpus}\n"
                else:
                    assert (
                        val % args["nodes"] == 0
                    ), f"Number of tasks ({val}) must be a multiple of the number of nodes ({args['nodes']})."
                    sbatch_commands += f"#SBATCH --gpus-per-node={val // args['nodes']}\n"
                    sbatch_commands += "#SBATCH --ntasks-per-node=8\n"
            if "container" in key:
                # container flags are srun (pyxis) options, not sbatch directives
                srun_command += f" --{key}={val} \\\n"
            else:
                sbatch_commands += f"#SBATCH --{key}={val}\n"
        else:
            # it's a parameter for the training
            if val is None:
                continue
            if key in ["results_folder"] and val == globals()[key]:
                # unchanged default: no need to forward it
                continue
            key = key.replace("_", "-")
            if isinstance(val, bool):
                # BooleanOptionalAction style: --flag / --no-flag
                if val:
                    task_args += f"--{key} "
                else:
                    task_args += f"--no-{key} "
                continue
            if isinstance(val, str):
                task_args += f'--{key} "{val}" '
            else:
                task_args += f"--{key} {val} "

    with open(file_name, "w+") as f:
        f.write(sbatch_commands + srun_command + python_command.format(task_args))

    # delete all runscripts older than a month
    n_old_files = int(os.popen("find experiments/sbatch/ -type f -mtime +30 | wc -l").read())
    if n_old_files > 0:
        print(f"Deleting {n_old_files} old runscripts.")
        os.system("find experiments/sbatch/ -type f -mtime +30 -delete")

    return file_name, sbatch_cmd_args
|
||||
|
||||
|
||||
if __name__ == "__main__":
    if not inside_slurm():
        # Head-node path: build an sbatch runscript and submit it (or run locally).
        parser = slurm_parser()
        args = vars(parser.parse_args())

        if not args["local"]:
            script_name, cmd_args = create_runscript(args)
            # Submit the generated script; this same file is then re-executed
            # inside the container, taking the inside_slurm() branch below.
            if args["interactive"]:
                # srun wrapper: output goes to the terminal instead of the slurm logfile
                os.system(f"python3 srun-sbatch.py {script_name}")
            else:
                os.system(f"sbatch {cmd_args} {script_name}")  # sbatch to queue the job on the cluster
            exit(0)

        # local execution is wanted: strip slurm-only keys before validation
        for key in list(args.keys()):
            # NOTE(review): slurm_defaults keys use underscores elsewhere in this
            # file (e.g. "container_image"), so converting to dashes here means
            # multi-word slurm keys are never matched/popped — verify whether this
            # should be `key in slurm_defaults` instead.
            if key.replace("_", "-") in slurm_defaults:
                args.pop(key)
        args = parse_args(args, parser)

    else:
        # Inside the slurm container (RANK is set): parse the training args directly.
        args = parse_args()

    # Record the git state for reproducibility / experiment tracking.
    args["branch"] = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode("utf-8")
    args["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")

    # Dispatch to the task implementation; imports are deferred so each task
    # only pays for the modules it needs.
    if args["task"] == "pre-train":
        from train import pretrain

        pretrain(**args)

    elif args["task"] == "fine-tune":
        from train import finetune

        finetune(**args)

    elif args["task"] == "fine-tune-head":
        from train import finetune

        # same entry point as fine-tune, but only the classification head is trained
        finetune(**args, head_only=True)

    elif args["task"] == "parser-test":
        # Dry run: show how the raw args are expanded into training kwargs.
        from copy import copy

        from utils import prep_kwargs, log_args

        kwargs = prep_kwargs(copy(args))
        log_args(kwargs)

    elif args["task"] == "eval-metrics":
        from evaluate import evaluate_metrics

        evaluate_metrics(**args)

    elif args["task"] == "eval":
        from evaluate import evaluate

        evaluate(**args)

    elif args["task"] == "eval-attr":
        from evaluate import evaluate_attributions

        evaluate_attributions(**args)

    elif args["task"] == "continue":
        # resume a previously saved training state
        from recover import continue_training

        continue_training(**args)

    elif args["task"] == "eval-center-bias":
        from evaluate import evaluate_center_bias

        evaluate_center_bias(**args)

    elif args["task"] == "eval-size-bias":
        from evaluate import evaluate_size_bias

        evaluate_size_bias(**args)

    elif args["task"] == "load-images":
        from test import load_images

        load_images(**args)

    elif args["task"] == "save-images":
        from test import save_images

        save_images(**args)

    else:
        raise NotImplementedError(f"Task {args['task']} is not implemented.")
|
||||
Reference in New Issue
Block a user