AAAI Version
This commit is contained in:
97
AAAI Supplementary Material/Model Training Code/test.py
Normal file
97
AAAI Supplementary Material/Model Training Code/test.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from load_dataset import prepare_dataset
|
||||
from utils import log_args, log_formatter, prep_kwargs
|
||||
|
||||
|
||||
def load_images(dataset, **kwargs):
|
||||
args = prep_kwargs(kwargs)
|
||||
args.dataset = dataset
|
||||
|
||||
args.aug_normalize = False
|
||||
|
||||
loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
|
||||
images = next(iter(loader))[0]
|
||||
|
||||
images = images.permute(0, 2, 3, 1).numpy()
|
||||
images = [images[i] for i in range(images.shape[0])]
|
||||
|
||||
rows = math.ceil(math.sqrt(len(images) / 2))
|
||||
ims_per_row = len(images) // rows
|
||||
|
||||
fig, axs = plt.subplots(rows, ims_per_row)
|
||||
axs = [ax for row in axs for ax in row]
|
||||
for img, ax in zip(images, axs):
|
||||
ax.imshow(img)
|
||||
fig.suptitle(f"Examples from {dataset}")
|
||||
fig.tight_layout(pad=0)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def save_images(dataset, out_dir, ipc=None, **kwargs):
|
||||
args = prep_kwargs(kwargs)
|
||||
args.dataset = dataset
|
||||
args.out_dir = out_dir
|
||||
args.ipc = ipc
|
||||
args.aug_normalize = False
|
||||
|
||||
log_file = os.path.join(out_dir, "save_images.log")
|
||||
logger.remove()
|
||||
logger.configure(extra=dict(run_name=f"Save images of {dataset}", rank=0, world_size=-1))
|
||||
logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
|
||||
logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
|
||||
logger.info(f"Out dir '{out_dir}'")
|
||||
log_args(args)
|
||||
|
||||
loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
|
||||
n_ims = [0 for i in range(args.n_classes)]
|
||||
|
||||
if args.n_classes == 1000:
|
||||
# assume its ImageNet classes
|
||||
logger.info("1000 classes => assuming ImageNet class names")
|
||||
with open("data/misc_dataset_files/imagenet_labels.txt", "r") as f:
|
||||
lines = f.readlines()
|
||||
labels = [l.strip().split(" ")[0] for l in lines]
|
||||
lbl_to_cls_name = sorted(labels)
|
||||
else:
|
||||
lbl_to_cls_name = [i for i in range(args.n_classes)]
|
||||
|
||||
for cls_name in lbl_to_cls_name:
|
||||
os.makedirs(os.path.join(args.out_dir, cls_name), exist_ok=True)
|
||||
|
||||
skipped_ims = 0
|
||||
tqdm_is_disabled = (not args.tqdm) or os.environ.get("TQDM_DISABLE", 0) != 0
|
||||
for i, (images, labels) in (
|
||||
pbar := tqdm(enumerate(loader), desc="Loading and saving images", disable=tqdm_is_disabled, total=len(loader))
|
||||
):
|
||||
images = (images.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
|
||||
images = [images[i] for i in range(images.shape[0])]
|
||||
labels = labels.tolist()
|
||||
|
||||
for img, lbl in zip(images, labels):
|
||||
if ipc is not None and n_ims[lbl] >= ipc:
|
||||
skipped_ims += 1
|
||||
continue
|
||||
|
||||
img = Image.fromarray(img).save(
|
||||
os.path.join(args.out_dir, lbl_to_cls_name[lbl], f"{lbl_to_cls_name[lbl]}_{n_ims[lbl]}.JPEG")
|
||||
)
|
||||
n_ims[lbl] += 1
|
||||
|
||||
if ipc is not None and sum(n_ims) >= args.n_classes * ipc:
|
||||
break
|
||||
if tqdm_is_disabled:
|
||||
if i % 1000 == 0:
|
||||
logger.info(f"Batch [{i+1}/{len(loader)}]: Saved {sum(n_ims)}, skipped {skipped_ims}")
|
||||
else:
|
||||
pbar.set_description(f"Loading and saving (saved {sum(n_ims)}, skipped {skipped_ims})")
|
||||
logger.success(f"Extracted {sum(n_ims)} images.")
|
||||
Reference in New Issue
Block a user