Files
ForAug/AAAI Supplementary Material/Model Training Code/data/fornet.py
Tobias Christian Nauen ff34712155 AAAI Version
2026-02-24 12:22:44 +01:00

485 lines
22 KiB
Python

import json
import os
import zipfile
from io import BytesIO
from math import floor
import numpy as np
import PIL
from datadings.torch import Compose
from loguru import logger
from PIL import Image, ImageFilter
from torch.utils.data import Dataset, get_worker_info
from torchvision import transforms as T
from tqdm.auto import tqdm
from data.data_utils import apply_dense_transforms
class ForNet(Dataset):
    """Recombine ImageNet foregrounds and backgrounds.

    Note:
        This dataset has exactly the ImageNet classes.
    """

    # Allowed values for ``background_combination``.
    _back_combs = ["same", "all", "original"]
    # Transform types applied to the background alone when paste_pre_transform=False.
    _bg_transforms = {T.RandomCrop, T.CenterCrop, T.Resize, T.RandomResizedCrop}

    def __init__(
        self,
        root,
        transform=None,
        train=True,
        target_transform=None,
        background_combination="all",
        fg_scale_jitter=0.3,
        fg_transform=None,
        pruning_ratio=0.8,
        return_fg_masks=False,
        fg_size_mode="range",
        fg_bates_n=1,
        paste_pre_transform=True,
        mask_smoothing_sigma=4.0,
        rel_jut_out=0.0,
        fg_in_nonant=None,
        size_fact=1.0,
        orig_img_prob=0.0,
        orig_ds=None,
        _orig_ds_file_type="JPEG",
        epochs=0,
        _album_compose=False,
    ):
        """Create RecombinationNet dataset.

        Args:
            root (str): Root folder for the dataset.
            transform (T.Compose | list, optional): Transform to apply to the image. Defaults to None.
            train (bool, optional): On the train set (False -> val set). Defaults to True.
            target_transform (T.Compose | list, optional): Transform to apply to the target values. Defaults to None.
            background_combination (str, optional): Which backgrounds to combine with foregrounds
                ("same", "all", or "original"). Defaults to "all".
            fg_scale_jitter (float, optional): How much the size of the foreground may be changed
                (random ratio). Defaults to 0.3.
            fg_transform (optional): Transform to apply to the foreground before pasting it onto the
                background. This is supposed to be a random rotation, mainly. Defaults to None.
            pruning_ratio (float, optional): Prune backgrounds with
                (foreground size / background size) >= pruning_ratio. Backgrounds from images that
                contain very large foreground objects are mostly computer generated and therefore
                relatively unnatural. Defaults to 0.8; use a value > 1.0 for the full dataset.
            return_fg_masks (bool, optional): Also return the foreground masks. Defaults to False.
            fg_size_mode (str, optional): How to determine the size of the pasted foreground, based on
                the foreground sizes of the foreground and background images
                ("max", "min", "mean", or "range"). Defaults to "range".
            fg_bates_n (int, optional): Bates parameter for the distribution of the object position in
                the foreground. Defaults to 1 (uniform distribution). The higher the value, the more
                likely the object is in the center. For fg_bates_n = 0, the object is always centered.
            paste_pre_transform (bool, optional): Paste the foreground onto the background before
                applying the transform. If False, the background will be cropped and resized before
                pasting the foreground. Defaults to True.
            mask_smoothing_sigma (float, optional): Sigma for the Gaussian blur of the mask edge.
                Defaults to 4.0.
            rel_jut_out (float, optional): How much the foreground is allowed to stand/jut out of the
                background (and then be cut off). Defaults to 0.0.
            fg_in_nonant (int, optional): If not None, the foreground will be placed in a specific
                nonant (0-8) of the image. Defaults to None.
            size_fact (float, optional): Factor to multiply the size of the foreground with. Defaults to 1.0.
            orig_img_prob (float | str, optional): Probability to use the original image instead of the
                fg-bg recombinations; may also be a schedule name ("linear", "revlinear", "cos").
                Defaults to 0.0.
            orig_ds (Dataset | str, optional): Original dataset (without transforms), or a path to it,
                used for the original images. Defaults to None.
            _orig_ds_file_type (str, optional): File type of the original dataset. Defaults to "JPEG".
            epochs (int, optional): Number of epochs to train on. Used by the orig_img_prob schedules.

        Note:
            For more information on the Bates distribution, see
            https://en.wikipedia.org/wiki/Bates_distribution.
            For fg_bates_n < 0, we extend the Bates distribution to focus more and more on the edges.
            This is done by sampling B ~ Bates(|fg_bates_n|) and passing it through
            f(x) = x + 0.5 - floor(x + 0.5).
            For the list of transformations that will be applied to the background only
            (if paste_pre_transform=False), see ForNet._bg_transforms.
            A nonant in this case refers to a square in a 3x3 grid dividing the image.
        """
        assert (
            background_combination in self._back_combs
        ), f"background_combination={background_combination} is not supported. Use one of {self._back_combs}"
        # Prefer the unpacked folder layout only when the zips are absent but the folders exist.
        if (not os.path.exists(f"{root}/backgrounds_{'train' if train else 'val'}.zip")) and os.path.exists(
            os.path.join(root, "train" if train else "val", "backgrounds")
        ):
            self._mode = "folder"
        else:
            self._mode = "zip"
        if self._mode == "zip":
            # Index the archive members once; per-worker handles are opened lazily in _wrkr_info.
            try:
                with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip:
                    self.backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
                with zipfile.ZipFile(f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r") as fg_zip:
                    self.foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
            except FileNotFoundError as e:
                logger.error(
                    f"RecombinationNet: {e}. Make sure to have the background and foreground zips in the root"
                    f" directory: found {os.listdir(root)}"
                )
                raise e
            # The class code is the second-to-last path component (<...>/<cls>/<file>.WEBP).
            classes = set([f.split("/")[-2] for f in self.foregrounds])
        else:
            logger.info("ForNet folder mode: loading classes")
            classes = set(os.listdir(os.path.join(root, "train" if train else "val", "foregrounds")))
            foregrounds = []
            backgrounds = []
            for cls in tqdm(classes, desc="Loading files"):
                foregrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "foregrounds", cls))
                    ]
                )
                backgrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "backgrounds", cls))
                    ]
                )
            self.foregrounds = foregrounds
            self.backgrounds = backgrounds
        # Sort by the numeric part of the WordNet id (e.g. "n01440764") for a stable class order.
        self.classes = sorted(list(classes), key=lambda x: int(x[1:]))
        assert os.path.exists(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json"), (
            f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json not found, provide the information or set"
            " pruning_ratio=1.0"
        )
        with open(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json", "r") as f:
            self.fg_bg_ratios = json.load(f)
        if self._mode == "folder":
            # The json keys contain zip-style paths; reduce them to "<cls>/<file>" for folder mode.
            self.fg_bg_ratios = {"/".join(key.split("/")[-2:]): val for key, val in self.fg_bg_ratios.items()}
            logger.info(f"Renamed fg_bg_ratios keys to {list(self.fg_bg_ratios.keys())[:3]}...")
        # BUGFIX: build the per-class fallback from the *unpruned* background list unconditionally.
        # Previously this dict was only built for pruning_ratio <= 1.0, which caused a NameError
        # below when background_combination == "same" was combined with pruning_ratio > 1.0.
        backup_backgrounds = {}
        for bg_file in self.backgrounds:
            bg_cls = bg_file.split("/")[-2]
            backup_backgrounds[bg_cls] = bg_file
        if pruning_ratio <= 1.0:
            # Drop backgrounds whose original foreground covered too much of the image.
            self.backgrounds = [
                bg for bg in self.backgrounds if bg in self.fg_bg_ratios and self.fg_bg_ratios[bg] < pruning_ratio
            ]
        self.root = root
        self.train = train
        self.background_combination = background_combination
        self.fg_scale_jitter = fg_scale_jitter
        self.fg_transform = fg_transform
        self.return_fg_masks = return_fg_masks
        self.paste_pre_transform = paste_pre_transform
        self.mask_smoothing_sigma = mask_smoothing_sigma
        self.rel_jut_out = rel_jut_out
        self.size_fact = size_fact
        self.fg_in_nonant = fg_in_nonant
        # -1 is accepted as an explicit "no specific nonant" value (see __getitem__, which checks >= 0).
        assert fg_in_nonant is None or -1 <= fg_in_nonant < 9, f"fg_in_nonant={fg_in_nonant} not in [0, 8] or None"
        self.orig_img_prob = orig_img_prob
        if orig_img_prob != 0.0:
            assert (isinstance(orig_img_prob, float) and orig_img_prob > 0.0) or orig_img_prob in [
                "linear",
                "cos",
                "revlinear",
            ]
            # ROBUSTNESS: the schedules divide by self.epochs in __getitem__; fail early and clearly
            # instead of with a ZeroDivisionError during training.
            if orig_img_prob in ["linear", "cos", "revlinear"]:
                assert epochs > 0, "epochs must be > 0 when orig_img_prob is a schedule"
            assert orig_ds is not None, "orig_ds must be provided if orig_img_prob > 0.0"
            assert not return_fg_masks, "can't provide fg masks for original images (yet)"
            assert os.path.exists(os.path.join(root, f"{'train' if train else 'val'}_indices.json")) or isinstance(
                orig_ds, str
            ), f"{root}/{'train' if train else 'val'}_indices.json must be provided if orig_ds is a dataset"
            if not isinstance(orig_ds, str):
                with open(os.path.join(root, f"{'train' if train else 'val'}_indices.json"), "r") as f:
                    self.key_to_orig_idx = json.load(f)
            else:
                # Make sure the path ends in the split folder.
                if not (
                    orig_ds.endswith("train" if train else "val") or orig_ds.endswith("train/" if train else "val/")
                ):
                    orig_ds = f"{orig_ds}/{'train' if train else 'val'}"
                self.key_to_orig_idx = None
        self._orig_ds_file_type = _orig_ds_file_type
        self.orig_ds = orig_ds
        self.epochs = epochs
        self._epoch = 0
        assert fg_size_mode in [
            "max",
            "min",
            "mean",
            "range",
        ], f"fg_size_mode={fg_size_mode} not supported; use one of ['max', 'min', 'mean', 'range']"
        self.fg_size_mode = fg_size_mode
        self.fg_bates_n = fg_bates_n
        if not paste_pre_transform:
            if isinstance(transform, (T.Compose, Compose)):
                transform = transform.transforms
            # ROBUSTNESS: tolerate transform=None instead of raising a TypeError in the loop below.
            if transform is None:
                transform = []
            # Do cropping and resizing mainly on the background; paste the foreground on top later.
            self.bg_transform = []
            self.join_transform = []
            for tf in transform:
                if isinstance(tf, tuple(self._bg_transforms)):
                    self.bg_transform.append(tf)
                else:
                    self.join_transform.append(tf)
            if _album_compose:
                from data.album_transf import AlbumTorchCompose

                self.bg_transform = AlbumTorchCompose(self.bg_transform)
                self.join_transform = AlbumTorchCompose(self.join_transform)
            else:
                self.bg_transform = T.Compose(self.bg_transform)
                self.join_transform = T.Compose(self.join_transform)
        else:
            if isinstance(transform, list):
                if _album_compose:
                    from data.album_transf import AlbumTorchCompose

                    self.join_transform = AlbumTorchCompose(transform)
                else:
                    self.join_transform = T.Compose(transform)
            else:
                self.join_transform = transform
            self.bg_transform = None
        self.trgt_map = {cls: i for i, cls in enumerate(self.classes)}
        self.target_transform = target_transform
        # For "same", map every class to its own (pruned) backgrounds.
        self.cls_to_allowed_bg = {}
        for bg_file in self.backgrounds:
            if background_combination == "same":
                bg_cls = bg_file.split("/")[-2]
                if bg_cls not in self.cls_to_allowed_bg:
                    self.cls_to_allowed_bg[bg_cls] = []
                self.cls_to_allowed_bg[bg_cls].append(bg_file)
        if background_combination == "same":
            # If pruning removed every background of a class, fall back to one unpruned background.
            for cls_code in classes:
                if cls_code not in self.cls_to_allowed_bg or len(self.cls_to_allowed_bg[cls_code]) == 0:
                    self.cls_to_allowed_bg[cls_code] = [backup_backgrounds[cls_code]]
                    logger.warning(f"No background for class {cls_code}, using {backup_backgrounds[cls_code]}")
        # Per-worker open zip handles, lazily created in _wrkr_info (handles can't be shared).
        self._zf = {}

    @property
    def epoch(self):
        # Current epoch; drives the orig_img_prob schedules in __getitem__.
        return self._epoch

    @epoch.setter
    def epoch(self, value):
        self._epoch = value

    def __len__(self):
        """Size of the dataset.

        Returns:
            int: number of foregrounds
        """
        return len(self.foregrounds)

    def num_classes(self):
        """Return the number of classes in the dataset."""
        return len(self.classes)

    def _get_fg(self, idx):
        """Load the raw foreground image at index idx (no conversion or transforms).

        Args:
            idx (int): foreground index

        Returns:
            PIL.Image.Image: the foreground image
        """
        fg_file = self.foregrounds[idx]
        if self._mode == "zip":
            worker_id = self._wrkr_info()
            with self._zf[worker_id]["fg"].open(fg_file) as f:
                fg_data = BytesIO(f.read())
            return Image.open(fg_data)
        # BUGFIX: folder mode previously raised a KeyError here because self._zf is only
        # populated in zip mode; mirror the folder handling used in __getitem__.
        return Image.open(os.path.join(self.root, "train" if self.train else "val", "foregrounds", fg_file))

    def _wrkr_info(self):
        """Return the dataloader worker id; in zip mode, lazily open this worker's zip handles.

        Zip handles cannot be shared across dataloader workers, so each worker gets its own pair.
        """
        worker_id = get_worker_info().id if get_worker_info() else 0
        if worker_id not in self._zf and self._mode == "zip":
            self._zf[worker_id] = {
                "bg": zipfile.ZipFile(f"{self.root}/backgrounds_{'train' if self.train else 'val'}.zip", "r"),
                "fg": zipfile.ZipFile(f"{self.root}/foregrounds_{'train' if self.train else 'val'}.zip", "r"),
            }
        return worker_id

    def __getitem__(self, idx):
        """Get the foreground at index idx and combine it with a (random) background.

        Args:
            idx (int): foreground index

        Returns:
            torch.Tensor, torch.Tensor: image, target (plus the fg mask if return_fg_masks)
        """
        worker_id = self._wrkr_info()
        fg_file = self.foregrounds[idx]
        trgt_cls = fg_file.split("/")[-2]
        # Possibly return the original image instead of a recombination.
        # BUGFIX ("revlinear"): the probability should decrease linearly from 1 to 0, i.e.
        # (epochs - epoch) / epochs. The previous (epoch - epochs) / epochs is never positive,
        # so the branch could never trigger.
        if (
            (self.orig_img_prob == "linear" and np.random.rand() < self._epoch / self.epochs)
            or (self.orig_img_prob == "revlinear" and np.random.rand() < (self.epochs - self._epoch) / self.epochs)
            or (self.orig_img_prob == "cos" and np.random.rand() > np.cos(np.pi * self._epoch / (2 * self.epochs)))
            or (
                isinstance(self.orig_img_prob, float)
                and self.orig_img_prob > 0.0
                and np.random.rand() < self.orig_img_prob
            )
        ):
            data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
            if isinstance(self.orig_ds, str):
                image_file = os.path.join(self.orig_ds, f"{data_key}.{self._orig_ds_file_type}")
                orig_img = Image.open(image_file).convert("RGB")
            else:
                orig_data = self.orig_ds[self.key_to_orig_idx[data_key]]
                orig_img = orig_data["image"] if isinstance(orig_data, dict) else orig_data[0]
            if self.bg_transform:
                orig_img = self.bg_transform(orig_img)
            if self.join_transform:
                orig_img = self.join_transform(orig_img)
            trgt = self.trgt_map[trgt_cls]
            if self.target_transform:
                trgt = self.target_transform(trgt)
            return orig_img, trgt
        # Load the foreground; the alpha channel carries the segmentation mask.
        if self._mode == "zip":
            with self._zf[worker_id]["fg"].open(fg_file) as f:
                fg_data = BytesIO(f.read())
            try:
                fg_img = Image.open(fg_data).convert("RGBA")
            except PIL.UnidentifiedImageError as e:
                logger.error(f"Error with idx={idx}, file={self.foregrounds[idx]}")
                raise e
        else:
            fg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "foregrounds", fg_file)
            ).convert("RGBA")
        if self.fg_transform:
            fg_img = self.fg_transform(fg_img)
        # Fraction of the foreground crop actually covered by the object (mean alpha in [0, 1]).
        fg_size_factor = T.ToTensor()(fg_img.split()[-1]).mean().item()
        # Pick a background according to the combination strategy.
        if self.background_combination == "all":
            bg_idx = np.random.randint(len(self.backgrounds))
            bg_file = self.backgrounds[bg_idx]
        elif self.background_combination == "original":
            bg_file = fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")
        else:
            bg_idx = np.random.randint(len(self.cls_to_allowed_bg[trgt_cls]))
            bg_file = self.cls_to_allowed_bg[trgt_cls][bg_idx]
        if self._mode == "zip":
            with self._zf[worker_id]["bg"].open(bg_file) as f:
                bg_data = BytesIO(f.read())
            bg_img = Image.open(bg_data).convert("RGB")
        else:
            bg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "backgrounds", bg_file)
            ).convert("RGB")
        if not self.paste_pre_transform:
            # Crop/resize the background first; the foreground is pasted on top afterwards.
            bg_img = self.bg_transform(bg_img)
        bg_size = bg_img.size
        # Choose a scale factor such that the relative foreground area matches the goal ratio.
        bg_area = bg_size[0] * bg_size[1]
        if self.fg_in_nonant is not None:
            # The foreground has to fit into a single cell of the 3x3 grid.
            bg_area = bg_area / 9
        orig_fg_ratio = self.fg_bg_ratios[fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")]
        bg_fg_ratio = self.fg_bg_ratios[bg_file]
        if self.fg_size_mode == "max":
            goal_fg_ratio_lower = goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "min":
            goal_fg_ratio_lower = goal_fg_ratio_upper = min(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "mean":
            goal_fg_ratio_lower = goal_fg_ratio_upper = (orig_fg_ratio + bg_fg_ratio) / 2
        else:  # "range": sample anywhere between the two ratios
            goal_fg_ratio_lower = min(orig_fg_ratio, bg_fg_ratio)
            goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
        # Divide by fg_size_factor because only that fraction of the crop is object pixels.
        fg_scale = (
            np.random.uniform(
                goal_fg_ratio_lower * (1 - self.fg_scale_jitter), goal_fg_ratio_upper * (1 + self.fg_scale_jitter)
            )
            / fg_size_factor
            * self.size_fact
        )
        # Resize the foreground to the goal area while preserving its aspect ratio.
        goal_shape_y = round(np.sqrt(bg_area * fg_scale * fg_img.size[1] / fg_img.size[0]))
        goal_shape_x = round(np.sqrt(bg_area * fg_scale * fg_img.size[0] / fg_img.size[1]))
        fg_img = fg_img.resize((goal_shape_x, goal_shape_y))
        if fg_img.size[0] > bg_size[0] or fg_img.size[1] > bg_size[1]:
            # Random crop (deterministic center crop at eval time) so the foreground fits.
            goal_w, goal_h = (min(fg_img.size[0], bg_size[0]), min(fg_img.size[1], bg_size[1]))
            fg_img = T.RandomCrop((goal_h, goal_w))(fg_img) if self.train else T.CenterCrop((goal_h, goal_w))(fg_img)
        # Sample the paste position: the mean of |n| uniforms is Bates(|n|) distributed
        # (n=1 => uniform); n=0 pins the object to the center.
        z1, z2 = (
            (
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),
            )
            if self.fg_bates_n != 0
            else (0.5, 0.5)
        )
        if self.fg_bates_n < 0:
            # Shift by 0.5 (mod 1) to concentrate probability mass at the edges instead.
            z1 = z1 + 0.5 - floor(z1 + 0.5)
            z2 = z2 + 0.5 - floor(z2 + 0.5)
        # Allowed offset range, letting the foreground jut out by rel_jut_out at most.
        x_min = -self.rel_jut_out * fg_img.size[0]
        x_max = bg_size[0] - fg_img.size[0] * (1 - self.rel_jut_out)
        y_min = -self.rel_jut_out * fg_img.size[1]
        y_max = bg_size[1] - fg_img.size[1] * (1 - self.rel_jut_out)
        if self.fg_in_nonant is not None and self.fg_in_nonant >= 0:
            # Constrain the position to the requested cell of the 3x3 grid.
            x_min = (self.fg_in_nonant % 3) * bg_size[0] / 3
            x_max = ((self.fg_in_nonant % 3) + 1) * bg_size[0] / 3 - fg_img.size[0]
            y_min = (self.fg_in_nonant // 3) * bg_size[1] / 3
            y_max = ((self.fg_in_nonant // 3) + 1) * bg_size[1] / 3 - fg_img.size[1]
        # If the foreground is larger than the allowed range, center it in that range.
        if x_min > x_max:
            x_min = x_max = (x_min + x_max) / 2
        if y_min > y_max:
            y_min = y_max = (y_min + y_max) / 2
        offs_x = round(z1 * (x_max - x_min) + x_min)
        offs_y = round(z2 * (y_max - y_min) + y_min)
        paste_mask = fg_img.split()[-1]
        if self.mask_smoothing_sigma > 0.0:
            # Soften the mask edge with a randomized blur; keep the interior (alpha > 128) solid.
            sigma = (np.random.rand() * 0.9 + 0.1) * self.mask_smoothing_sigma
            paste_mask = paste_mask.filter(ImageFilter.GaussianBlur(radius=sigma))
            paste_mask = paste_mask.point(lambda p: 2 * p - 255 if p > 128 else 0)
        bg_img.paste(fg_img.convert("RGB"), (offs_x, offs_y), paste_mask)
        bg_img = bg_img.convert("RGB")
        if self.return_fg_masks:
            fg_mask = Image.new("L", bg_size, 0)
            fg_mask.paste(paste_mask, (offs_x, offs_y))
            fg_mask = T.ToTensor()(fg_mask)[0]
            bg_img = T.ToTensor()(bg_img)
            if self.join_transform:
                # Apply the same (random) transform jointly to image and mask.
                bg_img, fg_mask = apply_dense_transforms(bg_img, fg_mask, self.join_transform)
        else:
            bg_img = self.join_transform(bg_img)
        if trgt_cls not in self.trgt_map:
            raise ValueError(f"trgt_cls={trgt_cls} not in trgt_map: {self.trgt_map}")
        trgt = self.trgt_map[trgt_cls]
        if self.target_transform:
            trgt = self.target_transform(trgt)
        if self.return_fg_masks:
            return bg_img, trgt, fg_mask
        return bg_img, trgt