AAAI Version
This commit is contained in:
@@ -0,0 +1,211 @@
|
||||
import argparse
|
||||
import os
|
||||
from functools import partial
|
||||
from math import log
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose, RandomCrop, Resize, ToTensor
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
try:
|
||||
from models import load_pretrained
|
||||
except ModuleNotFoundError:
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from models import load_pretrained
|
||||
|
||||
from utils import prep_kwargs
|
||||
|
||||
|
||||
def score_5(
    idx,
    bg_probs,
    mean_probs,
    fg_ratio,
    max_idx,
    fg_ratio_max=0.9,
    fg_ratio_min=0.002,
    fg_ratio_exp=0.4,  # learned: 0.487
    idx_exp=0.01,  # learned: 0.043
    bg_probs_exp=0.2,  # learned: 0.24
    opt_fg_ratio=0.1,  # learned: 0.1
    mean_probs_exp=0.2,  # learned: 0.2
    fg_ratio_penalty=1,  # learned: -0.446 ???
):
    """Score one image version as a weighted sum of log-terms (higher is better).

    Combines: classifier confidence on the foreground composites (mean_probs),
    low classifier confidence on the background alone (bg_probs), closeness of
    the foreground-area ratio to opt_fg_ratio, a slight preference for earlier
    version indices, and a flat bonus when fg_ratio lies strictly inside
    (fg_ratio_min, fg_ratio_max).
    """
    confidence_term = log(mean_probs) * mean_probs_exp
    background_term = log(1 - bg_probs) * bg_probs_exp
    ratio_term = log(1 - abs(fg_ratio - opt_fg_ratio)) * fg_ratio_exp
    # idx / (max_idx + 1) is in [0, 1): earlier versions get a (small) bonus.
    recency_term = log(1 - idx / (max_idx + 1)) * idx_exp
    # Bool * penalty: adds fg_ratio_penalty only when the ratio is in range.
    in_range_bonus = (fg_ratio_min < fg_ratio < fg_ratio_max) * fg_ratio_penalty
    return confidence_term + background_term + ratio_term + recency_term + in_range_bonus  # TODO: KEEP IT LIKE IT IS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Inspect image versions")
# Root directory that contains the "backgrounds" and "foregrounds" subfolders.
parser.add_argument("-f", "--base_folder", type=str, required=True, help="Base folder to inspect")
parser.add_argument(
    "-batch_size",
    type=int,
    default=32,
    help="Batch size for model inspection. Will be 1 background and batch_size - 1 foregrounds",
)
parser.add_argument("-imsize", type=int, default=224, help="Image size")
# "manual" uses score_5's built-in defaults; "automatic" overrides them with learned values.
parser.add_argument(
    "-score_f_weights", choices=["manual", "automatic"], default="manual", help="Score function hyperparameters"
)
# Only used when -score_f_weights=automatic (becomes fg_ratio_penalty in score_5).
parser.add_argument("-auto_fg_pen_val", type=float, default=-0.446, help="Automatic foreground penalty value")
parser.add_argument("-d", "--dataset", choices=["tinyimagenet", "imagenet"], default="tinyimagenet", help="Dataset")

args = parser.parse_args()
|
||||
|
||||
# Select the scoring function: either score_5 with its hand-tuned defaults,
# or a variant with the automatically learned hyperparameters baked in.
if args.score_f_weights == "manual":
    score_f = score_5
else:
    score_f = partial(
        score_5,
        fg_ratio_exp=0.487,
        idx_exp=0.043,
        bg_probs_exp=0.24,
        fg_ratio_penalty=args.auto_fg_pen_val,
    )
|
||||
|
||||
# Expected layout: <base_folder>/backgrounds/<class>/ and <base_folder>/foregrounds/<class>/.
bg_folder = os.path.join(args.base_folder, "backgrounds")
fg_folder = os.path.join(args.base_folder, "foregrounds")

