# Driver script: build a Segment & Recombine dataset (segmentation jobs,
# optional filtering subversions, archiving, and foreground-size statistics).
# Standard library only: CLI parsing, filesystem paths, output-name
# validation, and launching sbatch/srun worker processes.
import argparse
import os
import re
import subprocess


# Command-line interface for the Segment & Recombine dataset-creation driver.
parser = argparse.ArgumentParser(description="Create Segment & Recombine dataset")
parser.add_argument(
    "-d", "--dataset",
    required=True,
    choices=["imagenet", "tinyimagenet"],
    help="Dataset to use",
)
parser.add_argument(
    "-r", "--dataset_root",
    required=True,
    help="Root directory of the dataset",
)
parser.add_argument(
    "-t", "--threshold",
    type=float,
    default=0.3,
    help="Detection threshold",
)
parser.add_argument(
    "--out_root",
    type=str,
    required=True,
    help="Root directory where the output directory will be created.",
)
parser.add_argument(
    "-o", "--output",
    type=str,
    required=True,
    help="Output directory",
)
parser.add_argument(
    "--parent_labels",
    type=int,
    default=2,
    help="Number of parent labels to use; steps to go up the tree",
)
parser.add_argument(
    "--output_ims",
    choices=["best", "all"],
    default="best",
    help="Output all or best masks",
)
parser.add_argument(
    "--mask_merge_threshold",
    type=float,
    default=0.9,
    help="Threshold on IoU for merging masks",
)
parser.add_argument(
    "--parent_in_prompt",
    action="store_true",
    help="Include parent label in the prompt",
)
# NOTE(review): single-dash long option spelling kept for CLI compatibility.
parser.add_argument(
    "-infill_model",
    choices=["LaMa", "AttErase"],
    default="LaMa",
    help="Infilling model to use",
)
# "continue" is a Python keyword, so the parsed flag lands on args.continue_.
parser.add_argument(
    "--continue",
    dest="continue_",
    action="store_true",
    help="Continue from previous run",
)

args = parser.parse_args()
out_root = args.out_root

# Folder containing this script, and the sibling "Model Training Code" tree
# that holds the filtering utilities invoked further down.
base_folder = os.path.dirname(__file__)
training_code_folder = os.path.join(base_folder, os.pardir, "Model Training Code")

# Output names must look like e.g. "INSegment_v3" or "TinyINSegment_v3_f1":
#   group(1): optional "Tiny" prefix
#   group(2): version digits
#   group(3): optional "_f<digits>" filter-subversion suffix
ds_name_re = re.compile(r"(Tiny)?INSegment_v(\d*)(_f\d*)?")
name_match = ds_name_re.match(args.output)
assert name_match, f"Output name {args.output} does not match the expected format: {ds_name_re.pattern}"
# With --output_ims all, the _f1/_f2 subversions are created by this script,
# so the requested name must not already carry a filter suffix.
assert (
    args.output_ims == "best" or name_match.group(3) is None
), "For output_ims == 'all', the filter subversions will be automatically created."
# Refuse to clobber an existing output unless --continue was given.
# NOTE(review): paths are built by plain string concatenation throughout this
# script, so out_root presumably must end with a path separator — confirm.
assert args.continue_ or not os.path.exists(out_root + args.output), f"Output directory {args.output} already exists."
# Canonical, sorted "key = value" description of this run's configuration.
run_settings = {
    "dataset": args.dataset,
    "threshold": args.threshold,
    "output_ims": args.output_ims,
    "mask_merge_threshold": args.mask_merge_threshold,
    "parent_labels": args.parent_labels,
    "parent_in_prompt": args.parent_in_prompt,
    "infill_model": args.infill_model,
}
settings_file = [f"{key} = {run_settings[key]}" for key in sorted(run_settings)]

if os.path.exists(out_root + args.output) and os.path.exists(out_root + args.output + "/settings.txt"):
    # Continuing a previous run: the stored settings must match exactly,
    # otherwise the partial outputs on disk would be inconsistent.
    with open(out_root + args.output + "/settings.txt", "r") as f:
        old_settings = [line.strip() for line in f.read().split("\n") if line.strip()]
    assert old_settings == settings_file, (
        f"Settings file {out_root + args.output}/settings.txt does not match current settings: old: {old_settings} vs"
        f" new: {settings_file}"
    )
else:
    # Fresh run: create the output directory and record the settings.
    os.makedirs(out_root + args.output, exist_ok=True)
    with open(out_root + args.output + "/settings.txt", "w") as f:
        f.write("\n".join(settings_file) + "\n")
