AAAI Version
This commit is contained in:
@@ -0,0 +1,195 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
parser = argparse.ArgumentParser(description="Create Segment & Recombine dataset")
|
||||
parser.add_argument("-d", "--dataset", choices=["imagenet", "tinyimagenet"], required=True, help="Dataset to use")
|
||||
parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
|
||||
parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
|
||||
parser.add_argument(
|
||||
"--out_root", type=str, required=True, help="Root directory where the output directory will be created."
|
||||
)
|
||||
parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
|
||||
parser.add_argument(
|
||||
"--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
|
||||
)
|
||||
parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
|
||||
parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
|
||||
parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
|
||||
parser.add_argument("-infill_model", choices=["LaMa", "AttErase"], default="LaMa", help="Infilling model to use")
|
||||
parser.add_argument("--continue", dest="continue_", action="store_true", help="Continue from previous run")
|
||||
|
||||
args = parser.parse_args()
|
||||
out_root = args.out_root
|
||||
|
||||
base_folder = os.path.dirname(__file__)
|
||||
training_code_folder = os.path.join(base_folder, os.pardir, "Model Training Code")
|
||||
|
||||
ds_name_re = re.compile(r"(Tiny)?INSegment_v(\d*)(_f\d*)?")
|
||||
name_match = ds_name_re.match(args.output)
|
||||
assert name_match, f"Output name {args.output} does not match the expected format: {ds_name_re.pattern}"
|
||||
assert (
|
||||
args.output_ims == "best" or name_match.group(3) is None
|
||||
), "For output_ims == 'all', the filter subversions will be automatically created."
|
||||
assert args.continue_ or not os.path.exists(out_root + args.output), f"Output directory {args.output} already exists."
|
||||
|
||||
settings_file = {
|
||||
"dataset": args.dataset,
|
||||
"threshold": args.threshold,
|
||||
"output_ims": args.output_ims,
|
||||
"mask_merge_threshold": args.mask_merge_threshold,
|
||||
"parent_labels": args.parent_labels,
|
||||
"parent_in_prompt": args.parent_in_prompt,
|
||||
"infill_model": args.infill_model,
|
||||
}
|
||||
settings_file = [f"{k} = {str(settings_file[k])}" for k in sorted(list(settings_file.keys()))]
|
||||
if os.path.exists(out_root + args.output) and os.path.exists(out_root + args.output + "/settings.txt"):
|
||||
with open(out_root + args.output + "/settings.txt", "r") as f:
|
||||
old_settings = f.read().split("\n")
|
||||
old_settings = [line.strip() for line in old_settings if line if len(line.strip()) > 0]
|
||||
assert old_settings == settings_file, (
|
||||
f"Settings file {out_root + args.output}/settings.txt does not match current settings: old: {old_settings} vs"
|
||||
f" new: {settings_file}"
|
||||
)
|
||||
else:
|
||||
os.makedirs(out_root + args.output, exist_ok=True)
|
||||
with open(out_root + args.output + "/settings.txt", "w") as f:
|
||||
f.write("\n".join(settings_file) + "\n")
|
||||
|
||||
general_args = [
|
||||
"sbatch",
|
||||
"sbatch-segment-tinyimnet-wait" if args.dataset == "tinyimagenet" else "sbatch-segment-imagenet-wait",
|
||||
"-r",
|
||||
args.dataset_root,
|
||||
"-o",
|
||||
out_root + args.output,
|
||||
"--parent_labels",
|
||||
str(args.parent_labels),
|
||||
"--output_ims",
|
||||
args.output_ims,
|
||||
"--mask_merge_threshold",
|
||||
str(args.mask_merge_threshold),
|
||||
"-t",
|
||||
str(args.threshold),
|
||||
"-model",
|
||||
args.infill_model,
|
||||
]
|
||||
|
||||
if args.parent_in_prompt:
|
||||
general_args.append("--parent_in_prompt")
|
||||
|
||||
print(f"Starting segmentation: {' '.join(general_args)} for {args.dataset}-val and {args.dataset}")
|
||||
p_train = subprocess.Popen(general_args + ["-d", args.dataset], cwd=base_folder)
|
||||
if args.dataset == "imagenet":
|
||||
general_args[1] = "sbatch-segment-imagenet-val-wait"
|
||||
p_val = subprocess.Popen(general_args + ["-d", args.dataset + "-val"], cwd=base_folder)
|
||||
|
||||
# detect if exit in error
|
||||
p_val.wait()
|
||||
p_train.wait()
|
||||
rcodes = (p_val.returncode, p_train.returncode)
|
||||
if any(rcode != 0 for rcode in rcodes):
|
||||
print(f"Error in segmentation (val, train): {rcodes}")
|
||||
exit(1)
|
||||
print("Segmentation done.")
|
||||
|
||||
if args.output_ims == "all":
|
||||
print("copy to subversions for filtering")
|
||||
p_1 = subprocess.Popen(
|
||||
[
|
||||
"cp",
|
||||
"-rl",
|
||||
os.path.join(out_root, args.output),
|
||||
out_root + f"{args.output}_f1/",
|
||||
]
|
||||
)
|
||||
p_2 = subprocess.Popen(
|
||||
[
|
||||
"cp",
|
||||
"-rl",
|
||||
os.path.join(out_root, args.output),
|
||||
out_root + f"{args.output}_f2/",
|
||||
]
|
||||
)
|
||||
p_1.wait()
|
||||
p_2.wait()
|
||||
print("Filtering subversions copied over.")
|
||||
|
||||
print("Starting filtering")
|
||||
filtering_args = ["./experiments/general_srun.sh", "experiments/filter_segmentation_versions.py"]
|
||||
p_val_f1 = subprocess.Popen(
|
||||
filtering_args + ["-f", out_root + f"{args.output}_f1" + "/val", "-d", args.dataset], cwd=training_code_folder
|
||||
)
|
||||
p_train_f1 = subprocess.Popen(
|
||||
filtering_args + ["-f", out_root + f"{args.output}_f1" + "/train", "-d", args.dataset], cwd=training_code_folder
|
||||
)
|
||||
p_val_f2 = subprocess.Popen(
|
||||
filtering_args
|
||||
+ ["-f", out_root + f"{args.output}_f2" + "/val", "-score_f_weights", "automatic", "-d", args.dataset],
|
||||
cwd=training_code_folder,
|
||||
)
|
||||
p_train_f2 = subprocess.Popen(
|
||||
filtering_args
|
||||
+ ["-f", out_root + f"{args.output}_f2" + "/train", "-score_f_weights", "automatic", "-d", args.dataset],
|
||||
cwd=training_code_folder,
|
||||
)
|
||||
|
||||
p_val_f1.wait()
|
||||
p_train_f1.wait()
|
||||
p_val_f2.wait()
|
||||
p_train_f2.wait()
|
||||
|
||||
ds_folders = [f"{args.output}_f1", f"{args.output}_f2"]
|
||||
else:
|
||||
if name_match.group(2) is None:
|
||||
new_name = args.output + "_f1"
|
||||
print(f"Renaming output folder: {args.output} -> {new_name}")
|
||||
os.rename(out_root + args.output, out_root + new_name)
|
||||
else:
|
||||
new_name = args.output
|
||||
ds_folders = [new_name]
|
||||
|
||||
for folder in ds_folders:
|
||||
print(f"Zipping up {folder}")
|
||||
p_val_fg = subprocess.Popen(
|
||||
["zip", "-r", "-0", "foregrounds_val.zip", "val/foregrounds", ">", "/dev/null", "2>&1"], cwd=out_root + folder
|
||||
)
|
||||
p_train_fg = subprocess.Popen(
|
||||
["zip", "-r", "-0", "foregrounds_train.zip", "train/foregrounds", ">", "/dev/null", "2>&1"],
|
||||
cwd=out_root + folder,
|
||||
)
|
||||
p_val_bg = subprocess.Popen(
|
||||
["zip", "-r", "-0", "backgrounds_val.zip", "val/backgrounds", ">", "/dev/null", "2>&1"], cwd=out_root + folder
|
||||
)
|
||||
p_train_bg = subprocess.Popen(
|
||||
["zip", "-r", "-0", "backgrounds_train.zip", "train/backgrounds", ">", "/dev/null", "2>&1"],
|
||||
cwd=out_root + folder,
|
||||
)
|
||||
|
||||
p_val_fg.wait()
|
||||
p_train_fg.wait()
|
||||
p_val_bg.wait()
|
||||
p_train_bg.wait()
|
||||
|
||||
print(f"Gathering foreground size ratios for {folder}")
|
||||
p_val = subprocess.Popen(
|
||||
[
|
||||
"./srun-general.sh",
|
||||
"python",
|
||||
"foreground_size_ratio.py",
|
||||
"--root",
|
||||
args.dataset_root,
|
||||
"-ds",
|
||||
folder,
|
||||
"-mode",
|
||||
"val",
|
||||
],
|
||||
cwd=base_folder,
|
||||
)
|
||||
p_train = subprocess.Popen(
|
||||
["./srun-general.sh", "python", "foreground_size_ratio.py", "--root", args.dataset_root, "-ds", folder],
|
||||
cwd=base_folder,
|
||||
)
|
||||
p_val.wait()
|
||||
p_train.wait()
|
||||
Reference in New Issue
Block a user