AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,97 @@
import math
import os
import sys
import numpy as np
from loguru import logger
from matplotlib import pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from load_dataset import prepare_dataset
from utils import log_args, log_formatter, prep_kwargs
def load_images(dataset, **kwargs):
    """Show the first batch of ``dataset`` as a grid of images.

    The loader is built by ``prepare_dataset`` with ``aug_normalize`` forced to
    ``False``; only the first batch it yields is displayed.

    Args:
        dataset: Dataset identifier, forwarded to ``prepare_dataset`` and used
            in the figure title.
        **kwargs: Additional options, converted to the argument object by
            ``prep_kwargs``.
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.aug_normalize = False

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)

    batch = next(iter(loader))[0]
    # NOTE(review): assumes the batch is channels-first (N, C, H, W); it is
    # permuted to channels-last because that is what ``imshow`` expects.
    images = list(batch.permute(0, 2, 3, 1).numpy())
    if not images:
        logger.warning(f"First batch of {dataset} is empty, nothing to show")
        return

    # Aim for a grid roughly twice as wide as it is tall. The column count is
    # rounded up so that every image gets a cell.
    rows = max(1, math.ceil(math.sqrt(len(images) / 2)))
    ims_per_row = math.ceil(len(images) / rows)

    # squeeze=False always yields a 2-D array of axes, also for a single row
    # or a single image.
    fig, axs = plt.subplots(rows, ims_per_row, squeeze=False)
    axes = axs.ravel()
    for img, ax in zip(images, axes):
        ax.imshow(img)
    # Cells left over after rounding up stay blank instead of showing empty axes.
    for ax in axes[len(images):]:
        ax.axis("off")

    fig.suptitle(f"Examples from {dataset}")
    fig.tight_layout(pad=0)
    plt.show()
def save_images(dataset, out_dir, ipc=None, **kwargs):
    """Write the images of ``dataset`` to disk, one sub-folder per class.

    Each image is stored as ``<out_dir>/<class>/<class>_<k>.JPEG`` where ``k``
    counts the images already saved for that class. Progress is logged to
    stderr and to ``<out_dir>/save_images.log``.

    Class folder names are the class indices as strings, except when the
    dataset has exactly 1000 classes: then the names are read from
    ``data/misc_dataset_files/imagenet_labels.txt`` (first space-separated
    token of each line, sorted).

    Args:
        dataset: Dataset identifier, forwarded to ``prepare_dataset``.
        out_dir: Directory that receives the class folders and the log file.
        ipc: Maximum number of images saved per class; ``None`` saves all.
        **kwargs: Additional options, converted to the argument object by
            ``prep_kwargs`` (``log_level`` and ``tqdm`` are read from it here).
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.out_dir = out_dir
    args.ipc = ipc
    args.aug_normalize = False

    log_file = os.path.join(out_dir, "save_images.log")
    log_level = args.log_level.upper()
    logger.remove()
    logger.configure(extra=dict(run_name=f"Save images of {dataset}", rank=0, world_size=-1))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=log_level)
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=log_level)
    logger.info(f"Out dir '{out_dir}'")
    log_args(args)

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    n_ims = [0] * args.n_classes

    if args.n_classes == 1000:
        # assume its ImageNet classes
        logger.info("1000 classes => assuming ImageNet class names")
        with open("data/misc_dataset_files/imagenet_labels.txt", "r") as f:
            lbl_to_cls_name = sorted(line.strip().split(" ")[0] for line in f)
    else:
        # Names are used as path components, so they must be strings.
        lbl_to_cls_name = [str(idx) for idx in range(args.n_classes)]

    for cls_name in lbl_to_cls_name:
        os.makedirs(os.path.join(args.out_dir, cls_name), exist_ok=True)

    skipped_ims = 0
    # Total number of images to save before stopping early (only with ipc).
    max_total = None if ipc is None else args.n_classes * ipc
    # NOTE: environment values are strings, so any set TQDM_DISABLE (even "0")
    # compares unequal to the integer default and disables the bar.
    tqdm_is_disabled = (not args.tqdm) or os.environ.get("TQDM_DISABLE", 0) != 0

    for i, (images, labels) in (
        pbar := tqdm(enumerate(loader), desc="Loading and saving images", disable=tqdm_is_disabled, total=len(loader))
    ):
        # NOTE(review): assumes float images in [0, 1], channels-first; they
        # are converted to channels-last uint8 for PIL.
        images = (images.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
        # NOTE(review): labels are used as list indices below, so this assumes
        # one integer class label per image.
        labels = labels.tolist()

        for img, lbl in zip(images, labels):
            if ipc is not None and n_ims[lbl] >= ipc:
                skipped_ims += 1
                continue
            cls_name = lbl_to_cls_name[lbl]
            Image.fromarray(img).save(os.path.join(args.out_dir, cls_name, f"{cls_name}_{n_ims[lbl]}.JPEG"))
            n_ims[lbl] += 1

        # Every class has reached its quota: no need to read further batches.
        if max_total is not None and sum(n_ims) >= max_total:
            break

        if tqdm_is_disabled:
            if i % 1000 == 0:
                logger.info(f"Batch [{i+1}/{len(loader)}]: Saved {sum(n_ims)}, skipped {skipped_ims}")
        else:
            pbar.set_description(f"Loading and saving (saved {sum(n_ims)}, skipped {skipped_ims})")

    logger.success(f"Extracted {sum(n_ims)} images.")