98 lines
3.4 KiB
Python
98 lines
3.4 KiB
Python
import math
|
|
import os
|
|
import sys
|
|
|
|
import numpy as np
|
|
from loguru import logger
|
|
from matplotlib import pyplot as plt
|
|
from PIL import Image
|
|
from tqdm.auto import tqdm
|
|
|
|
from load_dataset import prepare_dataset
|
|
from utils import log_args, log_formatter, prep_kwargs
|
|
|
|
|
|
def load_images(dataset, **kwargs):
    """Show the first batch of un-normalized images from ``dataset`` in a grid.

    Args:
        dataset: Dataset identifier forwarded to ``prepare_dataset``.
        **kwargs: Extra options converted into an ``args`` namespace by
            ``prep_kwargs`` and passed on to ``prepare_dataset``.

    Raises:
        ValueError: If the first batch returned by the loader is empty.
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset

    # Disable normalization so the raw pixel values can be displayed directly.
    args.aug_normalize = False

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)

    # Only the first batch is shown; element 0 of the batch tuple holds the images.
    images = next(iter(loader))[0]

    # Move channels last for matplotlib (input assumed to be (N, C, H, W)).
    images = images.permute(0, 2, 3, 1).numpy()
    if images.shape[-1] == 1:
        # imshow rejects (H, W, 1); drop the singleton channel for grayscale data.
        images = images[..., 0]

    n_images = images.shape[0]
    if n_images == 0:
        raise ValueError(f"First batch of {dataset} contains no images")

    rows = math.ceil(math.sqrt(n_images / 2))
    # Round up so rows * cols >= n_images; floor division would silently drop images.
    cols = math.ceil(n_images / rows)

    # squeeze=False keeps `axs` 2-D even when rows or cols is 1, so ravel() is always valid.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    flat_axes = axs.ravel()

    for img, ax in zip(images, flat_axes):
        ax.imshow(img, cmap="gray" if img.ndim == 2 else None)

    # Hide unused cells at the end of the last row.
    for ax in flat_axes[n_images:]:
        ax.axis("off")

    fig.suptitle(f"Examples from {dataset}")
    fig.tight_layout(pad=0)

    plt.show()
|
|
|
|
|
|
def _configure_save_images_logging(log_file, dataset, log_level):
    """Reset loguru and route output to stderr (colored) and ``log_file`` (plain)."""
    logger.remove()
    logger.configure(extra=dict(run_name=f"Save images of {dataset}", rank=0, world_size=-1))
    level = log_level.upper()
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=level)
    # colorize=False: otherwise ANSI color escape codes end up inside the log file.
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=False, level=level)


def _save_images_class_names(n_classes):
    """Return the output directory name for each integer label, indexed by label."""
    if n_classes == 1000:
        # assume its ImageNet classes
        logger.info("1000 classes => assuming ImageNet class names")
        with open("data/misc_dataset_files/imagenet_labels.txt", "r", encoding="utf-8") as f:
            # Skip blank lines: an empty name would sort to index 0 and shift every label.
            labels = [line.strip().split(" ")[0] for line in f if line.strip()]
        return sorted(labels)
    # Names must be strings: os.path.join raises TypeError on ints.
    return [str(i) for i in range(n_classes)]


def _batch_to_uint8_images(images):
    """Convert an image batch tensor into a list of uint8 channel-last arrays.

    Assumes ``images`` is (N, C, H, W) with float values in [0, 1], since
    normalization is disabled. TODO: confirm against ``prepare_dataset``.
    """
    arr = images.permute(0, 2, 3, 1).numpy()
    # Round rather than truncate (127.9999 -> 128, not 127), and clip so values
    # slightly above 1.0 do not wrap around to 0 in uint8.
    arr = np.clip(np.rint(arr * 255), 0, 255).astype(np.uint8)
    if arr.shape[-1] == 1:
        # PIL cannot build an image from (H, W, 1); use a 2-D grayscale array.
        arr = arr[..., 0]
    return list(arr)


def save_images(dataset, out_dir, ipc=None, **kwargs):
    """Export images from ``dataset`` as JPEGs in one subdirectory per class.

    Files are written to ``<out_dir>/<class_name>/<class_name>_<k>.JPEG``. A log
    is written to ``<out_dir>/save_images.log`` as well as to stderr.

    Args:
        dataset: Dataset identifier forwarded to ``prepare_dataset``.
        out_dir: Root output directory.
        ipc: Optional maximum number of images saved per class. When every
            class has reached it, loading stops early. ``None`` saves everything.
        **kwargs: Extra options converted into an ``args`` namespace by
            ``prep_kwargs``; ``log_level`` and ``tqdm`` are read from it.
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.out_dir = out_dir
    args.ipc = ipc
    args.aug_normalize = False

    _configure_save_images_logging(os.path.join(out_dir, "save_images.log"), dataset, args.log_level)
    logger.info(f"Out dir '{out_dir}'")
    log_args(args)

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)

    # Per-label count of saved images; also used as the next file index.
    n_ims = [0] * args.n_classes
    lbl_to_cls_name = _save_images_class_names(args.n_classes)
    for cls_name in lbl_to_cls_name:
        os.makedirs(os.path.join(args.out_dir, cls_name), exist_ok=True)

    total_quota = None if ipc is None else args.n_classes * ipc
    n_saved = 0
    skipped_ims = 0

    # The env var is a string when set, so compare against "0" rather than the int 0.
    tqdm_is_disabled = (not args.tqdm) or os.environ.get("TQDM_DISABLE", "0") != "0"

    with tqdm(
        enumerate(loader), desc="Loading and saving images", disable=tqdm_is_disabled, total=len(loader)
    ) as pbar:
        for i, (images, labels) in pbar:
            quota_reached = False
            for img, lbl in zip(_batch_to_uint8_images(images), labels.tolist()):
                if ipc is not None and n_ims[lbl] >= ipc:
                    skipped_ims += 1
                    continue

                cls_name = lbl_to_cls_name[lbl]
                Image.fromarray(img).save(os.path.join(args.out_dir, cls_name, f"{cls_name}_{n_ims[lbl]}.JPEG"))
                n_ims[lbl] += 1
                n_saved += 1

                # Every class is full, so later batches cannot contribute anything.
                if total_quota is not None and n_saved >= total_quota:
                    quota_reached = True
                    break

            if tqdm_is_disabled:
                if i % 1000 == 0:
                    logger.info(f"Batch [{i+1}/{len(loader)}]: Saved {n_saved}, skipped {skipped_ims}")
            else:
                pbar.set_description(f"Loading and saving (saved {n_saved}, skipped {skipped_ims})")

            # Stop the outer batch loop too, not just the per-image loop.
            if quota_reached:
                break

    logger.success(f"Extracted {n_saved} images.")
    # Sinks use enqueue=True; wait until queued messages are written before returning.
    logger.complete()
|