Spaces:
Running
Running
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from transformers import CLIPModel, CLIPProcessor | |
| import os | |
| from dotenv import load_dotenv | |
# ============================================================
# Settings
# ============================================================
load_dotenv()

# Read HF_TOKEN from the process environment (after load_dotenv has run).
hf_token = os.environ.get("HF_TOKEN")

# True  -> check every class found in the input JSON.
# False -> check only the class named by TARGET_CLASS_DIR.
CHECK_ALL_CLASSES = True

# Root of the raw data tree; image paths are resolved against it when
# checking all classes.
DATA_RAW_ROOT_DIR = Path("data", "raw")

# Class folder used when checking a single class.
# Only consulted when CHECK_ALL_CLASSES is False.
TARGET_CLASS_DIR = Path("data", "raw")

# Input JSON file.
INPUT_JSON_PATH = Path("data", "annotations", "captions_flo_all.json")

# Output JSON file.
OUTPUT_JSON_PATH = Path("data", "annotations", "clip_checked_flo_all.json")

# CLIP checkpoint to load.
MODEL_NAME = "openai/clip-vit-base-patch32"

# Number of image-caption pairs scored per batch.
BATCH_SIZE = 32

# Bottom percentiles of the score distribution labelled fail / review.
FAIL_BOTTOM_PERCENT = 10
REVIEW_BOTTOM_PERCENT = 20

# NOTE(review): the label text below looks mis-encoded (probably Korean
# UTF-8 decoded with the wrong codec) - confirm against the original file.
print("๊ฒฝ๋ก : ", INPUT_JSON_PATH)
# ============================================================
# JSON input / output
# ============================================================
def load_json(path: Path) -> list[dict[str, Any]]:
    """Read a UTF-8 JSON file whose top-level value must be an array.

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed top-level list.

    Raises:
        ValueError: If the top-level JSON value is not a list.
    """
    parsed = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(parsed, list):
        return parsed
    # NOTE(review): the message text looks mis-encoded (probably Korean UTF-8
    # decoded with the wrong codec) - confirm against the original file.
    raise ValueError("์ ๋ ฅ JSON์ ๋ฐ๋์ ๋ฐฐ์ด ํํ์ฌ์ผ ํฉ๋๋ค.")
def save_json(data: list[dict[str, Any]], path: Path) -> None:
    """Write ``data`` to ``path`` as UTF-8 JSON with a 4-space indent.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than as ``\\u`` escapes.

    Args:
        data: Records to serialize.
        path: Destination file; overwritten if it already exists.
    """
    out_dir = path.parent
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=4, ensure_ascii=False)
# ============================================================
# Class / path handling
# ============================================================
def get_target_class_name() -> str:
    """Return the last path component of TARGET_CLASS_DIR.

    Example: with TARGET_CLASS_DIR = data/raw/airplane this returns
    "airplane".
    """
    class_dir = TARGET_CLASS_DIR
    return class_dir.name
def get_class_name_from_image_value(image_value: str) -> str:
    """Extract the leading folder name from a JSON ``image`` value.

    Backslashes are treated as path separators. For example,
    "airplane/hf_airplane_001.jpg" yields "airplane".

    Args:
        image_value: Relative image path as stored in the JSON.

    Returns:
        The first path component, or an empty string when the value has
        fewer than two components (no folder part).
    """
    components = Path(image_value.replace("\\", "/")).parts
    return components[0] if len(components) >= 2 else ""
def is_target_item(item: dict[str, Any]) -> bool:
    """Decide whether a JSON record should be checked.

    When CHECK_ALL_CLASSES is true every record is accepted. Otherwise a
    record is accepted only if the first folder of its ``image`` value
    equals the name of TARGET_CLASS_DIR.

    Args:
        item: One record from the input JSON.

    Returns:
        True if the record is in scope for this run.
    """
    if not CHECK_ALL_CLASSES:
        raw_image = str(item.get("image", ""))
        record_class = get_class_name_from_image_value(raw_image)
        return record_class == get_target_class_name()
    return True
