"""Module to evaluate trained models."""
|
|
|
|
import os
|
|
from contextlib import nullcontext
|
|
from datetime import datetime
|
|
from math import sqrt
|
|
from time import time
|
|
|
|
import torch
|
|
from loguru import logger
|
|
from timm.loss import LabelSmoothingCrossEntropy
|
|
from timm.models.resnet import ResNet as TimmResNet
|
|
from torch import distributed as dist
|
|
from torch.nn import functional as F
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
|
|
from engine import (
|
|
_evaluate,
|
|
setup_criteria_mixup,
|
|
setup_model_optim_sched_scaler,
|
|
setup_tracking_and_logging,
|
|
wandb_available,
|
|
)
|
|
from load_dataset import prepare_dataset
|
|
from metrics import calculate_metrics
|
|
from models import load_pretrained
|
|
from utils import (
|
|
RepeatedDataset,
|
|
ddp_cleanup,
|
|
ddp_setup,
|
|
denormalize,
|
|
get_cpu_name,
|
|
grad_cam_reshape_transform,
|
|
prep_kwargs,
|
|
set_filter_warnings,
|
|
)
|
|
|
|
|
|
def evaluate_metrics(model, dataset, **kwargs):
    """Evaluate efficiency metrics for a given model.

    Loads a checkpoint, rebuilds the model/optimizer/scaler, and runs
    ``calculate_metrics`` on the training loader of ``dataset``. Results
    are logged and, if wandb is available, pushed to wandb (rank 0 only).

    Args:
        model (str): path to model state .tar
        dataset (str): name of the dataset to evaluate on
        **kwargs: further arguments

    """
    set_filter_warnings()
    model_path = model  # `model` is re-bound to the loaded module below
    args = prep_kwargs(kwargs)
    if args.cuda:
        # distributed setup returns (flag, local device, world size, rank, _)
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        rank = 0
        args.compile_model = False  # torch.compile is not used on CPU

    # restore the training-time arguments stored alongside the weights
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    # reuse run identifiers so logs/wandb attach to the original training run
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    setup_tracking_and_logging(args, rank=rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None)

    train_loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    old_args["eval_imsize"] = args.imsize  # record the image size used for this metrics run
    args.model = model_name = old_args.model
    args.dataset = dataset
    args.epochs = 5  # NOTE(review): differs from the hard-coded epochs=10 below — confirm which is intended

    model, optim, _, scaler = setup_model_optim_sched_scaler(model, device, epochs=10, args=args, head_only=False)

    if rank == 0:
        logger.info(
            f"Evaluate metrics for model {model_name} on {dataset}. "
            f"It was {old_args.task.replace('-','')}d on {old_args.dataset} for {save_state['epoch']} "
            "epochs."
        )
        # logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of training arguments: {old_args}")
        logger.info(f"full set of eval-metrics arguments: {args}")

    # logged on every rank so each participating device is recorded
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    metrics = calculate_metrics(
        args, model, rank=rank, device=device, optim=optim, scaler=scaler, train_loader=train_loader, key_start="eval/"
    )
    if rank == 0:
        logger.info(f"Metrics: {metrics}")
        if wandb_available():
            import wandb

            wandb.log(metrics)
