| |
| |
| |
| |
| |
|
|
| from contextlib import nullcontext |
| from math import ceil |
| from typing import Callable, Optional, Union |
|
|
| import torch |
| from torch import Tensor |
| import gc |
| from torch.nn import functional as F |
|
|
| from flow_matching.path import MixtureDiscreteProbPath |
|
|
| from flow_matching.solver.solver import Solver |
| from flow_matching.utils import categorical, ModelWrapper |
| from .utils import get_nearest_times |
| from ..utils.multi_guidance import * |
|
|
| try: |
| from tqdm import tqdm |
|
|
| TQDM_AVAILABLE = True |
| except ImportError: |
| TQDM_AVAILABLE = False |
|
|
|
|
class MixtureDiscreteEulerSolver(Solver):
    r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t` the marginal probability path of ``path``.
    Given :math:`X_t \sim p_t`, the algorithm of solver step from :math:`t` to :math:`t+h` for the i-th coordinate is:

    .. math::

        \begin{align*}
        & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\
        & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\
        & Z^i_{\text{change}} \sim U[0,1]\\
        & X_{t+h}^i \sim \begin{cases}
            \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\
            \delta_{X_t^i}(\cdot) \text{ else }
        \end{cases}
        \end{align*}

    Where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity is of the mixture probability path is:

    .. math::

        u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],

    where

    .. math::
        \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right],

    and

    .. math::

        \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].

    The source distribution :math:`p(x^i)` is given by ``p``.

    Args:
        model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`), output must be [..., vocabulary_size].
        path (MixtureDiscreteProbPath): Probability path used for x-prediction training.
        vocabulary_size (int): size of the discrete vocabulary.
        source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when divergence-free term for the probability velocity is non-zero. Defaults to None.
    """

    def __init__(
        self,
        model: ModelWrapper,
        path: MixtureDiscreteProbPath,
        vocabulary_size: int,
        source_distribution_p: Optional[Tensor] = None,
    ):
        super().__init__()
        self.model = model
        self.path = path
        self.vocabulary_size = vocabulary_size

        if source_distribution_p is not None:
            # The divergence-free term needs one probability per vocabulary entry.
            assert source_distribution_p.shape == torch.Size(
                [vocabulary_size]
            ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {source_distribution_p.shape}."

        self.source_distribution_p = source_distribution_p

    @torch.no_grad()
    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        **model_extras,
    ) -> Tensor:
        """
        Sample a sequence of discrete values from the given model.

        .. code-block:: python

            import torch
            from flow_matching.utils import ModelWrapper
            from flow_matching.solver import MixtureDiscreteEulerSolver

            class DummyModel(ModelWrapper):
                def __init__(self):
                    super().__init__(None)
                def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
                    return ...

            model = DummyModel()
            solver = MixtureDiscreteEulerSolver(model=model)

            x_init = torch.LongTensor([122, 725])
            step_size = 0.001
            time_grid = torch.tensor([0.0, 1.0])

            result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

        Args:
            x_init (Tensor): The initial state.
            step_size (Optional[float]): If float then time discretization is uniform with the given step size. If None then time discretization is set to be time_grid.
            div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function. Defaults to 0.0.
            dtype_categorical (torch.dtype): Precision to use for categorical sampler. Defaults to torch.float32.
            time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]).
            return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False.
            verbose (bool): Whether to print progress bars. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The sampled sequence of discrete values.

        Raises:
            ImportError: To run in verbose mode, tqdm must be installed.
        """
        if not div_free == 0.0:
            assert (
                self.source_distribution_p is not None
            ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity."

        time_grid = time_grid.to(device=x_init.device)

        # The interval endpoints are needed by both discretization branches and by
        # the progress bar below. Computing them unconditionally fixes a NameError
        # that previously occurred with verbose=True and step_size=None.
        t_init = time_grid[0].item()
        t_final = time_grid[-1].item()

        if step_size is None:
            # Time discretization is given directly by time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # Uniform time discretization with the given step size.
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

        if return_intermediates:
            # Remember the caller's requested ordering, then snap the requested
            # times onto the actual discretization grid.
            order = torch.argsort(time_grid)
            time_grid = get_nearest_times(
                time_grid=time_grid, t_discretization=t_discretization
            )

        x_t = x_init.clone()
        steps_counter = 0
        res = []

        if return_intermediates:
            res = [x_init.clone()]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
        else:
            ctx = nullcontext()

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                # Sample X_1 ~ p_{1|t}(.|X_t) from the model posterior.
                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                if i == n_steps - 1:
                    # Final step: jump directly to the sampled x_1.
                    x_t = x_1
                else:
                    # Conditional probability velocity of the mixture path.
                    scheduler_output = self.path.scheduler(t=t)

                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t

                    delta_1 = F.one_hot(x_1, num_classes=self.vocabulary_size).to(
                        k_t.dtype
                    )
                    u = d_k_t / (1 - k_t) * delta_1

                    # Optional divergence-free component of the velocity.
                    div_free_t = div_free(t) if callable(div_free) else div_free

                    if div_free_t > 0:
                        p_0 = self.source_distribution_p[(None,) * x_t.dim()]
                        u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * (
                            (1 - k_t) * p_0 + k_t * delta_1
                        )

                    # Zero the rate of staying in the current state; the remaining
                    # mass is the total outgoing jump intensity per coordinate.
                    delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size)
                    u = torch.where(
                        delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
                    )

                    # Decide per coordinate whether a jump occurs within [t, t+h]
                    # (probability 1 - exp(-h * lambda)).
                    intensity = u.sum(dim=-1)
                    mask_jump = torch.rand(
                        size=x_t.shape, device=x_t.device
                    ) < 1 - torch.exp(-h * intensity)

                    if mask_jump.sum() > 0:
                        x_t[mask_jump] = categorical(
                            u[mask_jump].to(dtype=dtype_categorical)
                        )

                steps_counter += 1
                t = t + h

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                return torch.stack(res, dim=0)[order]
        else:
            return x_t

    @torch.no_grad()
    def multi_guidance_sample(
        self,
        args,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        score_models: list = None,
        num_objectives: int = 1,
        weights: list = None,
        **model_extras,
    ) -> Tensor:
        """
        Sample from the CTMC process while steering each Euler step with
        multi-objective guidance.

        Each step samples the model posterior as in :meth:`sample`, scores
        candidate transitions with ``score_models`` (via
        ``guided_transition_scoring``), filters candidates with adaptive
        hypercone filtering, and takes a guided Euler step.

        Args:
            args: Guidance hyper-parameters; must provide ``num_div`` and
                ``Phi_init`` (consumed by the guidance helpers).
            x_init (Tensor): The initial state.
            step_size (Optional[float]): If float, uniform time discretization
                with the given step size; if None, discretization is time_grid.
            div_free (Union[float, Callable[[float], float]]): Divergence-free
                coefficient. Only 0.0 is supported by this method.
            dtype_categorical (torch.dtype): Precision for the categorical sampler.
            time_grid (Tensor): Process is solved on [time_grid[0], time_grid[-1]].
            return_intermediates (bool): If True, also collect states at time_grid.
            verbose (bool): Whether to print a tqdm progress bar.
            score_models (list): Per-objective scoring models. Each is called as
                ``s(x_t, 1)`` when its signature has a ``t`` parameter, else ``s(x_t)``.
            num_objectives (int): Number of guidance objectives.
            weights (list): Optional fixed weight vector; if None, a weight vector
                is drawn via ``select_random_weight_vector``.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The sampled sequence of discrete values (or stacked
            intermediates if ``return_intermediates`` is True).

        Raises:
            NotImplementedError: If ``div_free`` is non-zero.
            ImportError: To run in verbose mode, tqdm must be installed.
        """
        # Local import: `inspect` is used below but is not imported at the top of
        # this module (it may or may not be re-exported by the wildcard
        # `multi_guidance` import, so import it explicitly here).
        import inspect

        if not div_free == 0.0:
            raise NotImplementedError

        time_grid = time_grid.to(device=x_init.device)

        # Endpoints are needed by both branches and by the progress bar
        # (same NameError fix as in `sample`).
        t_init = time_grid[0].item()
        t_final = time_grid[-1].item()

        if step_size is None:
            # Time discretization is given directly by time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # Uniform time discretization with the given step size.
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

        if return_intermediates:
            order = torch.argsort(time_grid)
            time_grid = get_nearest_times(
                time_grid=time_grid, t_discretization=t_discretization
            )

        x_t = x_init.clone()
        steps_counter = 0
        res = []

        if return_intermediates:
            res = [x_init.clone()]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
        else:
            ctx = nullcontext()

        # Guidance weight vector: fixed if provided, otherwise drawn at random.
        if weights is not None:
            w = torch.tensor(weights).to(device=x_init.device)
        else:
            w, _ = select_random_weight_vector(num_objectives, args.num_div)
        w = w.to(device=x_init.device)
        print(f"Weight Vector: {w}")
        Phi = args.Phi_init  # hypercone aperture, adapted over the trajectory
        ema_r_t = None  # EMA state threaded through the hypercone filter

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                # Sample X_1 ~ p_{1|t}(.|X_t) from the model posterior.
                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                # NOTE(review): unlike `sample`, there is no `x_t = x_1` jump on
                # the final step — the last iteration leaves x_t unchanged.
                # Confirm this is intentional for the guided sampler.
                if i != n_steps - 1:
                    # Probability velocity computed from the full posterior p_1t.
                    scheduler_output = self.path.scheduler(t=t)
                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t
                    u_t = d_k_t / (1 - k_t) * p_1t

                    # Score candidate transitions under each objective.
                    # NOTE(review): `w` is passed twice here, mirroring the
                    # original call — confirm against the helper's signature.
                    guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S = guided_transition_scoring(
                        x_t, u_t, w, score_models, t, w, args
                    )

                    # Keep only candidates inside the adaptive hypercone.
                    best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = adaptive_hypercone_filtering(
                        improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t
                    )

                    x_t = euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h)

                steps_counter += 1
                t = t + h

                # Per-step monitoring: evaluate every score model on the current
                # state and print the scores.
                scores = []
                for model_idx, s in enumerate(score_models):
                    # Score models may or may not take a time argument.
                    sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
                    if 't' in sig.parameters:
                        candidate_scores = s(x_t, 1)
                    else:
                        candidate_scores = s(x_t)

                    if isinstance(candidate_scores, tuple):
                        for score in candidate_scores:
                            scores.append(score.item())
                    else:
                        scores.append(candidate_scores.item())
                print(scores)

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                return torch.stack(res, dim=0)[order]
        else:
            return x_t