# Mini-ImageNet/src/models/transformer.py
# (Repository page header: ImAMJayKIM - "Upload 96 files" - commit c1596ac, verified)
import os
import torch
import torch.nn as nn
import math
from einops import rearrange
import matplotlib.pyplot as plt
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first sequence.

    The table is computed once and registered as the (non-trainable) buffer
    ``pe`` with shape ``(1, max_len, d_model)``, so it follows the module
    across ``.to(device)`` calls and is stored in the state dict.

    Args:
        d_model: Size of the last (feature) dimension. Odd values are
            supported; the final feature then only has a sine component.
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, d_model, max_len):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) columns, so drop the surplus frequency instead of crashing.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, caption):
        """Return ``caption`` with the positional encoding added.

        Args:
            caption: Tensor whose dimension 1 is the sequence axis and whose
                last dimension is ``d_model`` (the buffer is broadcast over
                dimension 0).

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_len = caption.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return self.pe[:, :seq_len] + caption
class MHA(nn.Module):
    """Multi-head scaled dot-product attention for batch-first inputs.

    Args:
        d_model: Feature size of queries, keys, values and of the output.
        nhead: Number of attention heads; must divide ``d_model``.
        drop_p: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, drop_p):
        super().__init__()
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})"
            )
        self.d_model = d_model
        self.nhead = nhead
        self.dropout = nn.Dropout(drop_p)
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model, d_model)
        self.fc_o = nn.Linear(d_model, d_model)
        # Scores are divided by sqrt(per-head dimension).
        self.scale = math.sqrt(d_model // nhead)

    def _split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, nhead, seq_len, head_dim)."""
        batch, seq_len, _ = x.shape
        return x.reshape(batch, seq_len, self.nhead, -1).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` to ``K``/``V``.

        Args:
            Q: Queries, ``(batch, q_len, d_model)``.
            K: Keys, ``(batch, kv_len, d_model)``.
            V: Values, ``(batch, kv_len, d_model)``.
            mask: Optional boolean tensor broadcastable to
                ``(batch, nhead, q_len, kv_len)``; ``True`` marks positions
                that must NOT be attended to.

        Returns:
            Tuple ``(output, attention_weights)`` where ``output`` is
            ``(batch, q_len, d_model)`` and ``attention_weights`` is
            ``(batch, nhead, q_len, kv_len)``. The returned weights are the
            ones actually used, i.e. after dropout.
        """
        Q = self._split_heads(self.fc_q(Q))
        K = self._split_heads(self.fc_k(K))
        V = self._split_heads(self.fc_v(V))

        attention_score = Q @ K.transpose(-1, -2) / self.scale
        if mask is not None:
            # -1e10 is not representable in float16 and makes masked_fill raise
            # under mixed precision; clamp to the dtype's most negative value.
            # For float32/float64 this is still exactly -1e10.
            fill_value = max(-1e10, torch.finfo(attention_score.dtype).min)
            attention_score = attention_score.masked_fill(mask, fill_value)

        attention_weights = torch.softmax(attention_score, dim=-1)
        attention_weights = self.dropout(attention_weights)

        attention = attention_weights @ V  # (batch, nhead, q_len, head_dim)
        batch, _, seq_len, _ = attention.shape
        # Merge the heads back: (batch, q_len, nhead * head_dim).
        x = attention.transpose(1, 2).reshape(batch, seq_len, -1)
        x = self.fc_o(x)
        return x, attention_weights
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, drop_p):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        # Built in this order so the sub-module indices (and therefore the
        # state-dict keys "linear.0.*" / "linear.3.*") stay stable.
        stages = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(d_ff, d_model),
        ]
        self.linear = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        return self.linear(x)
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Three sub-layers are applied in order, each followed by dropout, a
    residual connection and LayerNorm: masked self-attention, cross-attention
    over ``features``, and a position-wise feed-forward network.

    Args:
        d_model: Feature size of the token representations.
        nhead: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        drop_p: Dropout probability (shared by all sub-layers).
    """

    def __init__(self, d_model, nhead, d_ff, drop_p):
        super().__init__()
        # Self-attention over the tokens.
        self.MHA = MHA(d_model, nhead, drop_p)
        self.MHA_LN = nn.LayerNorm(d_model)
        # Cross-attention from tokens (queries) to features (keys/values).
        self.Cross_MHA = MHA(d_model, nhead, drop_p)
        self.Cross_MHA_LN = nn.LayerNorm(d_model)
        # Position-wise feed-forward.
        self.FFN = FeedForward(d_model, d_ff, drop_p)
        self.FFN_LN = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(drop_p)

    def _add_and_norm(self, x, sublayer_out, norm):
        """Return ``norm(x + dropout(sublayer_out))``."""
        return norm(x + self.drop(sublayer_out))

    def forward(self, x, features, mask):
        """Run the layer.

        Args:
            x: Token representations fed to self-attention as Q, K and V.
            features: Memory attended to by the cross-attention (used as K
                and V; no mask is applied to it).
            mask: Mask forwarded to the self-attention.

        Returns:
            Tuple ``(x, dec_weights, enc_dec_weights)``: the updated
            representations, the self-attention weights and the
            cross-attention weights.
        """
        self_out, dec_weights = self.MHA(x, x, x, mask)
        x = self._add_and_norm(x, self_out, self.MHA_LN)

        cross_out, enc_dec_weights = self.Cross_MHA(x, features, features, None)
        x = self._add_and_norm(x, cross_out, self.Cross_MHA_LN)

        ffn_out = self.FFN(x)
        x = self._add_and_norm(x, ffn_out, self.FFN_LN)
        return x, dec_weights, enc_dec_weights
class DecoderTransformer(nn.Module):
    """Transformer decoder that generates captions from encoded image features.

    Token ids are embedded, given sinusoidal positional encodings, passed
    through ``n_layers`` decoder layers that cross-attend to ``features`` and
    finally projected to vocabulary logits.

    Args:
        n_layers: Number of decoder layers.
        nhead: Attention heads per layer.
        d_model: Embedding / hidden size.
        d_ff: Feed-forward hidden size.
        voca_size: Vocabulary size (embedding rows and output logits).
        max_len: Maximum sequence length; also bounds generation, which emits
            at most ``max_len - 1`` tokens after the start token.
        drop_p: Dropout probability used inside the layers.
    """

    def __init__(self, n_layers=4, nhead=8, d_model=512, d_ff=2048, voca_size=10000, max_len=30, drop_p=0.1):
        super().__init__()
        self.nhead = nhead
        self.max_len = max_len
        self.embedding = nn.Embedding(voca_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, nhead, d_ff, drop_p) for _ in range(n_layers)]
        )
        self.fc_out = nn.Linear(d_model, voca_size)

    def make_mask(self, T, device):
        """Return a causal mask of shape ``(1, 1, T, T)``.

        Entries strictly above the diagonal are ``True`` (future positions
        that must be hidden); everything else is ``False``.
        """
        mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
        return mask.unsqueeze(0).unsqueeze(0)

    @staticmethod
    def _save_figure(fig, save_path):
        """Create the target directory if the path has one, then save ``fig``."""
        out_dir = os.path.dirname(save_path)
        # dirname is "" for a bare file name; os.makedirs("") would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    def show_dec_atten(self, atten, generated_caption, n_layer, save_path):
        """Save a heat map of decoder self-attention for one layer.

        Args:
            atten: Self-attention weights indexed as
                ``(layers, nhead, seq_len, seq_len)``.
            generated_caption: Sequence of token strings used as tick labels;
                its length selects the top-left square of the map.
            n_layer: 1-based layer number (``atten[n_layer - 1]`` is plotted,
                so 0 selects the last layer through negative indexing).
            save_path: Output image path; missing parent directories are
                created.
        """
        # Average over heads, then pick the requested layer.
        atten = atten.mean(dim=1)[n_layer - 1]
        # .float() so half / bfloat16 weights can be converted to NumPy.
        atten = atten.detach().float().cpu().numpy()
        seq_len = len(generated_caption)
        atten = atten[:seq_len, :seq_len]

        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            im = ax.imshow(atten, cmap="bone")
            ax.set_xticks(range(seq_len))
            ax.set_yticks(range(seq_len))
            ax.set_xticklabels(generated_caption, rotation=45, ha="right")
            ax.set_yticklabels(generated_caption)
            fig.colorbar(im, ax=ax)
            fig.tight_layout()
            self._save_figure(fig, save_path)
        finally:
            # Close this specific figure even if plotting or saving failed.
            plt.close(fig)

    def show_cross_atten(self, atten, generated_caption, n_layer, image, save_path):
        """Save per-token cross-attention heat maps overlaid on the image.

        Args:
            atten: Cross-attention weights indexed as
                ``(layers, nhead, seq_len, num_patches)``; ``num_patches``
                must be a perfect square (49 -> 7x7 in the original notes).
            generated_caption: Sequence of token strings, one subplot each.
            n_layer: 1-based layer number (``atten[n_layer - 1]`` is used).
            image: Either a ``(C, H, W)`` tensor normalised with the ImageNet
                mean/std (it is de-normalised here), or an array already laid
                out as ``(H, W, C)`` that is drawn as is.
            save_path: Output image path; missing parent directories are
                created.

        Raises:
            ValueError: If ``num_patches`` is not a perfect square.
        """
        import cv2
        import numpy as np

        # ---- attention preprocessing ----
        atten = atten.mean(dim=1)[n_layer - 1]  # (seq_len, num_patches)
        # float32 is required both by .numpy() for bf16 and by cv2.resize.
        atten = atten.detach().float().cpu().numpy()
        seq_len = len(generated_caption)
        atten = atten[:seq_len]

        num_patch = atten.shape[-1]
        side = math.isqrt(num_patch)
        if side * side != num_patch:
            raise ValueError(
                f"cannot arrange {num_patch} attention patches into a square grid"
            )

        # ---- image preparation ----
        # NOTE(review): the de-normalisation is applied to tensor inputs only;
        # confirm that non-tensor images are expected to be display-ready.
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu()
            # (C, H, W) -> (H, W, C)
            image = image.permute(1, 2, 0).numpy()
            # Undo the ImageNet normalisation.
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            image = image * std + mean
            image = np.clip(image, 0, 1)
        H, W = image.shape[:2]

        # ---- subplot grid ----
        n_cols = 4
        # At least one row so an empty caption still yields a (blank) figure.
        n_rows = max(1, math.ceil(seq_len / n_cols))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
        axes = np.array(axes).reshape(-1)

        try:
            # ---- one overlay per token ----
            for i, word in enumerate(generated_caption):
                heatmap = atten[i].reshape(side, side)
                # Upsample the patch grid to the image resolution.
                heatmap = cv2.resize(heatmap, (W, H), interpolation=cv2.INTER_CUBIC)
                # Min-max normalise; the epsilon guards a constant map.
                heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

                ax = axes[i]
                ax.imshow(image)
                ax.imshow(heatmap, cmap="jet", alpha=0.45)
                ax.set_title(word)
                ax.axis("off")

            # Hide the unused subplots.
            for ax in axes[seq_len:]:
                ax.axis("off")

            fig.tight_layout()
            self._save_figure(fig, save_path)
        finally:
            plt.close(fig)

    def forward(self, features, x):
        """Teacher-forced pass.

        Args:
            features: Encoded image features attended to by every layer.
            x: Token ids, ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, voca_size)``.
        """
        mask = self.make_mask(x.shape[1], x.device)
        x = self.pos_enc(self.embedding(x))
        for layer in self.layers:
            x, _, _ = layer(x, features, mask)
        return self.fc_out(x)

    def _decode_step(self, features, tokens):
        """Run the stack on ``tokens`` and return last-position logits.

        Returns:
            Tuple ``(logits, dec_weights, enc_dec_weights)`` where ``logits``
            is ``(batch, voca_size)`` and the other two are per-layer lists of
            detached attention tensors (left on the compute device).
        """
        x = self.pos_enc(self.embedding(tokens))
        mask = self.make_mask(tokens.shape[1], tokens.device)
        dec_weights, enc_dec_weights = [], []
        for layer in self.layers:
            x, dec_w, enc_dec_w = layer(x, features, mask)
            dec_weights.append(dec_w.detach())
            enc_dec_weights.append(enc_dec_w.detach())
        # Only the last position is needed to pick the next token.
        logits = self.fc_out(x[:, -1, :])
        return logits, dec_weights, enc_dec_weights

    @staticmethod
    def _stack_attention(per_layer):
        """Stack per-layer attention on dim 1 and move it to the CPU.

        Returns ``None`` when there is nothing to stack (no decoding step was
        run, or the model has no layers).
        """
        if not per_layer:
            return None
        return torch.stack(per_layer, dim=1).cpu()

    @torch.no_grad()
    def generate(self, features, start_token, end_token):
        """Greedy decoding.

        Runs without gradient tracking. Dropout is only disabled if the
        caller has put the module in ``eval()`` mode.

        Args:
            features: Encoded image features for the batch.
            start_token: Start-token ids, shape ``(batch,)``.
            end_token: End-token id; finished sequences are padded with it.

        Returns:
            Tuple ``(tokens, dec_atten, enc_dec_atten)``:
            ``tokens`` is a list of per-sample id lists without the start
            token; the attention tensors are CPU tensors from the LAST
            decoding step, shaped ``(batch, layers, nhead, seq_len, seq_len)``
            and ``(batch, layers, nhead, seq_len, num_patches)``. Both are
            ``None`` when no decoding step ran (``max_len <= 1``).
        """
        generated = start_token.unsqueeze(1)  # (batch, 1)
        finished = torch.zeros(generated.size(0), dtype=torch.bool, device=generated.device)

        dec_weights, enc_dec_weights = [], []
        for _ in range(self.max_len - 1):
            logits, dec_weights, enc_dec_weights = self._decode_step(features, generated)
            pred = torch.argmax(logits, dim=-1)  # (batch,)
            # Sequences that already ended keep emitting the end token.
            pred[finished] = end_token
            generated = torch.cat([generated, pred.unsqueeze(1)], dim=1)
            finished |= pred == end_token
            if finished.all():
                break

        # Only the final step's attention is returned, so the device->CPU copy
        # is done once here rather than on every step.
        return (
            generated[:, 1:].tolist(),
            self._stack_attention(dec_weights),
            self._stack_attention(enc_dec_weights),
        )

    @torch.no_grad()
    def generate_beam(self, features, start_token, end_token, beam_size, length_alpha=0.7):
        """Beam-search decoding, one sample at a time.

        Hypotheses are ranked by ``score / len(seq) ** length_alpha`` where
        ``score`` is the summed token log-probability and ``len(seq)``
        includes the start token. Runs without gradient tracking.

        Args:
            features: Encoded image features, indexed by sample on dim 0.
            start_token: Start-token ids, shape ``(batch,)``.
            end_token: End-token id (int or 0-dim tensor).
            beam_size: Number of hypotheses kept per step (>= 1). Values
                larger than the vocabulary are clamped when expanding.
            length_alpha: Length-normalisation exponent.

        Returns:
            Tuple ``(sequences, dec_atten, enc_dec_atten)`` of per-sample
            lists. ``sequences[b]`` is the best id list without the start
            token. The attention entries are CPU tensors of shape
            ``(layers, nhead, seq_len, seq_len)`` /
            ``(layers, nhead, seq_len, num_patches)`` from the step that
            produced the best hypothesis' last token, or ``None`` when that
            hypothesis was never expanded.

        Raises:
            ValueError: If ``beam_size`` is smaller than 1.
        """
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")

        # Resolve a tensor end token once instead of syncing on every compare.
        end_id = end_token.item() if isinstance(end_token, torch.Tensor) else end_token

        def normalized_score(seq, score):
            return score / (len(seq) ** length_alpha)

        def finalize(per_layer):
            stacked = self._stack_attention(per_layer)
            return None if stacked is None else stacked.squeeze(0)

        all_generated = []
        all_dec_atten = []
        all_enc_dec_atten = []

        for b, feature in enumerate(features):
            feature = feature.unsqueeze(0)  # (1, seq, dim)
            # Each beam: (token ids, summed log-prob, self-attn, cross-attn).
            beams = [([start_token[b].item()], 0.0, None, None)]

            for _ in range(self.max_len - 1):
                candidates = []
                for seq, score, prev_dec, prev_enc_dec in beams:
                    if seq[-1] == end_id:
                        # Finished hypotheses are carried over unchanged.
                        candidates.append((seq, score, prev_dec, prev_enc_dec))
                        continue

                    input_seq = torch.tensor(seq, device=feature.device).unsqueeze(0)
                    logits, dec_w, enc_dec_w = self._decode_step(feature, input_seq)
                    log_probs = torch.log_softmax(logits, dim=-1)
                    k = min(beam_size, log_probs.size(-1))
                    topk_scores, topk_ids = torch.topk(log_probs, k, dim=-1)
                    # One host transfer per tensor instead of 2*k .item() calls.
                    for token, token_score in zip(topk_ids[0].tolist(), topk_scores[0].tolist()):
                        candidates.append((seq + [token], score + token_score, dec_w, enc_dec_w))

                beams = sorted(
                    candidates,
                    key=lambda cand: normalized_score(cand[0], cand[1]),
                    reverse=True,
                )[:beam_size]
                if all(seq[-1] == end_id for seq, _, _, _ in beams):
                    break

            best_seq, _, best_dec, best_enc_dec = beams[0]
            all_generated.append(best_seq[1:])  # drop the start token
            # Attention is moved to the CPU only for the winning hypothesis.
            all_dec_atten.append(finalize(best_dec))
            all_enc_dec_atten.append(finalize(best_enc_dec))

        return all_generated, all_dec_atten, all_enc_dec_atten