AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,400 @@
import argparse
import contextlib
import json
import os
from copy import copy
from nltk.corpus import wordnet as wn
class bcolors:
"""Colors for terminal output."""
HEADER = "\033[95m"
OKBLUE = "\033[94m"
OKCYAN = "\033[96m"
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
BOLD = "\033[1m"
UNDERLINE = "\033[4m"
def _lemmas_str(synset):
return ", ".join([lemma.name() for lemma in synset.lemmas()])
class WNEntry:
"""One wordnet synset."""
def __init__(
self,
name: str,
id: int,
lemmas: str,
parent_id: int,
depth: int = None,
in_image_net: bool = False,
child_ids: list = None,
in_main_tree: bool = True,
_n_images: int = 0,
_description: str = None,
_name: str = None,
_pruned: bool = False,
):
self.name = name
self.id = id
self.lemmas = lemmas
self.parent_id = parent_id
self.depth = depth
self.in_image_net = in_image_net
self.child_ids = child_ids
self.in_main_tree = in_main_tree
self._n_images = _n_images
self._description = _description
self._name = _name
self._pruned = _pruned
def __str__(self, tree=None, accumulate=True, colors=True, max_depth=0, max_children=None):
green = f"{bcolors.OKGREEN}" if colors else ""
red = f"{bcolors.FAIL}" if colors else ""
end = f"{bcolors.ENDC}" if colors else ""
start_symb = f"{green}+{end}" if self.in_image_net else f"{red}-{end}"
n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
if self.child_ids is None or tree is None or max_depth == 0:
return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
children = self.child_ids
if max_children is not None and len(children) > max_children:
children = children[:max_children]
return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
[
"\n ".join(
tree.nodes[child_id]
.__str__(tree=tree, accumulate=accumulate, colors=colors, max_depth=max_depth - 1)
.split("\n")
)
for child_id in children
]
)
def tree_diff(self, tree_1, tree_2):
if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
else:
start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
n_ims = (
f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
)
if self.child_ids is None:
return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
)
def prune(self, tree):
if self._pruned or self.parent_id is None:
return
if self.child_ids is not None:
for child_id in self.child_ids:
tree[child_id].prune(tree)
self._pruned = True
parent_node = tree.nodes[self.parent_id]
try:
parent_node.child_ids.remove(self.id)
except ValueError as e:
print(
f"Error removing {self.name} from"
f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
)
while parent_node._pruned:
parent_node = tree.nodes[parent_node.parent_id]
parent_node._n_images += self._n_images
self._n_images = 0
@property
def description(self):
if not self._description:
self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
return self._description
@property
def print_name(self):
return self.name.split(".")[0]
@property
def is_leaf(self):
return self.child_ids is None or len(self.child_ids) == 0
def get_branch(self, tree=None):
if self.parent_id is None or tree is None:
return self.print_name
parent = tree.nodes[self.parent_id]
return parent.get_branch(tree) + " > " + self.print_name
def get_branch_list(self, tree):
if self.parent_id is None:
return [self]
parent = tree.nodes[self.parent_id]
return parent.get_branch_list(tree) + [self]
def to_dict(self):
return {
"name": self.name,
"id": self.id,
"lemmas": self.lemmas,
"parent_id": self.parent_id,
"depth": self.depth,
"in_image_net": self.in_image_net,
"child_ids": self.child_ids,
"in_main_tree": self.in_main_tree,
"_n_images": self._n_images,
"_description": self._description,
"_name": self._name,
"_pruned": self._pruned,
}
@classmethod
def from_dict(cls, d):
return cls(**d)
def n_images(self, tree=None):
if tree is None or self.child_ids is None or len(self.child_ids) == 0:
return self._n_images
return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images
def n_children(self, tree=None):
if self.child_ids is None:
return 0
if tree is None or len(self.child_ids) == 0:
return len(self.child_ids)
return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])
def get_examples(self, tree, n_examples=3):
if self.child_ids is None or len(self.child_ids) == 0:
return ""
child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
max_images = max(child_images.values())
if max_images == 0:
# go on number of child nodes
child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
# sorted childids by number of images
top_children = [
child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
]
top_children = top_children[: min(n_examples, len(top_children))]
return ", ".join(
[f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
)
class WNTree:
def __init__(self, root=1740, nodes=None):
if isinstance(root, int):
root_id = root
root_synset = wn.synset_from_pos_and_offset("n", root)
root_node = WNEntry(
root_synset.name(),
root_id,
_lemmas_str(root_synset),
parent_id=None,
depth=0,
)
else:
assert isinstance(root, WNEntry)
root_id = root.id
root_node = root
self.root = root_node
self.nodes = {root_id: self.root} if nodes is None else nodes
self.parentless = []
self.label_index = None
self.pruned = set()
def to_dict(self):
return {
"root": self.root.to_dict(),
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
"parentless": self.parentless,
"pruned": list(self.pruned),
}
def prune(self, min_images):
pruned_nodes = set()
# prune all nodes that have fewer than min_images below them
for node_id, node in self.nodes.items():
if node.n_images(self) < min_images:
pruned_nodes.add(node_id)
node.prune(self)
# prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
node_stack = [self.root]
node_idx = 0
while node_idx < len(node_stack):
node = node_stack[node_idx]
if node.child_ids is not None:
for child_id in node.child_ids:
child = self.nodes[child_id]
node_stack.append(child)
node_idx += 1
# now prune the stack from the bottom up
for node in node_stack[::-1]:
# only look at images of that class, not of additional children
if node.n_images() < min_images:
pruned_nodes.add(node.id)
node.prune(self)
self.pruned = pruned_nodes
return pruned_nodes
@classmethod
def from_dict(cls, d):
tree = cls()
tree.root = WNEntry.from_dict(d["root"])
tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
tree.parentless = d["parentless"]
if "pruned" in d:
tree.pruned = set(d["pruned"])
return tree
def add_node(self, node_id, in_in=True):
if node_id in self.nodes:
self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
return
synset = wn.synset_from_pos_and_offset("n", node_id)
# print(f"adding node {synset.name()} with id {node_id}")
hypernyms = synset.hypernyms()
if len(hypernyms) == 0:
parent_id = None
self.parentless.append(node_id)
main_tree = False
print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
else:
parent_id = synset.hypernyms()[0].offset()
if parent_id not in self.nodes:
self.add_node(parent_id, in_in=False)
parent = self.nodes[parent_id]
if parent.child_ids is None:
parent.child_ids = []
parent.child_ids.append(node_id)
main_tree = parent.in_main_tree
depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
node = WNEntry(
synset.name(),
node_id,
_lemmas_str(synset),
parent_id=parent_id,
in_image_net=in_in,
depth=depth,
in_main_tree=main_tree,
)
self.nodes[node_id] = node
def __len__(self):
return len(self.nodes)
def image_net_len(self, only_main_tree=False):
return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
def max_depth(self, only_main_tree=False):
return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
def __str__(self, colors=True):
return (
f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self, colors=colors)}\nParentless:\n"
+ "\n".join([self.nodes[node_id].__str__(tree=self, colors=colors) for node_id in self.parentless])
)
def save(self, path):
with open(path, "w") as f:
json.dump(self.to_dict(), f)
@classmethod
def load(cls, path):
with open(path, "r") as f:
tree_dict = json.load(f)
return cls.from_dict(tree_dict)
def subtree(self, node):
node_id = self[node].id
if node_id not in self.nodes:
return None
node_queue = [self.nodes[node_id]]
subtree_ids = set()
while len(node_queue) > 0:
node = node_queue.pop(0)
subtree_ids.add(node.id)
if node.child_ids is not None:
node_queue += [self.nodes[child_id] for child_id in node.child_ids]
subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
subtree_root = subtree_nodes[node_id]
subtree_root.parent_id = None
depth_diff = subtree_root.depth
for node in subtree_nodes.values():
node.depth -= depth_diff
return WNTree(root=subtree_root, nodes=subtree_nodes)
def _make_label_index(self):
self.label_index = sorted(
[node_id for node_id, node in self.nodes.items() if node.n_images(self) > 0 and not node._pruned]
)
def get_label(self, node_id):
if self.label_index is None:
self._make_label_index()
while self.nodes[node_id]._pruned:
node_id = self.nodes[node_id].parent_id
return self.label_index.index(node_id)
def n_labels(self):
if self.label_index is None:
self._make_label_index()
return len(self.label_index)
def __contains__(self, item):
if isinstance(item, str):
if item[0] == "n":
item = int(item[1:])
else:
return False
if isinstance(item, int):
return item in self.nodes
if isinstance(item, WNEntry):
return item.id in self.nodes
return False
def __getitem__(self, item):
if isinstance(item, str) and item[0].startswith("n"):
with contextlib.suppress(ValueError):
item = int(item[1:])
if isinstance(item, str) and ".n." in item:
for node in self.nodes.values():
if item == node.name:
return node
raise KeyError(f"Item {item} not found in tree")
if isinstance(item, int):
return self.nodes[item]
if isinstance(item, WNEntry):
return self.nodes[item.id]
raise KeyError(f"Item {item} not found in tree")
def __iter__(self):
return iter(self.nodes.keys())