Spaces:
Running
Running
| import os | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import yaml | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
# Workspace root: taken from the WORKSPACE_ROOT environment variable when set,
# otherwise the directory that contains this file.
WORKSPACE_ROOT = Path(
    os.environ.get("WORKSPACE_ROOT", Path(__file__).resolve().parent)
).resolve()
SRC_DIR = WORKSPACE_ROOT / "src"

# SRC_DIR stays first on the module search path, as before.
sys.path.insert(0, str(SRC_DIR))
# The project imports that follow use the ``src.`` package prefix, which only
# resolves when WORKSPACE_ROOT itself is on sys.path. Without this, pointing
# WORKSPACE_ROOT at a different directory would break those imports.
if str(WORKSPACE_ROOT) not in sys.path:
    sys.path.insert(1, str(WORKSPACE_ROOT))
| from src.models.swin import EncoderSwinTiny | |
| from src.transforms.image_transform import get_classification_valid_transform | |
| from src.utils.captioning_inference import build_caption_runtime, decode_tokens | |
| from src.visualization.generate_gradcam import ( | |
| SwinClassifierWrapper, | |
| reshape_transform, | |
| ) | |
# Lazily built runtime caches. Each stays None until the first request;
# get_classification_runtime / get_captioning_runtime populate them.
CLASSIFICATION_STATE = CAPTIONING_STATE = None
def load_params():
    """Load ``params.yaml`` from the workspace root.

    The file carries the demo, model, and checkpoint settings used across
    this app.

    Returns:
        The object produced by ``yaml.safe_load`` for the file.
    """
    params_path = WORKSPACE_ROOT / "params.yaml"
    with params_path.open("r", encoding="utf-8") as params_file:
        return yaml.safe_load(params_file)
def load_class_names(params):
    """Return the training-time class names listed under ``demo.class_names``.

    Args:
        params: Parsed ``params.yaml`` mapping.

    Returns:
        The list of class-name strings, in the order given in the file.

    Raises:
        ValueError: If the entry is not a list of strings, or if it is empty.
    """
    names = params.get("demo", {}).get("class_names", [])
    is_string_list = isinstance(names, list) and all(
        isinstance(name, str) for name in names
    )
    if not is_string_list:
        raise ValueError("demo.class_names must be a list of class name strings.")
    if len(names) == 0:
        raise ValueError("No class names found in params.yaml demo.class_names.")
    return names
def get_device(params):
    """Pick the torch device from ``train.device`` in params, honoring CUDA availability.

    A CUDA setting (``"cuda"`` or an indexed ``"cuda:N"``) is used only when
    CUDA is actually available. Every other case falls back to the CPU: no CUDA,
    a non-CUDA setting, or a missing/null ``train`` section.

    Args:
        params: Parsed ``params.yaml`` mapping.

    Returns:
        ``torch.device`` to run models on. The default setting is ``"cuda"``.
    """
    # ``train: null`` in YAML yields None; treat it like a missing section.
    train_config = params.get("train") or {}
    device_name = str(train_config.get("device", "cuda"))
    # Accept indexed CUDA devices too, instead of silently dropping to CPU.
    is_cuda_setting = device_name == "cuda" or device_name.startswith("cuda:")
    if is_cuda_setting and torch.cuda.is_available():
        return torch.device(device_name)
    return torch.device("cpu")
def load_classification_checkpoint(model, checkpoint_path, device):
    """Load classification weights from ``checkpoint_path`` into ``model`` in place.

    Two on-disk formats are accepted: a raw state dict, or a dict wrapping
    the weights under ``"model_state_dict"``. For the wrapped format, only the
    wrapped weights are used. Tensors are mapped onto ``device`` while loading.
    """
    loaded = torch.load(checkpoint_path, map_location=device)
    is_wrapped = isinstance(loaded, dict) and "model_state_dict" in loaded
    state_dict = loaded["model_state_dict"] if is_wrapped else loaded
    model.load_state_dict(state_dict)
