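"""Foreground/background decomposition for ImageNet-style datasets.

For each sample this script builds WordNet-derived text prompts, runs
grounded segmentation to detect and mask the labeled object, selects or
merges the resulting masks, crops out the foreground, and erases/infills
the background with LaMa or AttentiveEraser. Foregrounds, backgrounds,
no-detection images, and error cases are written to separate folders.
"""
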
import argparse
import json
import os
from datetime import datetime
from math import ceil

import numpy as np
import torch
from datadings.reader import MsgpackReader
from datadings.torch import Compose, CompressedToPIL, Dataset
from PIL import Image, ImageFilter
from torchvision.transforms import ToPILImage, ToTensor
from tqdm.auto import tqdm

from attentive_eraser import AttentiveEraser
from grounded_segmentation import grounded_segmentation
from infill_lama import LaMa
from utils import already_segmented, save_img
from wordnet_tree import WNTree


# DataLoader collate helpers: with batch size 1 the sample dict is returned
# as-is; for larger batches the per-sample fields are grouped into lists.
def _collate_single(data):
    return data[0]


def _collate_multiple(datas):
    return {
        "images": [data["image"] for data in datas],
        "keys": [data["key"] for data in datas],
        "labels": [data["label"] for data in datas],
    }


def smallest_crop(image: torch.Tensor, mask: torch.Tensor):
    """Crop the image to the smallest box that still contains the mask.

    Mask and image have to be of the same spatial size.

    Args:
        image (torch.Tensor): image to crop, shape (C, H, W)
        mask (torch.Tensor): mask to crop to, shape (H, W) or (1, H, W)

    Returns:
        torch.Tensor: cropped image
    """
    if len(mask.shape) == 3:
        assert mask.shape[0] == 1
        mask = mask[0]
    assert (
        len(image.shape) == 3 and len(mask.shape) == 2 and image.shape[1:] == mask.shape
    ), f"Invalid shapes: {image.shape}, {mask.shape}"
    # nonzero column (width) and row (height) indices of the mask
    col_indices = mask.sum(dim=0).nonzero()
    row_indices = mask.sum(dim=1).nonzero()

    return image[
        :,
        row_indices.min().item() : row_indices.max().item() + 1,
        col_indices.min().item() : col_indices.max().item() + 1,
    ]
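

# Example (illustrative): for a (3, H, W) image and a mask that is nonzero
# only in rows 10..20 and columns 30..50, smallest_crop returns a tensor of
# shape (3, 11, 21).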


def _synset_to_prompt(synset, tree, parent_in_prompt):
    article = "an" if synset.print_name[0] in "aeiou" else "a"
    prompt = f"{article} {synset.print_name}."
    if synset.parent_id is None or not parent_in_prompt:
        return prompt
    parent = tree[synset.parent_id]
    return prompt[:-1] + f", a type of {parent.print_name}."
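

# e.g. (names illustrative): "a dog." without parent_in_prompt, or
# "a dog, a type of canine." when a parent exists and parent_in_prompt is set.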


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Grounded Segmentation")
    parser.add_argument(
        "-d",
        "--dataset",
        choices=["imagenet", "imagenet-val", "imagenet21k", "tinyimagenet", "tinyimagenet-val"],
        default="imagenet",
        help="Dataset to use",
    )
    parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
    parser.add_argument(
        "-p", "--processes", type=int, default=1, help="Number of processes that are used to process the data"
    )
    parser.add_argument("-id", type=int, default=0, help="ID of this process")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files, instead of skipping them")
    parser.add_argument(
        "--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
    )
    parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
    parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
    parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
    parser.add_argument(
        "-model",
        choices=["LaMa", "AttErase"],
        default="LaMa",
        help="Model to use for erasing/infilling. Defaults to LaMa.",
    )

    args = parser.parse_args()
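
    # Example invocation (script name and paths illustrative):
    #   python segment.py -d imagenet -r /data/ -o /out -p 4 -id 0 --parent_in_prompt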
    dataset_name = args.dataset.lower()
    part = "val" if dataset_name.endswith("-val") else "train"
    # msgpack locations, relative to the dataset root
    msgpack_paths = {
        "imagenet": "imagenet/msgpack/train.msgpack",
        "imagenet-val": "imagenet/msgpack/val.msgpack",
        "imagenet21k": "imagenet21k/train.msgpack",
        "tinyimagenet": "TinyINSegment/TinyIN_train.msgpack",
        "tinyimagenet-val": "TinyINSegment/TinyIN_val.msgpack",
    }
    if dataset_name not in msgpack_paths:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    reader = MsgpackReader(os.path.join(args.dataset_root, msgpack_paths[dataset_name]))
    dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    assert 0 <= args.id < args.processes, "ID must be in the range [0, processes)"
    if args.processes > 1:
        # shard the dataset: each process works on its own contiguous slice
        partlen = ceil(len(dataset) / args.processes)
        start_id = args.id * partlen
        end_id = min((args.id + 1) * partlen, len(dataset))
        dataset = torch.utils.data.Subset(dataset, list(range(start_id, end_id)))
    assert args.batch_size == 1, "Batch size must be 1 for grounded segmentation (for now)"
    dataset = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=args.debug,
        drop_last=False,
        num_workers=10,
        collate_fn=_collate_single if args.batch_size == 1 else _collate_multiple,
    )
    infill_model = (
        LaMa(device="cuda" if torch.cuda.is_available() else "cpu") if args.model == "LaMa" else AttentiveEraser()
    )
    if args.dataset.startswith("imagenet"):
        with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
            id_to_synset = json.load(f)
        # JSON keys are strings; convert them back to integer class labels
        id_to_synset = {int(k): v for k, v in id_to_synset.items()}
    elif args.dataset.startswith("tinyimagenet"):
        with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
            synsets = f.readlines()
        # each line starts with the synset id ("nXXXXXXXX: ..."); drop the "n" and parse the offset
        id_to_synset = [int(synset.split(":")[0].strip()[1:]) for synset in synsets]
        id_to_synset = sorted(id_to_synset)
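
    # id_to_synset maps an integer class label to a WordNet offset:
    # a dict for ImageNet, a sorted list (indexed by label) for TinyImageNet.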
    wordnet = WNTree.load("wordnet_data/imagenet21k+1k_masses_tree.json")

    # create the necessary output folders
    os.makedirs(os.path.join(args.output, part, "foregrounds"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "backgrounds"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "no_detect"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "error"), exist_ok=True)

    # fall back to manual progress prints when tqdm is disabled via TQDM_DISABLE=1
    man_print_progress = int(os.environ.get("TQDM_DISABLE", "0")) == 1
    for i, data in tqdm(enumerate(dataset), total=len(dataset)):
        if args.debug and i > 10:
            break

        if man_print_progress and i % 200 == 0:
            print(f"{datetime.now()} \t process {args.id}/{args.processes}: \t sample: {i}/{len(dataset)}", flush=True)
        image = data["image"]
        key = data["key"]

        if args.dataset.endswith("-val"):
            img_class = f"n{id_to_synset[data['label']]:08d}"  # on the val set, derive the synset id from the label
        else:
            img_class = key.split("/")[-1].split("_")[0]

        if (
            already_segmented(key, os.path.join(args.output, part, "foregrounds"), img_class=img_class)
            and not args.overwrite
        ):
            continue
        # collect the class label plus up to args.parent_labels ancestors from the WordNet tree
        label_id = data["label"]
        offset = id_to_synset[label_id]
        synset = wordnet[offset]
        labels = [_synset_to_prompt(synset, wordnet, args.parent_in_prompt)]
        while len(labels) < 1 + args.parent_labels and synset.parent_id is not None:
            synset = wordnet[synset.parent_id]
            labels.append(_synset_to_prompt(synset, wordnet, args.parent_in_prompt))
        # WordNet print names may contain underscores; prompts should use spaces
        labels = [label.replace("_", " ") for label in labels]

        assert len(labels) > 0, f"No labels for {key}"
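
        # labels now holds the class prompt first, followed by up to
        # args.parent_labels ancestor prompts, e.g. (illustrative):
        # ["a dog.", "a canine.", "a carnivore."]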
        # 1. detect and segment
        image_tensor, detections = grounded_segmentation(
            image, labels, threshold=args.threshold, polygon_refinement=True
        )
        # merge all detections for the same label
        masks = {}
        for detection in detections:
            if detection.label not in masks:
                masks[detection.label] = detection.mask
            else:
                masks[detection.label] |= detection.mask
        labels = list(masks.keys())
        masks = [masks[lbl] for lbl in labels]
        assert len(masks) == len(
            labels
        ), f"Number of masks ({len(masks)}) != number of labels ({len(labels)}); detections: {detections}"
        if args.debug:
            for detected_mask, label in zip(masks, labels):
                detected_mask = torch.from_numpy(detected_mask).unsqueeze(0) / 255
                mask_foreground = torch.cat((image_tensor * detected_mask, detected_mask), dim=0)
                ToPILImage()(mask_foreground).save(f"example_images/{key}_{label}_masked.png")
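
        # Mask selection below: "best" keeps the union of the two most-overlapping
        # masks (IoU = |A & B| / |A | B|; with uint8 0/255 masks the factor 255
        # cancels in the ratio), "all" merges near-duplicates above the threshold.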
        if len(masks) == 0:
            tqdm.write(f"No detections for {key}; skipping")
            save_img(image, key, os.path.join(args.output, part, "no_detect"), img_class=img_class, format="JPEG")
            continue
        elif len(masks) == 1:
            pass  # a single mask needs no selection or merging
        elif args.output_ims == "best":
            # find the two masks with the largest overlap and keep their union;
            # loop indices a/b must not shadow the outer sample index i
            max_overlap = 0
            max_overlap_mask = None
            using_labels = None
            for a, mask1 in enumerate(masks):
                for b, mask2 in enumerate(masks):
                    if a >= b:
                        continue
                    iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                    if iou > max_overlap:
                        max_overlap = iou
                        max_overlap_mask = mask1 | mask2
                        using_labels = f"{labels[a]} & {labels[b]}"
            if args.debug:
                tqdm.write(f"{key}:\tMax overlap: {max_overlap}, using {using_labels}")
            if max_overlap_mask is None:
                # no pair of masks overlaps at all; fall back to the first mask
                max_overlap_mask = masks[0]
            masks = [max_overlap_mask]
        else:
            # merge masks that are too similar, restarting the scan after each merge
            has_changed = True
            while has_changed:
                has_changed = False
                for a, mask1 in enumerate(masks):
                    for b, mask2 in enumerate(masks):
                        if a >= b:
                            continue
                        iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                        if iou > args.mask_merge_threshold:
                            masks[a] |= masks[b]
                            masks.pop(b)
                            labels.pop(b)
                            has_changed = True
                            break
                    if has_changed:
                        break
        for mask_idx, mask_array in enumerate(masks):
            mask = torch.from_numpy(mask_array).unsqueeze(0) / 255

            if args.debug:
                mask_image = ToPILImage()(mask)
                mask_image.save(f"example_images/{key}_mask.png")

            # RGBA foreground: masked RGB channels plus the mask as alpha channel
            foreground = torch.cat((image_tensor * mask, mask), dim=0)
            foreground = ToPILImage()(foreground)
            mask_img = foreground.split()[-1]
            foreground = smallest_crop(ToTensor()(foreground), ToTensor()(mask_img))  # TODO: fix smallest crop
            fg_img = ToPILImage()(foreground)

            background = (1 - mask) * image_tensor
            bg_image = ToPILImage()(background)
            # 2. infill background
            mask_image = Image.fromarray(mask_array)
            mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=7))
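            # The blur feathers the hard mask edge; presumably this helps the
            # infill model blend the erased region into its surroundings.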
            try:
                infilled_bg = infill_model(np.array(bg_image), np.array(mask_image))
                infilled_bg = Image.fromarray(np.uint8(infilled_bg))
            except RuntimeError as e:
                tqdm.write(
                    f"Error infilling {key}: bg_image.shape={np.array(bg_image).shape},"
                    f" mask_image.shape={np.array(mask_image).shape}\n{e}"
                )
                save_img(
                    image,
                    key,
                    os.path.join(args.output, part, "error"),
                    img_class=img_class,
                    img_version=mask_idx if len(masks) > 1 else None,
                    format="JPEG",
                )
                infilled_bg = None
            if args.debug:
                data["image"].save(f"example_images/{key}_orig.png")
                fg_img.save(f"example_images/{key}_fg.png")
                if infilled_bg is not None:  # infilling may have failed above
                    infilled_bg.save(f"example_images/{key}_bg.png")
            else:
                # save foreground (with alpha) and, if infilling succeeded, the background
                save_img(
                    fg_img,
                    key,
                    os.path.join(args.output, part, "foregrounds"),
                    img_class=img_class,
                    img_version=mask_idx if len(masks) > 1 else None,
                    format="WEBP",
                )
                if infilled_bg is not None:
                    save_img(
                        infilled_bg,
                        key,
                        os.path.join(args.output, part, "backgrounds"),
                        img_class=img_class,
                        img_version=mask_idx if len(masks) > 1 else None,
                        format="JPEG",
                    )

    if man_print_progress:
        print(
            f"{datetime.now()} \t process {args.id}/{args.processes}: \t done with all {len(dataset)} samples",
            flush=True,
        )