# Class folders are WordNet ids like "n01443537": sort numerically by the
# digits after the leading letter so the order matches the datasets' class indexing.
classes = os.listdir(fg_folder)
classes = sorted(classes, key=lambda x: int(x[1:]))
# 200 classes -> TinyImageNet, 1_000 -> ImageNet; anything else means a broken folder.
assert len(classes) in [200, 1_000], f"Expected 200 or 1_000 classes, got {len(classes)}"
|
||||
|
||||
# Collect every image (as "<class>/<stem>") that still has versioned foreground
# files, i.e. filenames whose final "_"-separated token before the extension
# starts with "v" (e.g. "img_12_v0.WEBP" -> stem "img_12").
total_images = set()
for in_cls in classes:
    cls_images = {
        # Drop the trailing "_v<k>" tag to recover the common image stem.
        os.path.join(in_cls, "_".join(img.split(".")[0].split("_")[:-1]))
        for img in os.listdir(os.path.join(fg_folder, in_cls))
        if img.split(".")[0].split("_")[-1].startswith("v")
    }
    total_images.update(cls_images)
total_images = list(total_images)
|
||||
|
||||
# Map WordNet ids to human-readable names; each file line looks like
# "nXXXXXXXX: name". The len(line) > 2 guard skips blank/near-empty lines.
# NOTE(review): path is relative to the current working directory — run from the repo root.
with open(os.path.join("data", "misc_dataset_files", "tinyimagenet_synset_names.txt"), "r") as f:
    in_cls_to_name = {line.split(":")[0].strip(): line.split(":")[1].strip() for line in f.readlines() if len(line) > 2}
|
||||
|
||||
# Placeholder lists: fill in checkpoint paths for the inspection models of the
# chosen dataset before running (empty list -> no models -> empty prob lists below).
if args.dataset == "tinyimagenet":
    inspection_model_paths = []  # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON TinyImageNet
elif args.dataset == "imagenet":
    inspection_model_paths = []  # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON ImageNet
