# DamageLensAI / src/data/dataset.py
# Uploaded by junaid17 in commit eef8873 ("Upload 43 files", verified).
import logging
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import ConvNextImageProcessor
from src.config import (
BATCH_SIZE,
NUM_WORKERS
)
from src.data.ingestion import collect_image_paths
from src.data.preprocessing import split_dataset
from src.data.augmentation import (
get_resnet_train_transforms,
get_resnet_val_transforms,
get_fusion_train_transforms,
get_fusion_val_transforms
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class ResNetDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs.

    Args:
        samples: Indexable, sized collection of ``(image_path, label)`` pairs.
        transforms: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned. When falsy, the PIL image is returned unchanged.
    """

    def __init__(self, samples, transforms=None):
        self.samples = samples
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``(image, label)``.

        The image is always converted to RGB; ``transforms`` (if set) is then
        applied to the converted image.
        """
        image_path, label = self.samples[idx]

        # ``Image.open`` is lazy and owns an open file handle. ``convert``
        # forces the pixel data to load and returns an independent in-memory
        # image, so the handle can be released deterministically here instead
        # of whenever the garbage collector gets to it.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        if self.transforms:
            image = self.transforms(image)
        return image, label
class FusionDataset(Dataset):
    """Map-style dataset producing two views of each image plus its label.

    Every item is a dict with three keys:

    * ``"pixel_values_eff"`` - the output of ``transforms`` applied to the
      RGB image.
    * ``"pixel_values_cnx"`` - the ``"pixel_values"`` entry returned by the
      ConvNeXt image processor for the same RGB image, with its leading axis
      squeezed away.
    * ``"labels"`` - the label stored alongside the image path.

    Args:
        samples: Indexable, sized collection of ``(image_path, label)`` pairs.
        transforms: Callable applied to the RGB ``PIL.Image``. Declared
            optional in the signature, but ``__getitem__`` refuses to run
            without it.
        convnext_model_name: Identifier passed to
            ``ConvNextImageProcessor.from_pretrained``.
    """

    def __init__(
        self,
        samples,
        transforms=None,
        convnext_model_name="facebook/convnext-small-224"
    ):
        self.samples = samples
        self.transforms = transforms
        logger.info("Loading ConvNeXt processor...")
        self.processor = ConvNextImageProcessor.from_pretrained(
            convnext_model_name
        )

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return the dict described on the class.

        Raises:
            ValueError: If no ``transforms`` callable was supplied.
        """
        image_path, label = self.samples[idx]

        # Fail before touching the disk: without transforms there is nothing
        # useful to return, so decoding the image first is wasted work.
        if not self.transforms:
            raise ValueError("Fusion transforms are required.")

        # ``Image.open`` is lazy and owns an open file handle. ``convert``
        # loads the pixels and returns an independent in-memory image, so the
        # handle is released as soon as the ``with`` block exits.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        eff_tensor = self.transforms(image)

        convnext_inputs = self.processor(
            images=image,
            return_tensors="pt"
        )
        # Drop the leading axis (presumably the batch-of-one dimension the
        # processor adds) so the DataLoader can batch the items itself.
        convnext_tensor = convnext_inputs["pixel_values"].squeeze(0)

        return {
            "pixel_values_eff": eff_tensor,
            "pixel_values_cnx": convnext_tensor,
            "labels": label
        }
def create_resnet_dataloaders():
    """Build the training and validation ``DataLoader`` pair for ResNet.

    Image paths are collected with ``collect_image_paths``, split with
    ``split_dataset``, wrapped in ``ResNetDataset`` with the matching
    train/validation transforms, and served in batches of ``BATCH_SIZE``
    using ``NUM_WORKERS`` workers. Only the training loader shuffles.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    logger.info("Creating ResNet dataloaders...")

    train_split, val_split = split_dataset(collect_image_paths())

    train_set = ResNetDataset(
        train_split, transforms=get_resnet_train_transforms()
    )
    val_set = ResNetDataset(
        val_split, transforms=get_resnet_val_transforms()
    )

    # Settings common to both loaders; only the shuffle flag differs.
    shared_options = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}
    loaders = (
        DataLoader(train_set, shuffle=True, **shared_options),
        DataLoader(val_set, shuffle=False, **shared_options),
    )

    logger.info("ResNet dataloaders created successfully.")
    return loaders
def create_fusion_dataloaders():
    """Build the training and validation ``DataLoader`` pair for the fusion model.

    Image paths are collected with ``collect_image_paths``, split with
    ``split_dataset``, wrapped in ``FusionDataset`` with the matching
    train/validation transforms, and served in batches of ``BATCH_SIZE``
    using ``NUM_WORKERS`` workers. Only the training loader shuffles.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    logger.info("Creating Fusion dataloaders...")

    train_split, val_split = split_dataset(collect_image_paths())

    # Each FusionDataset loads its own ConvNeXt processor on construction,
    # so the training dataset is built first, then the validation one.
    train_set = FusionDataset(
        train_split, transforms=get_fusion_train_transforms()
    )
    val_set = FusionDataset(
        val_split, transforms=get_fusion_val_transforms()
    )

    # Settings common to both loaders; only the shuffle flag differs.
    shared_options = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}
    loaders = (
        DataLoader(train_set, shuffle=True, **shared_options),
        DataLoader(val_set, shuffle=False, **shared_options),
    )

    logger.info("Fusion dataloaders created successfully.")
    return loaders
if __name__ == "__main__":
    # Manual smoke test: build both loader pairs, pull one training batch
    # from each, and print the batch shapes.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    print("\nTesting ResNet dataloaders...\n")
    resnet_train, _resnet_val = create_resnet_dataloaders()
    image_batch, label_batch = next(iter(resnet_train))
    print("ResNet batch shape:", image_batch.shape)
    print("ResNet labels shape:", label_batch.shape)

    print("\nTesting Fusion dataloaders...\n")
    fusion_train, _fusion_val = create_fusion_dataloaders()
    fusion_batch = next(iter(fusion_train))
    # One line per entry of the fusion batch dict, in the original order.
    for caption, key in (
        ("Fusion EfficientNet batch shape:", "pixel_values_eff"),
        ("Fusion ConvNeXt batch shape:", "pixel_values_cnx"),
        ("Fusion labels shape:", "labels"),
    ):
        print(caption, fusion_batch[key].shape)