# DamageLensAI/src/models/fusion_model.py
# Uploaded by junaid17 ("Upload 43 files", commit eef8873, verified).
import logging
import torch
import torch.nn as nn
from torchvision import models
from transformers import ConvNextModel
logger = logging.getLogger(__name__)
class FusionClassifier(nn.Module):
    """Classifier that fuses EfficientNet-V2-S and ConvNeXt image features.

    Each input image is encoded by two partially fine-tuned backbones:

    * the ``features`` + ``avgpool`` part of torchvision's
      ``efficientnet_v2_s`` (ImageNet-1K V1 weights), and
    * a Hugging Face ``ConvNextModel`` (its ``pooler_output``).

    The two pooled feature vectors are concatenated along the feature axis
    and passed through an MLP head (Linear -> LayerNorm -> GELU blocks with
    dropout) that outputs ``num_classes`` logits.

    Args:
        num_classes: Number of output logits of the fusion head.
        convnext_model_name: Hugging Face model id (or local path) handed
            to ``ConvNextModel.from_pretrained``.

    Attributes:
        eff_out_dim: Width of the pooled EfficientNet feature vector (taken
            from the input width of the original classifier layer).
        cnx_out_dim: Width of the pooled ConvNeXt feature vector (taken from
            the last entry of the checkpoint's ``config.hidden_sizes``).
    """

    def __init__(
        self,
        num_classes: int,
        convnext_model_name: str = "facebook/convnext-small-224",
    ) -> None:
        super().__init__()
        logger.info("Initializing Fusion model...")

        # EfficientNet-V2-S: freeze everything, then re-enable gradients for
        # the last three entries of ``features`` only.
        # NOTE(review): ``requires_grad=False`` does not stop normalization
        # layers with running statistics (if any live in the frozen stages)
        # from updating them in train mode -- confirm this is intended.
        eff = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
        )
        eff.requires_grad_(False)
        for block_idx in (5, 6, 7):
            eff.features[block_idx].requires_grad_(True)
        self.eff_features = eff.features
        self.eff_avgpool = eff.avgpool
        # The stock classifier is dropped; only its input width is needed.
        self.eff_out_dim = eff.classifier[1].in_features

        # ConvNeXt: freeze everything, then re-enable gradients for encoder
        # stages 2 and 3 and the final layer norm.
        cnx = ConvNextModel.from_pretrained(convnext_model_name)
        cnx.requires_grad_(False)
        for stage_idx in (2, 3):
            cnx.encoder.stages[stage_idx].requires_grad_(True)
        cnx.layernorm.requires_grad_(True)
        self.cnx_backbone = cnx
        # Read the pooled width from the checkpoint config instead of
        # hard-coding 768, so checkpoints with a different final stage width
        # still produce a correctly sized fusion head.
        self.cnx_out_dim = cnx.config.hidden_sizes[-1]

        fused_dim = self.eff_out_dim + self.cnx_out_dim
        self.fusion_head = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(fused_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )
        logger.info("Fusion model initialized successfully.")

    def forward(
        self,
        pixel_values_eff: torch.Tensor,
        pixel_values_cnx: torch.Tensor,
    ) -> torch.Tensor:
        """Compute class logits from the two backbone-specific image batches.

        Args:
            pixel_values_eff: Image batch for the EfficientNet branch
                (presumably ``(batch, 3, H, W)`` -- confirm against the
                data pipeline).
            pixel_values_cnx: Image batch for the ConvNeXt branch; must have
                the same batch size as ``pixel_values_eff`` so the features
                can be concatenated.

        Returns:
            Logits tensor of shape ``(batch, num_classes)``.
        """
        # EfficientNet branch: conv features -> pooling -> (batch, features).
        x_eff = self.eff_features(pixel_values_eff)
        x_eff = self.eff_avgpool(x_eff)
        x_eff = torch.flatten(x_eff, 1)

        # ConvNeXt branch: use the model's pooled output.
        cnx_out = self.cnx_backbone(
            pixel_values=pixel_values_cnx,
            return_dict=True,
        )
        x_cnx = cnx_out.pooler_output

        # Concatenate along the feature axis and classify.
        fused = torch.cat([x_eff, x_cnx], dim=1)
        logits = self.fusion_head(fused)
        return logits
if __name__ == "__main__":
    # Smoke test: build the model and push one random batch through it to
    # check that the two branches fuse and the head emits the expected shape.
    # ``logging`` is already imported at module level, so no re-import here.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    model = FusionClassifier(num_classes=6)
    # Only the output shape is inspected: switch off dropout and skip
    # building an autograd graph for this throw-away forward pass.
    model.eval()
    eff_dummy = torch.randn(2, 3, 260, 260)
    cnx_dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        output = model(eff_dummy, cnx_dummy)
    print("Fusion output shape:", output.shape)