| import os |
| import cv2 |
| import torch |
| import numpy as np |
| import pandas as pd |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from pytorchvideo.models.resnet import create_resnet |
| import torch.nn as nn |
| import torch.optim as optim |
| from tqdm import tqdm |
|
|
| |
| |
| |
class AirLettersDataset(Dataset):
    """Video-clip dataset driven by a CSV index.

    The CSV must provide a ``filename`` column (path relative to
    ``video_dir``) and a ``label`` column. Each item is a
    ``(frames, label_id)`` pair where ``frames`` is a float tensor laid out
    as ``(channels, num_frames, image_size, image_size)``.

    Args:
        csv_path: Path to the CSV index file.
        video_dir: Directory that contains the video files.
        num_frames: Number of frames sampled from every video.
        image_size: Side length (pixels) each frame is resized to.
    """

    # Label ids: 0-25 letters A-Z, 26-35 digits 0-9, 36 everything else.
    NUM_LETTERS = 26
    NUM_DIGITS = 10
    OTHER_CLASS_ID = 36
    # How many different videos __getitem__ tries before giving up.
    MAX_READ_ATTEMPTS = 10

    def __init__(self, csv_path, video_dir, num_frames=8, image_size=224):
        self.df = pd.read_csv(csv_path)
        # Header cells may carry stray whitespace; normalise so that
        # row['filename'] / row['label'] lookups work.
        self.df.columns = self.df.columns.str.strip()
        self.video_dir = video_dir
        self.num_frames = num_frames
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size)),
            transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]),
        ])

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(frames, label_id)`` for row ``idx``.

        If the video at ``idx`` cannot be decoded, a randomly chosen row is
        tried instead, up to ``MAX_READ_ATTEMPTS`` attempts in total.

        Raises:
            RuntimeError: If every attempt hit an unreadable video.
            ValueError: If the row's label cannot be parsed.
        """
        for _ in range(self.MAX_READ_ATTEMPTS):
            row = self.df.iloc[idx]
            video_path = os.path.join(self.video_dir, row['filename'])
            frames = self._load_video(video_path)
            if frames is not None:
                return frames, self._label_to_id(row['label'])
            # Unreadable video: substitute a random sample rather than abort.
            idx = np.random.randint(0, len(self.df))
        raise RuntimeError("Too many unreadable videos in dataset.")

    def _label_to_id(self, label_text):
        """Map a textual label to its integer class id.

        ``"... letter X"`` maps to 0-25, ``"... digit N"`` maps to 26-35 and
        any label mentioning neither keyword maps to ``OTHER_CLASS_ID``.
        Matching is case-insensitive.

        Raises:
            ValueError: If the text after ``letter`` is not a single A-Z
                character, or the text after ``digit`` is not an integer
                in 0-9. (Previously such labels crashed with an unrelated
                exception or produced an out-of-range class id.)
        """
        text = label_text.lower()
        if "letter" in text:
            tokens = text.split("letter")[-1].split()
            char = tokens[0] if tokens else ""
            if len(char) != 1 or not "a" <= char <= "z":
                raise ValueError(f"Cannot parse letter label: {label_text!r}")
            return ord(char) - ord("a")
        if "digit" in text:
            tokens = text.split("digit")[-1].split()
            try:
                value = int(tokens[0])
            except (IndexError, ValueError):
                raise ValueError(f"Cannot parse digit label: {label_text!r}") from None
            if not 0 <= value < self.NUM_DIGITS:
                raise ValueError(f"Digit label out of range: {label_text!r}")
            return self.NUM_LETTERS + value
        return self.OTHER_CLASS_ID

    def _load_video(self, video_path):
        """Decode ``num_frames`` evenly spaced frames from ``video_path``.

        Returns:
            A tensor of shape ``(C, num_frames, H, W)``, or ``None`` when the
            video cannot be opened or yields no frames (a warning is
            printed). Clips that yield fewer frames than requested are
            padded with all-zero frames.
        """
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total == 0 or not cap.isOpened():
                raise ValueError("Unreadable video")

            # Sample every `step`-th frame starting from frame 0.
            step = max(1, total // self.num_frames)
            frames = []
            for i in range(self.num_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
                ret, frame = cap.read()
                if not ret or frame is None:
                    continue
                # OpenCV decodes to BGR; the transform pipeline expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(frame))

            if not frames:
                raise ValueError("No valid frames")

            # Pad short clips so every sample has the same temporal length.
            missing = self.num_frames - len(frames)
            if missing > 0:
                frames.extend([torch.zeros_like(frames[0])] * missing)

            # stack -> (T, C, H, W); permute -> (C, T, H, W).
            return torch.stack(frames).permute(1, 0, 2, 3)
        except Exception as e:
            # Deliberate best-effort: any decode failure only skips the clip.
            print(f"[WARNING] Skipping unreadable video: {video_path} ({str(e)})")
            return None
        finally:
            # Always free the decoder handle, including on the error paths.
            if cap is not None:
                cap.release()
|
|
|
|
| |
| |
| |
# File holding the resumable training state (model, optimizer, position).
CHECKPOINT_PATH = "checkpoint.pth"
# Number of optimizer steps between two mid-epoch checkpoints.
SAVE_INTERVAL = 10000


def _save_checkpoint(model, optimizer, epoch, step, batch_idx):
    """Atomically write the resumable training state to CHECKPOINT_PATH."""
    tmp_path = CHECKPOINT_PATH + ".tmp"
    torch.save({
        'epoch': epoch,
        'step': step,
        'batch_idx': batch_idx,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, tmp_path)
    # Rename only after the write finished, so an interrupted save can never
    # leave a truncated file at CHECKPOINT_PATH.
    os.replace(tmp_path, CHECKPOINT_PATH)


def _evaluate(model, loader, device):
    """Return the top-1 accuracy (percent) of model on loader; 0.0 if empty."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            _, predicted = model(inputs).max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 100. * correct / total if total else 0.0