# Arguments shared by the train and val segmentation sbatch jobs.  Slot [1]
# holds the sbatch script name and is swapped for the imagenet-val variant
# just before the val job is launched.
general_args = [
    "sbatch",
    "sbatch-segment-tinyimnet-wait" if args.dataset == "tinyimagenet" else "sbatch-segment-imagenet-wait",
    "-r", args.dataset_root,
    "-o", out_root + args.output,
    "--parent_labels", str(args.parent_labels),
    "--output_ims", args.output_ims,
    "--mask_merge_threshold", str(args.mask_merge_threshold),
    "-t", str(args.threshold),
    "-model", args.infill_model,
]
if args.parent_in_prompt:
    general_args.append("--parent_in_prompt")

print(f"Starting segmentation: {' '.join(general_args)} for {args.dataset}-val and {args.dataset}")
# Launch the training-split job; the validation job runs concurrently.
p_train = subprocess.Popen(general_args + ["-d", args.dataset], cwd=base_folder)
# ImageNet uses a dedicated sbatch script for its validation split.
if args.dataset == "imagenet":
    general_args[1] = "sbatch-segment-imagenet-val-wait"
p_val = subprocess.Popen(general_args + ["-d", args.dataset + "-val"], cwd=base_folder)

# Block until both jobs finish, then abort on any non-zero exit status.
p_val.wait()
p_train.wait()
rcodes = (p_val.returncode, p_train.returncode)
if any(code != 0 for code in rcodes):
    print(f"Error in segmentation (val, train): {rcodes}")
    raise SystemExit(1)
print("Segmentation done.")
if args.output_ims == "all":
    # Hard-link copies (cp -rl) of the raw output, one per filtering variant,
    # made concurrently and then awaited.
    print("copy to subversions for filtering")
    copy_procs = [
        subprocess.Popen(
            ["cp", "-rl", os.path.join(out_root, args.output), out_root + f"{args.output}_f{i}/"]
        )
        for i in (1, 2)
    ]
    for proc in copy_procs:
        proc.wait()
    print("Filtering subversions copied over.")

    # Filter both splits of both subversions in parallel; the _f2 variant
    # additionally enables automatic score-function weights.
    print("Starting filtering")
    filtering_args = ["./experiments/general_srun.sh", "experiments/filter_segmentation_versions.py"]
    filter_procs = []
    for suffix, extra in (("_f1", []), ("_f2", ["-score_f_weights", "automatic"])):
        for split in ("val", "train"):
            filter_procs.append(
                subprocess.Popen(
                    filtering_args
                    + ["-f", out_root + args.output + suffix + "/" + split]
                    + extra
                    + ["-d", args.dataset],
                    cwd=training_code_folder,
                )
            )
    for proc in filter_procs:
        proc.wait()

    ds_folders = [f"{args.output}_f1", f"{args.output}_f2"]
else:
    # BUG FIX: the original tested `name_match.group(2) is None`, but group 2
    # is the version digits (r"\d*"), which matches the empty string and is
    # therefore "" — never None — so the rename branch could never run.  The
    # *optional* filter suffix is group(3); rename to "<name>_f1" when the
    # requested output name does not already carry a filter suffix.
    if name_match.group(3) is None:
        new_name = args.output + "_f1"
        print(f"Renaming output folder: {args.output} -> {new_name}")
        os.rename(out_root + args.output, out_root + new_name)
    else:
        new_name = args.output
    ds_folders = [new_name]
for folder in ds_folders:
    # Archive foreground/background crops of both splits (store-only, -0),
    # all four zips running concurrently.
    # BUG FIX: the original passed ">", "/dev/null", "2>&1" inside the argv
    # list; Popen runs zip without a shell, so those were literal arguments
    # (zip would try to archive a file named ">" and /dev/null itself rather
    # than being silenced).  Discard output via subprocess.DEVNULL instead.
    print(f"Zipping up {folder}")
    zip_procs = [
        subprocess.Popen(
            ["zip", "-r", "-0", f"{kind}s_{split}.zip", f"{split}/{kind}s"],
            cwd=out_root + folder,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        for kind in ("foreground", "background")
        for split in ("val", "train")
    ]
    for proc in zip_procs:
        proc.wait()

    # Compute foreground size ratios for val (explicit -mode) and train
    # (script default mode), concurrently.
    print(f"Gathering foreground size ratios for {folder}")
    ratio_cmd = [
        "./srun-general.sh",
        "python",
        "foreground_size_ratio.py",
        "--root",
        args.dataset_root,
        "-ds",
        folder,
    ]
    p_val = subprocess.Popen(ratio_cmd + ["-mode", "val"], cwd=base_folder)
    p_train = subprocess.Popen(ratio_cmd, cwd=base_folder)
    p_val.wait()
    p_train.wait()