def resolve_image_path(image_value: str) -> Path:
    """Map a JSON ``image`` value to a path on disk.

    Backslashes in ``image_value`` are treated as path separators.

    Checking all classes:
        DATA_RAW_ROOT_DIR / image_value
        e.g. "airplane/hf_airplane_001.jpg"
             -> data/raw/airplane/hf_airplane_001.jpg

    Checking a single class:
        TARGET_CLASS_DIR / file name only
        e.g. with TARGET_CLASS_DIR = data/raw/airplane
             -> data/raw/airplane/hf_airplane_001.jpg

    Args:
        image_value: Relative image path as stored in the JSON.

    Returns:
        The resolved image path.
    """
    relative = Path(image_value.replace("\\", "/"))
    if not CHECK_ALL_CLASSES:
        return TARGET_CLASS_DIR / relative.name
    return DATA_RAW_ROOT_DIR / relative
def load_image(image_path: Path) -> Image.Image | None:
    """Open an image and return an RGB copy, or None on any failure.

    Args:
        image_path: File to open.

    Returns:
        An RGB image that remains usable after the file handle is closed,
        or None if opening or converting raised an exception.
    """
    try:
        with Image.open(image_path) as handle:
            rgb = handle.convert("RGB")
            return rgb.copy()
    except Exception:
        # Best-effort: the caller reports a None result as "missing_image".
        return None
