Spaces:
Running
Running
| from torchmetrics.classification import ( | |
| MulticlassAccuracy | |
| ) | |
| from transforms.mixup import mixup_data | |
| from transforms.cutmix import cutmix_data | |
def train_one_epoch(
    model,
    loader,
    criterion,
    optimizer,
    device,
    num_classes,
    augmentation=None,
):
    """Train ``model`` for one pass over ``loader`` and report loss and accuracy.

    Args:
        model: Module to train; put into training mode via ``model.train()``.
        loader: Iterable of ``(images, labels)`` batches. It does not need to
            implement ``__len__``; batches are counted while iterating.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``;
            its result must support ``.backward()`` and ``.item()``.
        optimizer: Optimizer stepped once per batch.
        device: Device that each batch and the accuracy metric are moved to.
        num_classes: Number of classes passed to ``MulticlassAccuracy``.
        augmentation: ``"mixup"`` or ``"cutmix"`` to mix each batch before the
            forward pass. Any other value (including ``None`` or an
            unrecognized string) trains on the batch unchanged.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the average of the
        per-batch loss values. ``accuracy`` is the ``MulticlassAccuracy``
        accumulated over all batches, scored against the original (unmixed)
        labels.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    metric = MulticlassAccuracy(num_classes=num_classes).to(device)

    # Resolve the augmentation once instead of re-testing the string per batch.
    if augmentation == "mixup":
        mix_fn = mixup_data
    elif augmentation == "cutmix":
        mix_fn = cutmix_data
    else:
        mix_fn = None

    total_loss = 0.0
    num_batches = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        if mix_fn is not None:
            images, labels_a, labels_b, lam = mix_fn(images, labels)
            outputs = model(images)
            # Blend the two targets' losses with the same weight ``lam``
            # returned by the mixing function.
            loss = (
                lam * criterion(outputs, labels_a)
                + (1 - lam) * criterion(outputs, labels_b)
            )
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Count batches ourselves: len(loader) may be unavailable or wrong
        # (e.g. iterable-style datasets) and is 0 for an empty loader.
        num_batches += 1

        # Accuracy is scored against the original ``labels`` even when the
        # batch was mixed; detach since the metric needs no gradient.
        preds = outputs.detach().argmax(dim=1)
        metric.update(preds, labels)

    if num_batches == 0:
        raise ValueError("train_one_epoch: loader yielded no batches")

    acc = metric.compute().item()
    return total_loss / num_batches, acc