File size: 1,457 Bytes
c1596ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from torchmetrics.classification import (
    MulticlassAccuracy
)

from transforms.mixup import mixup_data
from transforms.cutmix import cutmix_data


def train_one_epoch(
    model,
    loader,
    criterion,
    optimizer,
    device,
    num_classes,
    augmentation=None
):
    """Train ``model`` for a single pass over ``loader``.

    Args:
        model: Model to train; it is put into training mode via ``model.train()``.
        loader: Iterable yielding ``(images, labels)`` batches. It does not
            need to implement ``__len__``; batches are counted as they arrive.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``;
            its result must support ``.backward()`` and ``.item()``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device that each batch and the accuracy metric are moved to.
        num_classes: Number of classes passed to ``MulticlassAccuracy``.
        augmentation: ``None`` for plain training, or ``"mixup"`` /
            ``"cutmix"`` to mix each batch with the corresponding transform.

    Returns:
        Tuple ``(mean_loss, accuracy)``: ``mean_loss`` is the mean of the
        per-batch loss values, and ``accuracy`` is the epoch-level
        ``MulticlassAccuracy`` as a Python float.

    Raises:
        ValueError: If ``augmentation`` is not ``None``, ``"mixup"`` or
            ``"cutmix"``, or if ``loader`` yields no batches.
    """
    # Batch-mixing transforms. Each is called as fn(images, labels) and
    # returns (mixed_images, labels_a, labels_b, lam).
    mixers = {
        "mixup": mixup_data,
        "cutmix": cutmix_data,
    }
    if augmentation is not None and augmentation not in mixers:
        # Previously an unknown value silently fell back to plain training,
        # which hides typos in configs.
        raise ValueError(
            f"Unsupported augmentation {augmentation!r}; "
            f"expected None or one of {sorted(mixers)}"
        )
    mix = mixers.get(augmentation)

    model.train()

    metric = MulticlassAccuracy(
        num_classes=num_classes
    ).to(device)

    total_loss = 0.0
    num_batches = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        if mix is not None:
            images, labels_a, labels_b, lam = mix(images, labels)
            outputs = model(images)
            # Loss is the lam-weighted mix of the losses against both label sets.
            loss = (
                lam * criterion(outputs, labels_a)
                + (1 - lam) * criterion(outputs, labels_b)
            )
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # Accuracy is always scored against the un-mixed `labels`, even when a
        # mixing augmentation was applied to the inputs.
        preds = outputs.argmax(dim=1)
        metric.update(preds, labels)

    if num_batches == 0:
        # Counting batches (rather than using len(loader)) also supports
        # loaders without __len__, and avoids a bare ZeroDivisionError here.
        raise ValueError("train_one_epoch: loader yielded no batches")

    acc = metric.compute().item()

    return total_loss / num_batches, acc