import torch
import torch.nn as nn


class DecoderGRU(nn.Module):
    """Single-layer GRU decoder that turns a feature vector into token ids.

    The feature vector is projected by a linear layer and used as the GRU's
    initial hidden state; token ids are embedded, run through the GRU, and
    mapped to vocabulary logits by a final linear layer.

    Args:
        voca_size: Number of entries in the vocabulary (embedding rows and
            output logits).
        emd_size: Embedding dimension fed to the GRU.
        hidden_size: GRU hidden-state dimension.
        max_len: Maximum number of decoding steps taken by ``generate``.
        feature_size: Width of the incoming feature vector. Defaults to 512,
            the value that was previously hard-coded.
    """

    def __init__(
        self,
        voca_size=10000,
        emd_size=256,
        hidden_size=512,
        max_len=30,
        feature_size=512,
    ):
        super().__init__()
        self.max_len = max_len
        # Submodules are created in a fixed order (h, embedding, gru, fc) so
        # that seeded parameter initialisation stays reproducible.
        self.h = nn.Linear(feature_size, hidden_size)
        self.embedding = nn.Embedding(voca_size, emd_size)
        self.gru = nn.GRU(emd_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, voca_size)

    def forward(self, feature, caption):
        """Return vocabulary logits for every position of ``caption``.

        Args:
            feature: Tensor of shape ``(batch, feature_size)``.
            caption: Integer token ids of shape ``(batch, seq_len)``; moved
                to ``feature``'s device before use.

        Returns:
            Tensor of shape ``(batch, seq_len, voca_size)``.
        """
        caption = caption.to(feature.device)
        # The GRU expects its initial state as (num_layers, batch, hidden).
        initial_hidden = self.h(feature).unsqueeze(0)
        embedded = self.embedding(caption)
        out, _ = self.gru(embedded, initial_hidden)
        return self.fc(out)

    @torch.no_grad()
    def generate(self, feature, start_token, end_token):
        """Greedily decode up to ``max_len`` tokens for each batch row.

        Runs without gradient tracking because only Python lists are
        returned, so an autograd graph would just hold memory.

        Args:
            feature: Tensor of shape ``(batch, feature_size)``.
            start_token: Integer tensor of shape ``(batch,)`` holding the
                first input token of every row.
            end_token: Id of the end-of-sequence token (an ``int``; a scalar
                tensor presumably works too -- confirm with callers).

        Returns:
            A list with one list of generated token ids per batch row,
            excluding the start token. Rows that finish early are padded
            with ``end_token`` up to the length of the longest row.
        """
        device = feature.device
        start_token = start_token.to(device)
        batch_size = start_token.size(0)

        hidden = self.h(feature).unsqueeze(0)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        step_input = self.embedding(start_token).unsqueeze(1)

        steps = []
        for _ in range(self.max_len):
            out, hidden = self.gru(step_input, hidden)
            logits = self.fc(out).squeeze(1)
            pred = torch.argmax(logits, dim=1)
            # Rows that already emitted end_token keep emitting it.
            pred[finished] = end_token
            steps.append(pred)
            finished |= pred == end_token
            if finished.all():
                break
            step_input = self.embedding(pred).unsqueeze(1)

        if not steps:
            # max_len == 0: nothing was decoded, one empty row per sample.
            return [[] for _ in range(batch_size)]
        return torch.stack(steps, dim=1).tolist()