def build_classification_runtime():
    """Assemble everything the classification tab needs into a single dict.

    Reads ``params.yaml`` and builds a Swin-T classifier sized to the
    configured class list. It then loads the final classification checkpoint
    and switches the model to eval mode.

    Returns:
        dict with keys ``params``, ``model``, ``model_name``, ``device``,
        ``class_names``, ``transform``, and ``checkpoint_path``.

    Raises:
        ValueError: If ``classification.model_name`` is not ``"swin_t"``. The
            Grad-CAM wrapper and model construction here are Swin-T only, so
            other models are rejected explicitly. Also raised when the class
            name list is invalid.
    """
    params = load_params()
    classification_config = params["classification"]
    model_name = classification_config["model_name"]
    if model_name != "swin_t":
        raise ValueError(
            "The combined Gradio demo currently supports only swin_t "
            f"for classification, got: {model_name}"
        )

    class_names = load_class_names(params)
    device = get_device(params)
    model = EncoderSwinTiny(num_classes=len(class_names)).to(device)

    checkpoint_path = WORKSPACE_ROOT / classification_config["final_checkpoint"]
    load_classification_checkpoint(model, checkpoint_path, device)
    model.eval()

    return dict(
        params=params,
        model=model,
        model_name=model_name,
        device=device,
        class_names=class_names,
        transform=get_classification_valid_transform(),
        checkpoint_path=checkpoint_path,
    )
def get_classification_runtime():
    """Return the cached classification runtime, building it on first use.

    Nothing is loaded before the first prediction request. The first call
    builds the runtime, and every later call reuses the cached dict.
    """
    global CLASSIFICATION_STATE
    if CLASSIFICATION_STATE is not None:
        return CLASSIFICATION_STATE
    CLASSIFICATION_STATE = build_classification_runtime()
    return CLASSIFICATION_STATE
def get_caption_checkpoint_path(params):
    """Return the captioning checkpoint path described by params.

    An explicit (non-empty) ``captioning.checkpoint.final_checkpoint`` wins.
    Otherwise the file name follows the training naming convention
    ``{encoder}-{decoder}_{version}_best.pt``. Either way the file lives under
    ``WORKSPACE_ROOT / captioning.checkpoint.save_dir``.

    Args:
        params: Parsed ``params.yaml`` mapping.

    Returns:
        ``pathlib.Path`` of the checkpoint file (existence is not checked).
    """
    captioning_config = params["captioning"]
    checkpoint_config = captioning_config["checkpoint"]
    explicit_name = checkpoint_config.get("final_checkpoint")
    if explicit_name:
        return WORKSPACE_ROOT / checkpoint_config["save_dir"] / explicit_name

    encoder_name, decoder_name, version = (
        captioning_config[key] for key in ("encoder", "decoder", "version")
    )
    fallback_name = f"{encoder_name}-{decoder_name}_{version}_best.pt"
    return WORKSPACE_ROOT / checkpoint_config["save_dir"] / fallback_name
def get_captioning_runtime():
    """Return the cached captioning runtime, building it on first use.

    Loading the encoder/decoder is deferred until the captioning tab is
    actually run. The first call builds the runtime from the checkpoint
    chosen by ``get_caption_checkpoint_path``, and later calls reuse it.
    """
    global CAPTIONING_STATE
    if CAPTIONING_STATE is not None:
        return CAPTIONING_STATE
    config = load_params()
    CAPTIONING_STATE = build_caption_runtime(
        WORKSPACE_ROOT,
        checkpoint_path=get_caption_checkpoint_path(config),
    )
    return CAPTIONING_STATE
def make_gradcam_overlay(model, image, tensor, device):
    """Render a Grad-CAM overlay for the classifier, targeting the last Swin block.

    Grad-CAM needs gradients, so the backbone and classifier parameters are
    switched to ``requires_grad=True`` for the duration of the CAM pass. Their
    original flags are restored afterwards, and the gradients the pass left
    behind are cleared. The model therefore ends up in the same trainability
    state it started in.

    Args:
        model: Classifier exposing ``backbone`` (with ``features``) and
            ``classifier`` submodules.
        image: PIL image the overlay is drawn on.
        tensor: Preprocessed input batch for ``model``, already on ``device``.
        device: Device the Grad-CAM wrapper is moved to.

    Returns:
        PIL image of the heatmap blended over a 224x224 copy of ``image``.
    """
    grad_params = list(model.backbone.parameters()) + list(model.classifier.parameters())
    # Capture every flag before changing any, so shared parameters restore correctly.
    original_flags = [param.requires_grad for param in grad_params]
    for param in grad_params:
        param.requires_grad = True

    try:
        gradcam_model = SwinClassifierWrapper(model).to(device)
        gradcam_model.eval()

        # show_cam_on_image expects a float RGB image in [0, 1].
        # NOTE(review): assumes the classification transform also yields
        # 224x224 inputs so the CAM lines up with this resized image — confirm.
        resized_image = image.convert("RGB").resize((224, 224))
        image_np = np.array(resized_image).astype(np.float32) / 255.0

        target_layer = model.backbone.features[-1][-1].norm2
        with GradCAM(
            model=gradcam_model,
            target_layers=[target_layer],
            reshape_transform=reshape_transform,
        ) as cam:
            # No explicit targets are passed, so the top-scoring class is explained.
            grayscale_cam = cam(input_tensor=tensor)[0]
    finally:
        for param, flag in zip(grad_params, original_flags):
            param.requires_grad = flag
        # Drop gradients accumulated by the CAM backward pass; inference never uses them.
        model.zero_grad(set_to_none=True)

    overlay = show_cam_on_image(
        image_np,
        grayscale_cam,
        use_rgb=True,
    )
    return Image.fromarray(overlay)
