DamageLensAI / src /export /conver_model.py
junaid17's picture
Upload 43 files
eef8873 verified
import os
import logging
import torch
from src.config import DEVICE, NUM_CLASSES, CHECKPOINT_DIR
from src.models.fusion_model import FusionClassifier
logger = logging.getLogger(__name__)
INPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model.pt"
OUTPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model_fp16.pt"
def convert_fusion_to_fp16():
logger.info("Initializing Fusion model for FP16 conversion...")
if not INPUT_CHECKPOINT.exists():
raise FileNotFoundError(
f"Fusion checkpoint not found: {INPUT_CHECKPOINT}"
)
model = FusionClassifier(
num_classes=NUM_CLASSES
).to(DEVICE)
logger.info(f"Loading checkpoint from: {INPUT_CHECKPOINT}")
checkpoint = torch.load(
INPUT_CHECKPOINT,
map_location=DEVICE
)
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
model.load_state_dict(checkpoint["model_state_dict"])
else:
model.load_state_dict(checkpoint)
logger.info("Model weights loaded successfully.")
model.eval()
logger.info("Converting model to FP16...")
model = model.half()
torch.save(
model.state_dict(),
OUTPUT_CHECKPOINT
)
size_mb = os.path.getsize(OUTPUT_CHECKPOINT) / (1024 * 1024)
logger.info(f"FP16 model saved at: {OUTPUT_CHECKPOINT}")
logger.info(f"FP16 model size: {size_mb:.2f} MB")
return OUTPUT_CHECKPOINT
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s"
)
fp16_path = convert_fusion_to_fp16()
print("\nFusion FP16 conversion completed successfully.")
print(f"Saved model: {fp16_path}")