# ============================================================
# Flatten captions
# ============================================================
def flatten_caption_items(data: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Expand each in-scope record into one entry per caption.

    An image with three captions becomes three image-caption pairs.
    Records rejected by ``is_target_item`` are skipped. A ``captions``
    value that is not a list is treated as having no captions.

    Args:
        data: Records loaded from the input JSON.

    Returns:
        A tuple ``(target_data, flat_items)``. ``target_data`` holds the
        accepted records (the same dict objects as in ``data``).
        ``flat_items`` holds one dict per caption; its ``item_index`` is the
        record's position in ``target_data``.
    """
    selected: list[dict[str, Any]] = []
    pairs: list[dict[str, Any]] = []

    for record in data:
        if not is_target_item(record):
            continue

        position = len(selected)
        selected.append(record)

        raw_captions = record.get("captions", [])
        caption_list = raw_captions if isinstance(raw_captions, list) else []
        image_value = str(record.get("image", ""))

        pairs.extend(
            {
                "item_index": position,
                "caption_index": idx,
                "image": image_value,
                "class": record.get("class", ""),
                "split": record.get("split", ""),
                "caption": str(text).strip(),
            }
            for idx, text in enumerate(caption_list)
        )

    return selected, pairs
# ============================================================
# CLIP score computation
# ============================================================
def _unscored_result(
    item: dict[str, Any],
    resolved_path: str,
    status: str,
    reason: str
) -> dict[str, Any]:
    """Build the result record for a pair that cannot be scored."""
    return {
        **item,
        "resolved_image_path": resolved_path,
        "clip_cosine": None,
        "clip_score": None,
        "clip_status": status,
        "clip_reason": reason
    }


def compute_clip_scores(
    flat_items: list[dict[str, Any]],
    model: CLIPModel,
    processor: CLIPProcessor,
    device: torch.device
) -> list[dict[str, Any]]:
    """Score every image-caption pair with CLIP, BATCH_SIZE pairs at a time.

    Each result is the input item plus ``resolved_image_path``,
    ``clip_cosine``, ``clip_score``, ``clip_status`` and ``clip_reason``:

    * image could not be opened -> status "missing_image", scores None
    * caption is empty          -> status "empty_caption", scores None
    * otherwise                 -> status "pending", with
      ``clip_cosine`` the cosine similarity of the L2-normalized image and
      text embeddings and ``clip_score = 2.5 * max(clip_cosine, 0)``,
      both rounded to 6 decimals.

    Within a batch, unscorable pairs are appended before scored ones, so
    the output order can differ from ``flat_items``; use ``item_index`` and
    ``caption_index`` to re-associate.

    Args:
        flat_items: Pairs with at least the keys ``image`` and ``caption``.
        model: CLIP model, already moved to ``device``.
        processor: Matching CLIP processor.
        device: Device the input tensors are moved to.

    Returns:
        One result dict per input pair.
    """
    results: list[dict[str, Any]] = []

    for start in tqdm(range(0, len(flat_items), BATCH_SIZE), desc="computing CLIP scores"):
        valid_items = []
        images = []
        texts = []

        for item in flat_items[start:start + BATCH_SIZE]:
            image_path = resolve_image_path(item["image"])
            resolved_path = str(image_path).replace("\\", "/")

            image = load_image(image_path)
            if image is None:
                results.append(_unscored_result(
                    item,
                    resolved_path,
                    "missing_image",
                    f"image file could not be opened: {image_path}"
                ))
                continue

            caption = item["caption"]
            if not caption:
                results.append(_unscored_result(
                    item,
                    resolved_path,
                    "empty_caption",
                    "caption is empty"
                ))
                continue

            valid_items.append({**item, "resolved_image_path": resolved_path})
            images.append(image)
            texts.append(caption)

        if not valid_items:
            continue

        inputs = processor(
            text=texts,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True
        )
        inputs = {key: value.to(device) for key, value in inputs.items()}

        # Inference only: disable autograd so no graph is kept per batch.
        with torch.no_grad():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pixel_values=inputs["pixel_values"]
            )
            image_features = F.normalize(outputs.image_embeds, p=2, dim=1)
            text_features = F.normalize(outputs.text_embeds, p=2, dim=1)
            # Row-wise dot product of unit vectors = per-pair cosine.
            cosine_scores = (image_features * text_features).sum(dim=1)

        # One device-to-host transfer per batch instead of one per element.
        for item, cosine_value in zip(valid_items, cosine_scores.cpu().tolist()):
            clip_score = 2.5 * max(cosine_value, 0.0)
            results.append({
                **item,
                "clip_cosine": round(cosine_value, 6),
                "clip_score": round(clip_score, 6),
                "clip_status": "pending",
                "clip_reason": ""
            })

    return results
# ============================================================
# pass / review / fail decision
# ============================================================
def assign_clip_status(results: list[dict[str, Any]]) -> None:
    """Label scored results as fail / review / pass, in place.

    Thresholds are the FAIL_BOTTOM_PERCENT and REVIEW_BOTTOM_PERCENT
    percentiles of all float ``clip_score`` values. A score at or below the
    fail threshold is "fail"; otherwise at or below the review threshold is
    "review"; otherwise "pass". Results whose ``clip_score`` is None keep
    their existing status. Nothing happens when there are no float scores.

    Args:
        results: Result dicts; ``clip_status`` and ``clip_reason`` are
            overwritten on scored entries.
    """
    scored_values = [
        entry["clip_score"]
        for entry in results
        if isinstance(entry.get("clip_score"), float)
    ]
    if not scored_values:
        return

    fail_threshold = np.percentile(scored_values, FAIL_BOTTOM_PERCENT)
    review_threshold = np.percentile(scored_values, REVIEW_BOTTOM_PERCENT)

    for entry in results:
        score = entry.get("clip_score")
        if score is None:
            continue

        # The fail check must come first; the branch order is significant.
        if score <= fail_threshold:
            verdict = ("fail", f"clip score is in the bottom {FAIL_BOTTOM_PERCENT}%")
        elif score <= review_threshold:
            verdict = ("review", f"clip score is in the bottom {REVIEW_BOTTOM_PERCENT}%")
        else:
            verdict = ("pass", "clip score is acceptable")

        entry["clip_status"], entry["clip_reason"] = verdict
# ============================================================
# Attach results to the original JSON structure
# ============================================================
def attach_results_to_data(
    target_data: list[dict[str, Any]],
    results: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Attach per-caption check results to their source records, in place.

    Every record in ``target_data`` gets a fresh ``caption_checks`` list.
    Results are then appended to the record at ``item_index``, ordered by
    ``(item_index, caption_index)``.

    Args:
        target_data: Records to annotate; mutated.
        results: Result dicts carrying ``item_index``, ``caption_index`` and
            ``caption``; the remaining check fields are optional.

    Returns:
        The same ``target_data`` list that was passed in.
    """
    for record in target_data:
        record["caption_checks"] = []

    ordered = list(results)
    ordered.sort(key=lambda row: (row["item_index"], row["caption_index"]))

    for row in ordered:
        owner = target_data[row["item_index"]]
        owner["caption_checks"].append({
            "caption_index": row["caption_index"],
            "caption": row["caption"],
            "resolved_image_path": row.get("resolved_image_path"),
            "clip_cosine": row.get("clip_cosine"),
            "clip_score": row.get("clip_score"),
            "clip_status": row.get("clip_status"),
            "clip_reason": row.get("clip_reason", "")
        })

    return target_data
# ============================================================
# Print summary
# ============================================================
def print_summary(
    target_data: list[dict[str, Any]],
    flat_items: list[dict[str, Any]],
    results: list[dict[str, Any]]
) -> None:
    """Print run settings, counts per status, and score statistics.

    Score statistics (min, max, mean, and the two percentile thresholds)
    are printed only when at least one result has a float ``clip_score``.

    Args:
        target_data: Records that were checked.
        flat_items: Image-caption pairs derived from them.
        results: Per-pair results.
    """
    tally: dict[str, int] = {}
    scores = []
    for row in results:
        label = row.get("clip_status", "unknown")
        tally[label] = tally.get(label, 0) + 1

        value = row.get("clip_score")
        if isinstance(value, float):
            scores.append(value)

    print("\n===== CLIP Score Summary =====")
    print(f"check all classes: {CHECK_ALL_CLASSES}")
    if not CHECK_ALL_CLASSES:
        print(f"target class dir: {TARGET_CLASS_DIR}")
        print(f"target class name: {get_target_class_name()}")
    else:
        print(f"data raw root dir: {DATA_RAW_ROOT_DIR}")
    print(f"target images: {len(target_data)}")
    print(f"target image-caption pairs: {len(flat_items)}")
    print(f"status count: {tally}")

    if not scores:
        return

    print(f"min score: {min(scores):.4f}")
    print(f"max score: {max(scores):.4f}")
    print(f"mean score: {np.mean(scores):.4f}")
    for percent in (FAIL_BOTTOM_PERCENT, REVIEW_BOTTOM_PERCENT):
        print(f"bottom {percent}% threshold: {np.percentile(scores, percent):.4f}")
# ============================================================
# Run
# ============================================================
def main() -> None:
    """Run the full check: load captions, score with CLIP, label, and save.

    The input JSON is loaded and validated before the CLIP model is
    loaded, so a missing path or an empty selection fails immediately
    instead of after the model has been fetched and moved to the device.

    Raises:
        FileNotFoundError: If the input JSON or the relevant image directory
            does not exist.
        ValueError: If no record in the input is selected for checking.
    """
    if not INPUT_JSON_PATH.exists():
        raise FileNotFoundError(f"input file not found: {INPUT_JSON_PATH}")

    if CHECK_ALL_CLASSES:
        if not DATA_RAW_ROOT_DIR.exists():
            raise FileNotFoundError(f"data raw root directory not found: {DATA_RAW_ROOT_DIR}")
    else:
        if not TARGET_CLASS_DIR.exists():
            raise FileNotFoundError(f"target class directory not found: {TARGET_CLASS_DIR}")

    # Cheap validation first; the model load below is the expensive step.
    data = load_json(INPUT_JSON_PATH)
    target_data, flat_items = flatten_caption_items(data)

    if not target_data:
        # NOTE(review): the message text looks mis-encoded (probably Korean
        # UTF-8 decoded with the wrong codec) - confirm against the original.
        raise ValueError("๊ฒ์ ๋์ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค. CHECK_ALL_CLASSES ๋๋ TARGET_CLASS_DIR ์ค์ ์ ํ์ธํ์ธ์.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")
    print(f"loading model: {MODEL_NAME}")

    model = CLIPModel.from_pretrained(MODEL_NAME, token=hf_token).to(device)
    processor = CLIPProcessor.from_pretrained(MODEL_NAME, token=hf_token)
    model.eval()

    results = compute_clip_scores(
        flat_items=flat_items,
        model=model,
        processor=processor,
        device=device
    )

    assign_clip_status(results)

    checked_data = attach_results_to_data(target_data, results)
    save_json(checked_data, OUTPUT_JSON_PATH)

    print_summary(target_data, flat_items, results)
    print(f"\nsaved: {OUTPUT_JSON_PATH}")
| if __name__ == "__main__": | |
| main() |