| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from dataclasses import dataclass |
| | from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence |
| |
|
| | import torch |
| | from transformers import DataCollatorForSeq2Seq |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from transformers import ProcessorMixin |
| |
|
| | from .template import Template |
| |
|
| |
|
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""
    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
    handling packed sequences and applying a lower triangular (causal) form to prevent future peeking.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals to `0.0`, `x` equals to `min_dtype`.

    Args:
        attention_mask_with_indices: 2D tensor (batch_size, seq_len) in which every token carries the
            index of the packed sequence it belongs to and `0` marks padding.
        dtype: floating point dtype of the returned mask; `torch.finfo(dtype).min` is used as the
            masked-out value.

    Returns:
        A (batch_size, 1, seq_len, seq_len) tensor of `dtype`, located on the same device as the input.
    """
    bsz, seq_len = attention_mask_with_indices.size()
    # Build every helper tensor on the input's device so that the function also works for masks
    # that do not live on the default (CPU) device.
    device = attention_mask_with_indices.device
    min_dtype = torch.finfo(dtype).min
    # expanded_mask[b, 0, i, j] holds the sequence index of key position j.
    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    # Key positions that are padding (index 0) are never visible.
    non_padding = expanded_mask != 0
    # A query may only look at keys of the same packed sequence (block-diagonal structure).
    same_sequence = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2))
    # Lower triangular part: a query may not look at later positions.
    causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device))
    visible = same_sequence & non_padding & causal
    # Visible positions become 0.0, everything else the most negative value of `dtype`.
    return torch.where(visible, torch.tensor(0, dtype=dtype, device=device), min_dtype)
| |
|
| |
|
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""
    Seq2seq data collator with additional support for multimodal (image / video) inputs of VLMs.

    Every feature should provide input_ids, attention_mask and labels, and may provide images and videos.
    """

    template: Optional["Template"] = None
    processor: Optional["ProcessorMixin"] = None

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        r"""
        Collates the text fields with the parent collator and merges in the multimodal inputs
        produced by the template's mm_plugin.
        """
        all_images, all_videos = [], []
        image_counts, video_counts, seq_lengths = [], [], []
        for example in features:
            # The media entries are removed from the feature dict before it reaches the parent collator;
            # a missing or empty entry counts as "no media".
            example_images = example.pop("images", None) or []
            example_videos = example.pop("videos", None) or []
            all_images.extend(example_images)
            all_videos.extend(example_videos)
            image_counts.append(len(example_images))
            video_counts.append(len(example_videos))
            seq_lengths.append(len(example["input_ids"]))

        mm_inputs = self.template.mm_plugin.get_mm_inputs(
            all_images, all_videos, image_counts, video_counts, seq_lengths, self.processor
        )
        if "token_type_ids" in mm_inputs:
            # token_type_ids are handed out per example so that they go through the parent collator
            # together with the other per-example fields.
            per_example_type_ids = mm_inputs.pop("token_type_ids")
            for idx, example in enumerate(features):
                example["token_type_ids"] = per_example_type_ids[idx]

        batch: Dict[str, "torch.Tensor"] = super().__call__(features)
        batch.update(mm_inputs)
        return batch
| |
|
| |
|
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator that can replace the 2D attention mask of a batch with a 4D attention mask.
    """

    block_diag_attn: bool = False
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        r"""
        Collates the features and, when block diagonal attention is enabled and the attention
        implementation is not flash_attention_2, converts the attention mask to its 4D form.
        """
        batch = super().__call__(features)
        needs_4d_mask = self.block_diag_attn and self.attn_implementation != "flash_attention_2"
        if needs_4d_mask:
            batch["attention_mask"] = prepare_4d_attention_mask(batch["attention_mask"], self.compute_dtype)

        return batch
| |
|
| |
|
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise (chosen / rejected) data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        r"""
        Flattens pairwise features into a single list and collates it with the parent collator.

        For n input features, 2 * n examples are produced: the first n are the chosen examples
        and the last n are the rejected examples.
        """
        paired_examples = []
        for side in ("chosen", "rejected"):
            # All examples of one side are appended before the other side starts.
            paired_examples.extend(
                {
                    "input_ids": example[f"{side}_input_ids"],
                    "attention_mask": example[f"{side}_attention_mask"],
                    "labels": example[f"{side}_labels"],
                    "images": example["images"],
                    "videos": example["videos"],
                }
                for example in features
            )

        return super().__call__(paired_examples)
| |
|
| |
|
@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        r"""
        Collates the target examples and the KL examples separately with the parent collator.

        The KL batch is merged into the target batch under `kl_`-prefixed keys and the
        per-example `kto_tags` are added as a tensor.
        """
        text_keys = ("input_ids", "attention_mask", "labels")
        media_keys = ("images", "videos")

        target_examples, kl_examples, tags = [], [], []
        for example in features:
            target_example = {name: example[name] for name in text_keys + media_keys}
            # The KL example has its own text fields but uses the media of the same feature.
            kl_example = {name: example[f"kl_{name}"] for name in text_keys}
            for name in media_keys:
                kl_example[name] = example[name]

            target_examples.append(target_example)
            kl_examples.append(kl_example)
            tags.append(example["kto_tags"])

        batch = super().__call__(target_examples)
        kl_batch = super().__call__(kl_examples)
        for name in text_keys:
            batch[f"kl_{name}"] = kl_batch[name]

        # token_type_ids are only present when the parent collator produced them.
        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(tags)
        return batch
| |
|