382 lines
13 KiB
Python
382 lines
13 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
from copy import copy
|
|
|
|
from loguru import logger
|
|
from nltk.corpus import wordnet as wn
|
|
|
|
|
|
class bcolors:
    """ANSI escape sequences for colouring and styling terminal text.

    Prefix a string with one of these codes and terminate it with ``ENDC``,
    e.g. ``f"{bcolors.FAIL}error{bcolors.ENDC}"``.
    """

    # Text attributes.
    ENDC = "\033[0m"  # reset every attribute back to the terminal default
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Bright foreground colours (SGR codes 91-96).
    FAIL = "\033[91m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    OKBLUE = "\033[94m"
    HEADER = "\033[95m"
    OKCYAN = "\033[96m"
|
|
|
|
|
|
def _lemmas_str(synset):
    """Return the lemma names of ``synset`` as one ``", "``-separated string.

    ``synset`` must expose ``lemmas()``, whose items each expose ``name()``.
    """
    lemma_names = (entry.name() for entry in synset.lemmas())
    return ", ".join(lemma_names)
|
|
|
|
|
|
class WNEntry:
    """One synset node of a WordNet hierarchy.

    Nodes reference each other by id (``parent_id`` / ``child_ids``) instead of
    by object, so every method that walks the hierarchy takes the owning tree
    as an argument. The tree must support ``tree.nodes[node_id]`` and, for
    :meth:`prune` and :meth:`tree_diff`, ``tree[node_id]``.

    Attributes:
        name: Synset name, e.g. ``"dog.n.01"``.
        id: Noun synset offset used as the node key.
        lemmas: Comma-separated lemma names.
        parent_id: Id of the parent node, or ``None`` for a root/parentless node.
        depth: Distance from the tree root (``None`` if unknown).
        in_image_net: Whether the synset itself is an ImageNet class.
        child_ids: Ids of the direct children, or ``None`` if none were added.
        in_main_tree: Whether the node is connected to the tree root.
        _n_images: Images assigned directly to this node (children excluded).
        _description: Cached WordNet definition (filled lazily).
        _name: Stored and serialised as-is; not read by this class.
        _pruned: Whether the node has been merged into an ancestor.
    """

    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        self.name = name
        self.id = id
        self.lemmas = lemmas
        self.parent_id = parent_id
        self.depth = depth
        self.in_image_net = in_image_net
        self.child_ids = child_ids
        self.in_main_tree = in_main_tree
        self._n_images = _n_images
        self._description = _description
        self._name = _name
        self._pruned = _pruned

    def __str__(self, tree=None, accumulate=True):
        """Render this node, and its descendants when ``tree`` is given.

        Args:
            tree: Owning tree; without it only this node's line is produced.
            accumulate: If true (and ``tree`` is given) also show the image
                total of the whole subtree next to the node's own count.

        Returns:
            A multi-line string; each nesting level is indented by one space.
        """
        if self.in_image_net:
            marker = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        else:
            marker = f"{bcolors.FAIL}-{bcolors.ENDC}"

        if accumulate and tree is not None:
            n_ims = f"{self._n_images} of Σ {self.n_images(tree)}"
        else:
            n_ims = self._n_images

        header = f"{marker}{self.name} ({self.id}) > {n_ims}"
        if self.child_ids is None or tree is None:
            return header

        # Propagate ``accumulate`` so the whole rendering uses one format
        # (children previously always fell back to the default).
        child_blocks = [
            "\n ".join(tree.nodes[child_id].__str__(tree, accumulate).split("\n"))
            for child_id in self.child_ids
        ]
        return header + "\n " + "\n ".join(child_blocks)

    def tree_diff(self, tree_1, tree_2):
        """Render how image counts changed between two trees for this subtree.

        Each line shows ``<own in tree_1> + <delta> of Σ <total in tree_1>/<total
        in tree_2>`` and is prefixed with a green ``+``, red ``-`` or blue ``=``
        depending on whether the node's own count grew, shrank or stayed equal.
        Children are taken from this node's ``child_ids`` and looked up in
        ``tree_1``; every visited id must also exist in ``tree_2``.
        """
        before = tree_1[self.id]
        after = tree_2[self.id]

        if after._n_images > before._n_images:
            marker = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif after._n_images < before._n_images:
            marker = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            marker = f"{bcolors.OKBLUE}={bcolors.ENDC}"

        # The "before" total must be resolved against tree_1; resolving it
        # against tree_2 mixed tree_1's own count with tree_2's child counts.
        n_ims = (
            f"{before._n_images} + {after._n_images - before._n_images} of Σ"
            f" {before.n_images(tree_1)}/{after.n_images(tree_2)}"
        )
        header = f"{marker}{self.name} ({self.id}) > {n_ims}"

        if self.child_ids is None:
            return header

        child_blocks = [
            "\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n"))
            for child_id in self.child_ids
        ]
        return header + "\n " + "\n ".join(child_blocks)

    def prune(self, tree):
        """Merge this node and its whole subtree into the nearest unpruned ancestor.

        Children are pruned first (their images land on this node), then this
        node is marked pruned, detached from its parent's ``child_ids`` and its
        images are added to the closest ancestor that is not itself pruned.

        A node without a parent (root or parentless) has nowhere to move its
        images to: its descendants are still merged into it, but the node
        itself is kept and stays unpruned. Calling this on an already pruned
        node is a no-op.
        """
        if self._pruned:
            return

        if self.child_ids is not None:
            # Iterate over a snapshot: every child removes itself from
            # ``self.child_ids`` while being pruned, which would otherwise make
            # the loop skip every second child.
            for child_id in list(self.child_ids):
                tree[child_id].prune(tree)

        if self.parent_id is None:
            return

        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )

        # Climb past ancestors that were pruned earlier.
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        """WordNet definition of the synset, looked up on first access and cached."""
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        """The synset name up to its first dot (``"dog.n.01"`` -> ``"dog"``)."""
        return self.name.split(".")[0]

    def get_branch(self, tree=None):
        """Return the path from the root to this node as ``"a > b > c"``.

        Without a tree (or for a node without parent) only this node's
        ``print_name`` is returned.
        """
        if self.parent_id is None or tree is None:
            return self.print_name

        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        """Return the nodes on the path from the root down to this node (inclusive)."""
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        """Return the node as a dict whose keys match the ``__init__`` parameters.

        ``child_ids`` is passed through by reference, not copied.
        """
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        """Build a node from a dict produced by :meth:`to_dict`."""
        return cls(**d)

    def n_images(self, tree=None):
        """Return the image count of this node.

        With ``tree`` the counts of all descendants are included; without it
        only the images assigned directly to this node are counted.
        """
        if tree is None or not self.child_ids:
            return self._n_images
        below = sum(tree.nodes[child_id].n_images(tree) for child_id in self.child_ids)
        return below + self._n_images

    def n_children(self, tree=None):
        """Return the number of direct children, or of all descendants if ``tree`` is given."""
        if self.child_ids is None:
            return 0
        direct = len(self.child_ids)
        if tree is None or direct == 0:
            return direct
        return direct + sum(tree.nodes[child_id].n_children(tree) for child_id in self.child_ids)

    def get_examples(self, tree, n_examples=3):
        """Describe the most prominent children as ``"name (definition), ..."``.

        Children are ranked by the image total of their subtree; if no child
        has any image they are ranked by their number of descendants instead.
        Returns an empty string for a node without children.
        """
        if not self.child_ids:
            return ""

        ranking = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        if max(ranking.values()) == 0:
            # No images anywhere below: fall back to subtree size.
            ranking = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}

        # Stable sort, so ties keep the order of ``child_ids``.
        top_children = sorted(ranking, key=ranking.get, reverse=True)[:n_examples]
        return ", ".join(
            f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children
        )
