# expanded_model / train_2.py
# DIMPU1516's picture
# Upload train_2.py with huggingface_hub
# bbef64d verified
import os
import cv2
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorchvideo.models.resnet import create_resnet
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
# -------------------------------
# Custom Dataset for AirLetters
# -------------------------------
class AirLettersDataset(Dataset):
    """Video-clip dataset indexed by a CSV file.

    The CSV must provide a ``filename`` column (joined onto ``video_dir``) and
    a ``label`` column containing free text such as ``"letter a"`` or
    ``"digit 3"``.  Each item is a ``(clip, class_id)`` pair where ``clip`` is
    a float tensor laid out as ``(channels, num_frames, image_size,
    image_size)`` and ``class_id`` is an ``int`` in ``0..36``.
    """

    # Class-id layout: 0-25 letters A-Z, 26-35 digits 0-9, 36 everything else.
    _DIGIT_OFFSET = 26
    _OTHER_CLASS_ID = 36
    # How many (re)draws __getitem__ attempts before giving up.
    _MAX_LOAD_ATTEMPTS = 10

    def __init__(self, csv_path, video_dir, num_frames=8, image_size=224):
        """Read the CSV index and build the per-frame transform.

        Args:
            csv_path: Path of the CSV index; column names are stripped of
                surrounding whitespace after loading.
            video_dir: Directory that the ``filename`` column is relative to.
            num_frames: Number of frames sampled from every video.
            image_size: Side length each frame is resized to.
        """
        self.df = pd.read_csv(csv_path)
        self.df.columns = self.df.columns.str.strip()
        self.video_dir = video_dir
        self.num_frames = num_frames
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size)),
            transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]),
        ])

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(clip, class_id)`` for row ``idx``.

        If the video cannot be decoded, a uniformly random row is tried
        instead (the returned label is then the replacement row's label).

        Raises:
            RuntimeError: If no readable video is found within the attempt
                budget.
        """
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            row = self.df.iloc[idx]
            video_path = os.path.join(self.video_dir, row['filename'])
            frames = self._load_video(video_path)
            if frames is not None:
                return frames, self._label_to_id(row['label'])
            idx = np.random.randint(0, len(self.df))
        raise RuntimeError("Too many unreadable videos in dataset.")

    @staticmethod
    def _token_after(text, keyword):
        """Return the first whitespace-delimited token after the last *keyword*, or ``""``."""
        tokens = text.split(keyword)[-1].split()
        return tokens[0] if tokens else ""

    def _label_to_id(self, label_text):
        """Map a textual label to a class id in ``0..36``.

        ``"letter X"`` maps to ``0..25`` and ``"digit N"`` to ``26..35``
        (case-insensitive).  Anything else -- including missing values and
        malformed letter/digit labels -- maps to the catch-all class ``36`` so
        the id can never fall outside the classifier's output range.
        """
        # str() guards against non-string cells (e.g. NaN read by pandas).
        text = str(label_text).lower()
        if "letter" in text:
            token = self._token_after(text, "letter")
            if len(token) == 1 and "a" <= token <= "z":
                return ord(token) - ord("a")
        elif "digit" in text:
            token = self._token_after(text, "digit")
            # Explicit ASCII range: str.isdigit() also accepts characters int() rejects.
            if len(token) == 1 and "0" <= token <= "9":
                return self._DIGIT_OFFSET + int(token)
        return self._OTHER_CLASS_ID

    def _load_video(self, video_path):
        """Decode ``num_frames`` evenly strided frames from *video_path*.

        Frames that fail to decode are skipped and the clip is padded with
        all-zero frames up to ``num_frames``.

        Returns:
            A ``(channels, num_frames, H, W)`` tensor, or ``None`` (after
            printing a warning) when the video cannot be read at all.
        """
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total == 0 or not cap.isOpened():
                raise ValueError("Unreadable video")

            step = max(1, total // self.num_frames)
            frames = []
            for i in range(self.num_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
                ret, frame = cap.read()
                if not ret or frame is None:
                    continue
                # OpenCV decodes to BGR; the transform pipeline expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(frame))

            if not frames:
                raise ValueError("No valid frames")
            while len(frames) < self.num_frames:
                frames.append(torch.zeros_like(frames[0]))
            # Stack gives (T, C, H, W); move channels first -> (C, T, H, W).
            return torch.stack(frames).permute(1, 0, 2, 3)
        except Exception as e:
            # Deliberate best-effort boundary: any decode/transform failure
            # marks the sample unreadable so __getitem__ can draw another one.
            print(f"[WARNING] Skipping unreadable video: {video_path} ({str(e)})")
            return None
        finally:
            # Release the capture handle on every path, including failures.
            if cap is not None:
                cap.release()
# -------------------------------
# Train + Evaluate Function
# -------------------------------
# File that train() writes checkpoints to and resumes from when it already exists.
CHECKPOINT_PATH = "checkpoint.pth"
# Number of optimizer steps between mid-epoch checkpoint saves in train().
SAVE_INTERVAL = 10000
def _save_checkpoint(path, model, optimizer, epoch, step, batch_idx):
    """Atomically write a resumable checkpoint to *path*.

    The state is written to a sibling temp file first and then renamed over
    *path*, so an interruption mid-write cannot corrupt the existing file.
    """
    tmp_path = f"{path}.tmp"
    torch.save({
        'epoch': epoch,
        'step': step,
        'batch_idx': batch_idx,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, path)


def _evaluate(model, loader, device):
    """Return top-1 accuracy in percent of *model* over *loader* (0.0 if empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 100. * correct / total if total else 0.0


def train(model, train_loader, val_loader, test_loader, device, num_epochs=5, lr=1e-4):
    """Train *model*, validate after every epoch, then test and save the weights.

    Training resumes from ``CHECKPOINT_PATH`` when that file exists.  A
    checkpoint is written every ``SAVE_INTERVAL`` optimizer steps and at the
    end of every epoch; the final weights go to ``resnext200_airletters.pth``.

    Args:
        model: Network mapping a batch of clips to per-class logits.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        test_loader: Iterable of ``(inputs, labels)`` test batches.
        device: Device the batches are moved to.
        num_epochs: Total number of epochs to train for.
        lr: Adam learning rate.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # ===== Resume variables =====
    start_epoch = 0
    global_step = 0
    resume_batch_idx = 0

    # ===== Load checkpoint if exists =====
    if os.path.exists(CHECKPOINT_PATH):
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # this script wrote itself.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['step']
        resume_batch_idx = checkpoint['batch_idx']
        print(f"πŸ” Resuming from Epoch {start_epoch}, Batch {resume_batch_idx}, Step {global_step}")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        batches_trained = 0
        # Only the epoch we resume into skips its already-trained prefix.
        # NOTE: if the loader shuffles, the skipped batches are not the same
        # samples seen before the interruption; only the count is preserved.
        skip_until = resume_batch_idx if epoch == start_epoch else 0

        loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch_idx, (inputs, labels) in loop:
            if batch_idx < skip_until:
                continue

            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            batches_trained += 1
            global_step += 1

            if global_step % SAVE_INTERVAL == 0:
                # Store the index of the NEXT batch: this one is already
                # trained and must not be repeated after a resume.
                _save_checkpoint(CHECKPOINT_PATH, model, optimizer, epoch, global_step, batch_idx + 1)
                print(f"\nπŸ’Ύ Checkpoint saved at step {global_step}")

        # Average over the batches actually trained (fewer than the loader
        # length in a resumed epoch) and guard against an empty epoch.
        epoch_loss = running_loss / batches_trained if batches_trained else 0.0
        train_acc = 100. * correct / total if total else 0.0
        print(f"\nβœ… Epoch {epoch+1} - Loss: {epoch_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # End-of-epoch checkpoint points at the start of the next epoch.
        _save_checkpoint(CHECKPOINT_PATH, model, optimizer, epoch + 1, global_step, 0)

        val_acc = _evaluate(model, val_loader, device)
        print(f"βœ… Validation Accuracy: {val_acc:.2f}%")

    # _evaluate switches to eval mode itself, so the test is correct even when
    # the epoch loop above did not run (checkpoint already at the last epoch).
    test_acc = _evaluate(model, test_loader, device)
    print(f"🎯 Final Test Accuracy: {test_acc:.2f}%")

    torch.save(model.state_dict(), "resnext200_airletters.pth")
    print("\nβœ… Model saved to resnext200_airletters.pth")
    print("πŸ“¦ Please upload this file to Hugging Face to preserve it.")
# -------------------------------
# Entry Point
# -------------------------------
def main():
    """Build the datasets, loaders and 3D ResNet, then run training."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("πŸš€ Using device:", device)

    # Update these with your paths.
    csv_paths = {"train": "train.csv", "val": "val.csv", "test": "test.csv"}
    video_dir = "/home/mluser/dataset/dataset/videos/videos"

    # Datasets first, then loaders, in train/val/test order.
    datasets = {
        split: AirLettersDataset(csv_path, video_dir)
        for split, csv_path in csv_paths.items()
    }
    loaders = {
        split: DataLoader(
            dataset,
            batch_size=2,
            shuffle=(split == "train"),  # only the training split is shuffled
            num_workers=2,
        )
        for split, dataset in datasets.items()
    }

    network = create_resnet(
        input_channel=3,
        model_num_class=37,
        model_depth=101,
        norm=nn.BatchNorm3d,
        activation=nn.ReLU,
    )
    network = network.to(device)

    train(network, loaders["train"], loaders["val"], loaders["test"], device)


if __name__ == "__main__":
    main()