# Build <mode>_indices.json: a mapping from "nXXXXXXXX/<image id>" keys to sample indices.
import argparse
import json
import os

from datadings.reader import MsgpackReader
from datadings.torch import Dataset
from tqdm.auto import tqdm
# Command-line interface: which split and dataset to process, and where the
# image data and segmentation data live on disk.
parser = argparse.ArgumentParser(description="Foreground size ratio")
parser.add_argument(
    "-mode",
    choices=["train", "val"],
    default="train",
    help="Train or val data?",
)
parser.add_argument(
    "-ds",
    "--dataset",
    choices=["imagenet", "tinyimagenet"],
    required=True,
    help="Dataset to use",
)
parser.add_argument(
    "-r",
    "--dataset_root",
    required=True,
    type=str,
    help="Root directory of the dataset",
)
parser.add_argument(
    "-s",
    "--segment_root",
    required=True,
    type=str,
    help="Root directory of the segmentation dataset",
)
args = parser.parse_args()
# Open the msgpack-backed dataset for the requested split.
# Fix: the original built paths with bare f-string concatenation
# (f"{args.dataset_root}imagenet/...") and therefore silently required
# dataset_root to end with "/" — inconsistent with the segment_root usage
# below, which inserts the separator explicitly. os.path.join produces the
# same path when a trailing slash is present and the correct path when not.
if args.dataset == "imagenet":
    msgpack_path = os.path.join(
        args.dataset_root, "imagenet", "msgpack", f"{args.mode}.msgpack"
    )
elif args.dataset == "tinyimagenet":
    msgpack_path = os.path.join(
        args.dataset_root, "TinyINSegment", f"TinyIN_{args.mode}.msgpack"
    )
else:
    # Unreachable via argparse `choices`, kept as a defensive guard.
    raise ValueError(f"Unknown dataset: {args.dataset}")

reader = MsgpackReader(msgpack_path)
dataset = Dataset(reader)
# Load the label-index -> WordNet synset-id table for the chosen dataset.
if args.dataset.startswith("imagenet"):
    # JSON maps stringified label indices to synset ids; re-key with ints.
    with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
        raw = json.load(f)
    id_to_synset = {int(label): synset for label, synset in raw.items()}
elif args.dataset.startswith("tinyimagenet"):
    # Each line carries a synset name before a ":"; drop the first character
    # (presumably the "n" prefix — matches the key format used below) and
    # sort so that list position corresponds to the label index.
    with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
        raw_lines = f.readlines()
    id_to_synset = sorted(
        int(line.split(":")[0].strip()[1:]) for line in raw_lines
    )
# Walk the whole dataset once and record, for every sample, its position in
# the dataset keyed by "n<synset id, zero-padded to 8>/<key without extension>".
key_to_idx = {}
for idx, sample in enumerate(tqdm(dataset, leave=True)):
    synset_id = id_to_synset[sample["label"]]
    stem = sample["key"].split(".")[0]
    key_to_idx[f"n{synset_id:08d}/{stem}"] = idx

print(len(key_to_idx))

# Persist the index next to the segmentation data for later lookups.
with open(f"{args.segment_root}/{args.mode}_indices.json", "w") as f:
    json.dump(key_to_idx, f)