|
|
|
|
|
|
class WNTree:
    """A hierarchy of :class:`WNEntry` nodes keyed by WordNet noun offset.

    Attributes:
        root: The root node.
        nodes: Mapping ``node_id -> WNEntry`` of every node in the tree.
        parentless: Ids of nodes that were added without any hypernym.
        label_index: Sorted ids of the nodes that act as class labels; built
            lazily and reset by :meth:`prune`.
        pruned: Ids reported by the last :meth:`prune` call.
    """

    def __init__(self, root=1740, nodes=None):
        """Create a tree.

        Args:
            root: Either a noun synset offset, which is looked up in WordNet
                (the default 1740 is presumably ``entity.n.01``), or an
                existing ``WNEntry`` to use as root.
            nodes: Optional ready-made ``id -> WNEntry`` mapping; if omitted
                the tree starts with just the root.
        """
        if isinstance(root, int):
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_node = root

        self.root = root_node
        self.nodes = {root_node.id: root_node} if nodes is None else nodes
        self.parentless = []
        self.label_index = None
        self.pruned = set()

    def to_dict(self):
        """Return a JSON-serialisable dict of the tree (``label_index`` is not stored)."""
        return {
            "root": self.root.to_dict(),
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "parentless": self.parentless,
            "pruned": list(self.pruned),
        }

    def prune(self, min_images):
        """Merge sparsely populated nodes into their ancestors.

        First every node whose subtree holds fewer than ``min_images`` images
        is pruned; then, walking the remaining tree bottom-up from the root,
        every node whose own (merged) count is still below ``min_images`` is
        pruned as well. Nodes without a parent are never pruned themselves
        because there is no ancestor that could absorb their images.

        Returns:
            The set of pruned node ids, which is also stored in ``self.pruned``.
        """
        pruned_nodes = set()

        # Pass 1: subtrees that are too small as a whole.
        for node_id, node in self.nodes.items():
            if node.parent_id is None:
                continue
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)

        # Pass 2: collect the surviving nodes breadth-first ...
        bfs_order = [self.root]
        cursor = 0
        while cursor < len(bfs_order):
            current = bfs_order[cursor]
            if current.child_ids is not None:
                bfs_order.extend(self.nodes[child_id] for child_id in current.child_ids)
            cursor += 1

        # ... and visit them in reverse so children are handled before their
        # parents. Only the node's own images count here, not its children's.
        for node in reversed(bfs_order):
            if node.parent_id is None:
                continue
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)

        # Pruning a node also prunes its descendants; record those as well.
        pruned_nodes.update(node_id for node_id, node in self.nodes.items() if node._pruned)

        # Labels depend on which nodes are pruned, so drop the cached index.
        self.label_index = None
        self.pruned = pruned_nodes
        return pruned_nodes

    @classmethod
    def from_dict(cls, d):
        """Rebuild a tree from the output of :meth:`to_dict`.

        Node ids are converted back to ``int`` because JSON stores object keys
        as strings. No WordNet lookup is performed.
        """
        nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
        root = WNEntry.from_dict(d["root"])
        # Use the instance stored in ``nodes`` so that ``tree.root`` and
        # ``tree.nodes[root.id]`` are the same object and stay in sync.
        root = nodes.get(root.id, root)

        tree = cls(root=root, nodes=nodes)
        tree.parentless = d["parentless"]
        if "pruned" in d:
            tree.pruned = set(d["pruned"])
        return tree

    def add_node(self, node_id, in_in=True):
        """Add the synset with offset ``node_id`` and, recursively, its ancestors.

        Only the first hypernym of a synset is followed. If the node already
        exists, only its ``in_image_net`` flag is updated (it is never reset
        to false). Synsets without hypernyms are recorded in ``parentless``.

        Args:
            node_id: Noun synset offset.
            in_in: Whether the synset itself is an ImageNet class; ancestors
                that are added implicitly get ``False``.
        """
        if node_id in self.nodes:
            existing = self.nodes[node_id]
            existing.in_image_net = in_in or existing.in_image_net
            return

        synset = wn.synset_from_pos_and_offset("n", node_id)
        hypernyms = synset.hypernyms()

        if len(hypernyms) == 0:
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            depth = 0
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            parent_id = hypernyms[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]

            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree
            depth = parent.depth + 1

        self.nodes[node_id] = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )

    def __len__(self):
        """Return the number of nodes in the tree."""
        return len(self.nodes)

    def image_net_len(self, only_main_tree=False):
        """Count nodes flagged ``in_image_net``, optionally only those in the main tree."""
        return sum(node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree)

    def max_depth(self, only_main_tree=False):
        """Return the largest node depth, optionally only among main-tree nodes."""
        return max(node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree)

    def __str__(self):
        """Render a summary line, the full tree and the parentless nodes."""
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self) for node_id in self.parentless])
        )

    def save(self, path):
        """Write the tree to ``path`` as JSON."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, path):
        """Read a tree previously written by :meth:`save`."""
        with open(path, "r") as f:
            tree_dict = json.load(f)
        return cls.from_dict(tree_dict)

    def subtree(self, node_id):
        """Return a new tree rooted at ``node_id``, or ``None`` if the id is unknown.

        The nodes of the result are copies (including their ``child_ids``
        lists), so modifying or pruning the subtree does not affect this tree.
        Depths are shifted so the new root has depth 0.
        """
        if node_id not in self.nodes:
            return None

        # Collect the ids of all descendants; visiting order is irrelevant.
        pending = [self.nodes[node_id]]
        subtree_ids = set()
        while pending:
            node = pending.pop()
            subtree_ids.add(node.id)
            if node.child_ids is not None:
                pending.extend(self.nodes[child_id] for child_id in node.child_ids)

        subtree_nodes = {}
        for sub_id in subtree_ids:
            clone = copy(self.nodes[sub_id])
            if clone.child_ids is not None:
                # ``copy`` is shallow; give the clone its own child list.
                clone.child_ids = list(clone.child_ids)
            subtree_nodes[sub_id] = clone

        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)

    def _make_label_index(self, include_merged=False):
        """Build ``label_index``: sorted ids of unpruned nodes that have images.

        With ``include_merged`` the images of a node's descendants count too.
        """
        count_tree = self if include_merged else None
        self.label_index = sorted(
            node_id
            for node_id, node in self.nodes.items()
            if node.n_images(count_tree) > 0 and not node._pruned
        )

    def get_label(self, node_id):
        """Return the class label (position in ``label_index``) for ``node_id``.

        A pruned node maps to the label of its nearest unpruned ancestor.

        Raises:
            ValueError: If the resolved node is not part of the label index.
        """
        if self.label_index is None:
            self._make_label_index()
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)

    def n_labels(self):
        """Return the number of class labels."""
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)

    def __contains__(self, item):
        """Return whether ``item`` refers to a node of this tree.

        Accepts the same kinds of keys as :meth:`__getitem__`: an integer
        offset, a WordNet id string such as ``"n00001740"``, a synset name
        such as ``"dog.n.01"`` or a ``WNEntry``. Anything else, including
        malformed strings, yields ``False`` instead of raising.
        """
        if isinstance(item, str):
            if item.startswith("n"):
                try:
                    return int(item[1:]) in self.nodes
                except ValueError:
                    pass  # not an "n<offset>" id; may still be a synset name
            return ".n." in item and any(item == node.name for node in self.nodes.values())
        if isinstance(item, int):
            return item in self.nodes
        if isinstance(item, WNEntry):
            return item.id in self.nodes
        return False

    def __getitem__(self, item):
        """Look up a node by offset, ``"n<offset>"`` string, synset name or ``WNEntry``.

        Raises:
            KeyError: If no matching node exists.
        """
        if isinstance(item, str) and item.startswith("n"):
            try:
                item = int(item[1:])
            except ValueError:
                pass  # not an "n<offset>" id; fall through to the name lookup

        if isinstance(item, str) and ".n." in item:
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")
|