import math
import os
import sys

import numpy as np
from loguru import logger
from matplotlib import pyplot as plt
from PIL import Image
from tqdm.auto import tqdm

from load_dataset import prepare_dataset
from utils import log_args, log_formatter, prep_kwargs


def load_images(dataset, **kwargs):
    """Display the first batch of a dataset as a grid of images.

    Args:
        dataset: Dataset identifier; forwarded to ``prepare_dataset`` and used in the figure title.
        **kwargs: Extra options, turned into the argument namespace by ``prep_kwargs``.
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    # NOTE(review): presumably keeps pixel values in a displayable range -- confirm in prepare_dataset.
    args.aug_normalize = False

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)

    # Only the first batch is shown; element 0 of a batch holds the images.
    batch = next(iter(loader))[0]
    # Move axis 1 to the end (presumably channels-first -> channels-last, as imshow expects).
    images = list(batch.permute(0, 2, 3, 1).numpy())

    # Aim for a grid that is roughly twice as wide as it is high.
    rows = math.ceil(math.sqrt(len(images) / 2))
    # Round up so no image is dropped when the batch does not fill the grid evenly.
    ims_per_row = math.ceil(len(images) / rows)

    # squeeze=False always yields a 2-D array of axes, also for one row or one column.
    fig, axs = plt.subplots(rows, ims_per_row, squeeze=False)
    axs = axs.ravel()

    for img, ax in zip(images, axs):
        ax.imshow(img)
    # Hide the grid cells that are left over when the batch does not fill the last row.
    for ax in axs[len(images):]:
        ax.axis("off")

    fig.suptitle(f"Examples from {dataset}")
    fig.tight_layout(pad=0)
    plt.show()


def save_images(dataset, out_dir, ipc=None, **kwargs):
    """Write the images of a dataset to disk as JPEG files, one folder per class.

    Files are stored as ``<out_dir>/<class>/<class>_<k>.JPEG`` where ``k`` counts the images
    already saved for that class. Progress is logged to stderr and to ``<out_dir>/save_images.log``.

    Args:
        dataset: Dataset identifier; forwarded to ``prepare_dataset``.
        out_dir: Root directory for the log file and the per-class folders.
        ipc: Optional maximum number of images saved per class. ``None`` saves every image.
        **kwargs: Extra options, turned into the argument namespace by ``prep_kwargs``.
            ``log_level`` and ``tqdm`` are read from that namespace.
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.out_dir = out_dir
    args.ipc = ipc
    # NOTE(review): the ``* 255`` below assumes un-normalized values in [0, 1] -- confirm.
    args.aug_normalize = False

    log_file = os.path.join(out_dir, "save_images.log")
    log_level = args.log_level.upper()
    logger.remove()
    logger.configure(extra=dict(run_name=f"Save images of {dataset}", rank=0, world_size=-1))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=log_level)
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=log_level)
    logger.info(f"Out dir '{out_dir}'")
    log_args(args)

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)

    # Number of images saved so far, indexed by label.
    n_ims = [0] * args.n_classes

    if args.n_classes == 1000:  # assume it's ImageNet
        logger.info("1000 classes => assuming ImageNet class names")
        with open("data/misc_dataset_files/imagenet_labels.txt", "r") as f:
            class_names = [line.strip().split(" ")[0] for line in f]
        lbl_to_cls_name = sorted(class_names)
    else:
        # Names must be strings: they are used as path components and os.path.join rejects ints.
        lbl_to_cls_name = [str(i) for i in range(args.n_classes)]

    for cls_name in lbl_to_cls_name:
        os.makedirs(os.path.join(args.out_dir, cls_name), exist_ok=True)

    skipped_ims = 0
    # Environment values are strings, so ANY value of TQDM_DISABLE (even "0") disables the bar.
    tqdm_is_disabled = (not args.tqdm) or "TQDM_DISABLE" in os.environ

    pbar = tqdm(enumerate(loader), desc="Loading and saving images", disable=tqdm_is_disabled, total=len(loader))
    for i, (images, labels) in pbar:
        # Move axis 1 to the end and convert to 8-bit values for PIL.
        images = (images.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)

        for img, lbl in zip(images, labels.tolist()):
            if ipc is not None and n_ims[lbl] >= ipc:
                skipped_ims += 1
                continue
            cls_name = lbl_to_cls_name[lbl]
            Image.fromarray(img).save(os.path.join(args.out_dir, cls_name, f"{cls_name}_{n_ims[lbl]}.JPEG"))
            n_ims[lbl] += 1

        # Every class has reached its quota: no need to read the rest of the dataset.
        if ipc is not None and sum(n_ims) >= args.n_classes * ipc:
            break

        if tqdm_is_disabled:
            # Without a progress bar, report to the log every 1000 batches instead.
            if i % 1000 == 0:
                logger.info(f"Batch [{i+1}/{len(loader)}]: Saved {sum(n_ims)}, skipped {skipped_ims}")
        else:
            pbar.set_description(f"Loading and saving (saved {sum(n_ims)}, skipped {skipped_ims})")

    logger.success(f"Extracted {sum(n_ims)} images.")