"""Module to evaluate trained models.""" import os from contextlib import nullcontext from datetime import datetime from math import sqrt from time import time import torch from loguru import logger from timm.loss import LabelSmoothingCrossEntropy from timm.models.resnet import ResNet as TimmResNet from torch import distributed as dist from torch.nn import functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from tqdm import tqdm from engine import ( _evaluate, setup_criteria_mixup, setup_model_optim_sched_scaler, setup_tracking_and_logging, wandb_available, ) from load_dataset import prepare_dataset from metrics import calculate_metrics from models import load_pretrained from utils import ( RepeatedDataset, ddp_cleanup, ddp_setup, denormalize, get_cpu_name, grad_cam_reshape_transform, prep_kwargs, set_filter_warnings, ) def evaluate_metrics(model, dataset, **kwargs): """Evaluate efficiency metrics for a given model. Args: model (str): path to model state .tar dataset (str): name of the dataset to evaluate on **kwargs: further arguments """ set_filter_warnings() model_path = model args = prep_kwargs(kwargs) if args.cuda: args.distributed, device, world_size, rank, _ = ddp_setup() torch.cuda.set_device(device) else: args.distributed = False device = torch.device("cpu") rank = 0 args.compile_model = False save_state = torch.load(model_path, map_location="cpu") old_args = prep_kwargs(save_state["args"]) args.model = old_args.model args.dataset = dataset args.run_name = old_args.run_name args.experiment_name = old_args.experiment_name args.wandb_run_id = old_args.wandb_run_id setup_tracking_and_logging(args, rank=rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None) train_loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args) model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True) old_args["eval_imsize"] = args.imsize args.model = model_name = old_args.model args.dataset = dataset args.epochs = 5 model, optim, _, scaler = setup_model_optim_sched_scaler(model, device, epochs=10, args=args, head_only=False) if rank == 0: logger.info( f"Evaluate metrics for model {model_name} on {dataset}. " f"It was {old_args.task.replace('-','')}d on {old_args.dataset} for {save_state['epoch']} " "epochs." ) # logger.info(f"full set of arguments: {args}") logger.info(f"full set of training arguments: {old_args}") logger.info(f"full set of eval-metrics arguments: {args}") logger.info( f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}" ) metrics = calculate_metrics( args, model, rank=rank, device=device, optim=optim, scaler=scaler, train_loader=train_loader, key_start="eval/" ) if rank == 0: logger.info(f"Metrics: {metrics}") if wandb_available(): import wandb wandb.log(metrics) def evaluate(model, dataset=None, val_dataset=None, **kwargs): """Evaluate model accuracy. Args: model (str): path to model state .tar dataset (str, optional): name of the dataset to evaluate on (Default value = None) val_dataset (str, optional): name of the dataset to evaluate on (Default value = None) **kwargs: further arguments Note: If `val_dataset` is not provided, the model will be evaluated on `dataset`. """ set_filter_warnings() model_path = model args = prep_kwargs(kwargs) if val_dataset is None: val_dataset = dataset args.dataset = dataset args.val_dataset = val_dataset if args.cuda: args.distributed, device, world_size, rank, _ = ddp_setup() torch.cuda.set_device(device) else: args.distributed = False device = torch.device("cpu") world_size = 1 rank = 0 args.compile_model = False args.batch_size = int(args.batch_size / world_size) save_state = torch.load(model_path, map_location="cpu") old_args = prep_kwargs(save_state["args"]) args.model = old_args.model args.dataset = dataset args.run_name = old_args.run_name args.experiment_name = old_args.experiment_name args.wandb_run_id = old_args.wandb_run_id run_folder = setup_tracking_and_logging( args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None ) val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset( val_dataset, args, train=False ) model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True) model = model.to(device) args.model = model_name = old_args.model args.dataset = dataset if rank == 0: logger.info( f"Evaluate model {model_name} on {val_dataset}. " f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs." ) if args.distributed: model = DDP(model) if args.compile_model: model = torch.compile(model) # log all devices logger.info( f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}" ) if rank == 0: logger.info(f"torch version {torch.__version__}") logger.info(f"full set of arguments: {args}") logger.info(f"full set of old arguments: {old_args}") if args.seed: torch.manual_seed(args.seed) val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing) if rank == 0: logger.info("start evaluation") logger.info(f"Run info at: '{run_folder}'") if rank == 0: val_time, val_stats = _evaluate( model.to(device), val_loader, epoch=save_state["epoch"] - 1, rank=rank, device=device, val_criterion=val_criterion, args=args, dali_server=dali_server, acc_dict_key=f"eval_{val_dataset}/acc{{}}", ) log_s = f"Evaluation done in {val_time}s" for key, val in val_stats.items(): log_s += f", {key}={val:.4f}" logger.info(log_s) if wandb_available(): import wandb wandb.log(val_stats) else: _evaluate( model.to(device), val_loader, epoch=save_state["epoch"] - 1, rank=rank, device=device, val_criterion=val_criterion, args=args, dali_server=dali_server, acc_dict_key=f"eval_{val_dataset}/acc{{}}", ) ddp_cleanup(args=args, rank=rank) def evaluate_center_bias(model, dataset=None, val_dataset=None, **kwargs): """Evaluate model accuracy in different nonants. Args: model (str): path to model state .tar dataset (str, optional): name of the dataset to evaluate on (Default value = None) val_dataset (str, optional): name of the dataset to evaluate on (Default value = None) **kwargs: further arguments Note: If `val_dataset` is not provided, the model will be evaluated on `dataset`. """ set_filter_warnings() model_path = model args = prep_kwargs(kwargs) if val_dataset is None: val_dataset = dataset if dataset is None: dataset = val_dataset assert dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)." args.dataset = dataset args.val_dataset = val_dataset if args.cuda: args.distributed, device, world_size, rank, _ = ddp_setup() torch.cuda.set_device(device) else: args.distributed = False device = torch.device("cpu") world_size = 1 rank = 0 args.compile_model = False args.batch_size = int(args.batch_size / world_size) save_state = torch.load(model_path, map_location="cpu") old_args = prep_kwargs(save_state["args"]) args.model = old_args.model args.dataset = dataset args.run_name = old_args.run_name args.experiment_name = old_args.experiment_name args.wandb_run_id = old_args.wandb_run_id run_folder = setup_tracking_and_logging( args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None ) assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation." _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False) model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True) model = model.to(device) args.model = model_name = old_args.model args.dataset = dataset if rank == 0: logger.info( f"Evaluate model {model_name} on {val_dataset}. " f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs." ) if args.distributed: model = DDP(model) if args.compile_model: model = torch.compile(model) # log all devices logger.info( f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}" ) if rank == 0: logger.info(f"torch version {torch.__version__}") logger.info(f"full set of arguments: {args}") logger.info(f"full set of old arguments: {old_args}") if args.seed: torch.manual_seed(args.seed) val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing) if rank == 0: logger.info("start evaluation") logger.info(f"Run info at: '{run_folder}'") if rank == 0: nonant_accs = [] for nonant in range(-1, 9): val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False) val_loader.dataset.fg_in_nonant = nonant logger.info(f"Evaluate nonant {nonant} for 5 rounds.") round_accs = [] for _ in range(5): val_time, val_stats = _evaluate( model.to(device), val_loader, epoch=save_state["epoch"] - 1, rank=rank, device=device, val_criterion=val_criterion, args=args, dali_server=dali_server, ) round_accs.append(val_stats["acc1"]) nonant_accs.append(sum(round_accs) / len(round_accs)) log_s = f"Evaluation done in {val_time}s: " for nonant, val in enumerate(nonant_accs[1:]): log_s += f", nonant {nonant}={val}% acc ({val / nonant_accs[0]} rel acc)" center_bias_val = 1 - ( min([nonant_accs[1], nonant_accs[3], nonant_accs[7], nonant_accs[9]]) + min([nonant_accs[2], nonant_accs[4], nonant_accs[6], nonant_accs[8]]) ) / (2 * nonant_accs[5]) log_s += f", center_bias={center_bias_val:.4f}" logger.info(log_s) if wandb_available(): import wandb wandb.log({f"eval_{args.val_dataset}/center_bias": center_bias_val}) else: raise NotImplementedError("Center bias evaluation not supported in distributed mode.") ddp_cleanup(args=args, rank=rank) def evaluate_size_bias(model, dataset=None, val_dataset=None, **kwargs): """Evaluate model accuracy for differently scaled foregrounds. Args: model (str): path to model state .tar dataset (str, optional): name of the dataset to evaluate on (Default value = None) val_dataset (str, optional): name of the dataset to evaluate on (Default value = None) **kwargs: further arguments Note: If `val_dataset` is not provided, the model will be evaluated on `dataset`. """ set_filter_warnings() model_path = model args = prep_kwargs(kwargs) if val_dataset is None: val_dataset = dataset if dataset is None: dataset = val_dataset assert val_dataset is not None and dataset is not None args.dataset = dataset args.val_dataset = val_dataset if args.cuda: args.distributed, device, world_size, rank, _ = ddp_setup() torch.cuda.set_device(device) else: args.distributed = False device = torch.device("cpu") world_size = 1 rank = 0 args.compile_model = False args.batch_size = int(args.batch_size / world_size) save_state = torch.load(model_path, map_location="cpu") old_args = prep_kwargs(save_state["args"]) args.model = old_args.model args.dataset = dataset args.run_name = old_args.run_name args.experiment_name = old_args.experiment_name args.wandb_run_id = old_args.wandb_run_id run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path, log_wandb=False) assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation." _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False) model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True) model = model.to(device) args.model = model_name = old_args.model args.dataset = dataset if rank == 0: logger.info( f"Evaluate model {model_name} on {val_dataset}. " f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs." ) if args.distributed: model = DDP(model) if args.compile_model: model = torch.compile(model) # log all devices logger.info( f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}" ) if rank == 0: logger.info(f"torch version {torch.__version__}") logger.info(f"full set of arguments: {args}") logger.info(f"full set of old arguments: {old_args}") if args.seed: torch.manual_seed(args.seed) val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing) if rank == 0: logger.info("start evaluation") logger.info(f"Run info at: '{run_folder}'") sizes = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 1.75, 2.0] if rank == 0: size_accs = [] val_times = 0 for size in sizes: val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False) val_loader.dataset.size_fact = size val_loader.dataset.fg_scale_jitter = 0.0 logger.info(f"Evaluate size factor {size} for 5 rounds.") round_accs = [] for _ in range(5): val_time, val_stats = _evaluate( model.to(device), val_loader, epoch=save_state["epoch"] - 1, rank=rank, device=device, val_criterion=val_criterion, args=args, dali_server=dali_server, ) round_accs.append(val_stats["acc1"]) val_times += val_time size_accs.append(sum(round_accs) / len(round_accs)) log_s = f"Evaluation done in {val_times}s: " for size, val in zip(sizes, size_accs): log_s += f", rel_size {size}={val}% acc ({val / size_accs[sizes.index(1.0)]} rel acc)" logger.info(log_s) else: raise NotImplementedError("Center bias evaluation not supported in distributed mode.") ddp_cleanup(args=args, rank=rank) def evaluate_attributions(model, dataset=None, val_dataset=None, **kwargs): """Evaluate model attributions using captum. Args: model (str): path to model state .tar dataset (str, optional): name of the dataset to evaluate on (Default value = None) val_dataset (str, optional): name of the dataset to evaluate on (Default value = None) **kwargs: further arguments Note: If `val_dataset` is not provided, the model will be evaluated on `dataset`. The `captum` package is required. """ from captum.attr import IntegratedGradients from pytorch_grad_cam import GradCAM, GradCAMPlusPlus from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget set_filter_warnings() model_path = model args = prep_kwargs(kwargs) if val_dataset is None: val_dataset = dataset assert val_dataset is not None, "Please set dataset (-ds) or validation dataset (-valds)" args.dataset = val_dataset args.val_dataset = val_dataset if args.cuda: args.distributed, device, world_size, rank, _ = ddp_setup() torch.cuda.set_device(device) else: args.distributed = False device = torch.device("cpu") world_size = 1 rank = 0 args.compile_model = False args.batch_size = int(args.batch_size / world_size) save_state = torch.load(model_path, map_location="cpu") old_args = prep_kwargs(save_state["args"]) args.model = old_args.model args.dataset = val_dataset args.run_name = old_args.run_name args.experiment_name = old_args.experiment_name args.wandb_run_id = old_args.wandb_run_id run_folder = setup_tracking_and_logging( args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None ) assert "fornet" in val_dataset.lower(), "Only ForNet supported for attribution evaluation." val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset( val_dataset, args, train=False ) val_loader.dataset.return_fg_masks = True model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True) model = model.to(device) args.model = model_name = old_args.model args.dataset = dataset # assert ( # args.imsize == old_args.imsize # ), f"Model was trained on {old_args.imsize}x{old_args.imsize} images. Not {args.imsize}x{args.imsize}." epoch = save_state["epoch"] if rank == 0: logger.info( f"Evaluate attributions of model {model_name} on {dataset}. " f"It was pretrained on {old_args.dataset} for {epoch} epochs." ) if args.distributed: model = DDP(model) if args.compile_model: model = torch.compile(model) # log all devices logger.info( f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}" ) if rank == 0: logger.info(f"torch version {torch.__version__}") if args.new_log: logger.info(f"full set of arguments: {args}") logger.info(f"full set of old arguments: {old_args}") else: logger.info(f"full set of attribution evaluation arguments: {old_args}") if args.seed: torch.manual_seed(args.seed) if rank == 0: logger.info(f"Run info at: '{run_folder}'") iterator = ( tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch}") if rank == 0 and args.tqdm else val_loader ) if args.debug: from matplotlib import pyplot as plt eval_attn_importance = False if isinstance(model, TimmResNet): reshape_transform = None target_layers = [model.layer4[-1]] elif model_name.lower().startswith("vit-"): reshape_transform = grad_cam_reshape_transform target_layers = [model.blocks[-1].norm1] eval_attn_importance = True from architectures.vit import _MatrixSaveAttn model.blocks[-1].attn = _MatrixSaveAttn.cast(model.blocks[-1].attn) elif model_name.lower().startswith("swin_"): reshape_transform = grad_cam_reshape_transform target_layers = [model.layers[-1].blocks[-1].norm1] else: raise NotImplementedError(f"Model {model_name} not supported for attribution evaluation.") model.eval() val_start = time() rel_ig_weights = 0.0 rel_attn_weights = 0.0 rel_cam_weights = {"GradCAM": 0.0, "GradCAM++": 0.0} if rank == 0: logger.info("Start attribution evaluation") if dali_server: dali_server.start_thread() for batch_data in iterator: xs, ys, fg_masks = batch_data xs, ys, fg_masks = ( xs.to(device, non_blocking=True), ys.to(device, non_blocking=True), fg_masks.float().to(device, non_blocking=True), ) with torch.amp.autocast("cuda") if args.eval_amp else nullcontext(): model.zero_grad() ig = IntegratedGradients(model) # we use attention temperature of 10 to make differences more apparent after exp attr_ig = ( ig.attribute(xs, target=ys, baselines=0.0, internal_batch_size=args.batch_size * 4).sum(dim=1) * 10 ) # B x W x H attr_probs = attr_ig.view(xs.shape[0], -1).softmax(dim=-1).view(xs.shape[0], *xs.shape[2:]) fg_masks = fg_masks.view(attr_probs.shape) fg_attrs = (attr_probs * fg_masks).sum(dim=(-1, -2)) rel_attr_weight = fg_attrs / fg_masks.mean(dim=(-1, -2)) rel_attr_weight = torch.where(fg_masks.mean(dim=(-1, -2)) > 0, rel_attr_weight, 1.0) if rel_attr_weight.isnan().any(): logger.error(f"NaNs in rel_attr_weight: {rel_attr_weight}, fg_mask_weights: {fg_masks.mean(dim=(-1, -2))}") break rel_ig_weights += rel_attr_weight.mean().item() cam_targets = [ClassifierOutputTarget(int(trgt)) for trgt in ys.tolist()] for method, name in zip([GradCAM, GradCAMPlusPlus], ["GradCAM", "GradCAM++"]): with method(model=model, target_layers=target_layers, reshape_transform=reshape_transform) as cam, ( torch.amp.autocast("cuda") if args.eval_amp else nullcontext() ): cam_attr = cam(input_tensor=xs, targets=cam_targets) cam_attr = torch.from_numpy(cam_attr).to(device) rel_cam_attr = (cam_attr * fg_masks).sum(dim=(-1, -2)) / cam_attr.sum(dim=(-1, -2)) cam_attr_weight = rel_cam_attr / fg_masks.mean(dim=(-1, -2)) cam_attr_weight = torch.where( (fg_masks.mean(dim=(-1, -2)) > 0) & (cam_attr.sum(dim=(-1, -2)) > 0), cam_attr_weight, 1.0 ) rel_cam_weights[name] += cam_attr_weight.mean().item() if cam_attr_weight.isnan().any(): logger.error( f"NaNs in cam_attr_weight ({name}): {cam_attr_weight}, fg_mask_weights:" f" {fg_masks.mean(dim=(-1, -2))}" ) break if eval_attn_importance: with torch.amp.autocast("cuda") if args.eval_amp else nullcontext(): pred = model(xs) # noqa: F841 last_attn_mat = model.blocks[-1].attn.attn_mat cls_tkn_attn = last_attn_mat[:, :, 0, 1:].mean(dim=1).squeeze(dim=1) # B x H x 1(CLS Token) X N -> B x N B, N = cls_tkn_attn.shape att_HW = int(sqrt(N)) cls_tkn_attn = cls_tkn_attn.view(B, 1, att_HW, att_HW) attn_attr = F.interpolate( cls_tkn_attn, size=(xs.shape[-2], xs.shape[-1]), mode="bilinear", align_corners=False ).view(B, xs.shape[-2], xs.shape[-1]) rel_attn_attr = (attn_attr * fg_masks).sum(dim=(-1, -2)) / attn_attr.sum(dim=(-1, -2)) attn_attr_weight = rel_attn_attr / fg_masks.mean(dim=(-1, -2)) attn_attr_weight = torch.where( (fg_masks.mean(dim=(-1, -2)) > 0) & (attn_attr.sum(dim=(-1, -2)) > 0), attn_attr_weight, 1.0 ) rel_attn_weights += attn_attr_weight.mean().item() if args.debug: logger.debug(f"Attribution scores: IG: {rel_attr_weight[:5]}, GradCAM(++): {cam_attr_weight[:5]}") num_subplots = 5 if eval_attn_importance else 4 fig, axs = plt.subplots(num_subplots, 4) for plt_i in range(4): axs[0][plt_i].imshow(denormalize(xs[plt_i]).permute(1, 2, 0).cpu().numpy()) axs[1][plt_i].imshow(fg_masks[plt_i].cpu().numpy()) axs[2][plt_i].imshow(attr_probs[plt_i].cpu().numpy()) axs[3][plt_i].imshow(cam_attr[plt_i].cpu().numpy()) if eval_attn_importance: axs[4][plt_i].imshow(attn_attr[plt_i].cpu().numpy()) plt.show() iterator_desc = ( f"IG weights: {rel_ig_weights / (iterator.n + 1):.4f}, GradCAM weights:" f" {rel_cam_weights['GradCAM'] / (iterator.n + 1):.4f}, GradCAM++ weights:" f" {rel_cam_weights['GradCAM++'] / (iterator.n + 1):.4f}" ) if eval_attn_importance: iterator_desc += f", Attn weights: {rel_attn_weights / (iterator.n + 1):.4f}" iterator.set_description(iterator_desc) if args.distributed: dist.barrier() val_end = time() rel_ig_weights /= len(iterator) rel_grad_cam = rel_cam_weights["GradCAM"] / len(iterator) rel_grad_cam_pp = rel_cam_weights["GradCAM++"] / len(iterator) rel_attn_weights /= len(iterator) if dali_server: dali_server.stop_thread() if args.distributed: gather_tensor = torch.Tensor([rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights]).to(device) dist.barrier() dist.all_reduce(gather_tensor) gather_tensor = (gather_tensor / world_size).tolist() rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights = gather_tensor if rank == 0: output_text = ( f"epoch {epoch}: eval_{args.val_dataset}/rel_ig_weights={rel_ig_weights}," f" eval_{args.val_dataset}/rel_grad_cam={rel_grad_cam}," f" eval_{args.val_dataset}/rel_grad_cam_pp={rel_grad_cam_pp}" ) if eval_attn_importance: output_text += f", eval_{args.val_dataset}/rel_attn_weights={rel_attn_weights}" output_text += f", eval_{args.val_dataset}/attribution_eval_time={val_end - val_start}s" logger.info(output_text) if wandb_available(): import wandb wandb_data = { f"eval_{args.val_dataset}/importance_ig": rel_ig_weights, f"eval_{args.val_dataset}/importance_grad_cam": rel_grad_cam, f"eval_{args.val_dataset}/importance_grad_cam_pp": rel_grad_cam_pp, } if eval_attn_importance: wandb_data[f"eval_{args.val_dataset}/importance_attn"] = rel_attn_weights wandb.log(wandb_data) ddp_cleanup(args=args, rank=rank)