else:
    # Unreachable via argparse choices, but kept as a defensive guard.
    raise ValueError(f"Unknown dataset {args.dataset}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load_pretrained returns a tuple; [0] is the model itself (project convention).
inspection_models = [
    load_pretrained(path, prep_kwargs({}), new_dataset_params=False)[0].to(device) for path in inspection_model_paths
]
# NOTE(review): RandomCrop(imsize) after Resize((imsize, imsize)) looks like a
# no-op crop — presumably intentional, but verify.
img_transform = Compose([Resize((args.imsize, args.imsize)), RandomCrop(args.imsize), ToTensor()])
|
||||
|
||||
# NOTE(review): total_versions is never appended to anywhere below — dead state?
total_versions = []

# For each image stem, score every surviving "_v<k>" version with the inspection
# models and keep only the best one, renaming it to drop the version tag.
# WARNING: this loop mutates the dataset on disk (os.remove / os.rename).
for img_name in tqdm(total_images, desc="Image version computation"):
    in_cls, img_name = img_name.split("/")
    # All foreground files in this class whose leading tokens match the stem.
    versions = set()
    for img in os.listdir(os.path.join(fg_folder, in_cls)):
        if "_".join(img.split("_")[: len(img_name.split("_"))]) == img_name:
            versions.add(img)
    if len(versions) == 1:
        # Single survivor: just strip the version tag (rename fg .WEBP and bg .JPEG).
        version = list(versions)[0]
        if version.split(".")[0].split("_")[-1].startswith("v"):
            tqdm.write(f"renaming single version image {version} to {img_name}.WEBP")
            os.rename(os.path.join(fg_folder, in_cls, version), os.path.join(fg_folder, in_cls, f"{img_name}.WEBP"))
            os.rename(
                os.path.join(bg_folder, in_cls, version.replace(".WEBP", ".JPEG")),
                os.path.join(bg_folder, in_cls, f"{img_name}.JPEG"),
            )
        continue
    elif len(versions) == 0:
        tqdm.write(f"Image {img_name} has no versions")
        continue
    # Sorted so v_idx below corresponds to the version number order.
    versions = sorted(list(versions))
    assert all(
        [version.split(".")[0].split("_")[-1].startswith("v") for version in versions]
    ), f"Weird Versions: {versions} for image {img_name}"
    assert len(versions) <= 3, f"Too many versions for image {img_name}: {versions}"

    version_scores = []
    for v_idx, version in enumerate(versions):
        # Foreground is an RGBA .WEBP cut-out; its matching background is a .JPEG with the same stem.
        img = Image.open(os.path.join(fg_folder, in_cls, version))
        bg_img = Image.open(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
        # Alpha channel of the foreground = binary-ish mask of covered pixels.
        img_mask = np.array(img.convert("RGBA").split()[-1])

        # Fraction of the background area covered by (alpha-weighted) foreground.
        fg_ratio = np.sum(img_mask) / (255 * bg_img.size[0] * bg_img.size[1])

        fg_size = img.size
        # batch_size - 1 grayscale canvases ramping from black to white, each
        # large enough to hold the foreground.
        monochrome_backgrounds = [
            Image.new(
                "RGB",
                (max(args.imsize, fg_size[0]), max(args.imsize, fg_size[1])),
                (255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2)),
            )
            for i in range(args.batch_size - 1)
        ]
        pasting_error = False
        for mc_bg in monochrome_backgrounds:
            try:
                # Center-paste the foreground using its own alpha as the paste mask.
                mc_bg.paste(img, ((args.imsize - fg_size[0]) // 2, (args.imsize - fg_size[1]) // 2), img)
            except ValueError as e:
                tqdm.write(f"Image {img_name} could not be pasted into background: {e}")
                pasting_error = True
                break

        # Batch layout: index 0 = background only, 1.. = foreground composites.
        inp_batch = torch.stack(
            [img_transform(bg_img)] + [img_transform(mc_bg) for mc_bg in monochrome_backgrounds], dim=0
        ).to(device)

        cls_idx = classes.index(in_cls)
        bg_probs = []
        mean_probs = []
        for model in inspection_models:
            model.eval()
            with torch.no_grad():
                # Probability of the true class for every batch element.
                out_probs = model(inp_batch).softmax(dim=-1)[:, cls_idx].cpu().numpy()
            bg_probs.append(out_probs[0])
            mean_probs.append(np.mean(out_probs[1:]))

        # average the lists (over inspection models)
        bg_probs = np.mean(bg_probs)
        mean_probs = np.mean(mean_probs)

        # A paste failure disqualifies the version outright (-100 << any real score).
        version_score = (
            score_f(
                idx=v_idx,
                bg_probs=float(bg_probs),
                mean_probs=float(mean_probs),
                fg_ratio=float(fg_ratio),
                max_idx=len(versions) - 1,
            )
            if not pasting_error
            else -100
        )
        version_scores.append(version_score)

    assert len(versions) == len(version_scores), f"Expected {len(versions)} scores, got {len(version_scores)}"

    # Only act when the scores actually distinguish the versions.
    if max(version_scores) > min(version_scores):
        # find best version
        best_version_idx = int(np.argmax(version_scores))
        best_version = versions[best_version_idx]

        # delete all other versions
        for version in versions:
            if version != best_version:
                os.remove(os.path.join(fg_folder, in_cls, version))
                os.remove(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
        # remove version tag in name
        new_version_name = "_".join(best_version.split("_")[:-1]) + "." + best_version.split(".")[-1]
        os.rename(os.path.join(fg_folder, in_cls, best_version), os.path.join(fg_folder, in_cls, new_version_name))
        os.rename(
            os.path.join(bg_folder, in_cls, f"{best_version.split('.')[0]}.JPEG"),
            os.path.join(bg_folder, in_cls, f"{new_version_name.split('.')[0]}.JPEG"),
        )
    else:
        # Ties (e.g. every version failed to paste) are left untouched for manual review.
        tqdm.write(f"All versions have the same score for image {img_name}")
||||
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
# Launch a Python script on the SLURM cluster inside a container.
# Usage: ./this_script.sh path/to/script.py [script args...]
# All arguments ("$@") are forwarded verbatim to python3.
# Requests: 1 node, 1 GPU (any listed partition), 24 CPUs, 64G RAM, 1 day.
# Replace the PATH/TO placeholders with site-specific values before use.

srun -K \
    --container-image=PATH/TO/SLURM/IMAGE \
    --container-workdir="$(pwd)" \
    --container-mounts=/ALL/IMPORTANT/MOUNTS,"$(pwd)":"$(pwd)" \
    --partition=RTXA6000,RTX3090,A100-40GB,A100-80GB,H100,H200 \
    --job-name="python" \
    --nodes=1 \
    --gpus=1 \
    --ntasks=1 \
    --cpus-per-task=24 \
    --mem=64G \
    --time=1-0 \
    --export="NLTK_DATA=/PATH/TO/NLTK_DATA" \
    python3 "$@"
|
||||
Reference in New Issue
Block a user