Spaces:
Running
Running
File size: 1,833 Bytes
c1596ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 | import torch
from torchmetrics.classification import (
MulticlassAccuracy,
MulticlassF1Score,
# precision / recall
MulticlassPrecision,
MulticlassRecall
)
def validation_one_epoch(
    model,
    loader,
    criterion,
    device,
    num_classes,
    *,
    with_precision_recall=False,
):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate. It is switched to eval mode and left there;
            the caller is responsible for switching back to train mode.
        loader: Iterable of ``(images, labels)`` batches. It does not need to
            implement ``__len__``; batches are counted while iterating.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``; its
            result must support ``.item()`` (i.e. a single-element tensor).
        device: Device that inputs, labels and metric states are moved to.
        num_classes: Number of classes passed to the torchmetrics metrics.
        with_precision_recall: When True, also compute macro-averaged
            precision and recall and append them to the returned tuple.

    Returns:
        ``(mean_loss, accuracy, macro_f1)`` as Python floats, or
        ``(mean_loss, accuracy, macro_f1, macro_precision, macro_recall)``
        when ``with_precision_recall`` is True. ``mean_loss`` is the
        unweighted mean of the per-batch loss values.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()

    # Order of this list defines the order of the returned scores.
    # NOTE(review): no ``average`` is passed to MulticlassAccuracy, so the
    # torchmetrics default applies; in recent torchmetrics that default is
    # "macro" (mean of per-class accuracy), not overall accuracy — confirm
    # this is the intended metric.
    metrics = [
        MulticlassAccuracy(num_classes=num_classes).to(device),
        MulticlassF1Score(num_classes=num_classes, average="macro").to(device),
    ]
    if with_precision_recall:
        metrics.append(
            MulticlassPrecision(num_classes=num_classes, average="macro").to(device)
        )
        metrics.append(
            MulticlassRecall(num_classes=num_classes, average="macro").to(device)
        )

    total_loss = 0.0
    # Counted explicitly instead of using len(loader): works for loaders
    # without __len__ and lets us detect an empty loader.
    num_batches = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            num_batches += 1

            # Predicted class = index of the largest output along dim 1.
            preds = outputs.argmax(dim=1)
            for metric in metrics:
                metric.update(preds, labels)

    if num_batches == 0:
        raise ValueError(
            "validation loader yielded no batches; cannot compute metrics"
        )

    scores = tuple(metric.compute().item() for metric in metrics)
    return (total_loss / num_batches, *scores)