def train(model, train_loader, val_loader, test_loader, device, num_epochs=5):
    """Train ``model`` with Adam + cross-entropy, resuming from a checkpoint.

    If ``CHECKPOINT_PATH`` exists, model/optimizer state and the position
    (epoch, batch index, global step) are restored from it. A checkpoint is
    written every ``SAVE_INTERVAL`` steps and at the end of every epoch.
    Validation accuracy is printed after each epoch; test accuracy is
    printed once after training, and the final weights are saved to
    ``resnext200_airletters.pth``.

    Args:
        model: Network mapping a batch of inputs to per-class logits.
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        test_loader: Iterable of ``(inputs, labels)`` test batches.
        device: Device the batches are moved to.
        num_epochs: Total number of epochs to train for.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    start_epoch = 0
    global_step = 0
    resume_batch_idx = 0

    if os.path.exists(CHECKPOINT_PATH):
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['step']
        resume_batch_idx = checkpoint['batch_idx']
        print(f"🔁 Resuming from Epoch {start_epoch}, Batch {resume_batch_idx}, Step {global_step}")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        batches_run = 0

        loop = tqdm(enumerate(train_loader), total=len(train_loader),
                    desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (inputs, labels) in loop:
            # Fast-forward past batches already consumed before the
            # interruption (first resumed epoch only).
            # NOTE(review): skipped batches are still loaded, and with a
            # shuffling loader they are not the same samples as before.
            if epoch == start_epoch and batch_idx < resume_batch_idx:
                continue

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_run += 1

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            global_step += 1

            if global_step % SAVE_INTERVAL == 0:
                # Store the index of the NEXT batch so that resuming does
                # not train the batch that was just finished a second time.
                _save_checkpoint(model, optimizer, epoch, global_step, batch_idx + 1)
                print(f"\n💾 Checkpoint saved at step {global_step}")

        # Average over the batches actually run (fewer on a resumed epoch)
        # and guard against an epoch that processed nothing.
        mean_loss = running_loss / batches_run if batches_run else 0.0
        train_acc = 100. * correct / total if total else 0.0
        print(f"\n✅ Epoch {epoch+1} - Loss: {mean_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # End-of-epoch checkpoint: resume at the start of the next epoch.
        _save_checkpoint(model, optimizer, epoch + 1, global_step, 0)

        val_acc = _evaluate(model, val_loader, device)
        print(f"✅ Validation Accuracy: {val_acc:.2f}%")

    # _evaluate switches to eval mode itself, so the test pass is correct
    # even when the epoch loop above did not run at all.
    test_acc = _evaluate(model, test_loader, device)
    print(f"🎯 Final Test Accuracy: {test_acc:.2f}%")

    torch.save(model.state_dict(), "resnext200_airletters.pth")
    print("\n✅ Model saved to resnext200_airletters.pth")
    print("📦 Please upload this file to Hugging Face to preserve it.")
|
|
| |
| |
| |
def main():
    """Build the datasets, loaders and model, then run training."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("🚀 Using device:", device)

    train_csv = "train.csv"
    val_csv = "val.csv"
    test_csv = "test.csv"
    video_dir = "/home/mluser/dataset/dataset/videos/videos"

    batch_size = 2
    num_workers = 2
    # AirLettersDataset._label_to_id emits class ids 0..36.
    num_classes = 37

    train_set = AirLettersDataset(train_csv, video_dir)
    val_set = AirLettersDataset(val_csv, video_dir)
    test_set = AirLettersDataset(test_csv, video_dir)

    # Only the training split is shuffled; evaluation order stays fixed.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    model = create_resnet(
        input_channel=3,
        model_num_class=num_classes,
        model_depth=101,
        norm=nn.BatchNorm3d,
        activation=nn.ReLU,
    ).to(device)
    train(model, train_loader, val_loader, test_loader, device)


if __name__ == "__main__":
    main()
|
|
|
|