import argparse
import json
import os
from collections import deque
from copy import copy

from loguru import logger
from nltk.corpus import wordnet as wn


class bcolors:
    """ANSI escape sequences used to colour the tree printouts."""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def _lemmas_str(synset):
    """Return the lemma names of a WordNet synset joined by ``", "``."""
    return ", ".join([lemma.name() for lemma in synset.lemmas()])


class WNEntry:
    """One WordNet noun synset inside a :class:`WNTree`.

    Nodes refer to each other by WordNet offset (``id``), not by object, so
    most tree-walking methods take the owning tree as an argument.
    """

    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        self.name = name  # synset name, e.g. "dog.n.01"
        self.id = id  # WordNet noun offset; key of this node in WNTree.nodes
        self.lemmas = lemmas  # comma-separated lemma names
        self.parent_id = parent_id  # offset of the parent node, None for roots
        self.depth = depth  # distance from the tree root
        self.in_image_net = in_image_net
        self.child_ids = child_ids  # list of child offsets, or None if never given children
        self.in_main_tree = in_main_tree  # False for nodes hanging off a parentless synset
        self._n_images = _n_images  # images attributed directly to this node
        self._description = _description  # lazily filled cache for `description`
        self._name = _name
        self._pruned = _pruned  # True once merged into an ancestor by `prune`

    def __str__(self, tree=None, accumulate=True):
        """Render this node and, when ``tree`` is given, its whole subtree.

        A green ``+`` marks ImageNet nodes and a red ``-`` marks the others.
        With ``accumulate`` and a ``tree``, the count is shown as
        "own of Σ subtree total".
        """
        if self.in_image_net:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        if accumulate and tree is not None:
            n_ims = f"{self._n_images} of Σ {self.n_images(tree)}"
        else:
            n_ims = self._n_images
        line = f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        if self.child_ids is None or tree is None:
            return line
        # Indent every line of each child's rendering by one more level.
        return line + "\n " + "\n ".join(
            ["\n ".join(tree.nodes[child_id].__str__(tree).split("\n")) for child_id in self.child_ids]
        )

    def tree_diff(self, tree_1, tree_2):
        """Render this subtree (structure taken from ``tree_1``) with image-count deltas vs ``tree_2``.

        Each line shows "old + delta of Σ total_in_tree_1/total_in_tree_2". A green
        ``+`` marks a gain, a red ``-`` a loss and a blue ``=`` no change in own images.
        """
        old = tree_1[self.id]
        new = tree_2[self.id]
        if new._n_images > old._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif new._n_images < old._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        # Each subtree total is computed within its own tree (previously the
        # tree_1 total was accidentally walked through tree_2's nodes).
        n_ims = (
            f"{old._n_images} + {new._n_images - old._n_images} of Σ"
            f" {old.n_images(tree_1)}/{new.n_images(tree_2)}"
        )
        line = f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        if self.child_ids is None:
            return line
        return line + "\n " + "\n ".join(
            ["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        """Merge this node and its whole subtree into the nearest unpruned ancestor.

        First the children are pruned recursively, so their images land in this
        node. Then this node is marked pruned and removed from its parent's
        ``child_ids``. Finally its images are added to the first ancestor that
        is not itself pruned.

        A node without a parent (the root, or a parentless synset) has nowhere
        to pass its images to. It absorbs its children but stays unpruned.
        """
        if self._pruned:
            return
        if self.child_ids is not None:
            # Iterate over a snapshot: each child removes itself from
            # self.child_ids while being pruned, and mutating a list while
            # iterating it would silently skip every other child.
            for child_id in list(self.child_ids):
                tree[child_id].prune(tree)
        if self.parent_id is None:
            return
        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
        # Parentless nodes are never pruned (see above), so this walk always
        # stops at a live ancestor.
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        """WordNet definition of this synset, looked up once and cached."""
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        """Synset name without the ``.n.XX`` suffix (e.g. ``"dog"`` for ``"dog.n.01"``)."""
        return self.name.split(".")[0]

    def get_branch(self, tree=None):
        """Return the ``" > "``-joined path of print names from the root down to this node."""
        if self.parent_id is None or tree is None:
            return self.print_name
        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        """Return the nodes on the path from the root down to (and including) this node."""
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        """Serialize this node to a JSON-compatible dict accepted by :meth:`from_dict`."""
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a node from the output of :meth:`to_dict`."""
        return cls(**d)

    def n_images(self, tree=None):
        """Return own images, plus those of all descendants when ``tree`` is given."""
        if tree is None or not self.child_ids:
            return self._n_images
        return self._n_images + sum(tree.nodes[child_id].n_images(tree) for child_id in self.child_ids)

    def n_children(self, tree=None):
        """Return the number of direct children, or of all descendants when ``tree`` is given."""
        if self.child_ids is None:
            return 0
        if tree is None or not self.child_ids:
            return len(self.child_ids)
        return len(self.child_ids) + sum(tree.nodes[child_id].n_children(tree) for child_id in self.child_ids)

    def get_examples(self, tree, n_examples=3):
        """Describe up to ``n_examples`` of the "largest" children as ``"name (definition)"``.

        Children are ranked by subtree image count. When no child has any
        images, they are ranked by number of descendants instead.
        """
        if not self.child_ids:
            return ""
        scores = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        if max(scores.values()) == 0:
            scores = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        top_children = sorted(scores, key=scores.get, reverse=True)[: min(n_examples, len(scores))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )


class WNTree:
    """The WordNet noun hierarchy (first-hypernym tree) annotated with ImageNet image counts.

    ``nodes`` maps WordNet offsets to :class:`WNEntry` objects. ``root`` is the
    same object as ``nodes[root.id]``.
    """

    def __init__(self, root=1740, nodes=None):
        # 1740 is passed to wn.synset_from_pos_and_offset("n", ...) to build the default root.
        if isinstance(root, int):
            root_id = root
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root_id,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_id = root.id
            root_node = root
        self.root = root_node
        self.nodes = {root_id: self.root} if nodes is None else nodes
        self.parentless = []  # ids of synsets that have no hypernym
        self.label_index = None  # lazily built sorted list of labelled node ids
        self.pruned = set()  # ids pruned by the most recent prune() call

    def to_dict(self):
        """Serialize the tree to a JSON-compatible dict accepted by :meth:`from_dict`."""
        return {
            "root": self.root.to_dict(),
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "parentless": self.parentless,
            "pruned": list(self.pruned),
        }

    def prune(self, min_images):
        """Merge under-populated nodes into their ancestors; return the set of pruned ids.

        The first pass prunes every node whose whole subtree has fewer than
        ``min_images`` images. The second pass walks the remaining tree bottom-up
        and prunes nodes whose own images (after absorbing pruned children) are
        still below ``min_images``. Parentless nodes are never pruned.
        """
        pruned_nodes = set()
        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            if node.n_images(self) < min_images:
                node.prune(self)
                if node._pruned:
                    pruned_nodes.add(node_id)

        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        node_stack = [self.root]
        node_idx = 0
        while node_idx < len(node_stack):  # breadth-first listing of the surviving tree
            node = node_stack[node_idx]
            if node.child_ids is not None:
                node_stack.extend(self.nodes[child_id] for child_id in node.child_ids)
            node_idx += 1
        # now prune the stack from the bottom up
        for node in reversed(node_stack):
            # only look at images of that class, not of additional children
            if node.n_images() < min_images:
                node.prune(self)
                if node._pruned:
                    pruned_nodes.add(node.id)

        self.pruned = pruned_nodes
        # Labels depend on which nodes are pruned; force a rebuild on next use.
        self.label_index = None
        return pruned_nodes

    @classmethod
    def from_dict(cls, d):
        """Rebuild a tree from the output of :meth:`to_dict` (JSON node keys may be strings)."""
        nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
        root = WNEntry.from_dict(d["root"])
        # Keep the invariant tree.root is tree.nodes[root.id], so mutations made
        # through either reference (e.g. by prune) stay in sync.
        root = nodes.get(root.id, root)
        tree = cls(root=root, nodes=nodes)
        tree.parentless = d["parentless"]
        if "pruned" in d:
            tree.pruned = set(d["pruned"])
        return tree

    def add_node(self, node_id, in_in=True):
        """Insert the noun synset with offset ``node_id``, adding missing ancestors first.

        Only the first hypernym is followed, so the result is a tree. Ancestors
        added implicitly are marked as not in ImageNet. For an existing node,
        ``in_in`` can only switch the ImageNet flag on.
        """
        if node_id in self.nodes:
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return
        synset = wn.synset_from_pos_and_offset("n", node_id)
        hypernyms = synset.hypernyms()
        if not hypernyms:
            parent_id = None
            depth = 0
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            parent_id = hypernyms[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]
            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree
            depth = parent.depth + 1
        self.nodes[node_id] = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )

    def __len__(self):
        return len(self.nodes)

    def image_net_len(self, only_main_tree=False):
        """Count nodes flagged as ImageNet classes (optionally only in the main tree)."""
        return sum(node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree)

    def max_depth(self, only_main_tree=False):
        """Return the largest node depth (optionally only in the main tree)."""
        return max(node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree)

    def __str__(self):
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self) for node_id in self.parentless])
        )

    def save(self, path):
        """Write the tree to ``path`` as JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, path):
        """Read a tree previously written by :meth:`save`."""
        with open(path, "r", encoding="utf-8") as f:
            tree_dict = json.load(f)
        return cls.from_dict(tree_dict)

    def subtree(self, node_id):
        """Return a new tree rooted at ``node_id`` (depths rebased to 0), or None if absent.

        Nodes are copied, including their ``child_ids`` lists, so pruning the
        subtree does not mutate this tree.
        """
        if node_id not in self.nodes:
            return None
        subtree_ids = set()
        queue = deque([node_id])
        while queue:
            current_id = queue.popleft()
            subtree_ids.add(current_id)
            child_ids = self.nodes[current_id].child_ids
            if child_ids is not None:
                queue.extend(child_ids)

        subtree_nodes = {}
        for sub_id in subtree_ids:
            node_copy = copy(self.nodes[sub_id])
            if node_copy.child_ids is not None:
                # copy() is shallow; give the copy its own list so the trees stay independent.
                node_copy.child_ids = list(node_copy.child_ids)
            subtree_nodes[sub_id] = node_copy

        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)

    def _make_label_index(self, include_merged=False):
        """Build the sorted list of unpruned node ids that have images (own, or subtree if include_merged)."""
        self.label_index = sorted(
            [
                node_id
                for node_id, node in self.nodes.items()
                if node.n_images(self if include_merged else None) > 0 and not node._pruned
            ]
        )

    def get_label(self, node_id):
        """Return the class index for ``node_id``, mapping pruned nodes to their surviving ancestor."""
        if self.label_index is None:
            self._make_label_index()
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)

    def n_labels(self):
        """Return the number of distinct class labels."""
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)

    def __contains__(self, item):
        """Membership test accepting everything :meth:`__getitem__` accepts."""
        try:
            self[item]
        except KeyError:
            return False
        return True

    def __getitem__(self, item):
        """Look up a node by offset (int), ``"n<offset>"`` string, synset name (``"dog.n.01"``) or WNEntry."""
        if isinstance(item, str) and item.startswith("n"):
            try:
                item = int(item[1:])
            except ValueError:
                pass
        if isinstance(item, str) and ".n." in item:
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")