Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| import shutil | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from PIL import Image | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from dotenv import load_dotenv | |
| from scripts.gradcam import get_resnet_gradcam, get_fusion_gradcam | |
| from scripts.yolo_predict import get_yolo_damage_boxes | |
| from scripts.load_models import initialize_models | |
# ---------------- LOGGING ----------------
# Configure the root logger once at import time: INFO and above, with each
# record carrying its timestamp, level and originating logger name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the startup hook and the request handlers.
logger = logging.getLogger(__name__)
# ---------------- ENV ----------------
# Load settings through python-dotenv before anything else reads them.
load_dotenv()

# ---------------- DIRECTORIES ----------------
# Uploaded inputs and generated result images are kept in separate folders
# beneath static/.
UPLOAD_DIR, RESULT_DIR = "static/uploads", "static/results"

# Create both folders up front; exist_ok makes repeated starts harmless.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# ---------------- GLOBAL MODELS ----------------
# Placeholders for the loaded models; they stay None until the startup hook
# assigns the real objects.
resnet_predictor = fusion_predictor = yolo_model = None

# Class index -> human-readable damage label, in index order (0..5).
CLASS_MAP = dict(
    enumerate(
        (
            "Front Breakage",
            "Front Crushed",
            "Front Normal",
            "Rear Breakage",
            "Rear Crushed",
            "Rear Normal",
        )
    )
)
| # ---------------- FASTAPI STARTUP ---------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models before the app starts serving and log on shutdown.

    The code before ``yield`` runs once at startup and fills the module-level
    ``resnet_predictor``, ``fusion_predictor`` and ``yolo_model`` globals from
    ``initialize_models(CLASS_MAP)``. The code after ``yield`` runs at
    shutdown.

    Args:
        app: The FastAPI application being started (not used here).

    Raises:
        RuntimeError: If model loading fails. The original exception is
            attached as the cause, so startup aborts with the full context.
    """
    global resnet_predictor, fusion_predictor, yolo_model

    logger.info("Loading models at startup...")
    try:
        resnet_predictor, fusion_predictor, yolo_model = initialize_models(CLASS_MAP)
    except Exception as e:
        logger.exception("Model loading failed.")
        # Chain explicitly so the underlying failure stays visible.
        raise RuntimeError(str(e)) from e
    logger.info("All models loaded successfully.")

    yield

    logger.info("Application shutdown.")
# ---------------- APP ----------------
app = FastAPI(lifespan=lifespan)

# Allowed CORS origins are read from the CORS_ALLOW_ORIGINS environment
# variable as a comma-separated list, e.g.
# "https://a.example,https://b.example". When the variable is unset or empty
# the previous behaviour is kept and every origin is allowed.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# NOTE(review): FastAPI's CORS documentation says wildcard origins should not
# be combined with allow_credentials=True. The flag is left as it was to avoid
# changing behaviour for existing clients; set CORS_ALLOW_ORIGINS to explicit
# origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expose the local static/ directory (uploads and results) under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")
| # ---------------- HELPERS ---------------- | |
def validate_image(upload_file: UploadFile) -> None:
    """Reject uploads whose declared content type is not an image.

    Only the ``content_type`` reported for the upload is checked; the file
    bytes are not inspected.

    Args:
        upload_file: The uploaded file to check.

    Raises:
        HTTPException: 400 when the content type is missing or does not start
            with ``image/``.
    """
    # content_type may be None when the client sends no Content-Type for the
    # part; treat that as "not an image" instead of failing on None.
    declared_type = upload_file.content_type or ""
    if not declared_type.startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail="Uploaded file must be an image."
        )
def save_upload(upload_file: UploadFile):
    """Copy an uploaded file into UPLOAD_DIR under a fresh UUID-based name.

    The stored name is always ``<uuid>_input.jpg``, whatever the upload's own
    name or type.

    Args:
        upload_file: The upload whose underlying file object is copied.

    Returns:
        A ``(unique_id, filename, file_path)`` tuple: the generated UUID
        string, the stored file name, and its path inside UPLOAD_DIR.
    """
    token = str(uuid.uuid4())
    stored_name = f"{token}_input.jpg"
    destination = os.path.join(UPLOAD_DIR, stored_name)

    # Stream the upload to disk rather than reading it fully into memory.
    with open(destination, "wb") as sink:
        shutil.copyfileobj(upload_file.file, sink)

    return token, stored_name, destination
| # ---------------- ROUTES ---------------- | |
def api_status():
    """Return a fixed payload stating that the API is running."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view; confirm it is registered on `app`.
    payload = {"status": "API is running"}
    return payload
