import torch

from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,

    # precision / recall
    MulticlassPrecision,
    MulticlassRecall
)


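def validation_one_epoch(
    model,
    loader,
    criterion,
    device,
    num_classes
):
    """Evaluate ``model`` on ``loader`` without updating its weights.

    Returns a tuple ``(avg_loss, accuracy, macro_f1)``. ``accuracy`` uses
    torchmetrics' default ``average="macro"``, i.e. the mean per-class
    recall, not the fraction of all samples classified correctly.
    """
    # Disable dropout and use running BatchNorm statistics
    model.eval()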

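    # average="macro" is torchmetrics' default for MulticlassAccuracy and is
    # stated explicitly here; use average="micro" for plain overall accuracy.
    acc_metric = MulticlassAccuracy(
        num_classes=num_classes,
        average="macro"
    ).to(device)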

    f1_metric = MulticlassF1Score(
        num_classes=num_classes,
        average="macro"
    ).to(device)

    # precision metric
    # precision_metric = MulticlassPrecision(
    #     num_classes=num_classes,
    #     average="macro"
    # ).to(device)

    # recall metric
    # recall_metric = MulticlassRecall(
    #     num_classes=num_classes,
    #     average="macro"
    # ).to(device)

    # Loss is accumulated per sample so that a smaller final batch does not
    # skew the average (assumes criterion uses reduction="mean")
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():

        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(
                outputs,
                labels
            )
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size
            # Predicted class = index of the largest logit
            preds = outputs.argmax(dim=1)

            acc_metric.update(
                preds,
                labels
            )

            f1_metric.update(
                preds,
                labels
            )

            # precision_metric.update(
            #     preds,
            #     labels
            # )

            # recall_metric.update(
            #     preds,
            #     labels
            # )

    acc = acc_metric.compute().item()
    f1 = f1_metric.compute().item()
    # precision = precision_metric.compute().item()
    # recall = recall_metric.compute().item()

    return (
        total_loss / max(num_samples, 1),
        acc,
        f1,
        # precision,
        # recall
    )
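

# Hypothetical cross-check helper (not part of the original module or of
# torchmetrics): a minimal sketch that recomputes the metrics above from a
# confusion matrix in plain torch, assuming integer class indices in
# [0, num_classes). It makes explicit that macro accuracy is the mean
# per-class recall, which can differ sharply from overall accuracy on
# imbalanced data. It is intended to mirror recent torchmetrics versions,
# which leave classes absent from both preds and labels out of the macro
# mean; verify against the installed version before relying on it.
def confusion_matrix_metrics(preds, labels, num_classes):
    # cm[t, p] counts samples whose true class is t and predicted class is p
    cm = torch.bincount(
        labels.flatten() * num_classes + preds.flatten(),
        minlength=num_classes * num_classes,
    ).reshape(num_classes, num_classes).double()

    tp = cm.diag()               # correct predictions per class
    support = cm.sum(dim=1)      # true samples per class (tp + fn)
    predicted = cm.sum(dim=0)    # predictions per class (tp + fp)

    # Classes with tp + fp + fn == 0 are excluded from the macro averages
    present = (support + predicted) > 0

    # Per-class recall; 0 where a class has no true samples
    recall = tp / support.clamp(min=1)
    # Per-class F1 = 2*tp / (2*tp + fp + fn) = 2*tp / (support + predicted)
    f1 = 2 * tp / (support + predicted).clamp(min=1)

    return {
        "overall_accuracy": (tp.sum() / cm.sum().clamp(min=1)).item(),
        "macro_accuracy": recall[present].mean().item(),
        "macro_f1": f1[present].mean().item(),
    }


# Example: 4 of 5 samples are correct (overall accuracy 0.8), but class 1 is
# predicted once and never correct, so macro accuracy is (0.8 + 0) / 2 = 0.4.
# confusion_matrix_metrics(torch.tensor([0, 0, 0, 0, 1]),
#                          torch.tensor([0, 0, 0, 0, 0]), num_classes=3)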