import argparse import json import os from datetime import datetime from math import ceil import numpy as np import torch from datadings.reader import MsgpackReader from datadings.torch import Compose, CompressedToPIL, Dataset from PIL import Image, ImageFilter from torchvision.transforms import ToPILImage, ToTensor from tqdm.auto import tqdm from attentive_eraser import AttentiveEraser from grounded_segmentation import grounded_segmentation from infill_lama import LaMa from utils import already_segmented, save_img from wordnet_tree import WNTree def _collate_single(data): return data[0] def _collate_multiple(datas): return { "images": [data["image"] for data in datas], "keys": [data["key"] for data in datas], "labels": [data["label"] for data in datas], } def smallest_crop(image: torch.Tensor, mask: torch.Tensor): """Crops the image to just so fit the mask given. Mask and image have to be of the same size. Args: image (torch.Tensor): image to crop mask (torch.Tensor): cropping to mask Returns: torch.Tensor: cropped image """ if len(mask.shape) == 3: assert mask.shape[0] == 1 mask = mask[0] assert ( len(image.shape) == 3 and len(mask.shape) == 2 and image.shape[1:] == mask.shape ), f"Invalid shapes: {image.shape}, {mask.shape}" dim0_indices = mask.sum(dim=0).nonzero() dim1_indices = mask.sum(dim=1).nonzero() return image[ :, dim1_indices.min().item() : dim1_indices.max().item() + 1, dim0_indices.min().item() : dim0_indices.max().item() + 1, ] def _synset_to_prompt(synset, tree, parent_in_prompt): prompt = f"an {synset.print_name}." if synset.print_name[0] in "aeiou" else f"a {synset.print_name}." if synset.parent_id is None or not parent_in_prompt: return prompt parent = tree[synset.parent_id] return prompt[:-1] + f", a type of {parent.print_name}." if __name__ == "__main__": parser = argparse.ArgumentParser(description="Grounded Segmentation") parser.add_argument( "-d", "--dataset", choices=["imagenet", "imagenet-val", "imagenet21k", "tinyimagenet", "tinyimagenet-val"], default="imagenet", help="Dataset to use", ) parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset") parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size") parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold") parser.add_argument("--debug", action="store_true", help="Debug mode") parser.add_argument("-o", "--output", type=str, required=True, help="Output directory") parser.add_argument( "-p", "--processes", type=int, default=1, help="Number of processes that are used to process the data" ) parser.add_argument("-id", type=int, default=0, help="ID of this process)") parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files, instead of skipping them") parser.add_argument( "--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree" ) parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks") parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks") parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt") parser.add_argument( "-model", choices=["LaMa", "AttErase"], default="LaMa", help="Model to use for erasing/infilling. Defaults to LaMa.", ) args = parser.parse_args() dataset = args.dataset.lower() part = "train" if dataset == "imagenet21k": reader = MsgpackReader(f"{args.dataset_root}imagenet21k/train.msgpack") dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])}) elif dataset == "imagenet-val": reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/val.msgpack") dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])}) part = "val" elif dataset == "imagenet": reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/train.msgpack") dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])}) elif dataset == "tinyimagenet": reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_train.msgpack") dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])}) elif dataset == "tinyimagenet-val": reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_val.msgpack") dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])}) part = "val" else: raise ValueError(f"Unknown dataset: {dataset}") assert 0 <= args.id < args.processes, "ID must be in the range [0, processes)" if args.processes > 1: partlen = ceil(len(dataset) / args.processes) start_id = args.id * partlen end_id = min((args.id + 1) * partlen, len(dataset)) dataset = torch.utils.data.Subset(dataset, list(range(start_id, end_id))) assert args.batch_size == 1, "Batch size must be 1 for grounded segmentation (for now)" dataset = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, shuffle=args.debug, drop_last=False, num_workers=10, collate_fn=_collate_single if args.batch_size == 1 else _collate_multiple, ) infill_model = ( LaMa(device="cuda" if torch.cuda.is_available() else "cpu") if args.model == "LaMa" else AttentiveEraser() ) if args.dataset.startswith("imagenet"): with open("wordnet_data/imagenet1k_synsets.json", "r") as f: id_to_synset = json.load(f) id_to_synset = {int(k): v for k, v in id_to_synset.items()} elif args.dataset.startswith("tinyimagenet"): with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f: synsets = f.readlines() id_to_synset = [int(synset.split(":")[0].strip()[1:]) for synset in synsets] id_to_synset = sorted(id_to_synset) wordnet = WNTree.load("wordnet_data/imagenet21k+1k_masses_tree.json") # create the necessary folders os.makedirs(os.path.join(args.output, part, "foregrounds"), exist_ok=True) os.makedirs(os.path.join(args.output, part, "backgrounds"), exist_ok=True) os.makedirs(os.path.join(args.output, part, "no_detect"), exist_ok=True) os.makedirs(os.path.join(args.output, part, "error"), exist_ok=True) man_print_progress = int(os.environ.get("TQDM_DISABLE", "0")) == 1 for i, data in tqdm(enumerate(dataset), total=len(dataset)): if args.debug and i > 10: break if man_print_progress and i % 200 == 0: print(f"{datetime.now()} \t process {args.id}/{args.processes}: \t sample: {i}/{len(dataset)}", flush=True) image = data["image"] key = data["key"] if args.dataset.endswith("-val"): img_class = f"n{id_to_synset[data['label']]:08d}" # overwrite the synset id on the val set else: img_class = key.split("/")[-1].split("_")[0] if ( already_segmented(key, os.path.join(args.output, part, "foregrounds"), img_class=img_class) and not args.overwrite ): continue # get the next 3 labels up in the wordnet tree label_id = data["label"] offset = id_to_synset[label_id] # int(data["key"].split("_")[0][1:]) synset = wordnet[offset] labels = [_synset_to_prompt(synset, wordnet, args.parent_in_prompt)] while len(labels) < 1 + args.parent_labels and synset.parent_id is not None: synset = wordnet[synset.parent_id] _label = _synset_to_prompt(synset, wordnet, args.parent_in_prompt) _label = _label.replace("_", " ") labels.append(_label) assert len(labels) > 0, f"No labels for {key}" # 1. detect and segment image_tensor, detections = grounded_segmentation( image, labels, threshold=args.threshold, polygon_refinement=True ) # merge all detections for the same label masks = {} for detection in detections: if detection.label not in masks: masks[detection.label] = detection.mask else: masks[detection.label] |= detection.mask labels = list(masks.keys()) masks = [masks[lbl] for lbl in labels] assert len(masks) == len( labels ), f"Number of masks ({len(masks)}) != number of labels ({len(labels)}); detections: {detections}" if args.debug: for detected_mask, label in zip(masks, labels): detected_mask = torch.from_numpy(detected_mask).unsqueeze(0) / 255 mask_foreground = torch.cat((image_tensor * detected_mask, detected_mask), dim=0) ToPILImage()(mask_foreground).save(f"example_images/{key}_{label}_masked.png") if len(masks) == 0: tqdm.write(f"No detections for {key}; skipping") save_img(image, key, os.path.join(args.output, part, "no_detect"), img_class=img_class, format="JPEG") continue elif len(masks) == 1: # max_overlap_mask = masks[0] masks = [masks[0]] elif args.output_ims == "best": # find the 2 masks with the largest overlap max_overlap = 0 max_overlap_mask = None using_labels = None for i, mask1 in enumerate(masks): for j, mask2 in enumerate(masks): if i >= j: continue iou = (mask1 & mask2).sum() / (mask1 | mask2).sum() if iou > max_overlap: max_overlap = iou max_overlap_mask = mask1 | mask2 using_labels = f"{labels[i]} & {labels[j]}" if args.debug: tqdm.write(f"{key}:\tMax overlap: {max_overlap}, using {using_labels}") if max_overlap_mask is None: # assert len(masks) > 0 and len(masks) == len(labels), f"No detections for {key}" max_overlap_mask = masks[0] masks = [max_overlap_mask] else: # merge masks that are too similar has_changed = True while has_changed: has_changed = False for i, mask1 in enumerate(masks): for j, mask2 in enumerate(masks): if i >= j: continue iou = (mask1 & mask2).sum() / (mask1 | mask2).sum() if iou > args.mask_merge_threshold: masks[i] |= masks[j] masks.pop(j) labels.pop(j) has_changed = True break if has_changed: break for mask_idx, mask_array in enumerate(masks): mask = torch.from_numpy(mask_array).unsqueeze(0) / 255 if args.debug: mask_image = ToPILImage()(mask) mask_image.save(f"example_images/{key}_mask.png") # foreground = image_tensor foreground = torch.cat((image_tensor * mask, mask), dim=0) foreground = ToPILImage()(foreground) mask_img = foreground.split()[-1] foreground = smallest_crop(ToTensor()(foreground), ToTensor()(mask_img)) # TODO: fix smallest crop fg_img = ToPILImage()(foreground) background = (1 - mask) * image_tensor bg_image = ToPILImage()(background) # 2. infill background mask_image = Image.fromarray(mask_array) mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=7)) try: infilled_bg = infill_model(np.array(bg_image), np.array(mask_image)) infilled_bg = Image.fromarray(np.uint8(infilled_bg)) except RuntimeError as e: tqdm.write( f"Error infilling {key}: bg_image.shape={np.array(bg_image).shape}," f" mask_image.shape={np.array(mask_image).shape}\n{e}" ) save_img( image, key, os.path.join(args.output, part, "error"), img_class=img_class, img_version=mask_idx if len(masks) > 1 else None, format="JPEG", ) infilled_bg = None if args.debug: data["image"].save(f"example_images/{key}_orig.png") fg_img.save(f"example_images/{key}_fg.png") infilled_bg.save(f"example_images/{key}_bg.png") else: # save files class_name = key.split("_")[0] save_img( fg_img, key, os.path.join(args.output, part, "foregrounds"), img_class=img_class, img_version=mask_idx if len(masks) > 1 else None, format="WEBP", ) if infilled_bg is not None: save_img( infilled_bg, key, os.path.join(args.output, part, "backgrounds"), img_class=img_class, img_version=mask_idx if len(masks) > 1 else None, format="JPEG", ) if man_print_progress: print( f"{datetime.now()} \t process {args.id}/{args.processes}: \t done with all {len(dataset)} samples", flush=True, )