def predict_classification(image, show_gradcam):
    """Classify an uploaded image and optionally attach a Grad-CAM overlay.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was
            uploaded.
        show_gradcam: When truthy, also compute the (expensive) Grad-CAM
            overlay.

    Returns:
        A 3-tuple matching the Gradio outputs ``(gradcam_image, summary, table)``:
        - the overlay, or ``None`` when it was not requested;
        - a ``" label (xx.xx%)"`` summary string;
        - ``[rank, class name, "xx.xx%"]`` rows for the top-k classes.
    """
    if image is None:
        # Keep the output arity the Gradio click handler expects.
        return None, "Please upload an image.", []

    runtime = get_classification_runtime()
    model, device = runtime["model"], runtime["device"]
    class_names = runtime["class_names"]

    rgb_image = image.convert("RGB")
    batch = runtime["transform"](rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        probabilities = torch.softmax(model(batch), dim=1)[0]

    # Clamp the configured top_k into [1, number of classes] so topk cannot fail.
    requested_k = int(runtime["params"]["demo"].get("top_k", 5))
    k = max(1, min(requested_k, len(class_names)))
    best_probs, best_indices = torch.topk(probabilities, k=k)
    best_probs = best_probs.detach().cpu().tolist()
    best_indices = best_indices.detach().cpu().tolist()

    summary = f" {class_names[best_indices[0]]} ({best_probs[0] * 100:.2f}%)"
    table = []
    for rank, (class_idx, prob) in enumerate(zip(best_indices, best_probs), start=1):
        table.append([rank, class_names[class_idx], f"{prob * 100:.2f}%"])

    gradcam_image = None
    if show_gradcam:
        gradcam_image = make_gradcam_overlay(model, rgb_image, batch, device)
    return gradcam_image, summary, table
def caption_token_labels(generated_tokens, runtime, caption):
    """Build per-token labels used as titles for the attention heatmaps.

    Special tokens (``<pad>``, ``<sos>``, ``<eos>``) are dropped. The remaining
    ids are mapped through ``runtime["i2w"]``, and unknown ids become
    ``"<unk>"``. Each id is normalised with ``int()`` first. Elements of a
    ``torch.Tensor`` hash by identity, so without this they would never match
    the int keys of ``special_ids`` / ``i2w``: specials would slip through
    and every word would become ``<unk>``.

    Args:
        generated_tokens: Iterable of token ids for one caption (ints or
            scalar tensors).
        runtime: Captioning runtime dict providing ``w2i`` and ``i2w``.
        caption: Decoded caption string, used as a fallback.

    Returns:
        List of label strings, or ``caption.split()`` when no non-special
        token remains.
    """
    w2i = runtime["w2i"]
    i2w = runtime["i2w"]
    special_ids = {w2i.get("<pad>"), w2i.get("<sos>"), w2i.get("<eos>")}
    token_ids = [int(token) for token in generated_tokens]
    labels = [
        i2w.get(token_id, "<unk>")
        for token_id in token_ids
        if token_id not in special_ids
    ]
    # Id-based labels line up with the attention steps more reliably than words.
    if labels:
        return labels
    # Exceptional case: no usable ids, so split the sentence into words instead.
    return caption.split()
def predict_captioning(image):
    """Generate a caption for an uploaded image plus a cross-attention heatmap.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was
            uploaded.

    Returns:
        ``(caption, heatmap_images)`` matching the Gradio outputs.
        ``heatmap_images`` is a list of ``(image_path, title)`` gallery entries;
        currently it holds one entry, for the decoder's last layer.
    """
    if image is None:
        # Keep the output arity the Gradio click handler expects.
        return "Please upload an image.", []

    runtime = get_captioning_runtime()
    params = runtime["params"]
    device = runtime["device"]
    w2i = runtime["w2i"]
    decoder = runtime["decoder"]

    image = image.convert("RGB")
    image_tensor = runtime["transform"](image).unsqueeze(0).to(device)

    # Pure inference: disable autograd so no graph is built for the encoder
    # pass and the token-by-token generation (classification does the same).
    with torch.no_grad():
        features = runtime["encoder"](
            image_tensor,
            return_features=True,
        )
        start_token = torch.full(
            (features.size(0),),
            w2i["<sos>"],
            dtype=torch.long,
            device=device,
        )

        beam_config = params["captioning"]["beam_search"]
        use_beam_search = beam_config.get("use_beam_search", True)
        beam_size = beam_config.get("beam_size", 3)
        if use_beam_search:
            # Beam search explores several candidates per step (enabled in params.yaml).
            generated_tokens, _, enc_dec_atten = decoder.generate_beam(
                features,
                start_token,
                w2i["<eos>"],
                beam_size,
            )
        else:
            # Greedy decoding: pick the most likely token at every step.
            generated_tokens, _, enc_dec_atten = decoder.generate(
                features,
                start_token,
                w2i["<eos>"],
            )

    caption = decode_tokens(
        generated_tokens[0],
        w2i,
        runtime["i2w"],
        params["captioning"]["tokenizer"]["use_subword"],
        sp_model_path=runtime["sp_model_path"],
    )
    caption_tokens = caption_token_labels(
        generated_tokens[0],
        runtime,
        caption,
    )

    # NOTE(review): a new temp directory is created per request and never
    # removed here; consider a shared output directory or periodic cleanup.
    tmp_dir = tempfile.mkdtemp(prefix="combined_captioning_gradio_")
    last_layer = len(decoder.layers)
    cross_atten_path = Path(tmp_dir) / "cross_attention_last_layer.jpg"
    decoder.show_cross_atten(
        enc_dec_atten[0],
        caption_tokens,
        last_layer,
        image_tensor.squeeze(0).detach().cpu(),
        str(cross_atten_path),
    )
    heatmap_images = [
        (
            str(cross_atten_path),
            f"Last Layer ({last_layer})",
        )
    ]
    return caption, heatmap_images
def _build_classification_tab(params, top_k):
    """Lay out the Classification tab and wire its Predict button (call inside a gr.Tab)."""
    gr.Markdown("Upload an image and classify it with the final checkpoint.")
    checkpoint_path = WORKSPACE_ROOT / params["classification"]["final_checkpoint"]
    gr.Markdown(f"checkpoint: {checkpoint_path}")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            gradcam_toggle = gr.Checkbox(
                value=bool(params["demo"].get("show_gradcam", True)),
                label="Show Grad-CAM",
            )
            predict_button = gr.Button("Predict", variant="primary")
        with gr.Column():
            gradcam_view = gr.Image(type="pil", label="Grad-CAM")
            summary_box = gr.Textbox(label="Prediction")
            ranking_table = gr.Dataframe(
                headers=["Rank", "Class", "Confidence"],
                datatype=["number", "str", "str"],
                label=f"Top-{top_k}",
                interactive=False,
            )
    predict_button.click(
        fn=predict_classification,
        inputs=[input_image, gradcam_toggle],
        outputs=[gradcam_view, summary_box, ranking_table],
    )


def _build_captioning_tab(caption_checkpoint):
    """Lay out the Captioning tab and wire its Generate button (call inside a gr.Tab)."""
    gr.Markdown("Upload an image and generate a caption with cross-attention heatmaps.")
    gr.Markdown(f"checkpoint: {caption_checkpoint}")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            generate_button = gr.Button("Generate Caption", variant="primary")
        with gr.Column():
            caption_box = gr.Textbox(label="Generated Caption", lines=4)
            heatmap_gallery = gr.Gallery(
                label="Cross Attention Heatmaps",
                columns=2,
                object_fit="contain",
                height="auto",
            )
    generate_button.click(
        fn=predict_captioning,
        inputs=[input_image],
        outputs=[caption_box, heatmap_gallery],
    )


def create_demo():
    """Build one Gradio Blocks app with a Classification tab and a Captioning tab.

    Models are not loaded here; each tab's handler loads its runtime lazily
    on the first request.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    params = load_params()
    top_k = max(1, int(params["demo"].get("top_k", 5)))
    caption_checkpoint = get_caption_checkpoint_path(params)
    title = "ImageNet Classification and Captioning Demo"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        with gr.Tabs():
            with gr.Tab("Classification"):
                _build_classification_tab(params, top_k)
            with gr.Tab("Captioning"):
                _build_captioning_tab(caption_checkpoint)
    return demo
if __name__ == "__main__":
    params = load_params()
    demo = create_demo()
    # Host, port, and the public-share flag all come from params.yaml's demo section.
    demo_config = params["demo"]
    demo.launch(
        server_name=demo_config["host"],
        server_port=demo_config["port"],
        share=demo_config["share"],
    )