Mini-ImageNet / src /caption /check_clip_score.py
ImAMJayKIM's picture
Upload 96 files
c1596ac verified
import json
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor
import os
from dotenv import load_dotenv
# ============================================================
# Settings
# ============================================================
load_dotenv()
# Read HF_TOKEN from the process environment (original note: "read HF_TOKEN from .env")
hf_token = os.getenv("HF_TOKEN")
# Set to True to check every class
# Set to False to check only one specific class
CHECK_ALL_CLASSES = True
# Root of the raw data; image paths are resolved against it when checking all classes
DATA_RAW_ROOT_DIR = Path("data/raw")
# Class folder used when checking only one specific class
# Only used when CHECK_ALL_CLASSES = False
# NOTE(review): with this value TARGET_CLASS_DIR.name is "raw", so the
# class filter only matches items whose image path starts with "raw/" --
# presumably this should point at a class folder (e.g. data/raw/airplane); confirm.
TARGET_CLASS_DIR = Path("data/raw")
# Input JSON file
INPUT_JSON_PATH = Path("data/annotations/captions_flo_all.json")
# Output JSON file
OUTPUT_JSON_PATH = Path("data/annotations/clip_checked_flo_all.json")
# CLIP model to use
MODEL_NAME = "openai/clip-vit-base-patch32"
# Number of image-caption pairs processed per batch
BATCH_SIZE = 32
# Bottom percentiles of the score distribution that are marked fail / review
FAIL_BOTTOM_PERCENT = 10
REVIEW_BOTTOM_PERCENT = 20
# NOTE(review): the label below looks mis-encoded (UTF-8 read through a
# legacy code page); it presumably meant the Korean word for "path" -- confirm
# the file encoding before changing it. This print runs at import time.
print("๊ฒฝ๋กœ : " , INPUT_JSON_PATH)
# ============================================================
# JSON input / output
# ============================================================
def load_json(path: Path) -> list[dict[str, Any]]:
with path.open("r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("์ž…๋ ฅ JSON์€ ๋ฐ˜๋“œ์‹œ ๋ฐฐ์—ด ํ˜•ํƒœ์—ฌ์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
return data
def save_json(data: list[dict[str, Any]], path: Path) -> None:
    """Write ``data`` to ``path`` as pretty-printed UTF-8 JSON.

    Missing parent directories are created. Non-ASCII characters are written
    as-is (``ensure_ascii=False``) with a 4-space indent.

    Args:
        data: JSON-serializable list of records.
        path: Destination file; overwritten if it already exists.

    Raises:
        TypeError: If ``data`` contains a value that ``json`` cannot
            serialize. The destination file is left untouched in that case.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize before opening the file: opening with "w" truncates it, so
    # streaming with json.dump would leave a truncated/partial file behind
    # if serialization failed part-way through.
    payload = json.dumps(data, ensure_ascii=False, indent=4)
    with path.open("w", encoding="utf-8") as f:
        f.write(payload)
# ============================================================
# Class / path handling
# ============================================================
def get_target_class_name() -> str:
    """Return the last path component of TARGET_CLASS_DIR.

    Example: TARGET_CLASS_DIR = data/raw/airplane yields "airplane".
    """
    class_dir = TARGET_CLASS_DIR
    return class_dir.name
def get_class_name_from_image_value(image_value: str) -> str:
    """Return the first folder of a JSON ``image`` value.

    Example: "airplane/hf_airplane_001.jpg" yields "airplane". Backslashes
    are treated as path separators. A value with fewer than two path
    components (a bare file name or an empty string) yields "".
    """
    components = Path(image_value.replace("\\", "/")).parts
    return components[0] if len(components) >= 2 else ""
def is_target_item(item: dict[str, Any]) -> bool:
    """Decide whether an annotation record should be checked.

    CHECK_ALL_CLASSES = True:
        every item is accepted.
    CHECK_ALL_CLASSES = False:
        only items whose ``image`` value starts with a folder named like
        TARGET_CLASS_DIR.name are accepted.
    """
    if CHECK_ALL_CLASSES:
        return True

    # A missing "image" key is treated as an empty string, which has no class folder.
    raw_image = str(item.get("image", ""))
    item_class = get_class_name_from_image_value(raw_image)
    wanted_class = get_target_class_name()
    return item_class == wanted_class
def resolve_image_path(image_value: str) -> Path:
    """Turn a JSON ``image`` value into a path on disk.

    Given "image": "airplane/hf_airplane_001.jpg":

    Checking all classes:
        DATA_RAW_ROOT_DIR / image
        -> data/raw/airplane/hf_airplane_001.jpg
    Checking a single class:
        TARGET_CLASS_DIR / file name only
        -> data/raw/airplane/hf_airplane_001.jpg

    Backslashes in the value are treated as path separators.
    """
    relative = Path(image_value.replace("\\", "/"))
    if not CHECK_ALL_CLASSES:
        # Single-class mode ignores any folder part of the JSON value.
        return TARGET_CLASS_DIR / relative.name
    return DATA_RAW_ROOT_DIR / relative
def load_image(image_path: Path) -> Image.Image | None:
    """Open an image and return a detached RGB copy, or None on any failure.

    Every exception is swallowed on purpose: the caller records a
    "missing_image" result for the pair instead of aborting the run.
    """
    try:
        with Image.open(image_path) as source:
            rgb = source.convert("RGB")
            # Copy so the returned image is independent of the handle closed by `with`.
            return rgb.copy()
    except Exception:
        return None
# ============================================================
# Flatten captions
# ============================================================
def flatten_caption_items(data: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Expand each selected image record into one entry per caption.

    An image with 3 captions becomes 3 image-caption pairs.

    Args:
        data: Annotation records; each may carry "image", "captions",
            "class" and "split" keys.

    Returns:
        A ``(target_data, flat_items)`` tuple. ``target_data`` holds the
        records accepted by ``is_target_item`` (the same dict objects, in
        input order). ``flat_items`` holds one dict per caption, whose
        "item_index" points back into ``target_data``.
    """
    target_data = []
    flat_items = []
    for item in data:
        if not is_target_item(item):
            continue
        target_item_index = len(target_data)
        target_data.append(item)
        image_value = str(item.get("image", ""))
        captions = item.get("captions", [])
        # A non-list "captions" value yields no pairs for this image.
        if not isinstance(captions, list):
            captions = []
        for caption_index, caption in enumerate(captions):
            # A JSON null caption must become "" rather than str(None) ==
            # "None"; otherwise the literal word "None" would be scored as
            # if it were a real caption instead of being flagged as empty.
            caption_text = "" if caption is None else str(caption).strip()
            flat_items.append({
                "item_index": target_item_index,
                "caption_index": caption_index,
                "image": image_value,
                "class": item.get("class", ""),
                "split": item.get("split", ""),
                "caption": caption_text,
            })
    return target_data, flat_items
# ============================================================
# CLIP score computation
# ============================================================
@torch.no_grad()
def compute_clip_scores(
    flat_items: list[dict[str, Any]],
    model: CLIPModel,
    processor: CLIPProcessor,
    device: torch.device
) -> list[dict[str, Any]]:
    """Score every image-caption pair with CLIP, BATCH_SIZE pairs at a time.

    Each returned dict is the input pair plus "resolved_image_path",
    "clip_cosine", "clip_score", "clip_status" and "clip_reason".
    Pairs whose image cannot be opened get status "missing_image"; pairs
    with an empty caption get "empty_caption"; both carry None scores.
    Scored pairs get status "pending" and clip_score = 2.5 * max(cosine, 0).

    Within a batch, skipped pairs are appended before scored ones, so the
    result order is not the input order.
    """

    def skipped(entry: dict[str, Any], resolved: str, status: str, reason: str) -> dict[str, Any]:
        # Result record for a pair that never reaches the model.
        return {
            **entry,
            "resolved_image_path": resolved,
            "clip_cosine": None,
            "clip_score": None,
            "clip_status": status,
            "clip_reason": reason,
        }

    scored = []
    batch_starts = range(0, len(flat_items), BATCH_SIZE)
    for offset in tqdm(batch_starts, desc="computing CLIP scores"):
        pending = []
        pictures = []
        sentences = []

        for entry in flat_items[offset:offset + BATCH_SIZE]:
            path = resolve_image_path(entry["image"])
            resolved = str(path).replace("\\", "/")

            picture = load_image(path)
            if picture is None:
                scored.append(
                    skipped(entry, resolved, "missing_image", f"image file could not be opened: {path}")
                )
                continue

            sentence = entry["caption"]
            if not sentence:
                scored.append(skipped(entry, resolved, "empty_caption", "caption is empty"))
                continue

            pending.append({**entry, "resolved_image_path": resolved})
            pictures.append(picture)
            sentences.append(sentence)

        if not pending:
            continue

        encoded = processor(
            text=sentences,
            images=pictures,
            return_tensors="pt",
            padding=True,
            truncation=True
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            pixel_values=encoded["pixel_values"]
        )

        # Row i of each embedding matrix belongs to pair i, so the row-wise
        # dot product of the unit vectors is the per-pair cosine similarity.
        image_unit = F.normalize(outputs.image_embeds, p=2, dim=1)
        text_unit = F.normalize(outputs.text_embeds, p=2, dim=1)
        cosines = (image_unit * text_unit).sum(dim=1)

        for entry, cosine in zip(pending, cosines):
            cosine_value = float(cosine.detach().cpu().item())
            # Negative similarities are clamped to zero before scaling.
            scaled = 2.5 * max(cosine_value, 0.0)
            scored.append({
                **entry,
                "clip_cosine": round(cosine_value, 6),
                "clip_score": round(scaled, 6),
                "clip_status": "pending",
                "clip_reason": ""
            })

    return scored
# ============================================================
# pass / review / fail decision
# ============================================================
def assign_clip_status(results: list[dict[str, Any]]) -> None:
    """Label each scored result as fail / review / pass, in place.

    Thresholds are the FAIL_BOTTOM_PERCENT and REVIEW_BOTTOM_PERCENT
    percentiles of all float scores in ``results``. A score at or below a
    threshold falls into that bucket. Results with a None score keep their
    existing status. Nothing happens when no float score is present.
    """
    scores = [
        entry["clip_score"]
        for entry in results
        if isinstance(entry.get("clip_score"), float)
    ]
    if not scores:
        return

    fail_threshold = np.percentile(scores, FAIL_BOTTOM_PERCENT)
    review_threshold = np.percentile(scores, REVIEW_BOTTOM_PERCENT)

    for entry in results:
        score = entry.get("clip_score")
        if score is None:
            continue

        if score <= fail_threshold:
            status = "fail"
            reason = f"clip score is in the bottom {FAIL_BOTTOM_PERCENT}%"
        elif score <= review_threshold:
            status = "review"
            reason = f"clip score is in the bottom {REVIEW_BOTTOM_PERCENT}%"
        else:
            status = "pass"
            reason = "clip score is acceptable"

        entry["clip_status"] = status
        entry["clip_reason"] = reason
# ============================================================
# Attach results back onto the original JSON structure
# ============================================================
def attach_results_to_data(
    target_data: list[dict[str, Any]],
    results: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Store per-caption check results on their image records.

    Every record in ``target_data`` gets a fresh "caption_checks" list
    (replacing any existing one). Results are then appended to the record
    at their "item_index", ordered by caption index. ``target_data`` is
    modified in place and also returned.
    """
    for record in target_data:
        record["caption_checks"] = []

    # Fields copied with a None fallback when the result lacks them.
    nullable_fields = ("resolved_image_path", "clip_cosine", "clip_score", "clip_status")

    ordered = sorted(results, key=lambda r: (r["item_index"], r["caption_index"]))
    for outcome in ordered:
        check = {
            "caption_index": outcome["caption_index"],
            "caption": outcome["caption"],
        }
        for field in nullable_fields:
            check[field] = outcome.get(field)
        check["clip_reason"] = outcome.get("clip_reason", "")
        target_data[outcome["item_index"]]["caption_checks"].append(check)

    return target_data
# ============================================================
# Summary output
# ============================================================
def print_summary(
    target_data: list[dict[str, Any]],
    flat_items: list[dict[str, Any]],
    results: list[dict[str, Any]]
) -> None:
    """Print run settings, per-status counts and score statistics to stdout.

    Score statistics (min, max, mean and the two percentile thresholds) are
    printed only when at least one result carries a float score.
    """
    status_count = {}
    for outcome in results:
        label = outcome.get("clip_status", "unknown")
        status_count[label] = status_count.get(label, 0) + 1

    valid_scores = [
        outcome["clip_score"]
        for outcome in results
        if isinstance(outcome.get("clip_score"), float)
    ]

    lines = [
        "\n===== CLIP Score Summary =====",
        f"check all classes: {CHECK_ALL_CLASSES}",
    ]

    if not CHECK_ALL_CLASSES:
        lines.append(f"target class dir: {TARGET_CLASS_DIR}")
        lines.append(f"target class name: {get_target_class_name()}")
    else:
        lines.append(f"data raw root dir: {DATA_RAW_ROOT_DIR}")

    lines.append(f"target images: {len(target_data)}")
    lines.append(f"target image-caption pairs: {len(flat_items)}")
    lines.append(f"status count: {status_count}")

    if valid_scores:
        fail_cut = np.percentile(valid_scores, FAIL_BOTTOM_PERCENT)
        review_cut = np.percentile(valid_scores, REVIEW_BOTTOM_PERCENT)
        lines.append(f"min score: {min(valid_scores):.4f}")
        lines.append(f"max score: {max(valid_scores):.4f}")
        lines.append(f"mean score: {np.mean(valid_scores):.4f}")
        lines.append(f"bottom {FAIL_BOTTOM_PERCENT}% threshold: {fail_cut:.4f}")
        lines.append(f"bottom {REVIEW_BOTTOM_PERCENT}% threshold: {review_cut:.4f}")

    # One joined print emits the same text as printing line by line.
    print("\n".join(lines))
# ============================================================
# Entry point
# ============================================================
def main() -> None:
    """Run the CLIP caption check from input JSON to output JSON.

    Steps: validate the configured paths, load and flatten the annotations,
    load the CLIP model, score every image-caption pair, assign
    pass/review/fail labels, write the annotated records to
    OUTPUT_JSON_PATH and print a summary.

    Raises:
        FileNotFoundError: If INPUT_JSON_PATH, or the data directory for the
            active mode, does not exist.
        ValueError: If the input is not a JSON array or no record matches
            the current class selection.
    """
    if not INPUT_JSON_PATH.exists():
        raise FileNotFoundError(f"input file not found: {INPUT_JSON_PATH}")

    if CHECK_ALL_CLASSES:
        if not DATA_RAW_ROOT_DIR.exists():
            raise FileNotFoundError(f"data raw root directory not found: {DATA_RAW_ROOT_DIR}")
    elif not TARGET_CLASS_DIR.exists():
        raise FileNotFoundError(f"target class directory not found: {TARGET_CLASS_DIR}")

    # Load and validate the annotations before touching the model, so a bad
    # input file or an empty selection fails fast instead of after the
    # (potentially slow) model download/load.
    data = load_json(INPUT_JSON_PATH)
    target_data, flat_items = flatten_caption_items(data)
    if not target_data:
        # Message means "No data to check. Verify the CHECK_ALL_CLASSES or
        # TARGET_CLASS_DIR setting." The literal had been stored as mojibake
        # and is restored to readable Korean text here.
        raise ValueError("검수 대상 데이터가 없습니다. CHECK_ALL_CLASSES 또는 TARGET_CLASS_DIR 설정을 확인하세요.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")
    print(f"loading model: {MODEL_NAME}")

    # hf_token is None when HF_TOKEN is not set in the environment.
    model = CLIPModel.from_pretrained(MODEL_NAME, token=hf_token).to(device)
    processor = CLIPProcessor.from_pretrained(MODEL_NAME, token=hf_token)
    model.eval()

    results = compute_clip_scores(
        flat_items=flat_items,
        model=model,
        processor=processor,
        device=device
    )
    assign_clip_status(results)

    checked_data = attach_results_to_data(target_data, results)
    save_json(checked_data, OUTPUT_JSON_PATH)

    print_summary(target_data, flat_items, results)
    print(f"\nsaved: {OUTPUT_JSON_PATH}")


if __name__ == "__main__":
    main()