AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,151 @@
import argparse
import shutil
import zipfile
from os import listdir, makedirs, path
from random import choice
from datadings.reader import MsgpackReader
from datadings.writer import FileWriter
from tqdm.auto import tqdm
parser = argparse.ArgumentParser()
parser.add_argument("-tiny_imagenet_zip", type=str, required=True, help="Path to the Tiny ImageNet zip file")
parser.add_argument("-output_dir", type=str, required=True, help="Directory to extract the image names to")
parser.add_argument("-in_segment_dir", type=str, required=True, help="Directory that holds the segmented ImageNet")
parser.add_argument(
"-imagenet_path", type=str, nargs="?", required=True, help="Path to the original ImageNet dataset (datadings)"
)
args = parser.parse_args()
images = {"train": set(), "val": set()}
with zipfile.ZipFile(args.tiny_imagenet_zip, "r") as zip_ref:
for info in tqdm(zip_ref.infolist(), desc="Gathering Images"):
if info.filename.endswith(".JPEG"):
if "/val/" in info.filename:
images["val"].add(info.filename.split("/")[-1])
elif "/train/" in info.filename:
images["train"].add(info.filename.split("/")[-1])
with open(path.join(args.output_dir, "tiny_imagenet_train_images.txt"), "w+") as f:
f.write("\n".join(images["train"]))
with open(path.join(args.output_dir, "tiny_imagenet_val_images.txt"), "w+") as f:
f.write("\n".join(images["val"]))
print(f"Found {len(images['train'])} training images and {len(images['val'])} validation images")
classes = {img_name.split("_")[0] for img_name in images["train"]}
classes = sorted(list(classes), key=lambda x: int(x[1:]))
assert len(classes) == 200, f"Expected 200 classes, found {len(classes)}"
assert len(images["train"]) == len(classes) * 500, f"Expected 100000 training images, found {len(images['train'])}"
assert len(images["val"]) == len(classes) * 50, f"Expected 10000 validation images, found {len(images['val'])}"
with open(path.join(args.output_dir, "tiny_imagenet_classes.txt"), "w+") as f:
f.write("\n".join(classes))
# copy over the relevant images
for split in ["train", "val"]:
ipc = 500 if split == "train" else 50
part = "foregrounds_WEBP"
with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
for synset in classes:
makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
tqdm.write(
f"skip {split} > {part} > {synset} with"
f" {len(listdir(path.join(args.output_dir, split, part, synset)))} ims"
)
pbar.update(ipc)
continue
for img in listdir(path.join(args.in_segment_dir, split, part, synset)):
orig_name = (
img.split(".")[0] + ".JPEG"
if split == "train"
else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
)
if orig_name in images[split]:
# makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
shutil.copy(
path.join(args.in_segment_dir, split, part, synset, img),
path.join(args.output_dir, split, part, synset, img),
)
pbar.update(1)
while len(listdir(path.join(args.output_dir, split, part, synset))) < min(
ipc, len(listdir(path.join(args.in_segment_dir, split, part, synset)))
):
# copy over more random images
image_names = [
(
img,
(
img.split(".")[0] + ".JPEG"
if split == "train"
else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
),
)
for img in listdir(path.join(args.in_segment_dir, split, part, synset))
]
image_names = [
img for img in image_names if img[1] not in listdir(path.join(args.output_dir, split, part, synset))
]
img = choice(image_names)[0]
# makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
shutil.copy(
path.join(args.in_segment_dir, split, part, synset, img),
path.join(args.output_dir, split, part, synset, img),
)
pbar.update(1)
tqdm.write(f"Extra image: {orig_name} to {split}/{part}/{synset}")
# copy over the background images corresponding to those foregrounds
part = "backgrounds_JPEG"
with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
for synset in classes:
makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
tqdm.write(f"skip {split} > {part} > {synset}")
pbar.update(ipc)
continue
for img in listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset)):
bg_name = img.replace(".WEBP", ".JPEG")
# makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
shutil.copy(
path.join(args.in_segment_dir, split, part, synset, bg_name),
path.join(args.output_dir, split, part, synset, bg_name),
)
pbar.update(1)
assert len(listdir(path.join(args.output_dir, split, part, synset))) == len(
listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset))
), (
f"Expected {len(listdir(path.join(args.output_dir, split, 'foregrounds_WEBP', synset)))} background"
f" images, found {len(listdir(path.join(args.output_dir, split, part, synset)))}"
)
# write the original dataset to datadings
for part in ["train", "val"]:
reader = MsgpackReader(path.join(args.imagenet_path, f"{part}.msgpack"))
with FileWriter(path.join(args.output_dir, f"TinyIN_{part}.msgpack")) as writer:
for data in tqdm(reader, desc=f"Writing {part} to datadings"):
key = data["key"].split("/")[-1]
allowed_synsets = [key.split("_")[0]] if part == "train" else classes
if part == "train" and allowed_synsets[0] not in classes:
continue
keep_image = False
label_synset = None
for synset in allowed_synsets:
for img in listdir(path.join(args.output_dir, part, "foregrounds_WEBP", synset)):
if img.split(".")[0] == key.split(".")[0]:
keep_image = True
label_synset = synset
break
if not keep_image:
continue
data["label"] = classes.index(label_synset)
writer.write(data)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,48 @@
import argparse
import os
import json
import torch
parser = argparse.ArgumentParser("Script to convert ImageNet trained models to ImageNet-9")
parser.add_argument("-m", "--model", type=str, required=True, help="Model weights (.pt file).")
parser.add_argument(
"--in_to_in9", type=str, default="/ds-sds/images/ImageNet-9/in_to_in9.json", help="Path to in_to_in9.json"
)
args = parser.parse_args()
checkpoint = torch.load(args.model, map_location="cpu")
model_state = checkpoint["model_state"]
head_keys = [k for k in model_state.keys() if ".head." in k or ".fc." in k]
print("weights that will be modified:", head_keys)
assert len(head_keys) > 0, "no head keys found :("
with open(args.in_to_in9, "r") as f:
in_to_in9_classes = json.load(f)
print(f"{len([k for k, v in in_to_in9_classes.items() if v == -1])} classes get mapped to -1")
print("map", len(in_to_in9_classes), " classes to", set(in_to_in9_classes.values()))
print("Building conversion matrix")
conversion_matrix = torch.zeros((9, 1000))
for in_idx, in9_idx in in_to_in9_classes.items():
if in9_idx == -1:
continue
in_idx = int(in_idx)
conversion_matrix[in9_idx, in_idx] = 1
print(f"Conversion matrix ({conversion_matrix.shape}) has {int(conversion_matrix.sum().item())} non-zero values")
for head_key in head_keys:
print(f"converting {head_key} of shape {model_state[head_key].shape}", end=" ")
model_state[head_key] = conversion_matrix @ model_state[head_key]
print(f"\tto {model_state[head_key].shape}")
checkpoint["model_state"] = model_state
checkpoint["args"]["n_classes"] = 9
save_folder = os.path.dirname(args.model)
orig_model_name = args.model.split(os.sep)[-1]
new_model_name = ".".join(orig_model_name.split(".")[:-1]) + "_to_in9." + orig_model_name.split(".")[-1]
print(f"saving model as {new_model_name} in {save_folder}")
torch.save(checkpoint, os.path.join(save_folder, new_model_name))

