| | """ |
| | 2025.10.1 |
| | 2025.10.1 |
| | 4.56.2 |
| | 0.22.2 |
| | __UNSLOTH_VERSIONING__ |
| | """ |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from torch import Tensor |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| | from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable |
| | from trl.trainer.dpo_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, DPOConfig, DPOTrainer, DataCollator, DataCollatorForPreference, DataLoader, Dataset, EvalLoopOutput, F, FDivergenceConstants, FDivergenceType, FeatureExtractionMixin, IterableDataset, Literal, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, Optional, PartialState, Path, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RunningMoments, SyncRefModelCallback, Trainer, TrainerCallback, Union, autocast, cap_exp, contextmanager, create_reference_model, dataclass, defaultdict, disable_dropout_in_model, empty_cache, flush_left, flush_right, generate_model_card, get_comet_experiment_url, get_peft_model, inspect, is_comet_available, is_liger_kernel_available, is_mlflow_available, is_peft_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, nullcontext, os, pad, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_fsdp, prepare_model_for_kbit_training, random, selective_log_softmax, shift_tokens_right, textwrap, torch, tqdm, wandb, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, logger, os, torch) |
| |
|
| |
|
| | import os |
| | from typing import * |
| | from dataclasses import dataclass, field |
| | from packaging.version import Version |
| | import torch |
| | import numpy as np |
| | from contextlib import nullcontext |
| | from torch.nn import functional as F |
| | from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling |
| | from transformers.training_args import ParallelMode |
| |
|
| | |
| | import functools |
| | from types import MethodType |
| | def prepare_for_training_mode(f): |
| | @functools.wraps(f) |
| | def wrapper(self, *args, **kwargs): |
| | |
| | if hasattr(self, 'model') and hasattr(self.model, "for_training"): |
| | self.model.for_training() |
| | output = f(self, *args, **kwargs) |
| | |
| | if hasattr(self, 'model') and hasattr(self.model, "for_inference"): |
| | self.model.for_inference() |
| | return output |
| | return wrapper |
| | pass |
| |
|
| | torch_compile_options = { |
| | "epilogue_fusion" : True, |
| | "max_autotune" : False, |
| | "shape_padding" : True, |
| | "trace.enabled" : False, |
| | "triton.cudagraphs" : False, |
| | } |
| |
|
| | @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) |
| | def chunked_selective_log_softmax(logits, index): |
| | |
| | chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) |
| | chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) |
| | all_per_token_logps = [] |
| | |
| | for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): |
| | chunk_logits = chunk_logits.to(torch.float32) |
| | selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) |
| | logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) |
| | per_token_logps = selected_logits - logsumexp_values |
| | all_per_token_logps.append(per_token_logps) |
| | pass |
| | all_per_token_logps = torch.concat(all_per_token_logps) |
| | all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) |
| | return all_per_token_logps |
| |
|
| | def calculate_pad_tokens_in_prompt( |
| | input_ids: torch.Tensor, |
| | logits_to_keep: int, |
| | pad_token_id: int |
| | ) -> torch.Tensor: |
| | """ |
| | Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens |
| | """ |
| | if logits_to_keep >= input_ids.shape[1]: |
| | raise ValueError("logits_to_keep must be smaller than the sequence length.") |
| |
|
| | prompt_section = input_ids[:, :-logits_to_keep] |
| |
|
| | padding_mask = (prompt_section == pad_token_id) |
| |
|
| | pad_token_counts = padding_mask.sum(dim=1) |
| |
|
| | return pad_token_counts |
| |
|
| | def create_completion_attention_mask( |
| | completion_input_ids: torch.Tensor, |
| | left_pad_tokens_per_prompt: torch.Tensor, |
| | max_left_pad: int, |
| | pad_token_id: int |
| | ) -> torch.Tensor: |
| | """ |
| | Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] |
| | |
| | Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens |
| | and pad are pad tokens, this function would make a completion mask that would 0 out the pad |
| | and p tokens. so in this example [0,0,0,1,1,1,0,0,0] |
| | """ |
| | batch_size, completion_len = completion_input_ids.shape |
| | device = completion_input_ids.device |
| |
|
| | num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt |
| |
|
| | indices = torch.arange(completion_len, device=device).unsqueeze(0) |
| | shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) |
| |
|
| | non_padding_mask = (completion_input_ids != pad_token_id) |
| |
|
| | final_mask = shift_mask & non_padding_mask |
| |
|
| | return final_mask |
| |
|
| | def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: |
| | """ |
| | Moves all padding tokens in each sequence of a batch to the right. |
| | """ |
| | mask = (tensor != pad_id) |
| | |
| | sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) |
| | packed_tensor = torch.gather(tensor, 1, sorted_indices) |
| | return packed_tensor |
| | @dataclass |
| | class UnslothDPOConfig(DPOConfig): |
| | """ |
| | |
| | Configuration class for the [`DPOTrainer`]. |
| | |
| | This class includes only the parameters that are specific to DPO training. For a full list of training arguments, |
| | please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may |
| | differ from those in [`~transformers.TrainingArguments`]. |
| | |
| | Using [`~transformers.HfArgumentParser`] we can turn this class into |
| | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the |
| | command line. |
| | |
| | Parameters: |
| | > Parameters that control the model and reference model |
| | |
| | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): |
| | Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the |
| | [`DPOTrainer`] is provided as a string. |
| | ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): |
| | Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument of the |
| | [`DPOTrainer`] is provided as a string. |
| | model_adapter_name (`str` or `None`, *optional*, defaults to `None`): |
| | Name of the train target PEFT adapter, when using LoRA with multiple adapters. |
| | ref_adapter_name (`str` or `None`, *optional*, defaults to `None`): |
| | Name of the reference PEFT adapter, when using LoRA with multiple adapters. |
| | force_use_ref_model (`bool`, *optional*, defaults to `False`): |
| | If you provide a PEFT model as the active model and wish to use a different model for the `ref_model`, set |
| | this flag to `True`. |
| | disable_dropout (`bool`, *optional*, defaults to `True`): |
| | Whether to disable dropout in the model and reference model. |
| | use_logits_to_keep (`bool`, *optional*, defaults to `False`): |
| | If `True`, only a specified number of logits are computed in the forward pass. This can be useful for |
| | saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios |
| | when working with very long prompts where labels are ignored (-100). |
| | |
| | > Parameters that control the data preprocessing |
| | |
| | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): |
| | Number of processes to use for processing the dataset. |
| | padding_value (`int` or `None`, *optional*, defaults to `None`): |
| | Padding value to use. If `None`, the padding value of the tokenizer is used. |
| | label_pad_token_id (`int`, *optional*, defaults to `-100`): |
| | Padding value to use for labels. |
| | max_prompt_length (`int` or `None`, *optional*, defaults to `512`): |
| | Maximum length of the prompt. |
| | max_completion_length (`int` or `None`, *optional*, defaults to `None`): |
| | Maximum length of the completion. |
| | max_length (`int` or `None`, *optional*, defaults to `1024`): |
| | Maximum length of the full sequence (prompt + completion). |
| | truncation_mode (`str`, *optional*, defaults to `"keep_end"`): |
| | Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and |
| | `"keep_start"`. |
| | padding_free (`bool`, *optional*, defaults to `False`): |
| | Whether to perform forward passes without padding by flattening all sequences in the batch into a single |
| | continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only |
| | supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened |
| | batch structure. |
| | precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): |
| | Whether to precompute the log probabilities from the reference model. Setting this to `True` allows |
| | training without needing the reference model during training, which can help reduce GPU memory usage. If |
| | set to `False` (default), the reference model will be used during training to compute log probabilities |
| | on-the-fly. |
| | precompute_ref_batch_size (`int` or `None`, *optional*, defaults to `None`): |
| | Batch size to use when precomputing reference model log probabilities. This can be set higher than the |
| | training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for |
| | training and `per_device_eval_batch_size` for evaluation. |
| | tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`): |
| | List of tools (callable functions) that will be accessible to the model. If the template does not support |
| | function calling, this argument will have no effect. |
| | |
| | > Parameters that control the training |
| | |
| | loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`): |
| | Type of loss to use. Possible values are: |
| | |
| | - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. |
| | - `"hinge"`: hinge loss on the normalized likelihood from the |
| | [SLiC](https://huggingface.co/papers/2305.10425) paper. |
| | - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. |
| | - `"exo_pair"`: pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper. |
| | - `"nca_pair"`: pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper. |
| | - `"robust"`: unbiased estimate of the DPO loss that is robust to preference noise from the [Robust |
| | DPO](https://huggingface.co/papers/2403.00409) paper. |
| | - `"bco_pair"`: pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper. |
| | - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) |
| | paper. |
| | - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. |
| | - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) |
| | paper. |
| | - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the |
| | [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. |
| | - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. |
| | - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. |
| | - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss). |
| | |
| | Multiple loss types can be combined using comma separation (e.g., `["sigmoid", "bco_pair", "sft"]` for |
| | [MPO](https://huggingface.co/papers/2411.10442)). The `loss_weights` parameter can be used to specify |
| | corresponding weights for each loss type. |
| | |
| | use_liger_loss (`bool`, *optional*, defaults to `False`): |
| | Whether to use Liger loss. |
| | base_model_attribute_name (`str`, *optional*, defaults to `"model"`): |
| | Name of the attribute in the model that contains the base model. This is used to get the base model from |
| | the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`. |
| | beta (`float`, *optional*, defaults to `0.1`): |
| | Parameter controlling the deviation from the reference model. Higher β means less deviation from the |
| | reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in |
| | the [paper](https://huggingface.co/papers/2310.12036). |
| | f_divergence_type (`str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`): |
| | Type of f-divergence regularization function to compute divergence between policy and reference model. |
| | f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`): |
| | α coefficient in the α-divergence u^-α regularization function for DPO loss. |
| | reference_free (`bool`, *optional*, defaults to `False`): |
| | Whether to ignore the provided reference model and implicitly use a reference model that assigns equal |
| | probability to all responses. |
| | label_smoothing (`float`, *optional*, defaults to `0.0`): |
| | Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and [Robust |
| | DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. |
| | use_weighting (`bool`, *optional*, defaults to `False`): |
| | Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827). |
| | rpo_alpha (`float`, *optional*, defaults to `None`): |
| | α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the |
| | weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the |
| | DPO loss. The paper recommends `rpo_alpha=1.0`. |
| | ld_alpha (`float` or `None`, *optional*, defaults to `None`): |
| | α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting |
| | of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose |
| | part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between |
| | `0.0` and `1.0`. |
| | discopop_tau (`float`, *optional*, defaults to `0.05`): |
| | τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls |
| | the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. |
| | loss_weights (`list[float]` or `None`, *optional*, defaults to `None`): |
| | List of loss weights for multi-loss combinations. Used when combining multiple loss types. Example: `[0.8, |
| | 0.2, 1.0]` for [MPO](https://huggingface.co/papers/2411.10442). If not provided, defaults to equal weights |
| | (`1.0`) for all loss types. |
| | sync_ref_model (`bool`, *optional*, defaults to `False`): |
| | Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using |
| | the `ref_model_mixup_alpha` parameter. This synchronization originates from the |
| | [TR-DPO](https://huggingface.co/papers/2404.09656) paper. |
| | ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): |
| | α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix |
| | between the current policy and the previous reference policy during updates. The reference policy is |
| | updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you |
| | must set `sync_ref_model=True`. |
| | ref_model_sync_steps (`int`, *optional*, defaults to `512`): |
| | τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how |
| | frequently the current policy is synchronized with the reference policy. To use this parameter, you must |
| | set `sync_ref_model=True`. |
| | |
| | > Parameters that control the logging |
| | |
| | generate_during_eval (`bool`, *optional*, defaults to `False`): |
| | Whether to generate and log completions from both the model and the reference model to W&B or Comet during |
| | evaluation. |
| | |
| | """ |
| | vllm_sampling_params: Optional[Any] = field( |
| | default = None, |
| | metadata = {'help': 'vLLM SamplingParams'}, |
| | ) |
| | unsloth_num_chunks : Optional[int] = field( |
| | default = -1, |
| | metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, |
| | ) |
| | max_seq_length : Optional[int] = field( |
| | default = None, |
| | metadata = {'help': 'Maximum sequence length to truncate to.'}, |
| | ) |
| | def __init__( |
| | self, |
| | output_dir = None, |
| | overwrite_output_dir = None, |
| | do_train = False, |
| | do_eval = False, |
| | do_predict = False, |
| | eval_strategy = 'no', |
| | prediction_loss_only = False, |
| | per_device_train_batch_size = 4, |
| | per_device_eval_batch_size = 4, |
| | per_gpu_train_batch_size = None, |
| | per_gpu_eval_batch_size = None, |
| | gradient_accumulation_steps = 2, |
| | eval_accumulation_steps = 2, |
| | eval_delay = 0, |
| | torch_empty_cache_steps = 250, |
| | learning_rate = 5e-05, |
| | weight_decay = 0.01, |
| | adam_beta1 = 0.9, |
| | adam_beta2 = 0.999, |
| | adam_epsilon = 1e-08, |
| | max_grad_norm = 1.0, |
| | num_train_epochs = 3.0, |
| | max_steps = -1, |
| | lr_scheduler_type = 'linear', |
| | warmup_ratio = 0.1, |
| | warmup_steps = 0, |
| | log_level = 'passive', |
| | log_level_replica = 'warning', |
| | log_on_each_node = True, |
| | logging_dir = None, |
| | logging_strategy = 'steps', |
| | logging_first_step = False, |
| | logging_steps = 1, |
| | logging_nan_inf_filter = False, |
| | save_strategy = 'steps', |
| | save_steps = 500, |
| | save_total_limit = None, |
| | save_safetensors = True, |
| | save_on_each_node = False, |
| | save_only_model = False, |
| | restore_callback_states_from_checkpoint = False, |
| | no_cuda = False, |
| | use_cpu = False, |
| | use_mps_device = False, |
| | seed = 3407, |
| | data_seed = 3407, |
| | jit_mode_eval = False, |
| | use_ipex = False, |
| | bf16 = False, |
| | fp16 = False, |
| | fp16_opt_level = 'O1', |
| | half_precision_backend = 'auto', |
| | bf16_full_eval = False, |
| | fp16_full_eval = False, |
| | tf32 = None, |
| | local_rank = -1, |
| | ddp_backend = None, |
| | tpu_num_cores = None, |
| | tpu_metrics_debug = False, |
| | debug = '', |
| | dataloader_drop_last = False, |
| | eval_steps = None, |
| | dataloader_num_workers = 0, |
| | dataloader_prefetch_factor = None, |
| | past_index = -1, |
| | run_name = None, |
| | disable_tqdm = None, |
| | remove_unused_columns = True, |
| | label_names = None, |
| | load_best_model_at_end = False, |
| | metric_for_best_model = None, |
| | greater_is_better = None, |
| | ignore_data_skip = False, |
| | fsdp = '', |
| | fsdp_min_num_params = 0, |
| | fsdp_config = None, |
| | fsdp_transformer_layer_cls_to_wrap = None, |
| | accelerator_config = None, |
| | parallelism_config = None, |
| | deepspeed = None, |
| | label_smoothing_factor = 0.0, |
| | optim = 'adamw_8bit', |
| | optim_args = None, |
| | adafactor = False, |
| | group_by_length = False, |
| | length_column_name = 'length', |
| | report_to = None, |
| | ddp_find_unused_parameters = None, |
| | ddp_bucket_cap_mb = None, |
| | ddp_broadcast_buffers = None, |
| | dataloader_pin_memory = True, |
| | dataloader_persistent_workers = False, |
| | skip_memory_metrics = True, |
| | use_legacy_prediction_loop = False, |
| | push_to_hub = False, |
| | resume_from_checkpoint = None, |
| | hub_model_id = None, |
| | hub_strategy = 'every_save', |
| | hub_token = None, |
| | hub_private_repo = None, |
| | hub_always_push = False, |
| | hub_revision = None, |
| | gradient_checkpointing = True, |
| | gradient_checkpointing_kwargs = None, |
| | include_inputs_for_metrics = False, |
| | eval_do_concat_batches = True, |
| | fp16_backend = 'auto', |
| | push_to_hub_model_id = None, |
| | push_to_hub_organization = None, |
| | push_to_hub_token = None, |
| | mp_parameters = '', |
| | auto_find_batch_size = False, |
| | full_determinism = False, |
| | torchdynamo = None, |
| | ray_scope = 'last', |
| | ddp_timeout = 1800, |
| | torch_compile = False, |
| | torch_compile_backend = None, |
| | torch_compile_mode = None, |
| | include_tokens_per_second = False, |
| | include_num_input_tokens_seen = False, |
| | neftune_noise_alpha = None, |
| | optim_target_modules = None, |
| | batch_eval_metrics = False, |
| | eval_on_start = False, |
| | use_liger_kernel = False, |
| | liger_kernel_config = None, |
| | eval_use_gather_object = False, |
| | average_tokens_across_devices = True, |
| | model_init_kwargs = None, |
| | ref_model_init_kwargs = None, |
| | model_adapter_name = None, |
| | ref_adapter_name = None, |
| | force_use_ref_model = False, |
| | disable_dropout = True, |
| | use_logits_to_keep = False, |
| | dataset_num_proc = None, |
| | padding_value = None, |
| | label_pad_token_id = -100, |
| | max_prompt_length = 512, |
| | max_completion_length = None, |
| | max_length = 1024, |
| | truncation_mode = 'keep_end', |
| | padding_free = False, |
| | precompute_ref_log_probs = False, |
| | precompute_ref_batch_size = None, |
| | tools = None, |
| | use_liger_loss = False, |
| | base_model_attribute_name = 'model', |
| | beta = 0.1, |
| | f_alpha_divergence_coef = 1.0, |
| | reference_free = False, |
| | label_smoothing = 0.0, |
| | use_weighting = False, |
| | rpo_alpha = None, |
| | ld_alpha = None, |
| | discopop_tau = 0.05, |
| | loss_weights = None, |
| | sync_ref_model = False, |
| | ref_model_mixup_alpha = 0.6, |
| | ref_model_sync_steps = 512, |
| | generate_during_eval = False, |
| | vllm_sampling_params = None, |
| | unsloth_num_chunks = -1, |
| | max_seq_length = None, |
| | **kwargs, |
| | ): |
| | if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') |
| | if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') |
| | if output_dir is None and save_strategy == 'steps' and save_steps == 500: |
| | output_dir = 'unsloth_training_checkpoints' |
| | save_strategy = 'no' |
| | if dataset_num_proc is None: |
| | from multiprocessing import cpu_count |
| | dataset_num_proc = max(cpu_count()+4, 2) |
| | |
| | super().__init__( |
| | output_dir = output_dir, |
| | overwrite_output_dir = overwrite_output_dir, |
| | do_train = do_train, |
| | do_eval = do_eval, |
| | do_predict = do_predict, |
| | eval_strategy = eval_strategy, |
| | prediction_loss_only = prediction_loss_only, |
| | per_device_train_batch_size = per_device_train_batch_size, |
| | per_device_eval_batch_size = per_device_eval_batch_size, |
| | per_gpu_train_batch_size = per_gpu_train_batch_size, |
| | per_gpu_eval_batch_size = per_gpu_eval_batch_size, |
| | gradient_accumulation_steps = gradient_accumulation_steps, |
| | eval_accumulation_steps = eval_accumulation_steps, |
| | eval_delay = eval_delay, |
| | torch_empty_cache_steps = torch_empty_cache_steps, |
| | learning_rate = learning_rate, |
| | weight_decay = weight_decay, |
| | adam_beta1 = adam_beta1, |
| | adam_beta2 = adam_beta2, |
| | adam_epsilon = adam_epsilon, |
| | max_grad_norm = max_grad_norm, |
| | num_train_epochs = num_train_epochs, |
| | max_steps = max_steps, |
| | lr_scheduler_type = lr_scheduler_type, |
| | warmup_ratio = warmup_ratio, |
| | warmup_steps = warmup_steps, |
| | log_level = log_level, |
| | log_level_replica = log_level_replica, |
| | log_on_each_node = log_on_each_node, |
| | logging_dir = logging_dir, |
| | logging_strategy = logging_strategy, |
| | logging_first_step = logging_first_step, |
| | logging_steps = logging_steps, |
| | logging_nan_inf_filter = logging_nan_inf_filter, |
| | save_strategy = save_strategy, |
| | save_steps = save_steps, |
| | save_total_limit = save_total_limit, |
| | save_safetensors = save_safetensors, |
| | save_on_each_node = save_on_each_node, |
| | save_only_model = save_only_model, |
| | restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, |
| | no_cuda = no_cuda, |
| | use_cpu = use_cpu, |
| | use_mps_device = use_mps_device, |
| | seed = seed, |
| | data_seed = data_seed, |
| | jit_mode_eval = jit_mode_eval, |
| | use_ipex = use_ipex, |
| | bf16 = bf16, |
| | fp16 = fp16, |
| | fp16_opt_level = fp16_opt_level, |
| | half_precision_backend = half_precision_backend, |
| | bf16_full_eval = bf16_full_eval, |
| | fp16_full_eval = fp16_full_eval, |
| | tf32 = tf32, |
| | local_rank = local_rank, |
| | ddp_backend = ddp_backend, |
| | tpu_num_cores = tpu_num_cores, |
| | tpu_metrics_debug = tpu_metrics_debug, |
| | debug = debug, |
| | dataloader_drop_last = dataloader_drop_last, |
| | eval_steps = eval_steps, |
| | dataloader_num_workers = dataloader_num_workers, |
| | dataloader_prefetch_factor = dataloader_prefetch_factor, |
| | past_index = past_index, |
| | run_name = run_name, |
| | disable_tqdm = disable_tqdm, |
| | remove_unused_columns = remove_unused_columns, |
| | label_names = label_names, |
| | load_best_model_at_end = load_best_model_at_end, |
| | metric_for_best_model = metric_for_best_model, |
| | greater_is_better = greater_is_better, |
| | ignore_data_skip = ignore_data_skip, |
| | fsdp = fsdp, |
| | fsdp_min_num_params = fsdp_min_num_params, |
| | fsdp_config = fsdp_config, |
| | fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, |
| | accelerator_config = accelerator_config, |
| | parallelism_config = parallelism_config, |
| | deepspeed = deepspeed, |
| | label_smoothing_factor = label_smoothing_factor, |
| | optim = optim, |
| | optim_args = optim_args, |
| | adafactor = adafactor, |
| | group_by_length = group_by_length, |
| | length_column_name = length_column_name, |
| | report_to = report_to, |
| | ddp_find_unused_parameters = ddp_find_unused_parameters, |
| | ddp_bucket_cap_mb = ddp_bucket_cap_mb, |
| | ddp_broadcast_buffers = ddp_broadcast_buffers, |
| | dataloader_pin_memory = dataloader_pin_memory, |
| | dataloader_persistent_workers = dataloader_persistent_workers, |
| | skip_memory_metrics = skip_memory_metrics, |
| | use_legacy_prediction_loop = use_legacy_prediction_loop, |
| | push_to_hub = push_to_hub, |
| | resume_from_checkpoint = resume_from_checkpoint, |
| | hub_model_id = hub_model_id, |
| | hub_strategy = hub_strategy, |
| | hub_token = hub_token, |
| | hub_private_repo = hub_private_repo, |
| | hub_always_push = hub_always_push, |
| | hub_revision = hub_revision, |
| | gradient_checkpointing = gradient_checkpointing, |
| | gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, |
| | include_inputs_for_metrics = include_inputs_for_metrics, |
| | eval_do_concat_batches = eval_do_concat_batches, |
| | fp16_backend = fp16_backend, |
| | push_to_hub_model_id = push_to_hub_model_id, |
| | push_to_hub_organization = push_to_hub_organization, |
| | push_to_hub_token = push_to_hub_token, |
| | mp_parameters = mp_parameters, |
| | auto_find_batch_size = auto_find_batch_size, |
| | full_determinism = full_determinism, |
| | torchdynamo = torchdynamo, |
| | ray_scope = ray_scope, |
| | ddp_timeout = ddp_timeout, |
| | torch_compile = torch_compile, |
| | torch_compile_backend = torch_compile_backend, |
| | torch_compile_mode = torch_compile_mode, |
| | include_tokens_per_second = include_tokens_per_second, |
| | include_num_input_tokens_seen = include_num_input_tokens_seen, |
| | neftune_noise_alpha = neftune_noise_alpha, |
| | optim_target_modules = optim_target_modules, |
| | batch_eval_metrics = batch_eval_metrics, |
| | eval_on_start = eval_on_start, |
| | use_liger_kernel = use_liger_kernel, |
| | liger_kernel_config = liger_kernel_config, |
| | eval_use_gather_object = eval_use_gather_object, |
| | average_tokens_across_devices = average_tokens_across_devices, |
| | model_init_kwargs = model_init_kwargs, |
| | ref_model_init_kwargs = ref_model_init_kwargs, |
| | model_adapter_name = model_adapter_name, |
| | ref_adapter_name = ref_adapter_name, |
| | force_use_ref_model = force_use_ref_model, |
| | disable_dropout = disable_dropout, |
| | use_logits_to_keep = use_logits_to_keep, |
| | dataset_num_proc = dataset_num_proc, |
| | padding_value = padding_value, |
| | label_pad_token_id = label_pad_token_id, |
| | max_prompt_length = max_prompt_length, |
| | max_completion_length = max_completion_length, |
| | max_length = max_length, |
| | truncation_mode = truncation_mode, |
| | padding_free = padding_free, |
| | precompute_ref_log_probs = precompute_ref_log_probs, |
| | precompute_ref_batch_size = precompute_ref_batch_size, |
| | tools = tools, |
| | use_liger_loss = use_liger_loss, |
| | base_model_attribute_name = base_model_attribute_name, |
| | beta = beta, |
| | f_alpha_divergence_coef = f_alpha_divergence_coef, |
| | reference_free = reference_free, |
| | label_smoothing = label_smoothing, |
| | use_weighting = use_weighting, |
| | rpo_alpha = rpo_alpha, |
| | ld_alpha = ld_alpha, |
| | discopop_tau = discopop_tau, |
| | loss_weights = loss_weights, |
| | sync_ref_model = sync_ref_model, |
| | ref_model_mixup_alpha = ref_model_mixup_alpha, |
| | ref_model_sync_steps = ref_model_sync_steps, |
| | generate_during_eval = generate_during_eval,**kwargs) |
| | self.vllm_sampling_params = vllm_sampling_params |
| | self.unsloth_num_chunks = unsloth_num_chunks |
| | self.max_seq_length = max_seq_length |
| | pass |
| |
|
| | class _UnslothDPOTrainer(Trainer): |
| | """""" |
| |
|
| | _tag_names = ["trl", "dpo"] |
| |
|
| | def __init__( |
| | self, |
| | model: Union[str, nn.Module, PreTrainedModel], |
| | ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, |
| | args: Optional[DPOConfig] = None, |
| | data_collator: Optional[DataCollator] = None, |
| | train_dataset: Optional[Union[Dataset, IterableDataset]] = None, |
| | eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, |
| | processing_class: Optional[ |
| | Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] |
| | ] = None, |
| | compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, |
| | callbacks: Optional[list[TrainerCallback]] = None, |
| | optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), |
| | optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, |
| | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, |
| | peft_config: Optional["PeftConfig"] = None, |
| | ): |
| | |
| | model_id = model if isinstance(model, str) else model.config._name_or_path |
| | if args is None: |
| | model_name = model_id.split("/")[-1] |
| | args = DPOConfig(f"{model_name}-DPO") |
| |
|
| | |
| | if processing_class is None: |
| | processing_class = AutoTokenizer.from_pretrained(model_id) |
| |
|
| | if args.padding_value is not None: |
| | self.padding_value = args.padding_value |
| | else: |
| | if hasattr(processing_class, "pad_token_id") and processing_class.pad_token_id is not None: |
| | self.padding_value = processing_class.pad_token_id |
| | elif hasattr(processing_class, "tokenizer") and processing_class.tokenizer.pad_token_id is not None: |
| | self.padding_value = processing_class.tokenizer.pad_token_id |
| | else: |
| | raise ValueError( |
| | "`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in the " |
| | "`processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set " |
| | "`tokenizer.pad_token` (e.g., `tokenizer.pad_token = tokenizer.eos_token`) before instantiating " |
| | "the trainer." |
| | ) |
| |
|
| | |
| | if not isinstance(model, str) and ref_model is model: |
| | raise ValueError( |
| | "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " |
| | "same as `model`, you must mass a copy of it, or `None` if you use peft." |
| | ) |
| |
|
| | if args.model_init_kwargs is not None and not isinstance(model, str): |
| | logger.warning( |
| | "You passed model_init_kwargs to the `DPOConfig`, but your model is already instantiated. " |
| | "The `model_init_kwargs` will be ignored." |
| | ) |
| | if isinstance(model, str): |
| | model = self._create_model_from_path(model, args) |
| |
|
| | if args.ref_model_init_kwargs is not None and not isinstance(ref_model, str): |
| | logger.warning( |
| | "You passed ref_model_init_kwargs to the `DPOConfig`, but your ref_model is already instantiated. " |
| | "The `ref_model_init_kwargs` will be ignored." |
| | ) |
| | if isinstance(ref_model, str): |
| | ref_model = self._create_model_from_path(ref_model, args, is_ref=True) |
| |
|
| | |
| | model = self._prepare_peft_model(model, ref_model, peft_config, args) |
| |
|
| | if args.generate_during_eval and not (is_wandb_available() or is_comet_available() or is_mlflow_available()): |
| | raise ValueError( |
| | "`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." |
| | " Please install `wandb`, `mlflow` or `comet-ml` to resolve." |
| | ) |
| |
|
| | self.is_encoder_decoder = model.config.is_encoder_decoder |
| | self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys() |
| | self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) |
| | self.model_adapter_name = args.model_adapter_name |
| | self.ref_adapter_name = args.ref_adapter_name |
| | self.reference_free = args.reference_free |
| |
|
| | if ref_model: |
| | self.ref_model = ref_model |
| | elif self.is_peft_model or args.precompute_ref_log_probs: |
| | |
| | self.ref_model = None |
| | else: |
| | self.ref_model = create_reference_model(model) |
| |
|
| | |
| | if args.disable_dropout: |
| | disable_dropout_in_model(model) |
| | if self.ref_model is not None: |
| | disable_dropout_in_model(self.ref_model) |
| |
|
| | |
| | if args.use_liger_loss: |
| | if not is_liger_kernel_available(): |
| | raise ImportError( |
| | "You set `use_liger_loss=True` but the liger kernel is not available. " |
| | "Please install liger-kernel first: `pip install liger-kernel`" |
| | ) |
| | if args.loss_type not in ["sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"]: |
| | raise ValueError( |
| | "You set `use_liger_loss=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. " |
| | "Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel." |
| | ) |
| | self.dpo_loss_fn = LigerFusedLinearDPOLoss( |
| | ignore_index=args.label_pad_token_id, |
| | beta=args.beta, |
| | use_ref_model=not args.reference_free, |
| | average_log_prob=False, |
| | loss_type=args.loss_type, |
| | ) |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | model.warnings_issued["estimate_tokens"] = True |
| |
|
| | |
| | if data_collator is None: |
| | data_collator = DataCollatorForPreference(pad_token_id=self.padding_value) |
| |
|
| | self.generate_during_eval = args.generate_during_eval |
| | self.label_pad_token_id = args.label_pad_token_id |
| | self.max_prompt_length = args.max_prompt_length |
| | self.max_completion_length = args.max_completion_length |
| | self.max_length = args.max_length |
| | self.truncation_mode = args.truncation_mode |
| | self.precompute_ref_log_probs = args.precompute_ref_log_probs |
| | self.use_logits_to_keep = args.use_logits_to_keep |
| |
|
| | if args.padding_free: |
| | if model.config._attn_implementation != "flash_attention_2": |
| | logger.warning( |
| | "Padding-free training is enabled, but the attention implementation is not set to " |
| | "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " |
| | "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " |
| | "other implementations may lead to unexpected behavior. To ensure compatibility, set " |
| | "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " |
| | "attention mechanism can handle flattened sequences." |
| | ) |
| | if args.per_device_train_batch_size == 1: |
| | logger.warning( |
| | "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " |
| | "of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size " |
| | "to at least 2." |
| | ) |
| | self.padding_free = args.padding_free |
| |
|
| | |
| | |
| | self._precomputed_train_ref_log_probs = False |
| | self._precomputed_eval_ref_log_probs = False |
| |
|
| | self.beta = args.beta |
| | self.label_smoothing = args.label_smoothing |
| | self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type] |
| | self.loss_weights = args.loss_weights |
| | self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) |
| | self.use_weighting = args.use_weighting |
| | self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) |
| | if self.aux_loss_enabled and self.aux_loss_coef == 0.0: |
| | logger.warning( |
| | "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " |
| | "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " |
| | "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " |
| | "loss.", |
| | ) |
| | for loss_type in self.loss_type: |
| | if ( |
| | loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"] |
| | and args.label_smoothing > 0 |
| | ): |
| | logger.warning( |
| | f"You are using the {loss_type} loss type that does not support label smoothing. The " |
| | "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this " |
| | "warning.", |
| | ) |
| | if loss_type == "kto_pair": |
| | raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.") |
| |
|
| | self._stored_metrics = defaultdict(lambda: defaultdict(list)) |
| | self.f_divergence_type = args.f_divergence_type |
| | self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} |
| | self.dataset_num_proc = args.dataset_num_proc |
| |
|
| | |
| | train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") |
| | if eval_dataset is not None: |
| | if isinstance(eval_dataset, dict): |
| | eval_dataset = { |
| | key: self._prepare_dataset(dataset, processing_class, args, key) |
| | for key, dataset in eval_dataset.items() |
| | } |
| | else: |
| | eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") |
| |
|
| | super().__init__( |
| | model=model, |
| | args=args, |
| | data_collator=data_collator, |
| | train_dataset=train_dataset, |
| | eval_dataset=eval_dataset, |
| | processing_class=processing_class, |
| | compute_metrics=compute_metrics, |
| | callbacks=callbacks, |
| | optimizers=optimizers, |
| | optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, |
| | preprocess_logits_for_metrics=preprocess_logits_for_metrics, |
| | ) |
| |
|
| | |
| | |
| | |
| | self.model_accepts_loss_kwargs = False |
| |
|
| | |
| | if hasattr(self.model, "add_model_tags"): |
| | self.model.add_model_tags(self._tag_names) |
| |
|
| | if not hasattr(self, "accelerator"): |
| | raise AttributeError( |
| | "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." |
| | ) |
| |
|
| | |
| | if self.is_deepspeed_enabled: |
| | if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: |
| | raise ValueError( |
| | "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." |
| | ) |
| |
|
| | if self.ref_model is None: |
| | if not (self.is_peft_model or self.precompute_ref_log_probs): |
| | raise ValueError( |
| | "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" |
| | ) |
| | if args.sync_ref_model: |
| | raise ValueError( |
| | "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`." |
| | ) |
| | else: |
| | if self.is_deepspeed_enabled: |
| | self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) |
| | elif self.is_fsdp_enabled: |
| | self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) |
| | else: |
| | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) |
| |
|
| | if args.sync_ref_model: |
| | if self.precompute_ref_log_probs: |
| | raise ValueError( |
| | "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`." |
| | ) |
| |
|
| | self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) |
| |
|
| | if "bco_pair" in self.loss_type: |
| | self.running = RunningMoments(self.accelerator) |
| |
|
| | def _create_model_from_path(self, model_path: str, args: DPOConfig, is_ref: bool = False) -> PreTrainedModel: |
| | """Creates a model from a path or model identifier.""" |
| | if not is_ref: |
| | model_init_kwargs = args.model_init_kwargs or {} |
| | else: |
| | model_init_kwargs = args.ref_model_init_kwargs or {} |
| |
|
| | |
| | torch_dtype = model_init_kwargs.get("torch_dtype") |
| | if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: |
| | pass |
| | elif isinstance(torch_dtype, str): |
| | torch_dtype = getattr(torch, torch_dtype) |
| | model_init_kwargs["torch_dtype"] = torch_dtype |
| | else: |
| | raise ValueError( |
| | "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing " |
| | f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." |
| | ) |
| |
|
| | |
| | model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) |
| | return model |
| |
|
| | def _prepare_peft_model( |
| | self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig |
| | ) -> PreTrainedModel: |
| | """Prepares a model for PEFT training.""" |
| | |
| | |
| | self._peft_has_been_casted_to_bf16 = False |
| |
|
| | if not is_peft_available() and peft_config is not None: |
| | raise ValueError( |
| | "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" |
| | ) |
| | elif is_peft_available() and peft_config is not None: |
| | |
| | if isinstance(model, PeftModel): |
| | model = model.merge_and_unload() |
| |
|
| | if ref_model is not None and not args.force_use_ref_model: |
| | raise ValueError( |
| | "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference" |
| | " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init." |
| | " if you want to use a different ref_model." |
| | ) |
| |
|
| | if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): |
| | _support_gc_kwargs = hasattr( |
| | args, "gradient_checkpointing_kwargs" |
| | ) and "gradient_checkpointing_kwargs" in list( |
| | inspect.signature(prepare_model_for_kbit_training).parameters |
| | ) |
| |
|
| | prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} |
| |
|
| | if _support_gc_kwargs: |
| | prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs |
| |
|
| | model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) |
| |
|
| | else: |
| | model = self._prepare_gradient_checkpointing(model, args) |
| |
|
| | |
| | model = get_peft_model(model, peft_config) |
| | if args.bf16 and getattr(model, "is_loaded_in_4bit", False): |
| | peft_module_casting_to_bf16(model) |
| | |
| | self._peft_has_been_casted_to_bf16 = True |
| |
|
| | else: |
| | model = self._prepare_gradient_checkpointing(model, args) |
| |
|
| | return model |
| |
|
| | def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig): |
| | """Prepare the gradienting checkpointing for the model.""" |
| | |
| | |
| | |
| | if args.gradient_checkpointing: |
| | |
| | if hasattr(model, "enable_input_require_grads"): |
| | model.enable_input_require_grads() |
| | else: |
| |
|
| | def make_inputs_require_grad(module, input, output): |
| | output.requires_grad_(True) |
| |
|
| | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
| |
|
| | return model |
| |
|
| | def _prepare_dataset( |
| | self, |
| | dataset: Union[Dataset, IterableDataset], |
| | processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], |
| | args: DPOConfig, |
| | dataset_name: str, |
| | ) -> Union[Dataset, IterableDataset]: |
| | |
| | map_kwargs = {} |
| | if isinstance(dataset, Dataset): |
| | map_kwargs["num_proc"] = args.dataset_num_proc |
| | map_kwargs["writer_batch_size"] = 10 |
| |
|
| | with PartialState().main_process_first(): |
| | |
| | if isinstance(dataset, Dataset): |
| | map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" |
| | dataset = dataset.map(maybe_extract_prompt, **map_kwargs) |
| |
|
| | |
| | if isinstance(dataset, Dataset): |
| | map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" |
| | dataset = dataset.map( |
| | maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs |
| | ) |
| |
|
| | |
| | if isinstance(dataset, Dataset): |
| | map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" |
| |
|
| | dataset = dataset.map( |
| | self.tokenize_row if not self.is_vision_model else self.process_row, |
| | remove_columns=["chosen", "rejected"], |
| | fn_kwargs={ |
| | "processing_class": processing_class, |
| | "max_prompt_length": args.max_prompt_length, |
| | "max_completion_length": args.max_completion_length, |
| | |
| | "add_special_tokens": False, |
| | }, |
| | **map_kwargs, |
| | ) |
| |
|
| | return dataset |
| |
|
| | @staticmethod |
| | def tokenize_row( |
| | features: dict[str, str], |
| | processing_class: PreTrainedTokenizerBase, |
| | max_prompt_length: Optional[int] = None, |
| | max_completion_length: Optional[int] = None, |
| | add_special_tokens: bool = True, |
| | ) -> dict[str, list[int]]: |
| | """ |
| | Tokenize a row of the dataset. |
| | |
| | Args: |
| | features (`dict[str, str]`): |
| | Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`. |
| | processing_class (`PreTrainedTokenizerBase`): |
| | Processing class used to process the data. |
| | max_prompt_length (`int` or `None`): |
| | Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. |
| | max_completion_length (`int` or `None`): |
| | Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. |
| | add_special_tokens (`bool`): |
| | Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, |
| | the prompt sequence will have a bos token prepended and an eos token appended. In any case, the |
| | completion sequences will have an eos token appended. |
| | |
| | Returns: |
| | `dict[str, list[int]]`: |
| | Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and |
| | `"rejected_input_ids". |
| | |
| | Example: |
| | ```python |
| | >>> from transformers import GPT2Tokenizer |
| | |
| | >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") |
| | >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} |
| | >>> DPOTrainer.tokenize_row( |
| | ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False |
| | ... ) |
| | {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} |
| | ``` |
| | """ |
| | tokenizer = processing_class |
| | prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] |
| | chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] |
| | rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] |
| |
|
| | |
| | if add_special_tokens: |
| | if tokenizer.bos_token_id is not None: |
| | prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids |
| | if tokenizer.eos_token_id is not None: |
| | prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] |
| | chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] |
| | rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] |
| |
|
| | |
| | if max_prompt_length is not None: |
| | prompt_input_ids = prompt_input_ids[-max_prompt_length:] |
| | if max_completion_length is not None: |
| | chosen_input_ids = chosen_input_ids[:max_completion_length] |
| | rejected_input_ids = rejected_input_ids[:max_completion_length] |
| |
|
| | return { |
| | "prompt_input_ids": prompt_input_ids, |
| | "chosen_input_ids": chosen_input_ids, |
| | "rejected_input_ids": rejected_input_ids, |
| | } |
| |
|
| | @staticmethod |
| | def process_row( |
| | features: dict[str, str], |
| | processing_class: PreTrainedTokenizerBase, |
| | max_prompt_length: Optional[int] = None, |
| | max_completion_length: Optional[int] = None, |
| | add_special_tokens: bool = True, |
| | ) -> dict[str, list[int]]: |
| | """ |
| | Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information. |
| | """ |
| | processor, tokenizer = processing_class, processing_class.tokenizer |
| | processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False) |
| |
|
| | prompt_input_ids = processed_features["input_ids"][0] |
| | pixel_values = processed_features["pixel_values"][0] |
| | chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] |
| | rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] |
| |
|
| | |
| | if add_special_tokens: |
| | if tokenizer.bos_token_id is not None: |
| | prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids |
| | if tokenizer.eos_token_id is not None: |
| | prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] |
| | chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] |
| | rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] |
| |
|
| | |
| | if max_prompt_length is not None: |
| | prompt_input_ids = prompt_input_ids[-max_prompt_length:] |
| | if max_completion_length is not None: |
| | chosen_input_ids = chosen_input_ids[:max_completion_length] |
| | rejected_input_ids = rejected_input_ids[:max_completion_length] |
| |
|
| | output = { |
| | "prompt_input_ids": prompt_input_ids, |
| | "pixel_values": pixel_values, |
| | "chosen_input_ids": chosen_input_ids, |
| | "rejected_input_ids": rejected_input_ids, |
| | } |
| |
|
| | if "pixel_attention_mask" in processed_features: |
| | output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] |
| | if "image_sizes" in processed_features: |
| | output["image_sizes"] = processed_features["image_sizes"][0] |
| |
|
| | return output |
| |
|
| | def _set_signature_columns_if_needed(self): |
| | |
| | |
| | |
| | |
| | if self._signature_columns is None: |
| | self._signature_columns = [ |
| | "prompt_input_ids", |
| | "chosen_input_ids", |
| | "rejected_input_ids", |
| | "image_sizes", |
| | "ref_chosen_logps", |
| | "ref_rejected_logps", |
| | ] |
| |
|
| | def get_train_dataloader(self) -> DataLoader: |
| | """ |
| | Returns the training [`~torch.utils.data.DataLoader`]. |
| | |
| | Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. |
| | """ |
| |
|
| | if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: |
| | batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size |
| | dataloader_params = { |
| | "batch_size": batch_size, |
| | "collate_fn": self.data_collator, |
| | "num_workers": self.args.dataloader_num_workers, |
| | "pin_memory": self.args.dataloader_pin_memory, |
| | "shuffle": False, |
| | } |
| |
|
| | |
| | data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) |
| |
|
| | ref_chosen_logps = [] |
| | ref_rejected_logps = [] |
| | for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): |
| | ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) |
| | ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( |
| | (ref_chosen_logp, ref_rejected_logp) |
| | ) |
| | ref_chosen_logps.append(ref_chosen_logp.cpu()) |
| | ref_rejected_logps.append(ref_rejected_logp.cpu()) |
| |
|
| | |
| | empty_cache() |
| | self.accelerator.free_memory() |
| |
|
| | all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() |
| | all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() |
| |
|
| | self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) |
| | self.train_dataset = self.train_dataset.add_column( |
| | name="ref_rejected_logps", column=all_ref_rejected_logps |
| | ) |
| |
|
| | self._precomputed_train_ref_log_probs = True |
| |
|
| | return super().get_train_dataloader() |
| |
|
| | def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: |
| | """ |
| | Returns the evaluation [`~torch.utils.data.DataLoader`]. |
| | |
| | Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. |
| | |
| | Args: |
| | eval_dataset (`torch.utils.data.Dataset`, *optional*): |
| | If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted |
| | by the `model.forward()` method are automatically removed. It must implement `__len__`. |
| | """ |
| | if eval_dataset is None and self.eval_dataset is None: |
| | raise ValueError("Trainer: evaluation requires an eval_dataset.") |
| | eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset |
| |
|
| | if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: |
| | batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size |
| | dataloader_params = { |
| | "batch_size": batch_size, |
| | "collate_fn": self.data_collator, |
| | "num_workers": self.args.dataloader_num_workers, |
| | "pin_memory": self.args.dataloader_pin_memory, |
| | "shuffle": False, |
| | } |
| |
|
| | |
| | data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) |
| |
|
| | ref_chosen_logps = [] |
| | ref_rejected_logps = [] |
| | for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): |
| | ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) |
| | ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( |
| | (ref_chosen_logp, ref_rejected_logp) |
| | ) |
| | ref_chosen_logps.append(ref_chosen_logp.cpu()) |
| | ref_rejected_logps.append(ref_rejected_logp.cpu()) |
| |
|
| | all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() |
| | all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() |
| |
|
| | eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) |
| | eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps) |
| |
|
| | |
| | if self.eval_dataset is not None: |
| | self.eval_dataset = eval_dataset |
| | self._precomputed_eval_ref_log_probs = True |
| |
|
| | return super().get_eval_dataloader(eval_dataset=eval_dataset) |
| |
|
| | @contextmanager |
| | def null_ref_context(self): |
| | """Context manager for handling null reference model (that is, peft adapter manipulation).""" |
| | with ( |
| | self.accelerator.unwrap_model(self.model).disable_adapter() |
| | if self.is_peft_model and not self.ref_adapter_name |
| | else nullcontext() |
| | ): |
| | if self.ref_adapter_name: |
| | self.model.set_adapter(self.ref_adapter_name) |
| | yield |
| | if self.ref_adapter_name: |
| | self.model.set_adapter(self.model_adapter_name or "default") |
| |
|
| | def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.Tensor, torch.Tensor]: |
| | """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" |
| | compte_ref_context_manager = ( |
| | autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() |
| | ) |
| | with torch.no_grad(), compte_ref_context_manager: |
| | if self.ref_model is None: |
| | with self.null_ref_context(): |
| | ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) |
| | else: |
| | ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) |
| | return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] |
| |
|
| | @staticmethod |
| | def concatenated_inputs( |
| | batch: dict[str, Union[list, torch.LongTensor]], padding_value: int |
| | ) -> dict[str, torch.LongTensor]: |
| | """ |
| | Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt and |
| | completion sequences. |
| | |
| | Args: |
| | batch (`dict[str, Union[list, torch.LongTensor]]`): |
| | A batch of input data. The batch must contain the following keys: |
| | |
| | - `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input |
| | IDs. |
| | - `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen |
| | completion input IDs. |
| | - `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected |
| | completion input IDs. |
| | - `"prompt_pixel_values"` (optional): Tensor for pixel values, if available. |
| | - `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available. |
| | |
| | padding_value (`int`): |
| | The padding value to use for the concatenated completion sequences (`chosen_input_ids` and |
| | `rejected_input_ids`). |
| | |
| | Returns: |
| | `dict[str, torch.LongTensor]`: A dictionary containing: |
| | |
| | - `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`. |
| | - `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 * |
| | batch_size, max_completion_length)`. |
| | - `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size, |
| | prompt_length)`. |
| | - `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 * |
| | batch_size, max_completion_length)`. |
| | - `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present. |
| | - `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if |
| | `"prompt_pixel_attention_mask"` are present. |
| | |
| | Notes: |
| | The completion input IDs and attention masks are padded to the maximum completion length of the chosen or |
| | rejected sequences. |
| | """ |
| | output = {} |
| |
|
| | |
| | output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0) |
| | output["prompt_attention_mask"] = torch.cat( |
| | [batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0 |
| | ) |
| | if "pixel_values" in batch: |
| | output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0) |
| |
|
| | if "pixel_attention_mask" in batch: |
| | output["pixel_attention_mask"] = torch.cat( |
| | [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0 |
| | ) |
| | if "image_sizes" in batch: |
| | output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0) |
| |
|
| | |
| | max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) |
| | output["completion_input_ids"] = torch.cat( |
| | ( |
| | pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value), |
| | pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value), |
| | ), |
| | ) |
| | output["completion_attention_mask"] = torch.cat( |
| | ( |
| | pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0), |
| | pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0), |
| | ), |
| | ) |
| |
|
| | return output |
| |
|
| | def dpo_loss( |
| | self, |
| | chosen_logps: torch.FloatTensor, |
| | rejected_logps: torch.FloatTensor, |
| | ref_chosen_logps: torch.FloatTensor, |
| | ref_rejected_logps: torch.FloatTensor, |
| | loss_type: str = "sigmoid", |
| | model_output: dict[str, torch.FloatTensor] = None, |
| | ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: |
| | """ |
| | Compute the DPO loss for a batch of policy and reference model log probabilities. |
| | |
| | Args: |
| | chosen_logps (`torch.FloatTensor`): |
| | Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`. |
| | rejected_logps (`torch.FloatTensor`): |
| | Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`. |
| | ref_chosen_logps (`torch.FloatTensor`): |
| | Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`. |
| | ref_rejected_logps (`torch.FloatTensor`): |
| | Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`. |
| | |
| | Returns: |
| | A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. The losses tensor contains the DPO |
| | loss for each example in the batch. The `chosen_rewards` and `rejected_rewards` tensors contain the rewards |
| | for the chosen and rejected responses, respectively. |
| | """ |
| | device = self.accelerator.device |
| |
|
| | |
| | chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device) |
| | rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device) |
| |
|
| | if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value: |
| | |
| | |
| | |
| | |
| | |
| | |
| | alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT |
| | if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params: |
| | alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY]) |
| | logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef |
| | else: |
| | logratios = chosen_logps - rejected_logps |
| | if self.reference_free: |
| | ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device) |
| | else: |
| | ref_logratios = ref_chosen_logps - ref_rejected_logps |
| |
|
| | logratios = logratios.to(self.accelerator.device) |
| | ref_logratios = ref_logratios.to(self.accelerator.device) |
| | logits = logratios - ref_logratios |
| |
|
| | if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value: |
| | |
| | |
| | |
| | |
| | |
| | |
| | logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios) |
| |
|
| | |
| | |
| | |
| | if loss_type == "sigmoid": |
| | losses = ( |
| | -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) |
| | - F.logsigmoid(-self.beta * logits) * self.label_smoothing |
| | ) |
| |
|
| | elif loss_type == "robust": |
| | losses = ( |
| | -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) |
| | + F.logsigmoid(-self.beta * logits) * self.label_smoothing |
| | ) / (1 - 2 * self.label_smoothing) |
| |
|
| | elif loss_type == "exo_pair": |
| | |
| | import math |
| |
|
| | if self.label_smoothing == 0: |
| | self.label_smoothing = 1e-3 |
| | losses = (self.beta * logits).sigmoid() * ( |
| | F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing) |
| | ) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)) |
| |
|
| | elif loss_type == "hinge": |
| | losses = torch.relu(1 - self.beta * logits) |
| |
|
| | elif loss_type == "ipo": |
| | |
| | losses = (logits - 1 / (2 * self.beta)) ** 2 |
| |
|
| | elif loss_type == "bco_pair": |
| | chosen_logratios = chosen_logps - ref_chosen_logps |
| | rejected_logratios = rejected_logps - ref_rejected_logps |
| | chosen_rewards = self.beta * chosen_logratios |
| | rejected_rewards = self.beta * rejected_logratios |
| | rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() |
| | self.running.update(rewards) |
| | delta = self.running.mean |
| | losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid( |
| | -(self.beta * rejected_logratios - delta) |
| | ) |
| |
|
| | elif loss_type == "sppo_hard": |
| | |
| | |
| | |
| | |
| | a = chosen_logps - ref_chosen_logps |
| | b = rejected_logps - ref_rejected_logps |
| | losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2 |
| |
|
| | elif loss_type == "nca_pair": |
| | chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta |
| | rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta |
| | losses = ( |
| | -F.logsigmoid(chosen_rewards) |
| | - 0.5 * F.logsigmoid(-chosen_rewards) |
| | - 0.5 * F.logsigmoid(-rejected_rewards) |
| | ) |
| |
|
| | elif loss_type == "aot_pair": |
| | chosen_logratios = chosen_logps - ref_chosen_logps |
| | rejected_logratios = rejected_logps - ref_rejected_logps |
| | chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0) |
| | rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0) |
| | delta = chosen_logratios_sorted - rejected_logratios_sorted |
| | losses = ( |
| | -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) |
| | - F.logsigmoid(-self.beta * delta) * self.label_smoothing |
| | ) |
| |
|
| | elif loss_type == "aot": |
| | logratios = chosen_logps - rejected_logps |
| | ref_logratios = ref_chosen_logps - ref_rejected_logps |
| | logratios_sorted, _ = torch.sort(logratios, dim=0) |
| | ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0) |
| | delta = logratios_sorted - ref_logratios_sorted |
| | losses = ( |
| | -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) |
| | - F.logsigmoid(-self.beta * delta) * self.label_smoothing |
| | ) |
| |
|
| | elif loss_type == "apo_zero": |
| | |
| | |
| | losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) |
| | losses_rejected = F.sigmoid(self.beta * rejected_logratios) |
| | losses = losses_chosen + losses_rejected |
| |
|
| | elif loss_type == "apo_down": |
| | |
| | |
| | |
| | losses_chosen = F.sigmoid(self.beta * chosen_logratios) |
| | losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) |
| | losses = losses_chosen + losses_rejected |
| |
|
| | elif loss_type == "discopop": |
| | |
| | |
| | logratios = chosen_logps - rejected_logps |
| | ref_logratios = ref_chosen_logps - ref_rejected_logps |
| | logits = logratios - ref_logratios |
| | logits = logits * self.beta |
| | |
| | log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau) |
| | logistic_component = -F.logsigmoid(logits) |
| | exp_component = torch.exp(-logits) |
| | |
| | losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation |
| |
|
| | elif loss_type == "sft": |
| | |
| | |
| | sft_loss = model_output["nll_loss"] |
| | |
| | batch_size = chosen_logps.shape[0] |
| | losses = sft_loss.expand(batch_size) |
| | |
| | chosen_rewards = torch.zeros_like(chosen_logps) |
| | rejected_rewards = torch.zeros_like(rejected_logps) |
| |
|
| | else: |
| | raise ValueError( |
| | f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " |
| | "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', " |
| | "'apo_down', 'sft']" |
| | ) |
| |
|
| | chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() |
| | rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach() |
| |
|
| | return losses, chosen_rewards, rejected_rewards |
| |
|
| | def _compute_loss_liger( |
| | self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] |
| | ) -> dict[str, torch.Tensor]: |
| | unwrapped_model = self.accelerator.unwrap_model(model) |
| | concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) |
| |
|
| | model_kwargs = {} |
| | if self.aux_loss_enabled: |
| | model_kwargs["output_router_logits"] = True |
| |
|
| | |
| | if "pixel_values" in concatenated_batch: |
| | model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] |
| | if "pixel_attention_mask" in concatenated_batch: |
| | model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] |
| | if "image_sizes" in concatenated_batch: |
| | model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] |
| |
|
| | prompt_attention_mask = concatenated_batch["prompt_attention_mask"] |
| | completion_attention_mask = concatenated_batch["completion_attention_mask"] |
| |
|
| | if self.is_encoder_decoder: |
| | |
| | encoder_outputs = unwrapped_model.get_encoder()( |
| | concatenated_batch["prompt_input_ids"], |
| | attention_mask=concatenated_batch["prompt_attention_mask"], |
| | return_dict=True, |
| | ) |
| | |
| | decoder_input_ids = shift_tokens_right( |
| | concatenated_batch["completion_input_ids"], |
| | unwrapped_model.config.decoder_start_token_id, |
| | ) |
| | |
| | decoder_outputs = unwrapped_model.get_decoder()( |
| | input_ids=decoder_input_ids, |
| | attention_mask=concatenated_batch["completion_attention_mask"], |
| | encoder_hidden_states=encoder_outputs.last_hidden_state, |
| | encoder_attention_mask=concatenated_batch["prompt_attention_mask"], |
| | use_cache=False, |
| | ) |
| | hidden_states = decoder_outputs.last_hidden_state |
| |
|
| | ref_hidden_states = None |
| | if not self.reference_free and self.ref_model is not None: |
| | unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) |
| | ref_encoder_outputs = unwrapped_ref_model.get_encoder()( |
| | concatenated_batch["prompt_input_ids"], |
| | attention_mask=concatenated_batch["prompt_attention_mask"], |
| | return_dict=True, |
| | ) |
| | ref_decoder_outputs = unwrapped_ref_model.get_decoder()( |
| | input_ids=decoder_input_ids, |
| | attention_mask=concatenated_batch["completion_attention_mask"], |
| | encoder_hidden_states=ref_encoder_outputs.last_hidden_state, |
| | encoder_attention_mask=concatenated_batch["prompt_attention_mask"], |
| | use_cache=False, |
| | ) |
| | ref_hidden_states = ref_decoder_outputs.last_hidden_state |
| | elif not self.reference_free: |
| | with self.null_ref_context(): |
| | ref_encoder_outputs = unwrapped_model.get_encoder()( |
| | concatenated_batch["prompt_input_ids"], |
| | attention_mask=concatenated_batch["prompt_attention_mask"], |
| | return_dict=True, |
| | ) |
| | ref_decoder_outputs = unwrapped_model.get_decoder()( |
| | input_ids=decoder_input_ids, |
| | attention_mask=concatenated_batch["completion_attention_mask"], |
| | encoder_hidden_states=ref_encoder_outputs.last_hidden_state, |
| | encoder_attention_mask=concatenated_batch["prompt_attention_mask"], |
| | use_cache=False, |
| | ) |
| | ref_hidden_states = ref_decoder_outputs.last_hidden_state |
| |
|
| | labels = concatenated_batch["completion_input_ids"] |
| | loss_mask = completion_attention_mask.bool() |
| | else: |
| | |
| | input_ids = torch.cat( |
| | (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 |
| | ) |
| | attention_mask = torch.cat( |
| | (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), |
| | dim=1, |
| | ) |
| | |
| | loss_mask = torch.cat( |
| | (torch.zeros_like(prompt_attention_mask), completion_attention_mask), |
| | dim=1, |
| | ) |
| |
|
| | |
| | if self.max_length is not None and self.max_length < attention_mask.size(1): |
| | if self.truncation_mode == "keep_start": |
| | |
| | |
| | |
| | attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) |
| | attention_mask = attention_mask[:, : self.max_length] |
| | input_ids = input_ids[:, : self.max_length] |
| | loss_mask = loss_mask[:, : self.max_length] |
| | elif self.truncation_mode == "keep_end": |
| | |
| | |
| | |
| | attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) |
| | input_ids = input_ids[:, -self.max_length :] |
| | attention_mask = attention_mask[:, -self.max_length :] |
| | loss_mask = loss_mask[:, -self.max_length :] |
| | attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) |
| | else: |
| | raise ValueError( |
| | f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " |
| | "'keep_start']." |
| | ) |
| | else: |
| | |
| | |
| | |
| | attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) |
| |
|
| | |
| | if self.use_logits_to_keep: |
| | first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() |
| | logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 |
| | model_kwargs["logits_to_keep"] = logits_to_keep |
| |
|
| | model_kwargs["output_hidden_states"] = True |
| |
|
| | |
| | if self.padding_free: |
| | input_ids = input_ids[attention_mask.bool()].unsqueeze(0) |
| | loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) |
| | position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 |
| | model_kwargs["position_ids"] = position_ids |
| | else: |
| | model_kwargs["attention_mask"] = attention_mask |
| |
|
| | |
| | if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None: |
| | base_model = unwrapped_model.get_decoder() |
| | else: |
| | base_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name) |
| | base_model = getattr(unwrapped_model, base_attr, unwrapped_model) |
| |
|
| | outputs = base_model( |
| | input_ids, |
| | use_cache=False, |
| | **model_kwargs, |
| | ) |
| | hidden_states = outputs.last_hidden_state[:, :-1] |
| |
|
| | |
| | ref_hidden_states = None |
| | if not self.reference_free and self.ref_model is not None: |
| | unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) |
| | if hasattr(unwrapped_ref_model, "get_decoder") and unwrapped_ref_model.get_decoder() is not None: |
| | ref_base_model = unwrapped_ref_model.get_decoder() |
| | else: |
| | ref_attr = getattr(unwrapped_ref_model, "base_model_prefix", self.args.base_model_attribute_name) |
| | ref_base_model = getattr(unwrapped_ref_model, ref_attr, unwrapped_ref_model) |
| |
|
| | ref_outputs = ref_base_model( |
| | input_ids, |
| | use_cache=False, |
| | **model_kwargs, |
| | ) |
| | ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] |
| | elif not self.reference_free: |
| | if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None: |
| | ref_base_model = unwrapped_model.get_decoder() |
| | else: |
| | ref_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name) |
| | ref_base_model = getattr(unwrapped_model, ref_attr, unwrapped_model) |
| | with self.null_ref_context(): |
| | ref_outputs = ref_base_model( |
| | input_ids, |
| | use_cache=False, |
| | **model_kwargs, |
| | ) |
| | ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] |
| |
|
| | masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id) |
| | labels = masked_input_ids[:, 1:] |
| |
|
| | |
| | lm_head = unwrapped_model.get_output_embeddings() |
| |
|
| | |
| | ref_weight = None |
| | ref_bias = None |
| | if not self.reference_free: |
| | if self.ref_model is not None: |
| | unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) |
| | ref_lm_head = unwrapped_ref_model.get_output_embeddings() |
| | else: |
| | with self.null_ref_context(): |
| | ref_lm_head = unwrapped_model.get_output_embeddings() |
| | ref_weight = ref_lm_head.weight |
| | ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None |
| |
|
| | |
| | loss_output = self.dpo_loss_fn( |
| | lm_head.weight, |
| | hidden_states, |
| | labels, |
| | bias=lm_head.bias if hasattr(lm_head, "bias") else None, |
| | ref_input=ref_hidden_states if not self.reference_free else None, |
| | ref_weight=ref_weight if not self.reference_free else None, |
| | ref_bias=ref_bias if not self.reference_free else None, |
| | ) |
| | ( |
| | loss, |
| | (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), |
| | ) = loss_output |
| |
|
| | output = { |
| | "loss": loss, |
| | "chosen_logps": chosen_logps, |
| | "rejected_logps": rejected_logps, |
| | "mean_chosen_logits": chosen_logits_mean, |
| | "mean_rejected_logits": rejected_logits_mean, |
| | "nll_loss": nll_loss, |
| | "chosen_rewards": aux_outputs[0], |
| | "rejected_rewards": aux_outputs[1], |
| | } |
| | if self.aux_loss_enabled: |
| | output["aux_loss"] = outputs.aux_loss |
| |
|
| | return output |
| |
|
| | def concatenated_forward( |
| | self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False |
| | ) -> dict[str, torch.Tensor]: |
| | """ |
| | Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. |
| | |
| | We do this to avoid doing two forward passes, because it's faster for FSDP. |
| | |
| | Args: |
| | model: |
| | Model to run the forward pass on. |
| | batch: |
| | Batch of input data. |
| | is_ref_model: |
| | Whether this method is being called for the reference model. If `True`, length desensitization is not |
| | applied. |
| | """ |
| | num_examples = batch["prompt_input_ids"].shape[0] |
| |
|
| | concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) |
| |
|
| | model_kwargs = {"use_cache": False} |
| | if self.aux_loss_enabled: |
| | model_kwargs["output_router_logits"] = True |
| |
|
| | |
| | if "pixel_values" in concatenated_batch: |
| | model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] |
| | if "pixel_attention_mask" in concatenated_batch: |
| | model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] |
| | if "image_sizes" in concatenated_batch: |
| | model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] |
| |
|
| | prompt_input_ids = concatenated_batch["prompt_input_ids"] |
| | prompt_attention_mask = concatenated_batch["prompt_attention_mask"] |
| | completion_input_ids = concatenated_batch["completion_input_ids"] |
| | completion_attention_mask = concatenated_batch["completion_attention_mask"] |
| | if self.is_encoder_decoder: |
| | labels = completion_input_ids |
| | labels[completion_attention_mask == 0] = self.label_pad_token_id |
| | outputs = model( |
| | input_ids=prompt_input_ids, |
| | attention_mask=prompt_attention_mask, |
| | labels=labels, |
| | **model_kwargs, |
| | ) |
| | logits = outputs.logits |
| | loss_mask = completion_attention_mask.bool() |
| | else: |
| | |
| | input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) |
| | attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) |
| | |
| | loss_mask = torch.cat( |
| | (torch.zeros_like(prompt_attention_mask), completion_attention_mask), |
| | dim=1, |
| | ) |
| |
|
| | |
| | if self.max_length is not None and self.max_length < attention_mask.size(1): |
| | if self.truncation_mode == "keep_start": |
| | |
| | |
| | |
| | attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) |
| | attention_mask = attention_mask[:, : self.max_length] |
| | input_ids = input_ids[:, : self.max_length] |
| | loss_mask = loss_mask[:, : self.max_length] |
| | elif self.truncation_mode == "keep_end": |
| | |
| | |
| | |
| | attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) |
| | input_ids = input_ids[:, -self.max_length :] |
| | attention_mask = attention_mask[:, -self.max_length :] |
| | loss_mask = loss_mask[:, -self.max_length :] |
| | attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) |
| | else: |
| | raise ValueError( |
| | f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " |
| | "'keep_start']." |
| | ) |
| | else: |
| | |
| | |
| | |
| | attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) |
| |
|
| | if self.use_logits_to_keep: |
| | |
| | |
| | |
| | |
| | first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() |
| | logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 |
| | model_kwargs["logits_to_keep"] = logits_to_keep |
| |
|
| | model_kwargs["output_hidden_states"] = True |
| |
|
| | if self.padding_free: |
| | |
| | |
| | |
| | input_ids = input_ids[attention_mask.bool()].unsqueeze(0) |
| | loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) |
| | position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 |
| | model_kwargs["position_ids"] = position_ids |
| | else: |
| | model_kwargs["attention_mask"] = attention_mask |
| |
|
| | outputs = model(input_ids, **model_kwargs) |
| | logits = outputs.logits |
| |
|
| | |
| | labels = torch.roll(input_ids, shifts=-1, dims=1) |
| | loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() |
| |
|
| | if self.use_logits_to_keep: |
| | |
| | |
| | |
| | |
| | |
| | |
| | labels = labels[:, -logits_to_keep:] |
| | loss_mask = loss_mask[:, -logits_to_keep:] |
| |
|
| | if logits.shape[:2] != labels.shape[:2]: |
| | |
| | seq_len = labels.shape[1] |
| | logits = logits[:, -seq_len:] |
| |
|
| | |
| | labels[~loss_mask] = 0 |
| | per_token_logps = selective_log_softmax(logits, labels) |
| | per_token_logps[~loss_mask] = 0 |
| | per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) |
| |
|
| | if self.padding_free: |
| | |
| | batch_size, seq_len = attention_mask.shape |
| | per_token_logps_ = torch.zeros( |
| | batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype |
| | ) |
| | per_token_logps_[attention_mask.bool()] = per_token_logps |
| | per_token_logps = per_token_logps_ |
| |
|
| | all_logps = per_token_logps[:, 1:].sum(-1) |
| |
|
| | output = {} |
| |
|
| | if self.use_weighting: |
| | with torch.no_grad(): |
| | |
| | logprobs = F.log_softmax(logits, dim=-1) |
| | weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) |
| | per_token_logps_adjusted = per_token_logps - weights_adjustment_factor |
| | all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) |
| | chosen_weights = all_weights[:num_examples] |
| | rejected_weights = all_weights[num_examples:] |
| | output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) |
| |
|
| | if self.args.rpo_alpha is not None or "sft" in self.loss_type: |
| | |
| | chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] |
| | chosen_labels = labels[:num_examples, :-1] if not self.is_encoder_decoder else labels[:num_examples] |
| |
|
| | |
| | output["nll_loss"] = F.cross_entropy( |
| | torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 |
| | ) |
| |
|
| | if "ipo" in self.loss_type: |
| | all_logps = all_logps / loss_mask.sum(-1) |
| |
|
| | if self.args.ld_alpha is not None and not is_ref_model: |
| | |
| | completion_lengths = loss_mask.sum(dim=1) |
| |
|
| | chosen_lengths = completion_lengths[:num_examples] |
| | rejected_lengths = completion_lengths[num_examples:] |
| | public_lengths = torch.min(chosen_lengths, rejected_lengths) |
| | public_lengths = torch.cat([public_lengths, public_lengths], dim=0) |
| |
|
| | seq_len = per_token_logps.size(1) |
| | position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps) |
| |
|
| | ld_mask = position_ids < public_lengths.unsqueeze(1) |
| | mask = position_ids < completion_lengths.unsqueeze(1) |
| |
|
| | front_mask = (ld_mask & mask).float() |
| | rear_mask = (~ld_mask & mask).float() |
| | front_logps = (per_token_logps * front_mask).sum(dim=1) |
| | rear_logps = (per_token_logps * rear_mask).sum(dim=1) |
| |
|
| | all_logps = front_logps + self.args.ld_alpha * rear_logps |
| |
|
| | output["chosen_logps"] = all_logps[:num_examples] |
| | output["rejected_logps"] = all_logps[num_examples:] |
| |
|
| | |
| | if self.padding_free: |
| | |
| | |
| | |
| | |
| | split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] |
| | mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() |
| | mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() |
| | else: |
| | mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() |
| | mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() |
| |
|
| | output["mean_chosen_logits"] = mean_chosen_logits |
| | output["mean_rejected_logits"] = mean_rejected_logits |
| |
|
| | if self.aux_loss_enabled: |
| | output["aux_loss"] = outputs.aux_loss |
| |
|
| | return output |
| |
|
| | def get_batch_loss_metrics( |
| | self, |
| | model: Union[PreTrainedModel, nn.Module], |
| | batch: dict[str, Union[list, torch.LongTensor]], |
| | train_eval: Literal["train", "eval"] = "train", |
| | ) -> tuple[torch.Tensor, dict[str, float]]: |
| | """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" |
| | metrics = {} |
| |
|
| | if self.args.use_liger_loss: |
| | model_output = self._compute_loss_liger(model, batch) |
| | losses = model_output["loss"] |
| | chosen_rewards = model_output["chosen_rewards"] |
| | rejected_rewards = model_output["rejected_rewards"] |
| | else: |
| | model_output = self.concatenated_forward(model, batch) |
| |
|
| | |
| | if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: |
| | ref_chosen_logps = batch["ref_chosen_logps"] |
| | ref_rejected_logps = batch["ref_rejected_logps"] |
| | else: |
| | ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) |
| |
|
| | |
| | losses = 0 |
| | chosen_rewards = 0 |
| | rejected_rewards = 0 |
| |
|
| | |
| | for idx, loss_type in enumerate(self.loss_type): |
| | |
| | _losses, _chosen_rewards, _rejected_rewards = self.dpo_loss( |
| | model_output["chosen_logps"], |
| | model_output["rejected_logps"], |
| | ref_chosen_logps, |
| | ref_rejected_logps, |
| | loss_type, |
| | model_output, |
| | ) |
| |
|
| | |
| | weight = self.loss_weights[idx] if self.loss_weights else 1.0 |
| | losses = losses + _losses * weight |
| | chosen_rewards = chosen_rewards + _chosen_rewards * weight |
| | rejected_rewards = rejected_rewards + _rejected_rewards * weight |
| |
|
| | reward_accuracies = (chosen_rewards > rejected_rewards).float() |
| |
|
| | if self.args.rpo_alpha is not None: |
| | losses = losses + self.args.rpo_alpha * model_output["nll_loss"] |
| |
|
| | if self.use_weighting: |
| | losses = losses * model_output["policy_weights"] |
| |
|
| | if self.aux_loss_enabled: |
| | losses = losses + self.aux_loss_coef * model_output["aux_loss"] |
| |
|
| | prefix = "eval_" if train_eval == "eval" else "" |
| | metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() |
| | metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() |
| | metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() |
| | metrics[f"{prefix}rewards/margins"] = ( |
| | self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() |
| | ) |
| | metrics[f"{prefix}logps/chosen"] = ( |
| | self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item() |
| | ) |
| | metrics[f"{prefix}logps/rejected"] = ( |
| | self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item() |
| | ) |
| | metrics[f"{prefix}logits/chosen"] = ( |
| | self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item() |
| | ) |
| | metrics[f"{prefix}logits/rejected"] = ( |
| | self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item() |
| | ) |
| | if self.args.rpo_alpha is not None or "sft" in self.loss_type: |
| | metrics[f"{prefix}nll_loss"] = ( |
| | self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item() |
| | ) |
| | if self.aux_loss_enabled: |
| | metrics[f"{prefix}aux_loss"] = ( |
| | self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item() |
| | ) |
| |
|
| | return losses.mean(), metrics |
| |
|
| | def compute_loss( |
| | self, |
| | model: Union[PreTrainedModel, nn.Module], |
| | inputs: dict[str, Union[torch.Tensor, Any]], |
| | return_outputs=False, |
| | num_items_in_batch=None, |
| | ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: |
| | compute_loss_context_manager = ( |
| | autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() |
| | ) |
| | with compute_loss_context_manager: |
| | loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") |
| |
|
| | |
| | loss = loss.to(self.args.device) |
| | |
| | self.store_metrics(metrics, train_eval="train") |
| |
|
| | if return_outputs: |
| | return loss, metrics |
| |
|
| | return loss |
| |
|
| | def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: |
| | """Generate samples from the model and reference model for the given batch of inputs.""" |
| |
|
| | |
| | |
| | generate_context_manager = ( |
| | autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() |
| | ) |
| |
|
| | with generate_context_manager: |
| | policy_output = model.generate( |
| | input_ids=batch["prompt_input_ids"], |
| | attention_mask=batch["prompt_attention_mask"], |
| | max_length=self.max_length, |
| | do_sample=True, |
| | pad_token_id=self.padding_value, |
| | ) |
| |
|
| | |
| | if "ref_output" in batch: |
| | ref_output = batch["ref_output"] |
| | else: |
| | if self.ref_model is None: |
| | with self.null_ref_context(): |
| | ref_output = self.model.generate( |
| | input_ids=batch["prompt_input_ids"], |
| | attention_mask=batch["prompt_attention_mask"], |
| | max_length=self.max_length, |
| | do_sample=True, |
| | pad_token_id=self.padding_value, |
| | ) |
| | else: |
| | ref_output = self.ref_model.generate( |
| | input_ids=batch["prompt_input_ids"], |
| | attention_mask=batch["prompt_attention_mask"], |
| | max_length=self.max_length, |
| | do_sample=True, |
| | pad_token_id=self.padding_value, |
| | ) |
| |
|
| | policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) |
| | policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) |
| |
|
| | ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) |
| | ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) |
| |
|
| | return policy_output_decoded, ref_output_decoded |
| |
|
| | def prediction_step( |
| | self, |
| | model: Union[PreTrainedModel, nn.Module], |
| | inputs: dict[str, Union[torch.Tensor, Any]], |
| | prediction_loss_only: bool, |
| | ignore_keys: Optional[list[str]] = None, |
| | ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: |
| | if ignore_keys is None: |
| | if hasattr(model, "config"): |
| | ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) |
| | else: |
| | ignore_keys = [] |
| |
|
| | prediction_context_manager = ( |
| | autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() |
| | ) |
| |
|
| | with torch.no_grad(), prediction_context_manager: |
| | loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") |
| |
|
| | |
| | self.store_metrics(metrics, train_eval="eval") |
| |
|
| | if prediction_loss_only: |
| | return loss.detach(), None, None |
| |
|
| | |
| | logits_dict = { |
| | "eval_logits/chosen": metrics["eval_logits/chosen"], |
| | "eval_logits/rejected": metrics["eval_logits/rejected"], |
| | } |
| | logits = [v for k, v in logits_dict.items() if k not in ignore_keys] |
| | logits = torch.tensor(logits, device=self.accelerator.device) |
| | labels = torch.zeros(logits.shape[0], device=self.accelerator.device) |
| |
|
| | return (loss.detach(), logits, labels) |
| |
|
| | def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: |
| | for key, value in metrics.items(): |
| | self._stored_metrics[train_eval][key].append(value) |
| |
|
| | def evaluation_loop( |
| | self, |
| | dataloader: DataLoader, |
| | description: str, |
| | prediction_loss_only: Optional[bool] = None, |
| | ignore_keys: Optional[list[str]] = None, |
| | metric_key_prefix: str = "eval", |
| | ) -> EvalLoopOutput: |
| | """ |
| | Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by |
| | `Trainer.evaluate()` and `Trainer.predict()`. |
| | |
| | Works both with or without labels. |
| | """ |
| |
|
| | |
| | if self.generate_during_eval: |
| | |
| | num_samples = len(dataloader.dataset) |
| | random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) |
| |
|
| | |
| | random_batch_dataset = dataloader.dataset.select(random_indices) |
| | random_batch = self.data_collator(random_batch_dataset) |
| | random_batch = self._prepare_inputs(random_batch) |
| |
|
| | policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch) |
| |
|
| | table = pd.DataFrame( |
| | columns=["Prompt", "Policy", "Ref Model"], |
| | data=[ |
| | [prompt, pol[len(prompt) :], ref[len(prompt) :]] |
| | for prompt, pol, ref in zip( |
| | random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded |
| | ) |
| | ], |
| | ) |
| | if "wandb" in self.args.report_to and self.accelerator.is_main_process: |
| | wandb.log({"game_log": wandb.Table(data=table)}) |
| |
|
| | if "comet_ml" in self.args.report_to: |
| | log_table_to_comet_experiment( |
| | name="game_log.csv", |
| | table=table, |
| | ) |
| |
|
| | if "mlflow" in self.args.report_to and self.accelerator.is_main_process: |
| | mlflow.log_table(data=table, artifact_file="game_log.json") |
| |
|
| | |
| | initial_output = super().evaluation_loop( |
| | dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix |
| | ) |
| |
|
| | return initial_output |
| |
|
| | def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: |
| | """ |
| | Log `logs` on the various objects watching training, including stored metrics. |
| | |
| | Args: |
| | logs (`dict[str, float]`): |
| | The values to log. |
| | start_time (`float` or `None`, *optional*, defaults to `None`): |
| | Start time of the training. |
| | """ |
| | |
| | train_eval = "train" if "loss" in logs else "eval" |
| | |
| | for key, metrics in self._stored_metrics[train_eval].items(): |
| | logs[key] = torch.tensor(metrics).mean().item() |
| | del self._stored_metrics[train_eval] |
| | return super().log(logs, start_time) |
| |
|
| | |
| | def _save_checkpoint(self, model, trial): |
| | if self.args.hub_model_id is None: |
| | model_name = Path(self.args.output_dir).name |
| | else: |
| | model_name = self.args.hub_model_id.split("/")[-1] |
| | self.create_model_card(model_name=model_name) |
| | super()._save_checkpoint(model, trial) |
| |
|
| | def create_model_card( |
| | self, |
| | model_name: Optional[str] = None, |
| | dataset_name: Optional[str] = None, |
| | tags: Union[str, list[str], None] = None, |
| | ): |
| | """ |
| | Creates a draft of a model card using the information available to the `Trainer`. |
| | |
| | Args: |
| | model_name (`str` or `None`, *optional*, defaults to `None`): |
| | Name of the model. |
| | dataset_name (`str` or `None`, *optional*, defaults to `None`): |
| | Name of the dataset used for training. |
| | tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): |
| | Tags to be associated with the model card. |
| | """ |
| | if not self.is_world_process_zero(): |
| | return |
| |
|
| | if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): |
| | base_model = self.model.config._name_or_path |
| | else: |
| | base_model = None |
| |
|
| | |
| | if tags is None: |
| | tags = set() |
| | elif isinstance(tags, str): |
| | tags = {tags} |
| | else: |
| | tags = set(tags) |
| |
|
| | if hasattr(self.model.config, "unsloth_version"): |
| | tags.add("unsloth") |
| |
|
| | if "JOB_ID" in os.environ: |
| | tags.add("hf_jobs") |
| |
|
| | tags.update(self._tag_names) |
| |
|
| | |
| | citation = textwrap.dedent( |
| | """\ |
| | @inproceedings{rafailov2023direct, |
| | title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}}, |
| | author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn}, |
| | year = 2023, |
| | booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023}, |
| | url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html}, |
| | editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine}, |
| | }""" |
| | ) |
| |
|
| | model_card = generate_model_card( |
| | base_model=base_model, |
| | model_name=model_name, |
| | hub_model_id=self.hub_model_id, |
| | dataset_name=dataset_name, |
| | tags=tags, |
| | wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, |
| | comet_url=get_comet_experiment_url(), |
| | trainer_name="DPO", |
| | trainer_citation=citation, |
| | paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model", |
| | paper_id="2305.18290", |
| | ) |
| |
|
| | model_card.save(os.path.join(self.args.output_dir, "README.md")) |
| | class UnslothDPOTrainer(_UnslothDPOTrainer): |
| | """ |
| | |
| | Trainer for Direct Preference Optimization (DPO) method. |
| | |
| | This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. |
| | |
| | Args: |
| | model (`Union[str, PreTrainedModel]`): |
| | Model to be trained. Can be either: |
| | |
| | - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a |
| | path to a *directory* containing model weights saved using |
| | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded |
| | using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in |
| | `args.model_init_kwargs`. |
| | - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. |
| | ref_model (`PreTrainedModelWrapper`): |
| | Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation |
| | and loss. If no reference model is provided, the trainer will create a reference model with the same |
| | architecture as the model to be optimized. |
| | args ([`DPOConfig`], *optional*, defaults to `None`): |
| | Configuration for this trainer. If `None`, a default configuration is used. |
| | data_collator (`DataCollator`, *optional*): |
| | Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. |
| | Will default to [`DataCollatorForPreference`]. |
| | train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): |
| | Dataset to use for training. DPO supports [preference](#preference) type and. The format of the samples can |
| | be either: |
| | |
| | - [Standard](dataset_formats#standard): Each sample contains plain text. |
| | - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role |
| | and content). |
| | eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, |
| | IterableDataset]]`): |
| | Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. |
| | processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`): |
| | Processing class used to process the data. If `None`, the processing class is loaded from the model's name |
| | with [`~transformers.AutoTokenizer.from_pretrained`]. |
| | compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): |
| | The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return |
| | a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to |
| | `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered |
| | after the last eval batch to signal that the function needs to calculate and return the global summary |
| | statistics rather than accumulating the batch-level statistics. |
| | callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): |
| | List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed |
| | in [here](https://huggingface.co/docs/transformers/main_classes/callback). |
| | |
| | If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] |
| | method. |
| | optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, |
| | None)`): |
| | A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your |
| | model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. |
| | optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to |
| | `None`): |
| | A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in |
| | `args`. Incompatible with the `optimizers` argument. |
| | preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to |
| | `None`): |
| | A function that preprocess the logits right before caching them at each evaluation step. Must take two |
| | tensors, the logits and the labels, and return the logits once processed as desired. The modifications made |
| | by this function will be reflected in the predictions received by `compute_metrics`. |
| | |
| | Note that the labels (second parameter) will be `None` if the dataset does not have them. |
| | peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): |
| | PEFT configuration used to wrap the model. If `None`, the model is not wrapped. |
| | |
| | """ |
| | def __init__( |
| | self, |
| | model, |
| | ref_model = None, |
| | args = None, |
| | data_collator = None, |
| | train_dataset = None, |
| | eval_dataset = None, |
| | processing_class = None, |
| | compute_metrics = None, |
| | callbacks = None, |
| | optimizer_cls_and_kwargs = None, |
| | preprocess_logits_for_metrics = None, |
| | peft_config = None, |
| | **kwargs |
| | ): |
| | if args is None: args = UnslothDPOConfig() |
| | use_bf16 = getattr(args, 'bf16', False) |
| | if type(use_bf16) is not bool: use_bf16 = False |
| | use_fp16 = getattr(args, 'fp16', False) |
| | if type(use_fp16) is not bool: use_fp16 = False |
| | force_float32 = False |
| | full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' |
| | if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): |
| | print('Unsloth: Switching to float32 training since model cannot work with float16') |
| | force_float32 = True |
| | mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') |
| | dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) |
| | if dtype is None: dtype = model.get_input_embeddings().dtype |
| | from unsloth_zoo.utils import _get_dtype |
| | dtype = _get_dtype(dtype) |
| | float16 = dtype == torch.float16 |
| | if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') |
| | if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') |
| | if force_float32: |
| | |
| | args.fp16 = False |
| | args.bf16 = False |
| | os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' |
| | elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': |
| | |
| | args.fp16 = float16 |
| | args.bf16 = not float16 |
| | os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' |
| | if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': |
| | args.eval_strategy = 'steps' |
| | if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 |
| | ga_steps = getattr(args, 'gradient_accumulation_steps', None) |
| | if ga_steps is not None and ga_steps > 1: |
| | from transformers import __version__ as transformers_version |
| | if Version(transformers_version) <= Version('4.45.2'): |
| | print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' |
| | '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') |
| | if getattr(args, 'eval_strategy', 'no') != 'no': |
| | eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) |
| | if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size |
| | if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps |
| | fp16_full_eval = getattr(args, 'fp16_full_eval', False) |
| | if type(fp16_full_eval) is not bool: fp16_full_eval = False |
| | bf16_full_eval = getattr(args, 'bf16_full_eval', False) |
| | if type(bf16_full_eval) is not bool: bf16_full_eval = False |
| | if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True |
| | if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False |
| | if force_float32: |
| | args.bf16_full_eval = False |
| | args.fp16_full_eval = False |
| | elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': |
| | args.bf16_full_eval = True |
| | args.fp16_full_eval = False |
| | elif not bf16_full_eval and not fp16_full_eval: |
| | args.bf16_full_eval = args.bf16 |
| | args.fp16_full_eval = args.fp16 |
| | _output_logits = False |
| | if locals().get('compute_metrics', None) is not None: _output_logits = True |
| | if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True |
| | if _output_logits: |
| | os.environ['UNSLOTH_RETURN_LOGITS'] = '1' |
| | if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): |
| | pass |
| | else: |
| | model_max_seq_length = getattr(model, 'max_seq_length', None) |
| | args_max_seq_length = getattr(args, 'max_seq_length', None) |
| | if args_max_seq_length is None and model_max_seq_length is not None: |
| | max_seq_length = model.max_seq_length |
| | if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length |
| | if model is not None and hasattr(model, 'for_training'): |
| | model.for_training() |
| | if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' |
| | if 'processing_class' in locals(): |
| | if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' |
| | if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' |
| | __tokenizer = processing_class if 'processing_class' in locals() else tokenizer |
| | from unsloth_zoo.vision_utils import UnslothVisionDataCollator |
| | if not isinstance(data_collator, UnslothVisionDataCollator): |
| | if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: |
| | data_collator = TransformersDataCollatorForLanguageModeling( |
| | __tokenizer, |
| | mlm = False, |
| | mlm_probability = 0.0, |
| | pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), |
| | ) |
| | elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: |
| | data_collator = DataCollatorForSeq2Seq( |
| | __tokenizer, |
| | pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), |
| | ) |
| | else: |
| | if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False |
| | if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' |
| | if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} |
| | if not isinstance(data_collator, UnslothVisionDataCollator): |
| | if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): |
| | if isinstance(data_collator, DataCollatorForSeq2Seq): |
| | data_collator = DataCollatorForSeq2Seq( |
| | __tokenizer.tokenizer, |
| | pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), |
| | ) |
| | else: |
| | data_collator = TransformersDataCollatorForLanguageModeling( |
| | __tokenizer.tokenizer, |
| | mlm = False, |
| | mlm_probability = 0.0, |
| | pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), |
| | ) |
| | other_metrics = [] |
| | |
| | from unsloth_zoo.logging_utils import PatchRLStatistics |
| | PatchRLStatistics('dpo_trainer', other_metrics) |
| | if hasattr(train_dataset, 'column_names'): |
| | column_names = set(train_dataset.column_names) |
| | check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask', |
| | 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels', |
| | 'prompt_input_ids', 'prompt_attention_mask'] |
| | if all(x in column_names for x in check): |
| | train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt']) |
| | del check, column_names |
| | |
| | |
| | |
| | if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: |
| | if getattr(args, "_n_gpu", 1) != 1: |
| | args._n_gpu = 1 |
| | if "model" in locals() and hasattr(model, "for_training"): |
| | model.for_training() |
| | super().__init__( |
| | model = model, |
| | ref_model = ref_model, |
| | args = args, |
| | data_collator = data_collator, |
| | train_dataset = train_dataset, |
| | eval_dataset = eval_dataset, |
| | processing_class = processing_class, |
| | compute_metrics = compute_metrics, |
| | callbacks = callbacks, |
| | optimizer_cls_and_kwargs = optimizer_cls_and_kwargs, |
| | preprocess_logits_for_metrics = preprocess_logits_for_metrics, |
| | peft_config = peft_config,**kwargs) |
| | if "model" in locals() and hasattr(model, "for_inference"): |
| | model.for_inference() |
| | if hasattr(self, 'neftune_hook_handle'): |
| | self.neftune_hook_handle.remove() |
| | if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle |
| | if getattr(args, 'neftune_noise_alpha', None) is not None: |
| | model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha |
| | pass |
| | if hasattr(self, 'accelerator'): |
| | scaler = self.accelerator.scaler |
| | current_model = model |
| | while hasattr(current_model, 'model'): |
| | current_model.accelerator_scaler = scaler |
| | current_model = current_model.model |
| | current_model.accelerator_scaler = scaler |
| | pass |
| | if hasattr(self, 'train'): |
| | self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) |
| | pass |
| | |
| | pass |
| |
|
| |
|
| | if hasattr(logger, "addFilter"): |
| | import logging |
| | class HideLoggingMessage(logging.Filter): |
| | def __init__(self, text): self.text = text |
| | def filter(self, x): return not (self.text in x.getMessage()) |
| | pass |
| | logger.addFilter(HideLoggingMessage("`use_cache=True`")) |
| |
|
| |
|