import math
import os

import torch
import torch.nn as nn
from einops import rearrange
import matplotlib.pyplot as plt


def _ensure_parent_dir(save_path):
    """Create the parent directory of ``save_path`` if it has one.

    ``os.makedirs("")`` raises ``FileNotFoundError``, so a bare filename such
    as ``"atten.png"`` (no directory component) must be skipped rather than
    passed through.
    """
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    Args:
        d_model: Embedding dimension. Odd values are supported; the last
            (unpaired) column receives only a sine term.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim div_term to match; for an even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, caption):
        """Return ``caption`` plus the encodings of its first ``caption.size(1)`` positions.

        Assumes ``caption`` is (batch, seq_len, d_model) with seq_len <= max_len
        -- TODO confirm with callers.
        """
        return self.pe[:, :caption.size(1)] + caption


class MHA(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V/output projections.

    Args:
        d_model: Model dimension; split evenly across ``nhead`` heads.
        nhead: Number of attention heads.
        drop_p: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, nhead, drop_p):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.dropout = nn.Dropout(drop_p)
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model, d_model)
        self.fc_o = nn.Linear(d_model, d_model)
        self.scale = math.sqrt(d_model // nhead)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` to ``K``/``V``.

        Args:
            Q: (batch, q_len, d_model) queries.
            K: (batch, kv_len, d_model) keys.
            V: (batch, kv_len, d_model) values.
            mask: Optional boolean tensor broadcastable to
                (batch, nhead, q_len, kv_len); ``True`` marks blocked positions.

        Returns:
            Tuple of the projected output (batch, q_len, d_model) and the
            post-dropout attention weights (batch, nhead, q_len, kv_len).
        """
        Q = self.fc_q(Q)
        K = self.fc_k(K)
        V = self.fc_v(V)

        Q = rearrange(Q, 'batch seq_len (nhead dim) -> batch nhead seq_len dim', nhead=self.nhead)
        K = rearrange(K, 'batch seq_len (nhead dim) -> batch nhead seq_len dim', nhead=self.nhead)
        V = rearrange(V, 'batch seq_len (nhead dim) -> batch nhead seq_len dim', nhead=self.nhead)

        attention_score = Q @ K.transpose(-1, -2) / self.scale

        if mask is not None:
            # Use the most negative finite value for the score dtype: a fixed
            # -1e10 overflows float16 (AMP) and makes masked_fill raise, while
            # in float32 both values drive the softmax weight to exactly 0.
            fill_value = torch.finfo(attention_score.dtype).min
            attention_score = attention_score.masked_fill(mask, fill_value)

        attention_weights = torch.softmax(attention_score, dim=-1)  # (B, nhead, q_len, kv_len)
        attention_weights = self.dropout(attention_weights)

        attention = attention_weights @ V

        x = rearrange(attention, 'batch nhead seq_len dim -> batch seq_len (nhead dim)')
        x = self.fc_o(x)

        return x, attention_weights


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, drop_p):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.linear = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        return self.linear(x)


class DecoderLayer(nn.Module):
    """One post-LayerNorm decoder layer: masked self-attention, cross-attention, FFN.

    Each sub-layer output goes through dropout, is added to its input
    (residual), and is then layer-normalized.
    """

    def __init__(self, d_model, nhead, d_ff, drop_p):
        super().__init__()
        self.MHA = MHA(d_model, nhead, drop_p)
        self.MHA_LN = nn.LayerNorm(d_model)

        self.Cross_MHA = MHA(d_model, nhead, drop_p)
        self.Cross_MHA_LN = nn.LayerNorm(d_model)

        self.FFN = FeedForward(d_model, d_ff, drop_p)
        self.FFN_LN = nn.LayerNorm(d_model)

        self.drop = nn.Dropout(drop_p)

    def forward(self, x, features, mask):
        """Run the layer.

        Args:
            x: Decoder input (batch, seq_len, d_model).
            features: Encoder output used as keys/values for cross-attention;
                no mask is applied to it.
            mask: Boolean self-attention mask (``True`` = blocked).

        Returns:
            Tuple of the layer output, the self-attention weights and the
            cross-attention weights.
        """
        residual, dec_weights = self.MHA(x, x, x, mask)
        residual = self.drop(residual)
        x = self.MHA_LN(x + residual)

        residual, enc_dec_weights = self.Cross_MHA(x, features, features, None)
        residual = self.drop(residual)
        x = self.Cross_MHA_LN(x + residual)

        residual = self.FFN(x)
        residual = self.drop(residual)
        x = self.FFN_LN(x + residual)

        return x, dec_weights, enc_dec_weights


class DecoderTransformer(nn.Module):
    """Transformer caption decoder over encoder ``features``.

    Args:
        n_layers: Number of stacked ``DecoderLayer`` modules.
        nhead: Attention heads per layer.
        d_model: Model / embedding dimension.
        d_ff: Hidden size of the feed-forward sub-layer.
        voca_size: Vocabulary size (embedding rows and output logits).
        max_len: Maximum caption length including the start token; also the
            positional-encoding length and the decoding-step budget.
        drop_p: Dropout probability.
    """

    def __init__(self, n_layers=4, nhead=8, d_model=512, d_ff=2048, voca_size=10000, max_len=30, drop_p=0.1):
        super().__init__()
        self.nhead = nhead
        self.max_len = max_len
        self.embedding = nn.Embedding(voca_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, nhead, d_ff, drop_p) for _ in range(n_layers)])
        self.fc_out = nn.Linear(d_model, voca_size)

    def make_mask(self, T, device):
        """Return a (1, 1, T, T) causal mask where ``True`` blocks future positions."""
        mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
        mask = mask.unsqueeze(0).unsqueeze(0)
        return mask

    def show_dec_atten(self, atten, generated_caption, n_layer, save_path):
        """Save a heatmap of decoder self-attention for one layer, averaged over heads.

        Args:
            atten: Per-sample attention of shape (layers, nhead, seq_len, seq_len)
                (e.g. one entry of ``generate_beam``'s attention list; index the
                batch dimension first for ``generate``'s output).
            generated_caption: Token strings used as axis labels; its length
                also crops the attention matrix.
            n_layer: 1-indexed layer to visualize.
            save_path: Output image path; missing parent directories are created.
        """
        atten = atten.mean(dim=1)  # (layers, seq_len, seq_len)
        atten = atten[n_layer - 1]  # (seq_len, seq_len)
        atten = atten.detach().cpu().numpy()

        seq_len = len(generated_caption)
        atten = atten[:seq_len, :seq_len]

        fig, ax = plt.subplots(figsize=(8, 8))
        im = ax.imshow(atten, cmap="bone")

        ax.set_xticks(range(seq_len))
        ax.set_yticks(range(seq_len))
        ax.set_xticklabels(generated_caption, rotation=45, ha="right")
        ax.set_yticklabels(generated_caption)

        plt.colorbar(im)
        plt.tight_layout()

        _ensure_parent_dir(save_path)
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close(fig)

    def show_cross_atten(self, atten, generated_caption, n_layer, image, save_path):
        """Save a per-word grid of cross-attention heatmaps overlaid on ``image``.

        Args:
            atten: Per-sample cross-attention of shape
                (layers, nhead, seq_len, num_patch); ``num_patch`` must be a
                perfect square (e.g. 49 -> 7x7 grid).
            generated_caption: Token strings; one subplot per token.
            n_layer: 1-indexed layer to visualize.
            image: A (C, H, W) tensor normalized with ImageNet mean/std (it is
                un-normalized here), or an array assumed to already be
                (H, W, C) in display range -- TODO confirm for non-tensor callers.
            save_path: Output image path; missing parent directories are created.

        Raises:
            ValueError: If ``num_patch`` is not a perfect square.
        """
        import cv2
        import numpy as np

        # --- attention preprocessing ---
        atten = atten.mean(dim=1)  # (layers, seq_len, num_patch)
        atten = atten[n_layer - 1]  # (seq_len, num_patch)
        atten = atten.detach().cpu().numpy()

        seq_len = len(generated_caption)
        atten = atten[:seq_len]

        num_patch = atten.shape[-1]
        side = int(math.sqrt(num_patch))
        if side * side != num_patch:
            raise ValueError(f"cross-attention has {num_patch} patches, which is not a square grid")

        # --- image preparation ---
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu()
            image = image.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
            # Undo the ImageNet normalization so the image displays correctly.
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            image = image * std + mean
            image = np.clip(image, 0, 1)

        H, W = image.shape[:2]

        # --- subplot grid (at least one row so an empty caption still renders) ---
        n_cols = 4
        n_rows = max(1, math.ceil(seq_len / n_cols))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
        axes = np.array(axes).reshape(-1)

        # --- per-word overlay ---
        for i in range(seq_len):
            heatmap = atten[i].reshape(side, side)
            heatmap = cv2.resize(heatmap, (W, H), interpolation=cv2.INTER_CUBIC)
            # Min-max normalize to [0, 1]; epsilon guards a constant heatmap.
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

            ax = axes[i]
            ax.imshow(image)
            ax.imshow(heatmap, cmap="jet", alpha=0.45)
            ax.set_title(generated_caption[i])
            ax.axis("off")

        # Hide the unused trailing subplots.
        for ax in axes[seq_len:]:
            ax.axis("off")

        plt.tight_layout()
        _ensure_parent_dir(save_path)
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close(fig)

    def forward(self, features, x):
        """Teacher-forced decoding.

        Args:
            features: Encoder output used as cross-attention keys/values.
            x: (batch, seq_len) token ids.

        Returns:
            Logits of shape (batch, seq_len, voca_size).
        """
        mask = self.make_mask(x.shape[1], x.device)

        x = self.embedding(x)
        x = self.pos_enc(x)

        for layer in self.layers:
            x, dec_weights, enc_dec_weights = layer(x, features, mask)

        x = self.fc_out(x)

        return x

    def generate(self, features, start_token, end_token):
        """Greedy batch decoding.

        The full prefix is re-decoded at every step (no key/value cache).
        Once a sequence emits ``end_token`` it keeps receiving ``end_token``
        until every sequence has finished or ``max_len - 1`` steps have run.

        Args:
            features: Encoder output for the batch.
            start_token: (B,) start-token ids.
            end_token: End-token id.

        Returns:
            Tuple of:
            - generated token ids without the start token, as a nested list;
            - decoder self-attention from the last step,
              (B, layers, nhead, T, T), or ``None`` if no step ran;
            - cross-attention from the last step, (B, layers, nhead, T, num_patch),
              or ``None`` if no step ran.
        """
        generated = start_token.unsqueeze(1)  # (B, 1)
        finished = torch.zeros(generated.size(0), dtype=torch.bool, device=features.device)  # (B,)

        # Bound before the loop so the return is valid even when max_len <= 1.
        dec_atten = None
        enc_dec_atten = None

        for _ in range(self.max_len - 1):
            x = self.embedding(generated)  # (B, T, d_model)
            x = self.pos_enc(x)
            mask = self.make_mask(generated.shape[1], generated.device)

            dec_layers = []
            enc_dec_layers = []
            for layer in self.layers:
                x, dec_weights, enc_dec_weights = layer(x, features, mask)
                dec_layers.append(dec_weights.detach().cpu())  # each (B, nhead, T, T)
                enc_dec_layers.append(enc_dec_weights.detach().cpu())  # each (B, nhead, T, num_patch)
            dec_atten = torch.stack(dec_layers, dim=1)
            enc_dec_atten = torch.stack(enc_dec_layers, dim=1)

            logits = self.fc_out(x)  # (B, T, voca_size)
            pred = torch.argmax(logits[:, -1, :], dim=-1)  # (B,)
            pred[finished] = end_token  # finished captions are padded with end_token

            generated = torch.cat([generated, pred.unsqueeze(1)], dim=1)
            finished |= (pred == end_token)

            if finished.all():
                break

        return generated[:, 1:].tolist(), dec_atten, enc_dec_atten

    def generate_beam(self, features, start_token, end_token, beam_size, length_alpha=0.7):
        """Beam-search decoding, one sample at a time.

        Candidates are ranked by ``score / len(seq) ** length_alpha``, where
        ``score`` is the summed log-probability and ``seq`` includes the
        start token. Beams ending in ``end_token`` are carried forward unchanged.

        Args:
            features: Encoder output; iterated along the first dimension.
            start_token: Per-sample start-token ids (indexed with ``[b].item()``).
            end_token: End-token id.
            beam_size: Number of beams kept and top-k tokens expanded per beam.
            length_alpha: Length-normalization exponent.

        Returns:
            Three lists with one entry per sample:
            - best token ids (start token dropped);
            - decoder self-attention (layers, nhead, T, T);
            - cross-attention (layers, nhead, T, num_patch).
            The attention entries are ``None`` if no decoding step ran.
        """
        all_generated = []
        all_dec_atten = []
        all_enc_dec_atten = []

        def normalized_score(seq, score):
            return score / (len(seq) ** length_alpha)

        for b in range(len(features)):
            feature = features[b].unsqueeze(0)  # (1, ...)
            # Each beam: (token ids, summed log-prob, self-attn, cross-attn).
            beams = [([start_token[b].item()], 0.0, None, None)]

            for _ in range(self.max_len - 1):
                candidates = []

                for seq, score, prev_dec, prev_enc_dec in beams:
                    if seq[-1] == end_token:
                        candidates.append((seq, score, prev_dec, prev_enc_dec))
                        continue

                    input_seq = torch.tensor(seq, device=feature.device).unsqueeze(0)  # (1, T)
                    x = self.embedding(input_seq)  # (1, T, d_model)
                    x = self.pos_enc(x)
                    mask = self.make_mask(input_seq.shape[1], input_seq.device)

                    dec_layers = []
                    enc_dec_layers = []
                    for layer in self.layers:
                        x, dec_weights, enc_dec_weights = layer(x, feature, mask)
                        dec_layers.append(dec_weights.detach().cpu())  # each (1, nhead, T, T)
                        enc_dec_layers.append(enc_dec_weights.detach().cpu())  # each (1, nhead, T, num_patch)
                    dec_atten = torch.stack(dec_layers, dim=1)  # (1, layers, nhead, T, T)
                    enc_dec_atten = torch.stack(enc_dec_layers, dim=1)  # (1, layers, nhead, T, num_patch)

                    logits = self.fc_out(x)  # (1, T, voca_size)
                    log_probs = torch.log_softmax(logits[:, -1, :], dim=-1)
                    topk_probs, topk_ids = torch.topk(log_probs, beam_size, dim=-1)

                    for k in range(beam_size):
                        token = topk_ids[0, k].item()
                        token_score = topk_probs[0, k].item()
                        candidates.append((seq + [token], score + token_score, dec_atten, enc_dec_atten))

                beams = sorted(candidates, key=lambda c: normalized_score(c[0], c[1]), reverse=True)[:beam_size]

                if all(seq[-1] == end_token for seq, _, _, _ in beams):
                    break

            best_seq, _, best_dec_atten, best_enc_dec_atten = beams[0]
            all_generated.append(best_seq[1:])  # drop the start token
            # Attention is None when max_len <= 1 and no step ran.
            all_dec_atten.append(None if best_dec_atten is None else best_dec_atten.squeeze(0))
            all_enc_dec_atten.append(None if best_enc_dec_atten is None else best_enc_dec_atten.squeeze(0))

        return all_generated, all_dec_atten, all_enc_dec_atten