|
|
|
|
|
|
def evaluate(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy.

    Loads a checkpoint, rebuilds the model, and runs one validation pass on
    ``val_dataset``. Statistics are logged and, if wandb is available, pushed
    to wandb (rank 0 only).

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If only one of `dataset` / `val_dataset` is provided, it is used for both.

    """
    set_filter_warnings()
    model_path = model  # `model` is re-bound to the loaded module below
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    # BUGFIX: also fall back the other way (consistent with evaluate_center_bias),
    # so passing only -valds no longer leaves `dataset` / `args.dataset` as None
    if dataset is None:
        dataset = val_dataset
    assert dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False  # torch.compile is not used on CPU
    # per-process batch size when running data-parallel
    args.batch_size = int(args.batch_size / world_size)

    # restore the training-time arguments stored alongside the weights
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    # reuse run identifiers so logs/wandb attach to the original training run
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )

    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset

    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )

    if args.distributed:
        model = DDP(model)

    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")

    if rank == 0:
        val_time, val_stats = _evaluate(
            model.to(device),
            val_loader,
            epoch=save_state["epoch"] - 1,
            rank=rank,
            device=device,
            val_criterion=val_criterion,
            args=args,
            dali_server=dali_server,
            acc_dict_key=f"eval_{val_dataset}/acc{{}}",
        )
        log_s = f"Evaluation done in {val_time}s"
        for key, val in val_stats.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log(val_stats)
    else:
        # non-zero ranks participate in the evaluation but do not log
        _evaluate(
            model.to(device),
            val_loader,
            epoch=save_state["epoch"] - 1,
            rank=rank,
            device=device,
            val_criterion=val_criterion,
            args=args,
            dali_server=dali_server,
            acc_dict_key=f"eval_{val_dataset}/acc{{}}",
        )

    ddp_cleanup(args=args, rank=rank)
|
|
|
|
|
|
def evaluate_center_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy in different nonants.

    Places the foreground into each of the nine nonants (plus a random
    placement for nonant -1), runs five evaluation rounds per placement, and
    derives a center-bias score from corner/edge/center accuracies. Only
    single-process (non-distributed) execution is supported.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.

    """
    set_filter_warnings()
    model_path = model  # `model` is re-bound to the loaded module below
    args = prep_kwargs(kwargs)
    # fall back to the other dataset argument if only one was given
    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    assert dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False  # torch.compile is not used on CPU
    # per-process batch size when running data-parallel
    args.batch_size = int(args.batch_size / world_size)

    # restore the training-time arguments stored alongside the weights
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    # reuse run identifiers so logs/wandb attach to the original training run
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )

    assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation."
    # probe dataset parameters only; per-nonant loaders are created in the loop below
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset

    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )

    if args.distributed:
        model = DDP(model)

    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")

    if rank == 0:
        nonant_accs = []
        # BUGFIX: accumulate total evaluation time; the log previously reported
        # only the duration of the very last round (cf. evaluate_size_bias)
        val_times = 0
        for nonant in range(-1, 9):
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.fg_in_nonant = nonant
            logger.info(f"Evaluate nonant {nonant} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
                val_times += val_time
            nonant_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_times}s: "
        for nonant, val in enumerate(nonant_accs[1:]):
            log_s += f", nonant {nonant}={val}% acc ({val / nonant_accs[0]} rel acc)"
        # index i holds nonant i-1 (index 0 is the random placement), so
        # indices 1,3,7,9 are corners, 2,4,6,8 are edges, 5 is the center
        center_bias_val = 1 - (
            min([nonant_accs[1], nonant_accs[3], nonant_accs[7], nonant_accs[9]])
            + min([nonant_accs[2], nonant_accs[4], nonant_accs[6], nonant_accs[8]])
        ) / (2 * nonant_accs[5])
        log_s += f", center_bias={center_bias_val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log({f"eval_{args.val_dataset}/center_bias": center_bias_val})
    else:
        raise NotImplementedError("Center bias evaluation not supported in distributed mode.")

    ddp_cleanup(args=args, rank=rank)
