| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import BertModel |
| from transformers.models.clip.modeling_clip import CLIPTextModel |
| from transformers.models.mpnet.modeling_mpnet import MPNetModel |
| from transformers.trainer import logger |
|
|
| from .align_transformers import build_align_transformer |
| from .common_layers import BasePreTrainedModel |
| from .configuration_radzero import CxrAlignConfig |
| from .losses import KeyPhraseAlignmentLoss |
| from .text_encoders import build_text_encoder |
| from .vision_encoders import Dinov2Model, build_vision_encoder |
|
|
|
|
class CxrAlignModel(BasePreTrainedModel):
    """Image/text alignment model.

    Combines a vision encoder, a text encoder, and an align transformer that
    is applied to the vision encoder's tokens. Sub-modules are built from the
    matching sections of a ``CxrAlignConfig``; training losses are built from
    ``config.kwargs["loss"]``.
    """

    config_class = CxrAlignConfig

    # Loss classes that may be named in ``config.kwargs["loss"]["apply"]``.
    # An explicit registry is used instead of ``eval`` on a config string.
    _LOSS_REGISTRY = {"KeyPhraseAlignmentLoss": KeyPhraseAlignmentLoss}

    def build_vision_model(self, config: CxrAlignConfig):
        """Build the vision encoder from ``config.vision_config``.

        ``config.pretrained_dir`` is copied onto the vision config before the
        encoder is built.
        """
        vision_config = config.vision_config
        vision_config.pretrained_dir = config.pretrained_dir
        return build_vision_encoder(vision_config)

    def build_text_model(self, config: CxrAlignConfig):
        """Build the text encoder from ``config.text_config``."""
        return build_text_encoder(config.text_config)

    def build_align_transformer_model(self, config: CxrAlignConfig):
        """Build the align transformer from ``config.align_transformer_config``."""
        return build_align_transformer(config.align_transformer_config)

    def __init__(self, config: CxrAlignConfig):
        super().__init__(config)

        logger.info("Build vision model ...")
        self.vision_model = self.build_vision_model(config)

        logger.info("Build text model ...")
        self.text_model = self.build_text_model(config)

        self.hidden_size = config.align_transformer_config.hidden_size

        if config.text_config.use_text_projection:
            # The projector's input width is read from ``config.hidden_size``,
            # which is only relied on for these encoder classes.
            if not isinstance(self.text_model, (CLIPTextModel, MPNetModel, BertModel)):
                raise TypeError(
                    "use_text_projection requires a CLIPTextModel, MPNetModel "
                    "or BertModel text encoder, got "
                    f"{type(self.text_model).__name__}"
                )
            text_dim = self.text_model.config.hidden_size
            # Output width 2 * hidden_size; presumably chosen to match the
            # width of ``image_features`` (CLS token concatenated with the mean
            # patch token) — confirm against the align transformer's output.
            self.text_projector = nn.Linear(text_dim, 2 * self.hidden_size)
        else:
            self.text_projector = None

        logger.info("Build align transformer model ...")
        self.align_transformer = self.build_align_transformer_model(config)

        logger.info("Build loss functions ...")
        loss_cfg = config.kwargs["loss"]
        loss_types = loss_cfg["apply"]
        loss_ratios = loss_cfg["ratio"]
        if len(loss_types) != len(loss_ratios):
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                f"loss 'apply' ({len(loss_types)} entries) and 'ratio' "
                f"({len(loss_ratios)} entries) must have the same length"
            )

        self.loss_ratio = {}
        self.loss_fns = nn.ModuleDict()
        for loss_type, ratio in zip(loss_types, loss_ratios):
            logger.info(f"Build {loss_type} loss function ...")
            if loss_type not in self._LOSS_REGISTRY:
                raise ValueError(
                    f"Unknown loss type {loss_type!r}; expected one of "
                    f"{sorted(self._LOSS_REGISTRY)}"
                )
            params = loss_cfg[loss_type]
            # Work on a copy so the per-process rank/world_size added below are
            # not written back into ``config`` (and from there into saved
            # checkpoints, where they would be stale on reload).
            loss_kwargs = {} if params is None else dict(params)
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                loss_kwargs["rank"] = torch.distributed.get_rank()
                loss_kwargs["world_size"] = torch.distributed.get_world_size()
            self.loss_fns[loss_type] = self._LOSS_REGISTRY[loss_type](**loss_kwargs)
            self.loss_ratio[loss_type] = ratio

        self.compute_logits_type = config.kwargs.get("compute_logits_type")
        self.use_negative_logits = config.kwargs.get("use_negative_logits")

        self.module_to_update = config.kwargs.get("module_to_update")

    def forward_vision_model(self, pixel_values):
        """Encode images and pass the tokens through the align transformer.

        Args:
            pixel_values: Image batch fed to the vision encoder.

        Returns:
            dict with ``"vision_tokens"`` (align-transformer output),
            ``"image_cls_token"`` (token 0), ``"image_patch_tokens"``
            (tokens 1:), and ``"image_features"`` (CLS token concatenated with
            the mean patch token, L2-normalized along dim 1).
            Tokens are assumed to be (batch, seq, dim) — TODO confirm.

        Raises:
            NotImplementedError: If the vision encoder is not a ``Dinov2Model``.
        """
        if not isinstance(self.vision_model, Dinov2Model):
            raise NotImplementedError

        vision_tokens = self.vision_model(pixel_values)["last_hidden_state"]
        vision_tokens = self.align_transformer(vision_tokens)

        cls_token = vision_tokens[:, 0]
        patch_tokens = vision_tokens[:, 1:]
        image_features = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
        image_features = F.normalize(image_features, p=2, dim=1)

        return {
            "vision_tokens": vision_tokens,
            "image_cls_token": cls_token,
            "image_patch_tokens": patch_tokens,
            "image_features": image_features,
        }

    def forward_text_model(self, encoded_input):
        """Encode tokenized text into pooled features.

        Args:
            encoded_input: Mapping with ``"input_ids"`` and ``"attention_mask"``.

        Returns:
            dict with ``"text_features_wo_l2_norm"`` (pooled features) and
            ``"text_features"`` (the same features L2-normalized along dim 1).

        Raises:
            NotImplementedError: If the text encoder is not an ``MPNetModel``.
        """
        if not isinstance(self.text_model, MPNetModel):
            raise NotImplementedError

        attention_mask = encoded_input["attention_mask"]
        model_output = self.text_model(
            input_ids=encoded_input["input_ids"],
            attention_mask=attention_mask,
        )
        # First element of the encoder output, used as per-token embeddings.
        token_embeddings = model_output[0]

        if self.text_projector is not None:
            token_embeddings = self.text_projector(token_embeddings)

        if self.config.text_config.use_cls_token:
            # Pool by taking the first token.
            text_features = token_embeddings[:, 0, :]
        else:
            # Attention-masked mean over tokens; the clamp avoids division by
            # zero when a row's mask sums to zero.
            mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            text_features = torch.sum(token_embeddings * mask, 1) / torch.clamp(
                mask.sum(1), min=1e-9
            )

        return {
            "text_features_wo_l2_norm": text_features,
            "text_features": F.normalize(text_features, p=2, dim=1),
        }

    def forward(
        self,
        pixel_values,
        encoded_key_phrases=None,
        return_loss=True,
        **kwargs,
    ):
        """Encode images and optionally compute the weighted training loss.

        Args:
            pixel_values: Image batch passed to ``forward_vision_model``.
            encoded_key_phrases: Tokenized key phrases consumed by
                ``KeyPhraseAlignmentLoss``; required when ``return_loss`` is
                true and that loss is configured.
            return_loss: When true, every configured loss is evaluated.
            **kwargs: Accepted and ignored.

        Returns:
            dict: The vision outputs, plus ``"losses"`` when ``return_loss``.
            ``losses["loss"]`` is the sum of each loss times its configured
            ratio; the key-phrase loss's own entries are copied alongside.

        Raises:
            ValueError: If ``encoded_key_phrases`` is None but a key-phrase
                loss must be computed.
            NotImplementedError: For any loss other than
                ``KeyPhraseAlignmentLoss``.
        """
        outputs = {}
        outputs.update(self.forward_vision_model(pixel_values))

        if not return_loss:
            return outputs

        total_loss = 0
        losses = {}
        for loss_type, loss_fn in self.loss_fns.items():
            if not isinstance(loss_fn, KeyPhraseAlignmentLoss):
                raise NotImplementedError
            if encoded_key_phrases is None:
                raise ValueError(
                    "encoded_key_phrases is required to compute "
                    "KeyPhraseAlignmentLoss"
                )

            loss_outputs = loss_fn(
                encoded_key_phrases,
                outputs["vision_tokens"],
                self.forward_text_model,
            )
            key_phrase_losses = loss_outputs["losses"]
            # Re-key the loss's aggregate "loss" entry so it does not collide
            # with the overall "loss" written below; copy any other entries.
            loop_loss = key_phrase_losses.pop("loss")
            losses["key_phrase_alignment_loss"] = loop_loss
            losses.update(key_phrase_losses)

            total_loss += loop_loss * self.loss_ratio[loss_type]

        losses["loss"] = total_loss
        outputs["losses"] = losses
        return outputs

    def compute_logits(
        self,
        pixel_values,
        encoded_key_phrases,
        **kwargs,
    ):
        """Score images against key phrases without computing a loss.

        Args:
            pixel_values: Image batch passed to ``forward_vision_model``.
            encoded_key_phrases: Sequence whose first element holds
                ``"input_ids"`` and ``"attention_mask"`` for all phrases;
                other elements are ignored.
            **kwargs: Accepted and ignored.

        Returns:
            dict: All outputs of ``KeyPhraseAlignmentLoss``, plus
            ``"similarity_scores"`` (mean of the returned ``t2i_attn_weights``)
            and ``"logits"`` (transposed ``t2i_logits`` divided by
            ``exp(loss_temperature)``).

        Raises:
            NotImplementedError: If ``compute_logits_type`` is not
                ``"key_phrase_alignment"``.
        """
        if self.compute_logits_type != "key_phrase_alignment":
            raise NotImplementedError(
                f"Unsupported compute_logits_type: {self.compute_logits_type!r}"
            )

        vision_outputs = self.forward_vision_model(pixel_values)
        key_phrase_loss = self.loss_fns["KeyPhraseAlignmentLoss"]

        # Split the first entry into one single-row batch per phrase; slicing
        # with ``i : i + 1`` keeps the leading batch axis.
        phrases = encoded_key_phrases[0]
        splited_key_phrases = [
            {
                "input_ids": phrases["input_ids"][i : i + 1],
                "attention_mask": phrases["attention_mask"][i : i + 1],
            }
            for i in range(phrases["input_ids"].size(0))
        ]

        loss_outputs = key_phrase_loss(
            splited_key_phrases,
            vision_outputs["vision_tokens"],
            self.forward_text_model,
            ddp_gather=False,
            need_attn_weights=True,
            compute_loss=False,
        )

        outputs = {}
        outputs.update(loss_outputs)

        # Average the list of text-to-image attention maps element-wise.
        similarity_scores = torch.mean(
            torch.stack(loss_outputs["t2i_attn_weights"]), dim=0
        )
        if key_phrase_loss.use_vision_cls_token:
            # Drop index 0 of the last axis, presumably the attention on the
            # vision CLS token — confirm against KeyPhraseAlignmentLoss.
            similarity_scores = similarity_scores[:, :, 1:]
        outputs["similarity_scores"] = similarity_scores

        logits = loss_outputs["t2i_logits"].T
        outputs["logits"] = logits / key_phrase_loss.loss_temperature.exp()
        return outputs
|
|