Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import torch | |
| from src.config import DEVICE, NUM_CLASSES, CHECKPOINT_DIR | |
| from src.models.fusion_model import FusionClassifier | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Source checkpoint read by the conversion, and the destination written by it.
# Both live under CHECKPOINT_DIR.
INPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model.pt"
OUTPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model_fp16.pt"
def convert_fusion_to_fp16():
    """Convert the saved Fusion checkpoint to FP16 and write it back to disk.

    Builds a ``FusionClassifier`` on ``DEVICE``, loads the weights stored at
    ``INPUT_CHECKPOINT``, calls ``.half()`` on the model and saves the
    resulting ``state_dict`` to ``OUTPUT_CHECKPOINT``. The size of the written
    file is logged in megabytes.

    Returns:
        ``OUTPUT_CHECKPOINT``, the path the FP16 state dict was saved to.

    Raises:
        FileNotFoundError: If ``INPUT_CHECKPOINT`` does not exist.
    """
    logger.info("Initializing Fusion model for FP16 conversion...")

    # Fail early, before building the model, when there is nothing to convert.
    if not INPUT_CHECKPOINT.exists():
        raise FileNotFoundError(f"Fusion checkpoint not found: {INPUT_CHECKPOINT}")

    fusion_net = FusionClassifier(num_classes=NUM_CLASSES)
    fusion_net = fusion_net.to(DEVICE)

    logger.info(f"Loading checkpoint from: {INPUT_CHECKPOINT}")
    loaded = torch.load(INPUT_CHECKPOINT, map_location=DEVICE)

    # The file may hold either a wrapper dict with the weights under
    # "model_state_dict", or the state dict itself.
    is_wrapped = isinstance(loaded, dict) and "model_state_dict" in loaded
    weights = loaded["model_state_dict"] if is_wrapped else loaded
    fusion_net.load_state_dict(weights)
    logger.info("Model weights loaded successfully.")

    fusion_net.eval()

    logger.info("Converting model to FP16...")
    fusion_net = fusion_net.half()
    torch.save(fusion_net.state_dict(), OUTPUT_CHECKPOINT)

    # Report the on-disk size of the new checkpoint in MiB.
    bytes_per_mb = 1024 * 1024
    size_mb = os.stat(OUTPUT_CHECKPOINT).st_size / bytes_per_mb
    logger.info(f"FP16 model saved at: {OUTPUT_CHECKPOINT}")
    logger.info(f"FP16 model size: {size_mb:.2f} MB")

    return OUTPUT_CHECKPOINT
if __name__ == "__main__":
    # Script entry point: configure INFO-level logging with timestamps, run
    # the conversion once, then report where the FP16 checkpoint was saved.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    saved_path = convert_fusion_to_fp16()

    print("\nFusion FP16 conversion completed successfully.")
    print(f"Saved model: {saved_path}")