| | """ |
| | 2025.10.1 |
| | 2025.10.1 |
| | 4.56.2 |
| | 0.22.2 |
| | __UNSLOTH_VERSIONING__ |
| | """ |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from torch import Tensor |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| | from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable |
| | from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, Path, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, logging, os, set_seed, textwrap, torch, wandb, warnings) |
| |
|
| |
|
| | import os |
| | from typing import * |
| | from dataclasses import dataclass, field |
| | from packaging.version import Version |
| | import torch |
| | import numpy as np |
| | from contextlib import nullcontext |
| | from torch.nn import functional as F |
| | from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling |
| | from transformers.training_args import ParallelMode |
| |
|
| | |
| | import functools |
| | from types import MethodType |
def prepare_for_training_mode(f):
    """Decorate a method so the owner's model is toggled around the call.

    Before ``f`` runs, ``self.model.for_training()`` is invoked when both
    ``self.model`` and that hook exist. After ``f`` returns normally,
    ``self.model.for_inference()`` is invoked under the same conditions.
    The wrapped method's return value is passed through unchanged.
    """

    def _invoke_mode_hook(owner, hook_name):
        # Objects without a `model`, or models without the hook, are left alone.
        if not hasattr(owner, "model"):
            return
        if hasattr(owner.model, hook_name):
            getattr(owner.model, hook_name)()

    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        _invoke_mode_hook(self, "for_training")
        result = f(self, *args, **kwargs)
        # Only reached when `f` did not raise.
        _invoke_mode_hook(self, "for_inference")
        return result

    return wrapper
| |
|
# Options dictionary handed to `torch.compile(options=...)` by the compiled
# helper(s) defined in this module.
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}
| |
|
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    """Return the log-softmax value of `logits` at the positions in `index`.

    `logits` is flattened to rows of size ``logits.shape[-1]`` and `index` to
    one entry per row. The rows are processed in (up to) 4 chunks so that the
    float32 upcast and the logsumexp are only materialised one chunk at a time.
    The result is reshaped to ``(logits.shape[0], logits.shape[1])``.
    """
    vocab_size = logits.shape[-1]
    logit_parts = torch.chunk(logits.reshape(-1, vocab_size), chunks = 4, dim = 0)
    index_parts = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)

    logps = []
    for part, part_index in zip(logit_parts, index_parts):
        part = part.to(torch.float32)
        # log_softmax(x)[i] == x[i] - logsumexp(x)
        picked = torch.gather(part, dim = -1, index = part_index.unsqueeze(-1)).squeeze(-1)
        logps.append(picked - torch.logsumexp(part, dim = -1))

    flat_logps = torch.concat(logps)
    return flat_logps.reshape((logits.shape[0], logits.shape[1]))