View File

@@ -0,0 +1,200 @@
n07695742: pretzel
n03902125: pay-phone, pay-station
n03980874: poncho
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n02730930: apron
n02699494: altar
n03201208: dining table, board
n02056570: king penguin, Aptenodytes patagonica
n04099969: rocking chair, rocker
n04366367: suspension bridge
n04067472: reel
n02808440: bathtub, bathing tub, bath, tub
n04540053: volleyball
n02403003: ox
n03100240: convertible
n04562935: water tower
n02788148: bannister, banister, balustrade, balusters, handrail
n02988304: CD player
n02423022: gazelle
n03637318: lampshade, lamp shade
n01774384: black widow, Latrodectus mactans
n01768244: trilobite
n07614500: ice cream, icecream
n04254777: sock
n02085620: Chihuahua
n01443537: goldfish, Carassius auratus
n01629819: European fire salamander, Salamandra salamandra
n02099601: golden retriever
n02321529: sea cucumber, holothurian
n03837869: obelisk
n02002724: black stork, Ciconia nigra
n02841315: binoculars, field glasses, opera glasses
n04560804: water jug
n02364673: guinea pig, Cavia cobaya
n03706229: magnetic compass
n09256479: coral reef
n09332890: lakeside, lakeshore
n03544143: hourglass
n02124075: Egyptian cat
n02948072: candle, taper, wax light
n01950731: sea slug, nudibranch
n02791270: barbershop
n03179701: desk
n02190166: fly
n04275548: spider web, spider's web
n04417672: thatch, thatched roof
n03930313: picket fence, paling
n02236044: mantis, mantid
n03976657: pole
n01774750: tarantula
n04376876: syringe
n04133789: sandal
n02099712: Labrador retriever
n04532670: viaduct
n04487081: trolleybus, trolley coach, trackless trolley
n09428293: seashore, coast, seacoast, sea-coast
n03160309: dam, dike, dyke
n03250847: drumstick
n02843684: birdhouse
n07768694: pomegranate
n03670208: limousine, limo
n03085013: computer keyboard, keypad
n02892201: brass, memorial tablet, plaque
n02233338: cockroach, roach
n03649909: lawn mower, mower
n03388043: fountain
n02917067: bullet train, bullet
n02486410: baboon
n04596742: wok
n03255030: dumbbell
n03937543: pill bottle
n02113799: standard poodle
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
n02906734: broom
n07920052: espresso
n01698640: American alligator, Alligator mississipiensis
n02123394: Persian cat
n03424325: gasmask, respirator, gas helmet
n02129165: lion, king of beasts, Panthera leo
n04008634: projectile, missile
n03042490: cliff dwelling
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
n02815834: beaker
n02395406: hog, pig, grunter, squealer, Sus scrofa
n01784675: centipede
n03126707: crane
n04399382: teddy, teddy bear
n07875152: potpie
n03733131: maypole
n02802426: basketball
n03891332: parking meter
n01910747: jellyfish
n03838899: oboe, hautboy, hautbois
n03770439: miniskirt, mini
n02281406: sulphur butterfly, sulfur butterfly
n03970156: plunger, plumber's helper
n09246464: cliff, drop, drop-off
n02206856: bee
n02074367: dugong, Dugong dugon
n03584254: iPod
n04179913: sewing machine
n04328186: stopwatch, stop watch
n07583066: guacamole
n01917289: brain coral
n03447447: gondola
n02823428: beer bottle
n03854065: organ, pipe organ
n02793495: barn
n04285008: sports car, sport car
n02231487: walking stick, walkingstick, stick insect
n04465501: tractor
n02814860: beacon, lighthouse, beacon light, pharos
n02883205: bow tie, bow-tie, bowtie
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
n04149813: scoreboard
n04023962: punching bag, punch bag, punching ball, punchball
n02226429: grasshopper, hopper
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
n02669723: academic gown, academic robe, judge's robe
n04486054: triumphal arch
n04070727: refrigerator, icebox
n03444034: go-kart
n02666196: abacus
n01945685: slug
n04251144: snorkel
n03617480: kimono
n03599486: jinrikisha, ricksha, rickshaw
n02437312: Arabian camel, dromedary, Camelus dromedarius
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n04118538: rugby ball
n01770393: scorpion
n04356056: sunglasses, dark glasses, shades
n03804744: nail
n02132136: brown bear, bruin, Ursus arctos
n03400231: frying pan, frypan, skillet
n03983396: pop bottle, soda bottle
n07734744: mushroom
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
n02410509: bison
n03404251: fur coat
n04456115: torch
n02123045: tabby, tabby cat
n03026506: Christmas stocking
n07715103: cauliflower
n04398044: teapot
n02927161: butcher shop, meat market
n07749582: lemon
n07615774: ice lolly, lolly, lollipop, popsicle
n02795169: barrel, cask
n04532106: vestment
n02837789: bikini, two-piece
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
n04265275: space heater
n02481823: chimpanzee, chimp, Pan troglodytes
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
n06596364: comic book
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
n02504458: African elephant, Loxodonta africana
n03014705: chest
n01944390: snail
n04146614: school bus
n01641577: bullfrog, Rana catesbeiana
n07720875: bell pepper
n02999410: chain
n01855672: goose
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
n07753592: banana
n07871810: meat loaf, meatloaf
n04501370: turnstile
n04311004: steel arch bridge
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
n04074963: remote control, remote
n03662601: lifeboat
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
n03089624: confectionery, confectionary, candy store
n04259630: sombrero
n03393912: freight car
n04597913: wooden spoon
n07711569: mashed potato
n03355925: flagpole, flagstaff
n02963159: cardigan
n07579787: plate
n02950826: cannon
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n02094433: Yorkshire terrier
n02909870: bucket, pail
n02058221: albatross, mollymawk
n01742172: boa constrictor, Constrictor constrictor
n09193705: alp
n04371430: swimming trunks, bathing trunks
n07747607: orange
n03814639: neck brace
n04507155: umbrella
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
n03763968: military uniform
n07873807: pizza, pizza pie
n03992509: potter's wheel
n03796401: moving van
n12267677: acorn