|
|
|
|
|
|
def evaluate_size_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy for differently scaled foregrounds.

    Rescales the foreground by a range of size factors, runs five evaluation
    rounds per factor, and logs accuracy relative to the unscaled (1.0)
    baseline. Only single-process (non-distributed) execution is supported.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.

    """
    set_filter_warnings()
    model_path = model  # `model` is re-bound to the loaded module below
    args = prep_kwargs(kwargs)
    # fall back to the other dataset argument if only one was given
    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    assert val_dataset is not None and dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False  # torch.compile is not used on CPU
    # per-process batch size when running data-parallel
    args.batch_size = int(args.batch_size / world_size)

    # restore the training-time arguments stored alongside the weights
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    # reuse run identifiers so logs attach to the original training run
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path, log_wandb=False)

    # BUGFIX: message said "center bias" — this is the size bias evaluation
    assert "fornet" in val_dataset.lower(), "Only ForNet supported for size bias evaluation."
    # probe dataset parameters only; per-size loaders are created in the loop below
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset

    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )

    if args.distributed:
        model = DDP(model)

    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")

    sizes = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 1.75, 2.0]
    if rank == 0:
        size_accs = []
        val_times = 0
        for size in sizes:
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.size_fact = size
            val_loader.dataset.fg_scale_jitter = 0.0  # deterministic scaling for this probe
            logger.info(f"Evaluate size factor {size} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
                val_times += val_time
            size_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_times}s: "
        # hoisted: the 1.0-scale baseline accuracy is loop-invariant
        base_acc = size_accs[sizes.index(1.0)]
        for size, val in zip(sizes, size_accs):
            log_s += f", rel_size {size}={val}% acc ({val / base_acc} rel acc)"
        logger.info(log_s)
    else:
        # BUGFIX: message said "center bias" — this is the size bias evaluation
        raise NotImplementedError("Size bias evaluation not supported in distributed mode.")

    ddp_cleanup(args=args, rank=rank)
|
|
|
|
|
|
def evaluate_attributions(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model attributions using captum.

    Measures how much attribution mass (Integrated Gradients, GradCAM,
    GradCAM++ and, for ViTs, CLS-token attention) falls on the foreground
    mask of each image, normalized by foreground area. Scores are averaged
    over batches (and ranks, if distributed), logged, and pushed to wandb
    when available (rank 0 only).

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
        The `captum` package is required.

    """
    from captum.attr import IntegratedGradients
    from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    set_filter_warnings()
    model_path = model  # `model` is re-bound to the loaded module below
    args = prep_kwargs(kwargs)
    # fall back to the other dataset argument if only one was given
    if val_dataset is None:
        val_dataset = dataset
    # BUGFIX: also fall back the other way, so `args.dataset` and the log text
    # below are not None when only -valds was given
    if dataset is None:
        dataset = val_dataset
    assert val_dataset is not None, "Please set dataset (-ds) or validation dataset (-valds)"
    args.dataset = val_dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False  # torch.compile is not used on CPU
    # per-process batch size when running data-parallel
    args.batch_size = int(args.batch_size / world_size)

    # restore the training-time arguments stored alongside the weights
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = val_dataset
    # reuse run identifiers so logs/wandb attach to the original training run
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )

    assert "fornet" in val_dataset.lower(), "Only ForNet supported for attribution evaluation."
    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )
    val_loader.dataset.return_fg_masks = True  # batches yield (image, label, foreground mask)

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    # assert (
    #     args.imsize == old_args.imsize
    # ), f"Model was trained on {old_args.imsize}x{old_args.imsize} images. Not {args.imsize}x{args.imsize}."
    epoch = save_state["epoch"]

    if rank == 0:
        logger.info(
            f"Evaluate attributions of model {model_name} on {dataset}. "
            f"It was pretrained on {old_args.dataset} for {epoch} epochs."
        )

    if args.distributed:
        model = DDP(model)

    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        if args.new_log:
            logger.info(f"full set of arguments: {args}")
            logger.info(f"full set of old arguments: {old_args}")
        else:
            logger.info(f"full set of attribution evaluation arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")

    use_tqdm = rank == 0 and args.tqdm
    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch}")
        if use_tqdm
        else val_loader
    )

    if args.debug:
        from matplotlib import pyplot as plt

    # pick GradCAM target layers (and attention hooks for ViTs) per architecture
    eval_attn_importance = False
    if isinstance(model, TimmResNet):
        # NOTE(review): under DDP/compile the wrapper hides the original class,
        # so a wrapped ResNet falls through to the NotImplementedError — confirm intended
        reshape_transform = None
        target_layers = [model.layer4[-1]]
    elif model_name.lower().startswith("vit-"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.blocks[-1].norm1]
        eval_attn_importance = True
        from architectures.vit import _MatrixSaveAttn

        # swap in an attention module that stores its attention matrix for inspection
        model.blocks[-1].attn = _MatrixSaveAttn.cast(model.blocks[-1].attn)
    elif model_name.lower().startswith("swin_"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.layers[-1].blocks[-1].norm1]
    else:
        raise NotImplementedError(f"Model {model_name} not supported for attribution evaluation.")

    model.eval()
    val_start = time()
    rel_ig_weights = 0.0
    rel_attn_weights = 0.0
    rel_cam_weights = {"GradCAM": 0.0, "GradCAM++": 0.0}
    if rank == 0:
        logger.info("Start attribution evaluation")
    if dali_server:
        dali_server.start_thread()
    for batch_data in iterator:
        xs, ys, fg_masks = batch_data

        xs, ys, fg_masks = (
            xs.to(device, non_blocking=True),
            ys.to(device, non_blocking=True),
            fg_masks.float().to(device, non_blocking=True),
        )

        with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
            model.zero_grad()
            ig = IntegratedGradients(model)
            # we use attention temperature of 10 to make differences more apparent after exp
            attr_ig = (
                ig.attribute(xs, target=ys, baselines=0.0, internal_batch_size=args.batch_size * 4).sum(dim=1) * 10
            )  # B x W x H
            attr_probs = attr_ig.view(xs.shape[0], -1).softmax(dim=-1).view(xs.shape[0], *xs.shape[2:])
            fg_masks = fg_masks.view(attr_probs.shape)

        # share of IG attribution mass on the foreground, normalized by foreground area
        fg_attrs = (attr_probs * fg_masks).sum(dim=(-1, -2))
        rel_attr_weight = fg_attrs / fg_masks.mean(dim=(-1, -2))
        rel_attr_weight = torch.where(fg_masks.mean(dim=(-1, -2)) > 0, rel_attr_weight, 1.0)
        if rel_attr_weight.isnan().any():
            logger.error(f"NaNs in rel_attr_weight: {rel_attr_weight}, fg_mask_weights: {fg_masks.mean(dim=(-1, -2))}")
            break
        rel_ig_weights += rel_attr_weight.mean().item()

        cam_targets = [ClassifierOutputTarget(int(trgt)) for trgt in ys.tolist()]
        for method, name in zip([GradCAM, GradCAMPlusPlus], ["GradCAM", "GradCAM++"]):
            with method(model=model, target_layers=target_layers, reshape_transform=reshape_transform) as cam, (
                torch.amp.autocast("cuda") if args.eval_amp else nullcontext()
            ):
                cam_attr = cam(input_tensor=xs, targets=cam_targets)

            cam_attr = torch.from_numpy(cam_attr).to(device)
            rel_cam_attr = (cam_attr * fg_masks).sum(dim=(-1, -2)) / cam_attr.sum(dim=(-1, -2))
            cam_attr_weight = rel_cam_attr / fg_masks.mean(dim=(-1, -2))
            cam_attr_weight = torch.where(
                (fg_masks.mean(dim=(-1, -2)) > 0) & (cam_attr.sum(dim=(-1, -2)) > 0), cam_attr_weight, 1.0
            )
            rel_cam_weights[name] += cam_attr_weight.mean().item()
            if cam_attr_weight.isnan().any():
                logger.error(
                    f"NaNs in cam_attr_weight ({name}): {cam_attr_weight}, fg_mask_weights:"
                    f" {fg_masks.mean(dim=(-1, -2))}"
                )
                break

        if eval_attn_importance:
            with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
                pred = model(xs)  # noqa: F841
                # attention of the CLS token to all patch tokens, averaged over heads
                last_attn_mat = model.blocks[-1].attn.attn_mat
                cls_tkn_attn = last_attn_mat[:, :, 0, 1:].mean(dim=1).squeeze(dim=1)  # B x H x 1(CLS Token) X N -> B x N
                B, N = cls_tkn_attn.shape
                att_HW = int(sqrt(N))
                cls_tkn_attn = cls_tkn_attn.view(B, 1, att_HW, att_HW)
                attn_attr = F.interpolate(
                    cls_tkn_attn, size=(xs.shape[-2], xs.shape[-1]), mode="bilinear", align_corners=False
                ).view(B, xs.shape[-2], xs.shape[-1])
                rel_attn_attr = (attn_attr * fg_masks).sum(dim=(-1, -2)) / attn_attr.sum(dim=(-1, -2))
                attn_attr_weight = rel_attn_attr / fg_masks.mean(dim=(-1, -2))
                attn_attr_weight = torch.where(
                    (fg_masks.mean(dim=(-1, -2)) > 0) & (attn_attr.sum(dim=(-1, -2)) > 0), attn_attr_weight, 1.0
                )
                rel_attn_weights += attn_attr_weight.mean().item()

        if args.debug:
            logger.debug(f"Attribution scores: IG: {rel_attr_weight[:5]}, GradCAM(++): {cam_attr_weight[:5]}")
            num_subplots = 5 if eval_attn_importance else 4
            fig, axs = plt.subplots(num_subplots, 4)
            for plt_i in range(4):
                axs[0][plt_i].imshow(denormalize(xs[plt_i]).permute(1, 2, 0).cpu().numpy())
                axs[1][plt_i].imshow(fg_masks[plt_i].cpu().numpy())
                axs[2][plt_i].imshow(attr_probs[plt_i].cpu().numpy())
                axs[3][plt_i].imshow(cam_attr[plt_i].cpu().numpy())
                if eval_attn_importance:
                    axs[4][plt_i].imshow(attn_attr[plt_i].cpu().numpy())
            plt.show()

        # BUGFIX: `.n` and `.set_description` exist only on tqdm; previously this
        # raised AttributeError whenever tqdm was disabled or on non-zero ranks
        if use_tqdm:
            iterator_desc = (
                f"IG weights: {rel_ig_weights / (iterator.n + 1):.4f}, GradCAM weights:"
                f" {rel_cam_weights['GradCAM'] / (iterator.n + 1):.4f}, GradCAM++ weights:"
                f" {rel_cam_weights['GradCAM++'] / (iterator.n + 1):.4f}"
            )
            if eval_attn_importance:
                iterator_desc += f", Attn weights: {rel_attn_weights / (iterator.n + 1):.4f}"

            iterator.set_description(iterator_desc)

    if args.distributed:
        dist.barrier()

    val_end = time()
    # average the accumulated per-batch means over the number of batches
    rel_ig_weights /= len(iterator)
    rel_grad_cam = rel_cam_weights["GradCAM"] / len(iterator)
    rel_grad_cam_pp = rel_cam_weights["GradCAM++"] / len(iterator)
    rel_attn_weights /= len(iterator)

    if dali_server:
        dali_server.stop_thread()

    if args.distributed:
        # all_reduce sums across ranks; divide by world size for the mean
        gather_tensor = torch.Tensor([rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor)
        gather_tensor = (gather_tensor / world_size).tolist()
        rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights = gather_tensor

    if rank == 0:
        output_text = (
            f"epoch {epoch}: eval_{args.val_dataset}/rel_ig_weights={rel_ig_weights},"
            f" eval_{args.val_dataset}/rel_grad_cam={rel_grad_cam},"
            f" eval_{args.val_dataset}/rel_grad_cam_pp={rel_grad_cam_pp}"
        )
        if eval_attn_importance:
            output_text += f", eval_{args.val_dataset}/rel_attn_weights={rel_attn_weights}"
        output_text += f", eval_{args.val_dataset}/attribution_eval_time={val_end - val_start}s"
        logger.info(output_text)
        if wandb_available():
            import wandb

            wandb_data = {
                f"eval_{args.val_dataset}/importance_ig": rel_ig_weights,
                f"eval_{args.val_dataset}/importance_grad_cam": rel_grad_cam,
                f"eval_{args.val_dataset}/importance_grad_cam_pp": rel_grad_cam_pp,
            }
            if eval_attn_importance:
                wandb_data[f"eval_{args.val_dataset}/importance_attn"] = rel_attn_weights
            wandb.log(wandb_data)

    ddp_cleanup(args=args, rank=rank)
|