import torch
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,
    # precision / recall
    MulticlassPrecision,
    MulticlassRecall,
)


def validation_one_epoch(
    model,
    loader,
    criterion,
    device,
    num_classes,
    *,
    return_precision_recall=False,
):
    """Evaluate ``model`` over every batch of ``loader`` without gradients.

    Args:
        model: Module to evaluate. It is switched to eval mode and left in
            eval mode on return; the caller must call ``model.train()`` again
            before resuming training.
        loader: Iterable yielding ``(images, labels)`` pairs. It does not need
            to implement ``__len__``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``;
            must return a scalar tensor (``.item()`` is called on it).
        device: Device the inputs and metric states are moved to.
        num_classes: Number of classes passed to the torchmetrics metrics.
        return_precision_recall: When True, also compute macro-averaged
            precision and recall and append them to the returned tuple.

    Returns:
        ``(avg_loss, acc, f1)`` by default, or
        ``(avg_loss, acc, f1, precision, recall)`` when
        ``return_precision_recall`` is True. All values are Python floats.

        ``avg_loss`` is the unweighted mean of the per-batch losses, so a
        smaller final batch counts as much as a full one.

        ``acc`` is torchmetrics' *macro* accuracy, i.e. per-class accuracy
        averaged over classes. It is not the overall fraction of correct
        predictions.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()

    # average="macro" is torchmetrics' default for MulticlassAccuracy; it is
    # spelled out so readers don't mistake the result for plain accuracy.
    metrics = {
        "acc": MulticlassAccuracy(num_classes=num_classes, average="macro"),
        "f1": MulticlassF1Score(num_classes=num_classes, average="macro"),
    }
    if return_precision_recall:
        metrics["precision"] = MulticlassPrecision(
            num_classes=num_classes, average="macro"
        )
        metrics["recall"] = MulticlassRecall(
            num_classes=num_classes, average="macro"
        )
    metrics = {name: metric.to(device) for name, metric in metrics.items()}

    total_loss = 0.0
    # Count batches ourselves instead of len(loader): works for unsized
    # loaders and lets us detect an empty loader explicitly.
    num_batches = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            num_batches += 1

            # Assumes class scores lie along dim 1 (e.g. (batch, num_classes)
            # logits) — TODO confirm against the model's output shape.
            preds = outputs.argmax(dim=1)
            for metric in metrics.values():
                metric.update(preds, labels)

    if num_batches == 0:
        raise ValueError("validation loader yielded no batches")

    results = {name: metric.compute().item() for name, metric in metrics.items()}
    avg_loss = total_loss / num_batches

    if return_precision_recall:
        return (
            avg_loss,
            results["acc"],
            results["f1"],
            results["precision"],
            results["recall"],
        )
    return (
        avg_loss,
        results["acc"],
        results["f1"],
    )