# ForAug/AAAI Supplementary Material/Model Training Code/evaluate.py
"""Module to evaluate trained models."""
import os
from contextlib import nullcontext
from datetime import datetime
from math import sqrt
from time import time
import torch
from loguru import logger
from timm.loss import LabelSmoothingCrossEntropy
from timm.models.resnet import ResNet as TimmResNet
from torch import distributed as dist
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from tqdm import tqdm
from engine import (
    _evaluate,
    setup_criteria_mixup,
    setup_model_optim_sched_scaler,
    setup_tracking_and_logging,
    wandb_available,
)
from load_dataset import prepare_dataset
from metrics import calculate_metrics
from models import load_pretrained
from utils import (
    RepeatedDataset,
    ddp_cleanup,
    ddp_setup,
    denormalize,
    get_cpu_name,
    grad_cam_reshape_transform,
    prep_kwargs,
    set_filter_warnings,
)


def evaluate_metrics(model, dataset, **kwargs):
    """Evaluate efficiency metrics for a given model.

    Args:
        model (str): path to model state .tar
        dataset (str): name of the dataset to evaluate on
        **kwargs: further arguments
    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)

    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        rank = 0
        args.compile_model = False

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    setup_tracking_and_logging(args, rank=rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None)

    train_loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    old_args["eval_imsize"] = args.imsize
    args.model = model_name = old_args.model
    args.dataset = dataset
    args.epochs = 5
    model, optim, _, scaler = setup_model_optim_sched_scaler(model, device, epochs=10, args=args, head_only=False)

    if rank == 0:
        logger.info(
            f"Evaluate metrics for model {model_name} on {dataset}. "
            f"It was {old_args.task.replace('-', '')}d on {old_args.dataset} for {save_state['epoch']} "
            "epochs."
        )
        # logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of training arguments: {old_args}")
        logger.info(f"full set of eval-metrics arguments: {args}")
        logger.info(
            f"evaluating on {device} -> "
            f"{torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
        )

    metrics = calculate_metrics(
        args, model, rank=rank, device=device, optim=optim, scaler=scaler, train_loader=train_loader, key_start="eval/"
    )
    if rank == 0:
        logger.info(f"Metrics: {metrics}")
        if wandb_available():
            import wandb

            wandb.log(metrics)
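

# Example usage (a minimal sketch with a hypothetical checkpoint path and kwargs,
# assuming the .tar file was written by the training code in this repository and
# that `cuda`/`batch_size` are among the kwargs understood by prep_kwargs):
#
#   evaluate_metrics("runs/vit-s_fornet/model_best.tar", "ImageNet", cuda=True, batch_size=128)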


def evaluate(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the training dataset, used for logging and as
            fallback for `val_dataset` (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset
    args.dataset = dataset
    args.val_dataset = val_dataset

    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    args.batch_size = int(args.batch_size / world_size)

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )

    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset

    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )

    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> "
        f"{torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)

    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")

    if rank == 0:
        val_time, val_stats = _evaluate(
            model.to(device),
            val_loader,
            epoch=save_state["epoch"] - 1,
            rank=rank,
            device=device,
            val_criterion=val_criterion,
            args=args,
            dali_server=dali_server,
            acc_dict_key=f"eval_{val_dataset}/acc{{}}",
        )
        log_s = f"Evaluation done in {val_time}s"
        for key, val in val_stats.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log(val_stats)
    else:
        _evaluate(
            model.to(device),
            val_loader,
            epoch=save_state["epoch"] - 1,
            rank=rank,
            device=device,
            val_criterion=val_criterion,
            args=args,
            dali_server=dali_server,
            acc_dict_key=f"eval_{val_dataset}/acc{{}}",
        )

    ddp_cleanup(args=args, rank=rank)
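

# Example usage (hypothetical checkpoint path and dataset names; a sketch only):
# evaluate a model that was trained on ForNet directly on another validation set
# by passing a separate `val_dataset`:
#
#   evaluate("runs/resnet50_fornet/model_best.tar", dataset="ForNet", val_dataset="ImageNet", cuda=True)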


def evaluate_center_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy with the foreground placed in different nonants (cells of a 3x3 grid).

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the training dataset, used for logging and as
            fallback for `val_dataset` (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    assert dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset

    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    args.batch_size = int(args.batch_size / world_size)

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )

    assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation."
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset

    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )

    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> "
        f"{torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)

    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")

    if rank == 0:
        nonant_accs = []
        # nonant_accs[0] (from nonant -1) serves as the baseline for the relative accuracies below
        for nonant in range(-1, 9):
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.fg_in_nonant = nonant
            logger.info(f"Evaluate nonant {nonant} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
            nonant_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_time}s: "
        for nonant, val in enumerate(nonant_accs[1:]):
            log_s += f", nonant {nonant}={val}% acc ({val / nonant_accs[0]} rel acc)"
        # indices 1/3/7/9 are the corner nonants, 2/4/6/8 the edge nonants, 5 the center
        center_bias_val = 1 - (
            min([nonant_accs[1], nonant_accs[3], nonant_accs[7], nonant_accs[9]])
            + min([nonant_accs[2], nonant_accs[4], nonant_accs[6], nonant_accs[8]])
        ) / (2 * nonant_accs[5])
        log_s += f", center_bias={center_bias_val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log({f"eval_{args.val_dataset}/center_bias": center_bias_val})
    else:
        raise NotImplementedError("Center bias evaluation not supported in distributed mode.")

    ddp_cleanup(args=args, rank=rank)
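

# Worked example of the center-bias formula above (made-up accuracies, for
# illustration only). nonant_accs[0] comes from nonant -1 and is the baseline;
# indices 1..9 map to nonants 0..8, so index 5 is the center nonant, indices
# 1/3/7/9 the corners, and 2/4/6/8 the edges:
#
#   nonant_accs = [80.0, 70.0, 72.0, 70.0, 74.0, 78.0, 74.0, 70.0, 72.0, 70.0]
#   worst_corner = min(70.0, 70.0, 70.0, 70.0)  # indices 1, 3, 7, 9 -> 70.0
#   worst_edge = min(72.0, 74.0, 74.0, 72.0)    # indices 2, 4, 6, 8 -> 72.0
#   center_bias = 1 - (worst_corner + worst_edge) / (2 * 78.0)  # ~0.0897
#
# A value of 0 means off-center foregrounds are recognized as well as centered
# ones; larger values indicate a stronger bias towards the image center.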


def evaluate_size_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy for differently scaled foregrounds.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the training dataset, used for logging and as
            fallback for `val_dataset` (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    assert val_dataset is not None and dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset

    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    args.batch_size = int(args.batch_size / world_size)

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path, log_wandb=False)

    assert "fornet" in val_dataset.lower(), "Only ForNet supported for size bias evaluation."
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset

    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )

    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> "
        f"{torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)

    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")

    sizes = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 1.75, 2.0]
    if rank == 0:
        size_accs = []
        val_times = 0
        for size in sizes:
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.size_fact = size
            val_loader.dataset.fg_scale_jitter = 0.0
            logger.info(f"Evaluate size factor {size} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
                val_times += val_time
            size_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_times}s: "
        for size, val in zip(sizes, size_accs):
            log_s += f", rel_size {size}={val}% acc ({val / size_accs[sizes.index(1.0)]} rel acc)"
        logger.info(log_s)
    else:
        raise NotImplementedError("Size bias evaluation not supported in distributed mode.")

    ddp_cleanup(args=args, rank=rank)
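

# Example usage (hypothetical checkpoint path; a sketch only): probe how accuracy
# changes when the ForNet foreground is rescaled between 0.1x and 2.0x, with the
# dataset's fg_scale_jitter set to 0 so only the fixed size factor is applied:
#
#   evaluate_size_bias("runs/vit-s_fornet/model_best.tar", val_dataset="ForNet", cuda=True)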


def evaluate_attributions(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model attributions using captum.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the training dataset, used as fallback for
            `val_dataset` (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
        The `captum` package is required.
    """
    from captum.attr import IntegratedGradients
    from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset
    assert val_dataset is not None, "Please set dataset (-ds) or validation dataset (-valds)"
    args.dataset = val_dataset
    args.val_dataset = val_dataset

    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    args.batch_size = int(args.batch_size / world_size)

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = val_dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )

    assert "fornet" in val_dataset.lower(), "Only ForNet supported for attribution evaluation."
    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )
    val_loader.dataset.return_fg_masks = True
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = val_dataset  # `dataset` may be None here, so use `val_dataset`
    # assert (
    #     args.imsize == old_args.imsize
    # ), f"Model was trained on {old_args.imsize}x{old_args.imsize} images. Not {args.imsize}x{args.imsize}."
    epoch = save_state["epoch"]

    if rank == 0:
        logger.info(
            f"Evaluate attributions of model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {epoch} epochs."
        )
    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> "
        f"{torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        if args.new_log:
            logger.info(f"full set of arguments: {args}")
            logger.info(f"full set of old arguments: {old_args}")
        else:
            logger.info(f"full set of attribution evaluation arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")

    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch}")
        if rank == 0 and args.tqdm
        else val_loader
    )
    if args.debug:
        from matplotlib import pyplot as plt

    eval_attn_importance = False
    if isinstance(model, TimmResNet):
        reshape_transform = None
        target_layers = [model.layer4[-1]]
    elif model_name.lower().startswith("vit-"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.blocks[-1].norm1]
        eval_attn_importance = True
        from architectures.vit import _MatrixSaveAttn

        model.blocks[-1].attn = _MatrixSaveAttn.cast(model.blocks[-1].attn)
    elif model_name.lower().startswith("swin_"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.layers[-1].blocks[-1].norm1]
    else:
        raise NotImplementedError(f"Model {model_name} not supported for attribution evaluation.")

    model.eval()
    val_start = time()
    rel_ig_weights = 0.0
    rel_attn_weights = 0.0
    rel_cam_weights = {"GradCAM": 0.0, "GradCAM++": 0.0}
    if rank == 0:
        logger.info("Start attribution evaluation")
    if dali_server:
        dali_server.start_thread()

    for batch_data in iterator:
        xs, ys, fg_masks = batch_data
        xs, ys, fg_masks = (
            xs.to(device, non_blocking=True),
            ys.to(device, non_blocking=True),
            fg_masks.float().to(device, non_blocking=True),
        )

        with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
            model.zero_grad()
            ig = IntegratedGradients(model)
            # scale attributions by 10 (a softmax temperature of 1/10) to make differences more apparent after the exp
            attr_ig = (
                ig.attribute(xs, target=ys, baselines=0.0, internal_batch_size=args.batch_size * 4).sum(dim=1) * 10
            )  # B x W x H
            attr_probs = attr_ig.view(xs.shape[0], -1).softmax(dim=-1).view(xs.shape[0], *xs.shape[2:])
            fg_masks = fg_masks.view(attr_probs.shape)
            fg_attrs = (attr_probs * fg_masks).sum(dim=(-1, -2))
            rel_attr_weight = fg_attrs / fg_masks.mean(dim=(-1, -2))
            rel_attr_weight = torch.where(fg_masks.mean(dim=(-1, -2)) > 0, rel_attr_weight, 1.0)
        if rel_attr_weight.isnan().any():
            logger.error(f"NaNs in rel_attr_weight: {rel_attr_weight}, fg_mask_weights: {fg_masks.mean(dim=(-1, -2))}")
            break
        rel_ig_weights += rel_attr_weight.mean().item()

        cam_targets = [ClassifierOutputTarget(int(trgt)) for trgt in ys.tolist()]
        for method, name in zip([GradCAM, GradCAMPlusPlus], ["GradCAM", "GradCAM++"]):
            with method(model=model, target_layers=target_layers, reshape_transform=reshape_transform) as cam, (
                torch.amp.autocast("cuda") if args.eval_amp else nullcontext()
            ):
                cam_attr = cam(input_tensor=xs, targets=cam_targets)
                cam_attr = torch.from_numpy(cam_attr).to(device)
                rel_cam_attr = (cam_attr * fg_masks).sum(dim=(-1, -2)) / cam_attr.sum(dim=(-1, -2))
                cam_attr_weight = rel_cam_attr / fg_masks.mean(dim=(-1, -2))
                cam_attr_weight = torch.where(
                    (fg_masks.mean(dim=(-1, -2)) > 0) & (cam_attr.sum(dim=(-1, -2)) > 0), cam_attr_weight, 1.0
                )
            rel_cam_weights[name] += cam_attr_weight.mean().item()
            if cam_attr_weight.isnan().any():
                logger.error(
                    f"NaNs in cam_attr_weight ({name}): {cam_attr_weight}, fg_mask_weights:"
                    f" {fg_masks.mean(dim=(-1, -2))}"
                )
                break

        if eval_attn_importance:
            with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
                pred = model(xs)  # noqa: F841
                last_attn_mat = model.blocks[-1].attn.attn_mat
                cls_tkn_attn = last_attn_mat[:, :, 0, 1:].mean(dim=1).squeeze(dim=1)  # B x H x 1(CLS Token) x N -> B x N
                B, N = cls_tkn_attn.shape
                att_HW = int(sqrt(N))
                cls_tkn_attn = cls_tkn_attn.view(B, 1, att_HW, att_HW)
                attn_attr = F.interpolate(
                    cls_tkn_attn, size=(xs.shape[-2], xs.shape[-1]), mode="bilinear", align_corners=False
                ).view(B, xs.shape[-2], xs.shape[-1])
                rel_attn_attr = (attn_attr * fg_masks).sum(dim=(-1, -2)) / attn_attr.sum(dim=(-1, -2))
                attn_attr_weight = rel_attn_attr / fg_masks.mean(dim=(-1, -2))
                attn_attr_weight = torch.where(
                    (fg_masks.mean(dim=(-1, -2)) > 0) & (attn_attr.sum(dim=(-1, -2)) > 0), attn_attr_weight, 1.0
                )
            rel_attn_weights += attn_attr_weight.mean().item()

        if args.debug:
            logger.debug(f"Attribution scores: IG: {rel_attr_weight[:5]}, GradCAM(++): {cam_attr_weight[:5]}")
            num_subplots = 5 if eval_attn_importance else 4
            fig, axs = plt.subplots(num_subplots, 4)
            for plt_i in range(4):
                axs[0][plt_i].imshow(denormalize(xs[plt_i]).permute(1, 2, 0).cpu().numpy())
                axs[1][plt_i].imshow(fg_masks[plt_i].cpu().numpy())
                axs[2][plt_i].imshow(attr_probs[plt_i].cpu().numpy())
                axs[3][plt_i].imshow(cam_attr[plt_i].cpu().numpy())
                if eval_attn_importance:
                    axs[4][plt_i].imshow(attn_attr[plt_i].cpu().numpy())
            plt.show()

        if rank == 0 and args.tqdm:  # only the tqdm iterator has `n` and `set_description`
            iterator_desc = (
                f"IG weights: {rel_ig_weights / (iterator.n + 1):.4f}, GradCAM weights:"
                f" {rel_cam_weights['GradCAM'] / (iterator.n + 1):.4f}, GradCAM++ weights:"
                f" {rel_cam_weights['GradCAM++'] / (iterator.n + 1):.4f}"
            )
            if eval_attn_importance:
                iterator_desc += f", Attn weights: {rel_attn_weights / (iterator.n + 1):.4f}"
            iterator.set_description(iterator_desc)

        if args.distributed:
            dist.barrier()

    val_end = time()
    rel_ig_weights /= len(iterator)
    rel_grad_cam = rel_cam_weights["GradCAM"] / len(iterator)
    rel_grad_cam_pp = rel_cam_weights["GradCAM++"] / len(iterator)
    rel_attn_weights /= len(iterator)
    if dali_server:
        dali_server.stop_thread()

    if args.distributed:
        gather_tensor = torch.Tensor([rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor)
        gather_tensor = (gather_tensor / world_size).tolist()
        rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights = gather_tensor

    if rank == 0:
        output_text = (
            f"epoch {epoch}: eval_{args.val_dataset}/rel_ig_weights={rel_ig_weights},"
            f" eval_{args.val_dataset}/rel_grad_cam={rel_grad_cam},"
            f" eval_{args.val_dataset}/rel_grad_cam_pp={rel_grad_cam_pp}"
        )
        if eval_attn_importance:
            output_text += f", eval_{args.val_dataset}/rel_attn_weights={rel_attn_weights}"
        output_text += f", eval_{args.val_dataset}/attribution_eval_time={val_end - val_start}s"
        logger.info(output_text)
        if wandb_available():
            import wandb

            wandb_data = {
                f"eval_{args.val_dataset}/importance_ig": rel_ig_weights,
                f"eval_{args.val_dataset}/importance_grad_cam": rel_grad_cam,
                f"eval_{args.val_dataset}/importance_grad_cam_pp": rel_grad_cam_pp,
            }
            if eval_attn_importance:
                wandb_data[f"eval_{args.val_dataset}/importance_attn"] = rel_attn_weights
            wandb.log(wandb_data)

    ddp_cleanup(args=args, rank=rank)
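

# Minimal sketch of the relative foreground-weight statistic computed above, on toy
# tensors (illustration only, not part of the pipeline): an attribution map is
# normalized to a distribution over pixels, and the mass that falls on the foreground
# is divided by the foreground's area fraction. A value of 1.0 means no preference;
# values above 1.0 mean the attribution method focuses disproportionately on the
# foreground:
#
#   attr = torch.rand(2, 224, 224)                              # B x H x W attribution maps
#   fg_masks = (torch.rand(2, 224, 224) > 0.5).float()          # B x H x W binary foreground masks
#   probs = attr.view(2, -1).softmax(dim=-1).view(2, 224, 224)  # distribution over pixels
#   fg_mass = (probs * fg_masks).sum(dim=(-1, -2))              # attribution mass on the foreground
#   rel_weight = fg_mass / fg_masks.mean(dim=(-1, -2))          # normalize by foreground area fraction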