| |
|
def calculate_pad_tokens_in_prompt(
    input_ids: torch.Tensor,
    logits_to_keep: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Count the pad tokens in the prompt part of every row of `input_ids`.

    The prompt part is everything except the last `logits_to_keep` columns,
    so for a prompt of [pad, pad, pad, cat] the count is 3.

    Args:
        input_ids: 2-D tensor whose second dimension is the sequence length.
        logits_to_keep: number of trailing columns that are NOT prompt.
        pad_token_id: token id that marks padding.

    Returns:
        1-D tensor holding one pad count per row.

    Raises:
        ValueError: if `logits_to_keep` is not smaller than the sequence length.
    """
    if logits_to_keep >= input_ids.shape[1]:
        raise ValueError("logits_to_keep must be smaller than the sequence length.")

    if logits_to_keep == 0:
        # `[:, :-0]` is an EMPTY slice, which would report zero pads even
        # though the whole row is prompt; take every column instead.
        prompt_section = input_ids
    else:
        prompt_section = input_ids[:, :-logits_to_keep]

    return (prompt_section == pad_token_id).sum(dim=1)
| |
|
def create_completion_attention_mask(
    completion_input_ids: torch.Tensor,
    left_pad_tokens_per_prompt: torch.Tensor,
    max_left_pad: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Build a boolean mask selecting only the real completion tokens.

    For a row laid out as [p,p,p,c,c,c,pad,pad,pad], where p are leftover
    prompt tokens picked up by slicing, c are completion tokens and pad are
    pad tokens, the mask is [0,0,0,1,1,1,0,0,0]: both the leading p tokens
    and every pad token are zeroed out.

    The number of leading positions dropped in a row is
    `max_left_pad - left_pad_tokens_per_prompt[row]`.
    """
    _, seq_len = completion_input_ids.shape
    positions = torch.arange(seq_len, device=completion_input_ids.device).unsqueeze(0)

    # Per-row count of leading (leftover prompt) positions to switch off.
    leading_to_drop = (max_left_pad - left_pad_tokens_per_prompt).unsqueeze(1)
    past_prompt = positions >= leading_to_drop

    is_real_token = (completion_input_ids != pad_token_id)
    return past_prompt & is_real_token
| |
|
def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
    """
    Pack the non-pad tokens of every row to the left, pushing pads to the right.

    The relative order of the non-pad tokens is preserved because the sort
    over the keep/pad flags is stable.
    """
    keep = (tensor != pad_id)
    # Stable descending sort on the flags: kept positions first, pads last.
    order = torch.argsort(keep, dim=1, descending=True, stable=True)
    return torch.gather(tensor, 1, order)
@dataclass
class UnslothDDPOConfig(DDPOConfig):
    """
    Configuration class for the [`DDPOTrainer`], with Unsloth defaults and extras.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    The defaults listed below are the ones of this class' `__init__`. Any other keyword argument
    (for example `tracker_kwargs`, `accelerator_kwargs` or `project_kwargs`) is forwarded unchanged
    to `DDPOConfig.__init__` through `**kwargs`.

    Parameters:
        exp_name (`str`, *optional*, defaults to `"colab_kernel_launcher"`):
            Name of this experiment.
        run_name (`str`, *optional*, defaults to `""`):
            Name of this run.
        seed (`int`, *optional*, defaults to `3407`):
            Random seed.
        log_with (`Literal["wandb", "tensorboard"]` or `None`, *optional*, defaults to `None`):
            Log with either 'wandb' or 'tensorboard', check
            https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
            Name of project to use for tracking.
        logdir (`str`, *optional*, defaults to `"logs"`):
            Top-level logging directory for checkpoint saving.
        num_epochs (`int`, *optional*, defaults to `100`):
            Number of epochs to train.
        save_freq (`int`, *optional*, defaults to `1`):
            Number of epochs between saving model checkpoints.
        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
            Number of checkpoints to keep before overwriting old ones.
        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
            Mixed precision training.
        allow_tf32 (`bool`, *optional*, defaults to `True`):
            Allow `tf32` on Ampere GPUs.
        resume_from (`str`, *optional*, defaults to `""`):
            Resume training from a checkpoint.
        sample_num_steps (`int`, *optional*, defaults to `50`):
            Number of sampler inference steps.
        sample_eta (`float`, *optional*, defaults to `1.0`):
            Eta parameter for the DDIM sampler.
        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
            Classifier-free guidance weight.
        sample_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for sampling.
        sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
            Number of batches to sample per epoch.
        train_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for training.
        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
            Use 8bit Adam optimizer from bitsandbytes.
        train_learning_rate (`float`, *optional*, defaults to `5e-5`):
            Learning rate. A notice is printed when it is below 1e-7 or above 1.
        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
            Adam beta1.
        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
            Adam beta2.
        train_adam_weight_decay (`float`, *optional*, defaults to `0.01`):
            Adam weight decay.
        train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
            Adam epsilon.
        train_gradient_accumulation_steps (`int`, *optional*, defaults to `2`):
            Number of gradient accumulation steps.
        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
            Maximum gradient norm for gradient clipping.
        train_num_inner_epochs (`int`, *optional*, defaults to `1`):
            Number of inner epochs per outer epoch.
        train_cfg (`bool`, *optional*, defaults to `True`):
            Whether to use classifier-free guidance during training.
        train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
            Clip advantages to the range.
        train_clip_range (`float`, *optional*, defaults to `1e-4`):
            PPO clip range.
        train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
            Fraction of timesteps to train on.
        per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
            Whether to track statistics for each prompt separately.
        per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
            Number of reward values to store in the buffer for each prompt.
        per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
            Minimum number of reward values to store in the buffer.
        async_reward_computation (`bool`, *optional*, defaults to `False`):
            Whether to compute rewards asynchronously.
        max_workers (`int`, *optional*, defaults to `2`):
            Maximum number of workers to use for async reward computation.
        negative_prompts (`str`, *optional*, defaults to `""`):
            Comma-separated list of prompts to use as negative examples.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the final model checkpoint to the Hub.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams; stored on the instance as-is.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunking setting used to reduce memory usage; stored on the instance as-is.
    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        exp_name = 'colab_kernel_launcher',
        run_name = '',
        seed = 3407,
        log_with = None,
        tracker_project_name = 'trl',
        logdir = 'logs',
        num_epochs = 100,
        save_freq = 1,
        num_checkpoint_limit = 5,
        mixed_precision = 'fp16',
        allow_tf32 = True,
        resume_from = '',
        sample_num_steps = 50,
        sample_eta = 1.0,
        sample_guidance_scale = 5.0,
        sample_batch_size = 1,
        sample_num_batches_per_epoch = 2,
        train_batch_size = 1,
        train_use_8bit_adam = False,
        train_learning_rate = 5e-05,
        train_adam_beta1 = 0.9,
        train_adam_beta2 = 0.999,
        train_adam_weight_decay = 0.01,
        train_adam_epsilon = 1e-08,
        train_gradient_accumulation_steps = 2,
        train_max_grad_norm = 1.0,
        train_num_inner_epochs = 1,
        train_cfg = True,
        train_adv_clip_max = 5.0,
        train_clip_range = 0.0001,
        train_timestep_fraction = 1.0,
        per_prompt_stat_tracking = False,
        per_prompt_stat_tracking_buffer_size = 16,
        per_prompt_stat_tracking_min_count = 16,
        async_reward_computation = False,
        max_workers = 2,
        negative_prompts = '',
        push_to_hub = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,

        **kwargs,
    ):
        # This config names its learning rate `train_learning_rate`; the sanity
        # checks previously read an undefined `learning_rate` and raised
        # NameError on every instantiation.
        if train_learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{train_learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if train_learning_rate > 1: print(f'Unsloth: Your learning rate of `{train_learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')

        super().__init__(
            exp_name = exp_name,
            run_name = run_name,
            seed = seed,
            log_with = log_with,
            tracker_project_name = tracker_project_name,
            logdir = logdir,
            num_epochs = num_epochs,
            save_freq = save_freq,
            num_checkpoint_limit = num_checkpoint_limit,
            mixed_precision = mixed_precision,
            allow_tf32 = allow_tf32,
            resume_from = resume_from,
            sample_num_steps = sample_num_steps,
            sample_eta = sample_eta,
            sample_guidance_scale = sample_guidance_scale,
            sample_batch_size = sample_batch_size,
            sample_num_batches_per_epoch = sample_num_batches_per_epoch,
            train_batch_size = train_batch_size,
            train_use_8bit_adam = train_use_8bit_adam,
            train_learning_rate = train_learning_rate,
            train_adam_beta1 = train_adam_beta1,
            train_adam_beta2 = train_adam_beta2,
            train_adam_weight_decay = train_adam_weight_decay,
            train_adam_epsilon = train_adam_epsilon,
            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
            train_max_grad_norm = train_max_grad_norm,
            train_num_inner_epochs = train_num_inner_epochs,
            train_cfg = train_cfg,
            train_adv_clip_max = train_adv_clip_max,
            train_clip_range = train_clip_range,
            train_timestep_fraction = train_timestep_fraction,
            per_prompt_stat_tracking = per_prompt_stat_tracking,
            per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
            per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
            async_reward_computation = async_reward_computation,
            max_workers = max_workers,
            negative_prompts = negative_prompts,
            push_to_hub = push_to_hub,**kwargs)
        # Unsloth-specific extras are not known to DDPOConfig, so they are set here.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
| |
|
| | class _UnslothDDPOTrainer(PyTorchModelHubMixin): |
| | """""" |
| |
|
| | _tag_names = ["trl", "ddpo"] |
| |
|
def __init__(
    self,
    config: DDPOConfig,
    reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
    prompt_function: Callable[[], tuple[str, Any]],
    sd_pipeline: DDPOStableDiffusionPipeline,
    image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
):
    """
    Set up the accelerator, pipeline, optimizer and bookkeeping for DDPO training.

    Args:
        config: trainer configuration; `config.resume_from` may be rewritten here to point at
            the newest `checkpoint_<n>` directory.
        reward_function: stored as `self.reward_fn`.
        prompt_function: stored as `self.prompt_fn`.
        sd_pipeline: pipeline whose `vae`, `text_encoder` and `unet` are moved to the
            accelerator device.
        image_samples_hook: optional callback stored as `self.image_samples_callback`; a warning
            is logged when it is missing.

    Raises:
        ValueError: if `resume_from` is a directory without any `checkpoint_*` entry, or if
            `_config_check` rejects the batch-size settings.
    """
    warnings.warn(
        "DDPOTrainer is deprecated and will be removed in version 0.23.0.",
        DeprecationWarning,
    )
    if image_samples_hook is None:
        logger.warning("No image_samples_hook provided; no images will be logged")

    self.prompt_fn = prompt_function
    self.reward_fn = reward_function
    self.config = config
    self.image_samples_callback = image_samples_hook

    accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

    if self.config.resume_from:
        self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
        if "checkpoint_" not in os.path.basename(self.config.resume_from):
            # `resume_from` is a parent directory: pick the highest-numbered checkpoint in it.
            checkpoints = list(
                filter(
                    lambda x: "checkpoint_" in x,
                    os.listdir(self.config.resume_from),
                )
            )
            if len(checkpoints) == 0:
                raise ValueError(f"No checkpoints found in {self.config.resume_from}")
            checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
            self.config.resume_from = os.path.join(
                self.config.resume_from,
                f"checkpoint_{checkpoint_numbers[-1]}",
            )

            # Continue the checkpoint numbering after the one being resumed.
            accelerator_project_config.iteration = checkpoint_numbers[-1] + 1

    # Number of timesteps of each sampled trajectory that are trained on.
    self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)

    self.accelerator = Accelerator(
        log_with=self.config.log_with,
        mixed_precision=self.config.mixed_precision,
        project_config=accelerator_project_config,
        # One backward pass is done per trained timestep (see `_train_batched_samples`), so the
        # per-sample accumulation count is multiplied by the number of trained timesteps.
        gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
        **self.config.accelerator_kwargs,
    )

    is_okay, message = self._config_check()
    if not is_okay:
        raise ValueError(message)

    is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"

    if self.accelerator.is_main_process:
        # For tensorboard the config dict is passed flat; otherwise it is nested under one key.
        self.accelerator.init_trackers(
            self.config.tracker_project_name,
            config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
            init_kwargs=self.config.tracker_kwargs,
        )

    logger.info(f"\n{config}")

    set_seed(self.config.seed, device_specific=True)

    self.sd_pipeline = sd_pipeline

    self.sd_pipeline.set_progress_bar_config(
        position=1,
        disable=not self.accelerator.is_local_main_process,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )

    # Choose the dtype the pipeline modules are cast to from the accelerator's
    # mixed-precision mode.
    if self.accelerator.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif self.accelerator.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16
    else:
        inference_dtype = torch.float32

    self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
    self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
    self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

    trainable_layers = self.sd_pipeline.get_trainable_layers()

    # Checkpoint (de)serialisation is delegated to the pipeline through these hooks.
    self.accelerator.register_save_state_pre_hook(self._save_model_hook)
    self.accelerator.register_load_state_pre_hook(self._load_model_hook)

    # Allow TF32 matmuls when requested and CUDA is available.
    if self.config.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    # `trainable_layers` may be a module (use its parameters) or already a list.
    self.optimizer = self._setup_optimizer(
        trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
    )

    # Text-encoder output for the negative prompt; reused for sampling in `_generate_samples`.
    self.neg_prompt_embed = self.sd_pipeline.text_encoder(
        self.sd_pipeline.tokenizer(
            [""] if self.config.negative_prompts is None else self.config.negative_prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.sd_pipeline.tokenizer.model_max_length,
        ).input_ids.to(self.accelerator.device)
    )[0]

    if config.per_prompt_stat_tracking:
        self.stat_tracker = PerPromptStatTracker(
            config.per_prompt_stat_tracking_buffer_size,
            config.per_prompt_stat_tracking_min_count,
        )

    # Prefer the pipeline's own autocast when it is truthy, else use the accelerator's.
    self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

    if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
        # With LoRA, keep only the parameters that require gradients.
        unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
        self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
    else:
        self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

    if self.config.async_reward_computation:
        self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)

    if config.resume_from:
        logger.info(f"Resuming from {config.resume_from}")
        self.accelerator.load_state(config.resume_from)
        # The checkpoint directory name ends in `_<epoch>`; resume with the next epoch.
        self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
    else:
        self.first_epoch = 0
| |
|
def compute_rewards(self, prompt_image_pairs, is_async=False):
    """
    Score every (images, prompts, prompt_metadata) entry with `self.reward_fn`.

    Args:
        prompt_image_pairs: iterable of `[images, prompts, prompt_metadata]` entries.
        is_async (bool): when True the reward function is dispatched through
            `self.executor`; otherwise it is called sequentially.

    Returns:
        The result of `zip(*rewards)`: an iterator that yields first the tuple of reward
        tensors (on the accelerator device) and then the tuple of reward metadata.
    """
    device = self.accelerator.device

    def _resolve(value):
        # `Executor.map` already yields finished results, so calling `.result()`
        # on them unconditionally fails. Reward functions that return futures
        # themselves are still supported.
        return value.result() if isinstance(value, futures.Future) else value

    if not is_async:
        rewards = []
        for images, prompts, prompt_metadata in prompt_image_pairs:
            reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
            rewards.append((torch.as_tensor(reward, device=device), reward_metadata))
    else:
        outcomes = self.executor.map(lambda pair: self.reward_fn(*pair), prompt_image_pairs)
        rewards = [
            (torch.as_tensor(_resolve(reward), device=device), _resolve(reward_metadata))
            for reward, reward_metadata in outcomes
        ]

    return zip(*rewards)
| |
|
def step(self, epoch: int, global_step: int):
    """
    Perform a single step of training.

    Args:
        epoch (int): The current epoch.
        global_step (int): The current global step.

    Side Effects:
        - Model weights are updated
        - Logs the statistics to the accelerator trackers.
        - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
          and the accelerator tracker.

    Returns:
        global_step (int): The updated global step.

    Raises:
        ValueError: if an inner epoch ends without the accelerator having synchronised gradients.
    """
    samples, prompt_image_data = self._generate_samples(
        iterations=self.config.sample_num_batches_per_epoch,
        batch_size=self.config.sample_batch_size,
    )

    # Collate the per-batch dicts into one dict by concatenating each key along dim 0.
    samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
    rewards, rewards_metadata = self.compute_rewards(
        prompt_image_data, is_async=self.config.async_reward_computation
    )

    # Attach reward and reward metadata to each [images, prompts, metadata] entry.
    for i, image_data in enumerate(prompt_image_data):
        image_data.extend([rewards[i], rewards_metadata[i]])

    if self.image_samples_callback is not None:
        self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])

    rewards = torch.cat(rewards)
    # Gather rewards from all processes and continue with a numpy array.
    rewards = self.accelerator.gather(rewards).cpu().numpy()

    self.accelerator.log(
        {
            "reward": rewards,
            "epoch": epoch,
            "reward_mean": rewards.mean(),
            "reward_std": rewards.std(),
        },
        step=global_step,
    )

    if self.config.per_prompt_stat_tracking:
        # Gather the prompt ids across processes and decode them so the tracker can key on text.
        prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
        prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
        advantages = self.stat_tracker.update(prompts, rewards)
    else:
        # Standardise over all gathered rewards; the epsilon avoids division by zero.
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    # Undo the gather: keep only the slice of advantages belonging to this process.
    samples["advantages"] = (
        torch.as_tensor(advantages)
        .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
        .to(self.accelerator.device)
    )

    del samples["prompt_ids"]

    total_batch_size, num_timesteps = samples["timesteps"].shape

    for inner_epoch in range(self.config.train_num_inner_epochs):
        # Shuffle the samples along the batch dimension.
        perm = torch.randperm(total_batch_size, device=self.accelerator.device)
        samples = {k: v[perm] for k, v in samples.items()}

        # Shuffle along the time dimension with an independent permutation per sample.
        perms = torch.stack(
            [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
        )

        for key in ["timesteps", "latents", "next_latents", "log_probs"]:
            samples[key] = samples[key][
                torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
                perms,
            ]

        original_keys = samples.keys()
        original_values = samples.values()
        # Re-batch every tensor to (-1, train_batch_size, ...).
        reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]

        # Transpose: one tuple of per-key tensors for each training batch ...
        transposed_values = zip(*reshaped_values)
        # ... and turn each tuple back into a dict keyed like `samples`.
        samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]

        self.sd_pipeline.unet.train()
        global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
        # The last accumulation window must have ended with an optimizer step.
        if not self.accelerator.sync_gradients:
            raise ValueError(
                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
            )

    # Epoch 0 is never saved; afterwards save every `save_freq` epochs, on the main process only.
    if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
        self.accelerator.save_state()

    return global_step
| |
|
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
    """
    Calculate the loss for a batch of an unpacked sample

    Args:
        latents (torch.Tensor):
            The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
        timesteps (torch.Tensor):
            The timesteps sampled from the diffusion model, shape: [batch_size]
        next_latents (torch.Tensor):
            The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height,
            width]
        log_probs (torch.Tensor):
            The log probabilities of the latents, shape: [batch_size]
        advantages (torch.Tensor):
            The advantages of the latents, shape: [batch_size]
        embeds (torch.Tensor):
            The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] Note: the "or" is because if
            train_cfg is True, the expectation is that negative prompts are concatenated to the embeds

    Returns:
        loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) (all of these are of shape (1,))
    """
    with self.autocast():
        if self.config.train_cfg:
            # Classifier-free guidance: run the doubled batch through the UNet once, then
            # split the prediction into its two halves and blend them.
            noise_pred = self.sd_pipeline.unet(
                torch.cat([latents] * 2),
                torch.cat([timesteps] * 2),
                embeds,
            ).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        else:
            noise_pred = self.sd_pipeline.unet(
                latents,
                timesteps,
                embeds,
            ).sample

        # Log-probability of the recorded `next_latents` given `latents` under the current model.
        scheduler_step_output = self.sd_pipeline.scheduler_step(
            noise_pred,
            timesteps,
            latents,
            eta=self.config.sample_eta,
            prev_sample=next_latents,
        )

        log_prob = scheduler_step_output.log_probs

    advantages = torch.clamp(
        advantages,
        -self.config.train_adv_clip_max,
        self.config.train_adv_clip_max,
    )

    # Importance ratio between the current log-prob and the one recorded at sampling time.
    ratio = torch.exp(log_prob - log_probs)

    loss = self.loss(advantages, self.config.train_clip_range, ratio)

    # Diagnostic: half the mean squared log-prob difference.
    approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)

    # Diagnostic: fraction of ratios lying outside the clip range around 1.
    clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())

    return loss, approx_kl, clipfrac
| |
|
def loss(
    self,
    advantages: torch.Tensor,
    clip_range: float,
    ratio: torch.Tensor,
):
    """
    Clipped surrogate loss.

    Takes the element-wise maximum of `-advantages * ratio` and the same product with
    `ratio` clamped to `[1 - clip_range, 1 + clip_range]`, then averages it.
    """
    lower_bound = 1.0 - clip_range
    upper_bound = 1.0 + clip_range
    bounded_ratio = torch.clamp(ratio, lower_bound, upper_bound)

    negated_advantages = -advantages
    per_element = torch.maximum(negated_advantages * ratio, negated_advantages * bounded_ratio)
    return torch.mean(per_element)
| |
|
| | def _setup_optimizer(self, trainable_layers_parameters): |
| | if self.config.train_use_8bit_adam: |
| | import bitsandbytes |
| |
|
| | optimizer_cls = bitsandbytes.optim.AdamW8bit |
| | else: |
| | optimizer_cls = torch.optim.AdamW |
| |
|
| | return optimizer_cls( |
| | trainable_layers_parameters, |
| | lr=self.config.train_learning_rate, |
| | betas=(self.config.train_adam_beta1, self.config.train_adam_beta2), |
| | weight_decay=self.config.train_adam_weight_decay, |
| | eps=self.config.train_adam_epsilon, |
| | ) |
| |
|
| | def _save_model_hook(self, models, weights, output_dir): |
| | self.sd_pipeline.save_checkpoint(models, weights, output_dir) |
| | weights.pop() |
| |
|
| | def _load_model_hook(self, models, input_dir): |
| | self.sd_pipeline.load_checkpoint(models, input_dir) |
| | models.pop() |
| |
|
def _generate_samples(self, iterations, batch_size):
    """
    Generate samples from the model

    Args:
        iterations (int): Number of iterations to generate samples for
        batch_size (int): Batch size to use for sampling

    Returns:
        samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
    """
    samples = []
    prompt_image_pairs = []
    # Sampling is done with the UNet in eval mode; `step` switches it back to train mode.
    self.sd_pipeline.unet.eval()

    # Repeat the negative-prompt embedding once per sample in the batch.
    sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

    for _ in range(iterations):
        # One `prompt_fn()` call per sample; split the (prompt, metadata) pairs into two tuples.
        prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])

        prompt_ids = self.sd_pipeline.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.sd_pipeline.tokenizer.model_max_length,
        ).input_ids.to(self.accelerator.device)
        prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]

        with self.autocast():
            sd_output = self.sd_pipeline(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=self.config.sample_num_steps,
                guidance_scale=self.config.sample_guidance_scale,
                eta=self.config.sample_eta,
                output_type="pt",
            )

            images = sd_output.images
            latents = sd_output.latents
            log_probs = sd_output.log_probs

        # Stack the per-step sequences along a new dim 1.
        latents = torch.stack(latents, dim=1)
        log_probs = torch.stack(log_probs, dim=1)
        timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1)

        samples.append(
            {
                "prompt_ids": prompt_ids,
                "prompt_embeds": prompt_embeds,
                "timesteps": timesteps,
                # Pair each latent with its successor: all but the last / all but the first.
                "latents": latents[:, :-1],
                "next_latents": latents[:, 1:],
                "log_probs": log_probs,
                "negative_prompt_embeds": sample_neg_prompt_embeds,
            }
        )
        prompt_image_pairs.append([images, prompts, prompt_metadata])

    return samples, prompt_image_pairs
| |
|
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
    """
    Train on a batch of samples. Main training segment

    Args:
        inner_epoch (int): The current inner epoch
        epoch (int): The current epoch
        global_step (int): The current global step
        batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on

    Side Effects:
        - Model weights are updated
        - Logs the statistics to the accelerator trackers.

    Returns:
        global_step (int): The updated global step
    """
    # Per-window accumulator for the logged statistics.
    info = defaultdict(list)
    for _i, sample in enumerate(batched_samples):
        if self.config.train_cfg:
            # Stack negative and positive prompt embeddings so `calculate_loss` can run the
            # doubled batch through the UNet in one call.
            embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
        else:
            embeds = sample["prompt_embeds"]

        for j in range(self.num_train_timesteps):
            with self.accelerator.accumulate(self.sd_pipeline.unet):
                loss, approx_kl, clipfrac = self.calculate_loss(
                    sample["latents"][:, j],
                    sample["timesteps"][:, j],
                    sample["next_latents"][:, j],
                    sample["log_probs"][:, j],
                    sample["advantages"],
                    embeds,
                )
                info["approx_kl"].append(approx_kl)
                info["clipfrac"].append(clipfrac)
                info["loss"].append(loss)

                self.accelerator.backward(loss)
                # Clip gradients only when the accelerator reports a sync step.
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(
                        self.trainable_layers.parameters()
                        if not isinstance(self.trainable_layers, list)
                        else self.trainable_layers,
                        self.config.train_max_grad_norm,
                    )
                self.optimizer.step()
                self.optimizer.zero_grad()

            # On a sync step: average the collected statistics, reduce them across
            # processes, log them, advance the step counter and start a fresh accumulator.
            if self.accelerator.sync_gradients:
                info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                info = self.accelerator.reduce(info, reduction="mean")
                info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                self.accelerator.log(info, step=global_step)
                global_step += 1
                info = defaultdict(list)
    return global_step
| |
|
| | def _config_check(self) -> tuple[bool, str]: |
| | samples_per_epoch = ( |
| | self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch |
| | ) |
| | total_train_batch_size = ( |
| | self.config.train_batch_size |
| | * self.accelerator.num_processes |
| | * self.config.train_gradient_accumulation_steps |
| | ) |
| |
|
| | if not self.config.sample_batch_size >= self.config.train_batch_size: |
| | return ( |
| | False, |
| | f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})", |
| | ) |
| | if not self.config.sample_batch_size % self.config.train_batch_size == 0: |
| | return ( |
| | False, |
| | f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})", |
| | ) |
| | if not samples_per_epoch % total_train_batch_size == 0: |
| | return ( |
| | False, |
| | f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})", |
| | ) |
| | return True, "" |
| |
|
def train(self, epochs: Optional[int] = None):
    """Run the training loop from `self.first_epoch` up to `epochs` (exclusive).

    Args:
        epochs: Upper bound of the epoch range. When `None`,
            `self.config.num_epochs` is used instead.
    """
    step_counter = 0
    last_epoch = self.config.num_epochs if epochs is None else epochs
    for epoch_index in range(self.first_epoch, last_epoch):
        # `step` returns the updated step counter, which is fed into the next epoch.
        step_counter = self.step(epoch_index, step_counter)
| |
|
| | def _save_pretrained(self, save_directory): |
| | self.sd_pipeline.save_pretrained(save_directory) |
| | self.create_model_card() |
| |
|
| | |
def _save_checkpoint(self, model, trial):
    """Create the model card, then delegate checkpoint saving to the parent class.

    The model name is the last path component of `self.args.output_dir` when
    no hub model id is set, otherwise the part of `self.args.hub_model_id`
    after the final "/".
    """
    hub_id = self.args.hub_model_id
    card_name = Path(self.args.output_dir).name if hub_id is None else hub_id.split("/")[-1]
    self.create_model_card(model_name=card_name)
    super()._save_checkpoint(model, trial)
| |
|
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Build a draft model card from the trainer's state and save it as `README.md`
    inside `self.args.output_dir`. Does nothing unless `self.is_world_process_zero()`
    returns a truthy value.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    model_config = self.model.config

    # Only report a base model when `_name_or_path` is not a local directory.
    base_model = None
    if hasattr(model_config, "_name_or_path"):
        if not os.path.isdir(model_config._name_or_path):
            base_model = model_config._name_or_path

    # Normalise `tags` into a set, whatever form the caller supplied.
    if isinstance(tags, str):
        tag_set = {tags}
    elif tags is None:
        tag_set = set()
    else:
        tag_set = set(tags)

    if hasattr(model_config, "unsloth_version"):
        tag_set.add("unsloth")

    if "JOB_ID" in os.environ:
        tag_set.add("hf_jobs")

    tag_set.update(self._tag_names)

    # BibTeX entry for the DDPO paper; dedent strips the common indentation.
    citation = textwrap.dedent("""\
    @inproceedings{black2024training,
    title = {{Training Diffusion Models with Reinforcement Learning}},
    author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
    year = 2024,
    booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
    publisher = {OpenReview.net},
    url = {https://openreview.net/forum?id=YCWjhGrJFD},
    }""")

    # Keyword arguments are assembled in the same order they are passed to the generator.
    card_fields = {
        "base_model": base_model,
        "model_name": model_name,
        "hub_model_id": self.hub_model_id,
        "dataset_name": dataset_name,
        "tags": tag_set,
        "wandb_url": wandb.run.url if is_wandb_available() and wandb.run is not None else None,
        "comet_url": get_comet_experiment_url(),
        "trainer_name": "DDPO",
        "trainer_citation": citation,
        "paper_title": "Training Diffusion Models with Reinforcement Learning",
        "paper_id": "2305.13301",
    }
    card = generate_model_card(**card_fields)

    readme_path = os.path.join(self.args.output_dir, "README.md")
    card.save(readme_path)
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
    """

    The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimise diffusion models. Note, this trainer is heavily
    inspired by the work here: https://github.com/kvablack/ddpo-pytorch As of now only Stable Diffusion based pipelines
    are supported

    Attributes:
        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for more
            details.
        **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
        **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
        **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
        **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images

    """
    def __init__(
        self,
        config,
        reward_function,
        prompt_function,
        sd_pipeline,
        image_samples_hook = None,
        **kwargs
    ):
        """Set up Unsloth's patches, then initialise the underlying DDPO trainer.

        Args:
            config: Trainer configuration; a default `UnslothDDPOConfig()` is
                created when `None` is passed.
            reward_function: Forwarded unchanged to the parent constructor.
            prompt_function: Forwarded unchanged to the parent constructor.
            sd_pipeline: Forwarded unchanged to the parent constructor.
            image_samples_hook: Optional hook, forwarded unchanged.
            **kwargs: Any extra keyword arguments, forwarded unchanged.
        """
        # This trainer takes `config`, not `args`; the setup logic below must
        # therefore read and patch `config` (reading an undefined local `args`
        # would raise UnboundLocalError on every construction).
        if config is None: config = UnslothDDPOConfig()
        other_metrics = []

        # NOTE(review): presumably registers DDPO statistics with Unsloth's
        # logging utilities - confirm against unsloth_zoo.logging_utils.
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ddpo_trainer', other_metrics)

        # Force a single-GPU view when the config reports a non-distributed run
        # with several GPUs. Attributes are read defensively because `config`
        # may not expose these fields at all.
        if (
            getattr(config, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED
            and getattr(config, "n_gpu", 1) > 1
        ):
            if getattr(config, "_n_gpu", 1) != 1:
                config._n_gpu = 1
        super().__init__(
            config = config,
            reward_function = reward_function,
            prompt_function = prompt_function,
            sd_pipeline = sd_pipeline,
            image_samples_hook = image_samples_hook,**kwargs)
| |
|
| |
|
# Install a filter that drops log records mentioning `use_cache=True`, but only
# when the imported `logger` object supports filters.
if hasattr(logger, "addFilter"):
    import logging

    class HideLoggingMessage(logging.Filter):
        """Logging filter that rejects any record whose message contains `text`."""

        def __init__(self, text):
            self.text = text

        def filter(self, x):
            # A falsy return value tells the logging machinery to drop the record.
            return self.text not in x.getMessage()

    logger.addFilter(HideLoggingMessage("`use_cache=True`"))
| |
|
| |
|