Files
ForAug/AAAI Supplementary Material/Model Training Code/test.py
Tobias Christian Nauen ff34712155 AAAI Version
2026-02-24 12:22:44 +01:00

98 lines
3.4 KiB
Python

import math
import os
import sys
import numpy as np
from loguru import logger
from matplotlib import pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from load_dataset import prepare_dataset
from utils import log_args, log_formatter, prep_kwargs
def load_images(dataset, **kwargs):
    """Show the first batch yielded by the dataset loader as a grid in a matplotlib window.

    Normalization is switched off (``args.aug_normalize = False``) so the raw pixel
    values can be passed straight to ``imshow``.

    Args:
        dataset: Dataset identifier; stored on ``args.dataset`` and forwarded to
            ``prepare_dataset``.
        **kwargs: Extra options, turned into the argument namespace by ``prep_kwargs``.
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.aug_normalize = False
    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)

    batch = next(iter(loader))[0]
    # Move axis 1 to the end: (N, C, H, W) -> (N, H, W, C), the channels-last layout
    # imshow expects (channels-first input layout assumed — TODO confirm).
    images = list(batch.permute(0, 2, 3, 1).numpy())
    n_images = len(images)
    if n_images == 0:
        logger.warning(f"First batch of {dataset} is empty; nothing to show.")
        return

    # Roughly twice as many columns as rows. Round the column count UP so that
    # every image gets its own axis (flooring would silently drop images).
    rows = math.ceil(math.sqrt(n_images / 2))
    cols = math.ceil(n_images / rows)
    # squeeze=False always returns a 2-D array of Axes, even for 1x1 or 1xN grids.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    axes = axs.ravel()
    for img, ax in zip(images, axes):
        ax.imshow(img)
    # Hide the grid cells left over when the images do not fill the grid exactly.
    for ax in axes[n_images:]:
        ax.axis("off")
    fig.suptitle(f"Examples from {dataset}")
    fig.tight_layout(pad=0)
    plt.show()
def _class_folder_names(n_classes):
    """Return the folder name used for every integer label ``0 .. n_classes - 1``.

    With exactly 1000 classes the dataset is assumed to be ImageNet: the first
    space-separated token of each line of ``data/misc_dataset_files/imagenet_labels.txt``
    is collected and the tokens are sorted, so label ``i`` maps to the ``i``-th sorted
    token. Otherwise the label index itself, as a string, is the folder name.
    """
    if n_classes == 1000:
        # assume its ImageNet classes
        logger.info("1000 classes => assuming ImageNet class names")
        with open("data/misc_dataset_files/imagenet_labels.txt", "r", encoding="utf-8") as f:
            labels = [line.strip().split(" ")[0] for line in f]
        return sorted(labels)
    # Strings, not ints: these names are joined into filesystem paths.
    return [str(i) for i in range(n_classes)]


def _tqdm_disabled(use_tqdm):
    """Return True if the progress bar should be hidden.

    The bar is hidden when ``use_tqdm`` is falsy, or when the ``TQDM_DISABLE``
    environment variable is set to anything other than ``""``, ``"0"`` or ``"false"``.
    Environment values are always strings, so they must not be compared against the int 0.
    """
    env_value = os.environ.get("TQDM_DISABLE", "").strip().lower()
    return (not use_tqdm) or env_value not in ("", "0", "false")


def save_images(dataset, out_dir, ipc=None, **kwargs):
    """Export the images of ``dataset`` as JPEG files, one sub-folder per class.

    Normalization is switched off (``args.aug_normalize = False``). Each image is written
    to ``<out_dir>/<class_name>/<class_name>_<k>.JPEG``, where ``k`` is the number of
    images already saved for that class. Progress is logged to stderr and to
    ``<out_dir>/save_images.log``.

    Args:
        dataset: Dataset identifier; forwarded to ``prepare_dataset``.
        out_dir: Root directory for the class folders and the log file.
        ipc: Optional cap on the number of images saved per class. Once every class
            reached the cap, iteration stops early. ``None`` saves every image.
        **kwargs: Extra options, turned into the argument namespace by ``prep_kwargs``
            (this function reads ``log_level`` and ``tqdm`` from it).

    Note:
        Calls ``logger.remove()``, which drops every previously registered loguru sink.
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.out_dir = out_dir
    args.ipc = ipc
    args.aug_normalize = False

    log_file = os.path.join(out_dir, "save_images.log")
    level = args.log_level.upper()
    logger.remove()
    logger.configure(extra=dict(run_name=f"Save images of {dataset}", rank=0, world_size=-1))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=level)
    # No colorization for the file sink; otherwise ANSI escape codes end up in the log file.
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=False, level=level)
    logger.info(f"Out dir '{out_dir}'")
    log_args(args)

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    n_ims = [0] * args.n_classes  # number of images saved so far, per label
    lbl_to_cls_name = _class_folder_names(args.n_classes)
    for cls_name in lbl_to_cls_name:
        os.makedirs(os.path.join(args.out_dir, cls_name), exist_ok=True)

    target_total = None if ipc is None else args.n_classes * ipc
    n_saved = 0  # running sum(n_ims), avoids an O(n_classes) sum per image
    skipped_ims = 0
    tqdm_is_disabled = _tqdm_disabled(args.tqdm)
    pbar = tqdm(enumerate(loader), desc="Loading and saving images", disable=tqdm_is_disabled, total=len(loader))
    for i, (images, labels) in pbar:
        # Move axis 1 to the end ((N, C, H, W) -> (N, H, W, C), channels-first input assumed —
        # TODO confirm) and scale to 8-bit. The *255 assumes float images in [0, 1].
        # Round instead of truncating (e.g. 2.9999998 -> 3) and clip so values slightly
        # outside [0, 1] do not wrap around in the uint8 cast.
        pixels = np.clip(np.rint(images.permute(0, 2, 3, 1).numpy() * 255), 0, 255).astype(np.uint8)
        # NOTE(review): assumes one integer label per image; multi-label targets would
        # produce lists here, which cannot index n_ims.
        for img, lbl in zip(pixels, labels.tolist()):
            if ipc is not None and n_ims[lbl] >= ipc:
                skipped_ims += 1
                continue
            cls_name = lbl_to_cls_name[lbl]
            Image.fromarray(img).save(os.path.join(args.out_dir, cls_name, f"{cls_name}_{n_ims[lbl]}.JPEG"))
            n_ims[lbl] += 1
            n_saved += 1
            if target_total is not None and n_saved >= target_total:
                break
        if tqdm_is_disabled:
            if i % 1000 == 0:
                logger.info(f"Batch [{i+1}/{len(loader)}]: Saved {n_saved}, skipped {skipped_ims}")
        else:
            pbar.set_description(f"Loading and saving (saved {n_saved}, skipped {skipped_ims})")
        # Every class reached its quota: stop instead of loading the remaining batches.
        if target_total is not None and n_saved >= target_total:
            break
    logger.success(f"Extracted {n_saved} images.")