async def predict_and_generate_cams(
    file: UploadFile = File(...),
    mode: str = "resnet"
):
    """Save an uploaded image and render a Grad-CAM visualisation for it.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm it is registered on `app`.

    Args:
        file: The uploaded image.
        mode: ``"resnet"`` or ``"fusion"`` (case-insensitive); selects which
            predictor and Grad-CAM routine are used.

    Returns:
        A dict with ``status``, ``mode``, the URL of the stored upload, and
        the URL of the generated visualisation under ``selected_viz`` and
        under the key matching the mode (``resnet_viz`` / ``fusion_viz``).
        The other mode's key is ``None``.

    Raises:
        HTTPException: 400 for a non-image upload or an unknown mode; 500 if
            saving the upload or generating the visualisation fails.
    """
    validate_image(file)

    mode = mode.lower()
    if mode != "resnet" and mode != "fusion":
        raise HTTPException(
            status_code=400,
            detail="mode must be 'resnet' or 'fusion'"
        )

    try:
        unique_id, input_filename, input_path = save_upload(file)

        # Choose the output name, Grad-CAM routine and predictor for the mode.
        # The predictors are read here, at request time, because they are
        # only assigned during application startup.
        is_resnet = mode == "resnet"
        if is_resnet:
            output_name = f"{unique_id}_resnet.jpg"
            render_cam, predictor = get_resnet_gradcam, resnet_predictor
        else:
            output_name = f"{unique_id}_fusion.jpg"
            render_cam, predictor = get_fusion_gradcam, fusion_predictor

        output_path = os.path.join(RESULT_DIR, output_name)
        render_cam(input_path, predictor, output_path)

        selected_viz = f"/static/results/{output_name}"
        return {
            "status": "success",
            "mode": mode,
            "original_image": f"/static/uploads/{input_filename}",
            "selected_viz": selected_viz,
            "resnet_viz": selected_viz if is_resnet else None,
            "fusion_viz": None if is_resnet else selected_viz
        }
    except Exception as e:
        logger.exception("GradCAM generation failed.")
        raise HTTPException(status_code=500, detail=str(e))
async def resnet_prediction(image: UploadFile = File(...)):
    """Classify an uploaded image with the ResNet predictor.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm it is registered on `app`.

    Args:
        image: The uploaded image.

    Returns:
        A dict with ``status`` set to ``"success"`` and ``prediction`` holding
        whatever ``resnet_predictor.resnet_predict`` returns.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded as one; 500 for any other failure.
    """
    validate_image(image)

    try:
        try:
            pil_image = Image.open(image.file).convert("RGB")
        except OSError as e:
            # Pillow reports unreadable or truncated image data as OSError.
            # That is a bad upload, not a server fault, so answer with 400.
            logger.warning("Uploaded file could not be decoded as an image: %s", e)
            raise HTTPException(
                status_code=400,
                detail="Uploaded file is not a valid image."
            ) from e

        result = resnet_predictor.resnet_predict(pil_image)
        return {
            "status": "success",
            "prediction": result
        }
    except HTTPException:
        # Let the 400 above through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.exception("ResNet prediction failed.")
        raise HTTPException(status_code=500, detail=str(e))
async def fusion_prediction(image: UploadFile = File(...)):
    """Classify an uploaded image with the fusion predictor.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm it is registered on `app`.

    Args:
        image: The uploaded image.

    Returns:
        A dict with ``status`` set to ``"success"`` and ``prediction`` holding
        whatever ``fusion_predictor.predict`` returns.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded as one; 500 for any other failure.
    """
    validate_image(image)

    try:
        try:
            pil_image = Image.open(image.file).convert("RGB")
        except OSError as e:
            # Pillow reports unreadable or truncated image data as OSError.
            # That is a bad upload, not a server fault, so answer with 400.
            logger.warning("Uploaded file could not be decoded as an image: %s", e)
            raise HTTPException(
                status_code=400,
                detail="Uploaded file is not a valid image."
            ) from e

        result = fusion_predictor.predict(pil_image)
        return {
            "status": "success",
            "prediction": result
        }
    except HTTPException:
        # Let the 400 above through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.exception("Fusion prediction failed.")
        raise HTTPException(status_code=500, detail=str(e))
async def yolo_detection(file: UploadFile = File(...)):
    """Save an uploaded image and run YOLO damage detection on it.

    NOTE(review): no route decorator is visible on this handler in this view;
    confirm it is registered on `app`.

    Args:
        file: The uploaded image.

    Returns:
        A dict with ``status``, the URLs of the stored upload and of the YOLO
        output image, and the ``detections``, ``total_detections`` and
        ``message`` entries taken from the ``get_yolo_damage_boxes`` result.

    Raises:
        HTTPException: 400 for a non-image upload; 500 if saving, detection,
            or reading the result fails.
    """
    validate_image(file)

    try:
        unique_id, input_filename, input_path = save_upload(file)

        output_name = f"{unique_id}_yolo.jpg"
        output_path = os.path.join(RESULT_DIR, output_name)
        result = get_yolo_damage_boxes(input_path, yolo_model, output_path)

        # Start with the fixed fields, then copy the detector's fields across.
        # The lookups stay inside the try so a missing key is still a 500.
        response = {
            "status": "success",
            "original_image": f"/static/uploads/{input_filename}",
            "yolo_image": f"/static/results/{output_name}",
        }
        response["detections"] = result["detections"]
        response["total_detections"] = result["total_detections"]
        response["message"] = result["message"]
        return response
    except Exception as e:
        logger.exception("YOLO detection failed.")
        raise HTTPException(status_code=500, detail=str(e))