"""Build a key → dataset-index map for ImageNet / TinyImageNet msgpack datasets.

For every sample in the chosen msgpack dataset, constructs a key of the form
``n<synset:08d>/<file stem>`` and records the sample's position in the dataset,
then writes the mapping to ``<segment_root>/<mode>_indices.json``.
"""
import argparse
import json
from pathlib import Path

from datadings.reader import MsgpackReader
from datadings.torch import Dataset
from tqdm.auto import tqdm


def _parse_args():
    """Parse the command-line interface (flags unchanged from the original script)."""
    parser = argparse.ArgumentParser(description="Foreground size ratio")
    parser.add_argument(
        "-mode", choices=["train", "val"], default="train", help="Train or val data?"
    )
    parser.add_argument(
        "-ds",
        "--dataset",
        choices=["imagenet", "tinyimagenet"],
        required=True,
        help="Dataset to use",
    )
    parser.add_argument(
        "-r",
        "--dataset_root",
        required=True,
        type=str,
        help="Root directory of the dataset",
    )
    parser.add_argument(
        "-s",
        "--segment_root",
        required=True,
        type=str,
        help="Root directory of the segmentation dataset",
    )
    return parser.parse_args()


def _open_dataset(dataset_name, dataset_root, mode):
    """Return a datadings ``Dataset`` for the requested split.

    Uses ``pathlib`` joins so ``dataset_root`` works with or without a
    trailing slash (the original string concatenation required one).
    """
    root = Path(dataset_root)
    if dataset_name == "imagenet":
        msgpack_path = root / "imagenet" / "msgpack" / f"{mode}.msgpack"
    elif dataset_name == "tinyimagenet":
        msgpack_path = root / "TinyINSegment" / f"TinyIN_{mode}.msgpack"
    else:
        # Unreachable via the CLI (argparse restricts choices); kept as a
        # defensive guard for programmatic callers.
        raise ValueError(f"Unknown dataset: {dataset_name}")
    return Dataset(MsgpackReader(str(msgpack_path)))


def _load_id_to_synset(dataset_name):
    """Load the label-id → synset-number mapping for ``dataset_name``.

    Returns a dict (imagenet) or a sorted list (tinyimagenet); both are
    indexed by integer label and yield integer synset numbers, which the
    caller formats as ``n%08d``.
    """
    if dataset_name == "imagenet":
        with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
            raw = json.load(f)
        # JSON keys are strings; labels are ints, so re-key the mapping.
        return {int(k): v for k, v in raw.items()}
    # tinyimagenet: lines look like "n<digits>: <name>"; strip the leading
    # "n" and parse the digits, then sort so list position == label id.
    with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
        lines = f.readlines()
    return sorted(int(line.split(":")[0].strip()[1:]) for line in lines)


def main():
    """Build and write the key → index mapping for the selected dataset split."""
    args = _parse_args()
    dataset = _open_dataset(args.dataset, args.dataset_root, args.mode)
    id_to_synset = _load_id_to_synset(args.dataset)

    # Key format: "n<synset zero-padded to 8>/<record key without extension>".
    # NOTE(review): the ":08d" format assumes synset values are ints — for
    # imagenet this depends on the JSON file's values; verify if it changes.
    key_to_idx = {
        f"n{id_to_synset[data['label']]:08d}/{data['key'].split('.')[0]}": i
        for i, data in enumerate(tqdm(dataset, leave=True))
    }
    print(len(key_to_idx))

    out_path = Path(args.segment_root) / f"{args.mode}_indices.json"
    with open(out_path, "w") as f:
        json.dump(key_to_idx, f)


if __name__ == "__main__":
    main()