import os

from loguru import logger
from PIL import Image
from torch.utils.data import Dataset


class CounterAnimal(Dataset):
    """Dataset to load the CounterAnimal dataset with ImageNet labels.

    The expected on-disk layout, as read by ``__init__``, is::

        <base_path>/<class idx> <anything>/<mode>.../<image>.{jpg,jpeg,png}

    i.e. every class folder name starts with an integer label followed by a
    space, and contains variant folders whose names start with the mode
    (``"common"`` or ``"counter"``).
    """

    # Supported dataset variants; a variant folder is selected when its name
    # starts with the requested mode.
    _MODES = ("counter", "common")
    # File extensions (lower-case, including the dot) treated as images.
    _IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png"})

    def __init__(self, base_path: str, mode: str, transform=None, target_transform=None, train: bool = False):
        """Create the dataset and build the (path, label) index.

        Args:
            base_path (str): path to the base folder (the one where the class
                folders are in)
            mode (str): mode/variant of the dataset (common/counter)
            transform: Image augmentation, applied to the loaded RGB image
            target_transform: label augmentation, applied to the integer label
            train: train or test set. Train set is not supported

        Raises:
            AssertionError: if ``mode`` is not supported, if ``train`` is
                truthy, or if no image was found below ``base_path``.
            ValueError: if a class folder name does not start with an integer.
        """
        super().__init__()
        self.base = base_path

        # Raised explicitly (instead of via ``assert``) so the checks are not
        # stripped under ``python -O``; the exception type and messages are
        # kept so existing callers see the same errors.
        if mode not in self._MODES:
            raise AssertionError(f"Supported modes are counter and common, but got '{mode}'")
        if train:
            raise AssertionError("CounterAnimal only consists of test data, not training data.")

        self.transform = transform
        self.target_transform = target_transform
        self.index = []

        # os.listdir returns entries in arbitrary order; sorting makes the
        # idx -> sample mapping reproducible across machines and runs.
        for class_folder in sorted(os.listdir(self.base)):
            class_dir = os.path.join(self.base, class_folder)
            if not os.path.isdir(class_dir):
                continue

            # The label is the integer prefix of the folder name.
            class_idx = int(class_folder.split(" ")[0])

            for variant_folder in sorted(os.listdir(class_dir)):
                if not variant_folder.startswith(mode):
                    continue
                variant_dir = os.path.join(class_dir, variant_folder)
                # A stray file whose name starts with the mode would make
                # os.listdir fail below; skip anything that is not a folder.
                if not os.path.isdir(variant_dir):
                    continue

                for file in sorted(os.listdir(variant_dir)):
                    # splitext requires a real extension, so a file that is
                    # merely *named* "jpg"/"png" is not mistaken for an image.
                    extension = os.path.splitext(file)[1].lower()
                    if extension in self._IMAGE_EXTENSIONS:
                        self.index.append((os.path.join(variant_dir, file), class_idx))

        if not self.index:
            raise AssertionError("did not find any images :(")

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            tuple: ``(img, label)`` where ``img`` is the image converted to
            RGB (after ``transform``, if given) and ``label`` is the class
            index (after ``target_transform``, if given).
        """
        path, label = self.index[idx]

        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed deterministically as soon as the block exits.
        with Image.open(path) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label