AAAI Version
This commit is contained in:
484
AAAI Supplementary Material/Model Training Code/data/fornet.py
Normal file
484
AAAI Supplementary Material/Model Training Code/data/fornet.py
Normal file
@@ -0,0 +1,484 @@
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from math import floor
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from datadings.torch import Compose
|
||||
from loguru import logger
|
||||
from PIL import Image, ImageFilter
|
||||
from torch.utils.data import Dataset, get_worker_info
|
||||
from torchvision import transforms as T
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from data.data_utils import apply_dense_transforms
|
||||
|
||||
|
||||
class ForNet(Dataset):
    """Recombine ImageNet foregrounds and backgrounds.

    Every item is a foreground cut-out (RGBA, alpha channel = object mask) that is
    rescaled and pasted onto a background image. The data is read either from
    ``{root}/foregrounds_{split}.zip`` / ``{root}/backgrounds_{split}.zip`` ("zip" mode)
    or from ``{root}/{split}/foregrounds`` / ``{root}/{split}/backgrounds`` ("folder" mode).

    Note:
        This dataset has exactly the ImageNet classes.

    """

    _back_combs = ["same", "all", "original"]
    _bg_transforms = {T.RandomCrop, T.CenterCrop, T.Resize, T.RandomResizedCrop}
    # String values of orig_img_prob that select an epoch-dependent schedule.
    _orig_img_schedules = ("linear", "cos", "revlinear")
    # Sub-folder name per image kind (folder mode).
    _kind_folders = {"fg": "foregrounds", "bg": "backgrounds"}

    def __init__(
        self,
        root,
        transform=None,
        train=True,
        target_transform=None,
        background_combination="all",
        fg_scale_jitter=0.3,
        fg_transform=None,
        pruning_ratio=0.8,
        return_fg_masks=False,
        fg_size_mode="range",
        fg_bates_n=1,
        paste_pre_transform=True,
        mask_smoothing_sigma=4.0,
        rel_jut_out=0.0,
        fg_in_nonant=None,
        size_fact=1.0,
        orig_img_prob=0.0,
        orig_ds=None,
        _orig_ds_file_type="JPEG",
        epochs=0,
        _album_compose=False,
    ):
        """Create RecombinationNet dataset.

        Args:
            root (str): Root folder for the dataset.
            transform (T.Compose | list, optional): Transform to apply to the image. Defaults to None.
            train (bool, optional): On the train set (False -> val set). Defaults to True.
            target_transform (callable, optional): Transform to apply to the target values. Defaults to None.
            background_combination (str, optional): Which backgrounds to combine with foregrounds; one of
                "same" (background of the same class), "all" (any background) or "original" (the
                background belonging to the foreground). Defaults to "all".
            fg_scale_jitter (float, optional): Relative jitter of the foreground size: the target area ratio is
                drawn uniformly from [lower * (1 - jitter), upper * (1 + jitter)]. Defaults to 0.3.
            fg_transform (callable, optional): Transform to apply to the foreground before pasting it onto the
                background. This is supposed to be a random rotation, mainly. Defaults to None.
            pruning_ratio (float, optional): Only backgrounds with (foreground size/background size) <
                <pruning_ratio> are kept. Backgrounds from images that contain very large foreground objects
                are mostly computer generated and therefore relatively unnatural. Values > 1.0 disable
                pruning. Defaults to 0.8.
            return_fg_masks (bool, optional): Return the foreground masks. Defaults to False.
            fg_size_mode (str, optional): How to determine the size of the foreground, based on the foreground
                sizes of the foreground and background images; one of "max", "min", "mean", "range".
                Defaults to "range".
            fg_bates_n (int, optional): Bates parameter for the distribution of the object position in the
                image. Defaults to 1 (uniform distribution). The higher the value, the more likely the object
                is in the center. For fg_bates_n = 0, the object is always in the center.
            paste_pre_transform (bool, optional): Paste the foreground onto the background before applying the
                transform. If False, the background will be cropped and resized before pasting the
                foreground. Defaults to True.
            mask_smoothing_sigma (float, optional): Maximum sigma for the Gaussian blur of the mask edge; values
                <= 0 disable smoothing. Defaults to 4.0.
            rel_jut_out (float, optional): How much is the foreground allowed to stand/jut out of the background
                (and then cut off). Defaults to 0.0.
            fg_in_nonant (int, optional): If not None, the foreground is scaled relative to one nonant of the
                image; for values 0-8 it is also placed inside that nonant, for -1 the position is not
                restricted. Defaults to None.
            size_fact (float, optional): Factor to multiply the size of the foreground with. Defaults to 1.0.
            orig_img_prob (float | str, optional): Probability to use the original image, instead of the fg-bg
                recombinations. Either a float or one of the schedules "linear" (epoch / epochs),
                "revlinear" ((epochs - epoch) / epochs) or "cos" (1 - cos(pi * epoch / (2 * epochs))).
                Defaults to 0.0.
            orig_ds (Dataset | str, optional): Original dataset (without transforms) or path to its image
                folder, used for the original images. Defaults to None.
            _orig_ds_file_type (str, optional): File type of the original dataset. Defaults to "JPEG".
            epochs (int, optional): Number of epochs to train on. Used for the orig_img_prob schedules and
                must be > 0 when one of them is selected. Defaults to 0.
            _album_compose (bool, optional): Compose transform lists with data.album_transf.AlbumTorchCompose
                instead of torchvision's Compose. Defaults to False.

        Note:
            For more information on the bates distribution, see https://en.wikipedia.org/wiki/Bates_distribution.
            For fg_bates_n < 0, we extend the bates distribution to focus more and more on the edges. This is
            done by sampling B ~ Bates(|fg_bates_n|) and then passing through f(x) = x + 0.5 - floor(x + 0.5).

            For the list of transformations that will be applied to the background only (if
            paste_pre_transform=False), see RecombinationNet._bg_transforms.

            A nonant in this case refers to a square in a 3x3 grid dividing the image.

        """
        assert (
            background_combination in self._back_combs
        ), f"background_combination={background_combination} is not supported. Use one of {self._back_combs}"

        split = "train" if train else "val"

        # Folder mode is only used when there is no background zip but an extracted folder exists.
        if (not os.path.exists(f"{root}/backgrounds_{split}.zip")) and os.path.exists(
            os.path.join(root, split, "backgrounds")
        ):
            self._mode = "folder"
        else:
            self._mode = "zip"

        if self._mode == "zip":
            try:
                with zipfile.ZipFile(f"{root}/backgrounds_{split}.zip", "r") as bg_zip:
                    self.backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
                with zipfile.ZipFile(f"{root}/foregrounds_{split}.zip", "r") as fg_zip:
                    self.foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
            except FileNotFoundError as e:
                logger.error(
                    f"RecombinationNet: {e}. Make sure to have the background and foreground zips in the root"
                    f" directory: found {os.listdir(root)}"
                )
                raise
            # The class code is the name of the directory that contains the file.
            classes = {f.split("/")[-2] for f in self.foregrounds}
        else:
            logger.info("ForNet folder mode: loading classes")
            fg_root = os.path.join(root, split, "foregrounds")
            bg_root = os.path.join(root, split, "backgrounds")
            classes = set(os.listdir(fg_root))
            foregrounds = []
            backgrounds = []
            for cls in tqdm(classes, desc="Loading files"):
                foregrounds.extend(f"{cls}/{f}" for f in os.listdir(os.path.join(fg_root, cls)))
                backgrounds.extend(f"{cls}/{f}" for f in os.listdir(os.path.join(bg_root, cls)))
            self.foregrounds = foregrounds
            self.backgrounds = backgrounds

        # Class codes are sorted by the integer that follows their first character.
        self.classes = sorted(classes, key=lambda x: int(x[1:]))

        # NOTE: the ratio file is required for every pruning_ratio, because __getitem__ reads it too.
        assert os.path.exists(f"{root}/fg_bg_ratios_{split}.json"), (
            f"{root}/fg_bg_ratios_{split}.json not found, provide the information or set"
            " pruning_ratio=1.0"
        )
        with open(f"{root}/fg_bg_ratios_{split}.json", "r") as f:
            self.fg_bg_ratios = json.load(f)
        if self._mode == "folder":
            # Folder-mode file names are "<class>/<file>", so reduce the keys to their last two components.
            self.fg_bg_ratios = {"/".join(key.split("/")[-2:]): val for key, val in self.fg_bg_ratios.items()}
            logger.info(f"Renamed fg_bg_ratios keys to {list(self.fg_bg_ratios.keys())[:3]}...")

        # One fallback background per class (the last one listed), taken BEFORE pruning. It is built
        # unconditionally so that background_combination="same" also works with pruning disabled.
        backup_backgrounds = {bg_file.split("/")[-2]: bg_file for bg_file in self.backgrounds}
        if pruning_ratio <= 1.0:
            self.backgrounds = [
                bg for bg in self.backgrounds if bg in self.fg_bg_ratios and self.fg_bg_ratios[bg] < pruning_ratio
            ]

        self.root = root
        self.train = train
        self.background_combination = background_combination
        self.fg_scale_jitter = fg_scale_jitter
        self.fg_transform = fg_transform
        self.return_fg_masks = return_fg_masks
        self.paste_pre_transform = paste_pre_transform
        self.mask_smoothing_sigma = mask_smoothing_sigma
        self.rel_jut_out = rel_jut_out
        self.size_fact = size_fact
        self.fg_in_nonant = fg_in_nonant
        assert fg_in_nonant is None or -1 <= fg_in_nonant < 9, f"fg_in_nonant={fg_in_nonant} not in [-1, 8] or None"

        self.orig_img_prob = orig_img_prob
        # Always define these so later attribute access never depends on the branch taken below.
        self.key_to_orig_idx = None
        self._orig_ds_file_type = _orig_ds_file_type
        if orig_img_prob != 0.0:
            assert (isinstance(orig_img_prob, float) and orig_img_prob > 0.0) or orig_img_prob in [
                "linear",
                "cos",
                "revlinear",
            ]
            assert orig_ds is not None, "orig_ds must be provided if orig_img_prob > 0.0"
            assert not return_fg_masks, "can't provide fg masks for original images (yet)"
            assert os.path.exists(os.path.join(root, f"{split}_indices.json")) or isinstance(
                orig_ds, str
            ), f"{root}/{split}_indices.json must be provided if orig_ds is a dataset"
            if not isinstance(orig_ds, str):
                # Dataset object: map "<class>/<image name>" keys to indices of orig_ds.
                with open(os.path.join(root, f"{split}_indices.json"), "r") as f:
                    self.key_to_orig_idx = json.load(f)
            elif not (orig_ds.endswith(split) or orig_ds.endswith(f"{split}/")):
                # Path: make sure it points at the split sub-folder.
                orig_ds = f"{orig_ds}/{split}"
        self.orig_ds = orig_ds
        self.epochs = epochs
        self._epoch = 0

        assert fg_size_mode in [
            "max",
            "min",
            "mean",
            "range",
        ], f"fg_size_mode={fg_size_mode} not supported; use one of ['max', 'min', 'mean', 'range']"
        self.fg_size_mode = fg_size_mode
        self.fg_bates_n = fg_bates_n

        if not paste_pre_transform:
            if isinstance(transform, (T.Compose, Compose)):
                transform = transform.transforms
            if transform is None:
                # No transform given: both parts become empty (identity) compositions.
                transform = []

            # do cropping and resizing mainly on background; paste foreground on top later
            bg_tfs = []
            join_tfs = []
            bg_only_types = tuple(self._bg_transforms)
            for tf in transform:
                (bg_tfs if isinstance(tf, bg_only_types) else join_tfs).append(tf)

            if _album_compose:
                from data.album_transf import AlbumTorchCompose

                self.bg_transform = AlbumTorchCompose(bg_tfs)
                self.join_transform = AlbumTorchCompose(join_tfs)
            else:
                self.bg_transform = T.Compose(bg_tfs)
                self.join_transform = T.Compose(join_tfs)
        else:
            if isinstance(transform, list):
                if _album_compose:
                    from data.album_transf import AlbumTorchCompose

                    self.join_transform = AlbumTorchCompose(transform)
                else:
                    self.join_transform = T.Compose(transform)
            else:
                self.join_transform = transform
            self.bg_transform = None

        self.trgt_map = {cls: i for i, cls in enumerate(self.classes)}
        self.target_transform = target_transform

        # class code -> backgrounds of that class; only filled for background_combination="same"
        self.cls_to_allowed_bg = {}
        if background_combination == "same":
            for bg_file in self.backgrounds:
                self.cls_to_allowed_bg.setdefault(bg_file.split("/")[-2], []).append(bg_file)
            for cls_code in classes:
                if not self.cls_to_allowed_bg.get(cls_code):
                    # All backgrounds of this class were pruned: fall back to an unpruned one.
                    self.cls_to_allowed_bg[cls_code] = [backup_backgrounds[cls_code]]
                    logger.warning(f"No background for class {cls_code}, using {backup_backgrounds[cls_code]}")

        # worker id -> {"bg": ZipFile, "fg": ZipFile}; opened lazily by _wrkr_info (zip mode only)
        self._zf = {}

    @property
    def epoch(self):
        """Current epoch, used by the orig_img_prob schedules."""
        return self._epoch

    @epoch.setter
    def epoch(self, value):
        self._epoch = value

    @property
    def _split(self):
        """Name of the data split ("train" or "val")."""
        return "train" if self.train else "val"

    def __len__(self):
        """Size of the dataset.

        Returns:
            int: number of foregrounds

        """
        return len(self.foregrounds)

    def num_classes(self):
        """Return the number of classes."""
        return len(self.classes)

    def _read_image(self, kind, file_name, worker_id):
        """Open the image `file_name` of `kind` ("fg" or "bg") from the zip archive or the folder.

        Args:
            kind (str): "fg" for foregrounds, "bg" for backgrounds.
            file_name (str): entry of self.foregrounds / self.backgrounds.
            worker_id (int): id returned by _wrkr_info (selects the zip handles).

        Returns:
            PIL.Image.Image: the opened (not yet converted) image

        """
        if self._mode == "zip":
            with self._zf[worker_id][kind].open(file_name) as f:
                # Read everything into memory, so the image does not depend on the zip stream.
                return Image.open(BytesIO(f.read()))
        return Image.open(os.path.join(self.root, self._split, self._kind_folders[kind], file_name))

    def _get_fg(self, idx):
        """Return the raw foreground image at index idx (works in zip and folder mode)."""
        worker_id = self._wrkr_info()
        return self._read_image("fg", self.foregrounds[idx], worker_id)

    def _wrkr_info(self):
        """Return the id of the current DataLoader worker (0 outside of workers).

        In zip mode, this also opens the zip archives for this worker on first use,
        so every worker reads through its own file handles.
        """
        info = get_worker_info()
        worker_id = info.id if info else 0

        if self._mode == "zip" and worker_id not in self._zf:
            split = self._split
            self._zf[worker_id] = {
                "bg": zipfile.ZipFile(f"{self.root}/backgrounds_{split}.zip", "r"),
                "fg": zipfile.ZipFile(f"{self.root}/foregrounds_{split}.zip", "r"),
            }
        return worker_id

    def _use_original_image(self):
        """Randomly decide whether to return the original image instead of a recombination.

        Returns:
            bool: True, if the original image should be used

        Raises:
            ValueError: if a schedule is selected but self.epochs is not positive

        """
        prob = self.orig_img_prob
        if isinstance(prob, str):
            if prob not in self._orig_img_schedules:
                return False
            if self.epochs <= 0:
                raise ValueError(f"orig_img_prob='{prob}' requires epochs > 0, but epochs={self.epochs}")
            if prob == "linear":
                return np.random.rand() < self._epoch / self.epochs
            if prob == "revlinear":
                # Decreasing from 1 to 0. The former (epoch - epochs) / epochs was never positive,
                # so this schedule never selected an original image.
                return np.random.rand() < (self.epochs - self._epoch) / self.epochs
            # "cos": the probability 1 - cos(pi * epoch / (2 * epochs)) increases from 0 to 1
            return np.random.rand() > np.cos(np.pi * self._epoch / (2 * self.epochs))
        # No random number is drawn when original images are disabled.
        return isinstance(prob, float) and prob > 0.0 and np.random.rand() < prob

    def _get_original(self, fg_file, trgt_cls):
        """Load the original image belonging to the foreground fg_file and apply the transforms.

        Args:
            fg_file (str): foreground file name
            trgt_cls (str): class code of the foreground

        Returns:
            tuple: (transformed) image, (transformed) target

        """
        data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
        if isinstance(self.orig_ds, str):
            image_file = os.path.join(self.orig_ds, f"{data_key}.{self._orig_ds_file_type}")
            orig_img = Image.open(image_file).convert("RGB")
        else:
            orig_data = self.orig_ds[self.key_to_orig_idx[data_key]]
            orig_img = orig_data["image"] if isinstance(orig_data, dict) else orig_data[0]

        if self.bg_transform:
            orig_img = self.bg_transform(orig_img)
        if self.join_transform:
            orig_img = self.join_transform(orig_img)
        trgt = self.trgt_map[trgt_cls]
        if self.target_transform:
            trgt = self.target_transform(trgt)
        return orig_img, trgt

    def __getitem__(self, idx):
        """Get the foreground at index idx and combine it with a (random) background.

        Args:
            idx (int): foreground index

        Returns:
            tuple: image, target; additionally the foreground mask if return_fg_masks is set

        """
        worker_id = self._wrkr_info()
        fg_file = self.foregrounds[idx]
        trgt_cls = fg_file.split("/")[-2]

        if self._use_original_image():
            return self._get_original(fg_file, trgt_cls)

        try:
            fg_img = self._read_image("fg", fg_file, worker_id).convert("RGBA")
        except PIL.UnidentifiedImageError:
            logger.error(f"Error with idx={idx}, file={self.foregrounds[idx]}")
            raise

        if self.fg_transform:
            fg_img = self.fg_transform(fg_img)
        # mean alpha value = fraction of the foreground image that is covered by the object
        fg_size_factor = T.ToTensor()(fg_img.split()[-1]).mean().item()

        if self.background_combination == "all":
            bg_file = self.backgrounds[np.random.randint(len(self.backgrounds))]
        elif self.background_combination == "original":
            bg_file = fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")
        else:
            allowed_bgs = self.cls_to_allowed_bg[trgt_cls]
            bg_file = allowed_bgs[np.random.randint(len(allowed_bgs))]

        bg_img = self._read_image("bg", bg_file, worker_id).convert("RGB")

        if not self.paste_pre_transform:
            bg_img = self.bg_transform(bg_img)

        bg_size = bg_img.size

        # choose scale factor, such that relative area is in fg_scale
        bg_area = bg_size[0] * bg_size[1]
        if self.fg_in_nonant is not None:
            bg_area = bg_area / 9

        orig_fg_ratio = self.fg_bg_ratios[fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")]
        bg_fg_ratio = self.fg_bg_ratios[bg_file]

        if self.fg_size_mode == "max":
            goal_fg_ratio_lower = goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "min":
            goal_fg_ratio_lower = goal_fg_ratio_upper = min(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "mean":
            goal_fg_ratio_lower = goal_fg_ratio_upper = (orig_fg_ratio + bg_fg_ratio) / 2
        else:
            # range
            goal_fg_ratio_lower = min(orig_fg_ratio, bg_fg_ratio)
            goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)

        # Dividing by fg_size_factor converts the object area ratio into an area ratio of the full cut-out.
        fg_scale = (
            np.random.uniform(
                goal_fg_ratio_lower * (1 - self.fg_scale_jitter), goal_fg_ratio_upper * (1 + self.fg_scale_jitter)
            )
            / fg_size_factor
            * self.size_fact
        )

        # Keep the aspect ratio of the cut-out; clamp to 1px, since PIL rejects empty target sizes.
        goal_shape_y = max(1, round(np.sqrt(bg_area * fg_scale * fg_img.size[1] / fg_img.size[0])))
        goal_shape_x = max(1, round(np.sqrt(bg_area * fg_scale * fg_img.size[0] / fg_img.size[1])))

        fg_img = fg_img.resize((goal_shape_x, goal_shape_y))

        if fg_img.size[0] > bg_size[0] or fg_img.size[1] > bg_size[1]:
            # random crop to fit
            goal_w, goal_h = (min(fg_img.size[0], bg_size[0]), min(fg_img.size[1], bg_size[1]))
            fg_img = T.RandomCrop((goal_h, goal_w))(fg_img) if self.train else T.CenterCrop((goal_h, goal_w))(fg_img)

        # paste fg on bg
        z1, z2 = (
            (
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),  # bates distribution n=1 => uniform
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),
            )
            if self.fg_bates_n != 0
            else (0.5, 0.5)
        )
        if self.fg_bates_n < 0:
            # shift by 0.5 and wrap around => mass is concentrated at the edges instead of the center
            z1 = z1 + 0.5 - floor(z1 + 0.5)
            z2 = z2 + 0.5 - floor(z2 + 0.5)

        # allowed range for the upper left corner of the foreground
        x_min = -self.rel_jut_out * fg_img.size[0]
        x_max = bg_size[0] - fg_img.size[0] * (1 - self.rel_jut_out)
        y_min = -self.rel_jut_out * fg_img.size[1]
        y_max = bg_size[1] - fg_img.size[1] * (1 - self.rel_jut_out)

        if self.fg_in_nonant is not None and self.fg_in_nonant >= 0:
            # restrict the position to the selected cell of the 3x3 grid (row-major order)
            x_min = (self.fg_in_nonant % 3) * bg_size[0] / 3
            x_max = ((self.fg_in_nonant % 3) + 1) * bg_size[0] / 3 - fg_img.size[0]
            y_min = (self.fg_in_nonant // 3) * bg_size[1] / 3
            y_max = ((self.fg_in_nonant // 3) + 1) * bg_size[1] / 3 - fg_img.size[1]

        # foreground does not fit into the allowed range => center it on that range
        if x_min > x_max:
            x_min = x_max = (x_min + x_max) / 2
        if y_min > y_max:
            y_min = y_max = (y_min + y_max) / 2

        offs_x = round(z1 * (x_max - x_min) + x_min)
        offs_y = round(z2 * (y_max - y_min) + y_min)

        paste_mask = fg_img.split()[-1]
        if self.mask_smoothing_sigma > 0.0:
            sigma = (np.random.rand() * 0.9 + 0.1) * self.mask_smoothing_sigma
            paste_mask = paste_mask.filter(ImageFilter.GaussianBlur(radius=sigma))
            # keep only the inner half of the blurred edge (values > 128) and stretch it to 1..255
            paste_mask = paste_mask.point(lambda p: 2 * p - 255 if p > 128 else 0)

        bg_img.paste(fg_img.convert("RGB"), (offs_x, offs_y), paste_mask)
        bg_img = bg_img.convert("RGB")

        if self.return_fg_masks:
            fg_mask = Image.new("L", bg_size, 0)
            fg_mask.paste(paste_mask, (offs_x, offs_y))

            fg_mask = T.ToTensor()(fg_mask)[0]
            bg_img = T.ToTensor()(bg_img)

            if self.join_transform:
                bg_img, fg_mask = apply_dense_transforms(bg_img, fg_mask, self.join_transform)
        elif self.join_transform:
            # join_transform is None when no transform was given (the default); return the image as is
            bg_img = self.join_transform(bg_img)

        if trgt_cls not in self.trgt_map:
            raise ValueError(f"trgt_cls={trgt_cls} not in trgt_map: {self.trgt_map}")
        trgt = self.trgt_map[trgt_cls]
        if self.target_transform:
            trgt = self.target_transform(trgt)

        if self.return_fg_masks:
            return bg_img, trgt, fg_mask

        return bg_img, trgt
|
||||
Reference in New Issue
Block a user