| import torch |
| import math |
| from tqdm import trange, tqdm |
|
|
| import k_diffusion as K |
| |
| |
def get_alphas_sigmas(t):
    """Returns the scaling factors for the clean image (alpha) and for the
    noise (sigma), given a timestep."""
    angle = t * math.pi / 2
    return torch.cos(angle), torch.sin(angle)
|
|
def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for
    the noise."""
    # atan2 recovers the angle on the quarter circle; rescale to t in [0, 1].
    angle = torch.atan2(sigma, alpha)
    return angle / math.pi * 2
|
|
def t_to_alpha_sigma(t):
    """Returns the scaling factors for the clean image and for the noise, given
    a timestep."""
    # Inverse of alpha_sigma_to_t: t=0 is all signal, t=1 is all noise.
    theta = t * math.pi / 2
    return torch.cos(theta), torch.sin(theta)
|
|
|
|
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
    """Draws samples from a model given starting noise, using the explicit
    Euler method on a uniform time grid from sigma_max down to 0.

    Args:
        model: callable taking (x, t_batch, **extra_args) and returning the
            instantaneous rate of change dx/dt (e.g. a rectified-flow model).
        x: starting noise tensor of shape (batch, ...).
        steps: number of Euler integration steps.
        sigma_max: time value the integration starts from.
        **extra_args: forwarded to every model call.

    Returns:
        The integrated sample at t = 0.
    """
    # steps + 1 grid points -> steps intervals, descending from sigma_max to 0.
    t = torch.linspace(sigma_max, 0, steps + 1)

    # total= lets tqdm show real progress; zip alone has no known length.
    for t_curr, t_prev in tqdm(zip(t[:-1], t[1:]), total=steps):
        # Broadcast the scalar timestep across the batch on the model's device.
        t_curr_tensor = t_curr * torch.ones(
            (x.shape[0],), dtype=x.dtype, device=x.device
        )
        dt = t_prev - t_curr  # negative: integrating backwards in time
        x = x + dt * model(x, t_curr_tensor, **extra_args)

    # x has been integrated down to t=0.
    return x
|
|
@torch.no_grad()
def sample(model, x, steps, eta, **extra_args):
    """Draws samples from a v-diffusion model given starting noise, using the
    DDIM/DDPM reverse process.

    Args:
        model: v-objective model called as model(x, t_batch, **extra_args).
        x: starting noise tensor of shape (batch, ...).
        steps: number of denoising steps.
        eta: amount of fresh noise injected per step (0 = deterministic DDIM,
            1 = ancestral/DDPM-like sampling).
        **extra_args: forwarded to every model call.

    Returns:
        The final predicted clean sample.
    """
    # Per-example vector of ones, used to broadcast the scalar timestep.
    ts = x.new_ones([x.shape[0]])

    # Timestep schedule from 1 (pure noise) down toward 0; the final 0 is
    # dropped since the last prediction is returned directly.
    t = torch.linspace(1, 0, steps + 1)[:-1]

    # Signal/noise scaling factors for every timestep in the schedule.
    alphas, sigmas = get_alphas_sigmas(t)

    for i in trange(steps):

        # Get the model output (v, the predicted velocity) under autocast.
        with torch.cuda.amp.autocast():
            v = model(x, ts * t[i], **extra_args).float()

        # Recover the predicted clean image and the predicted noise from v
        # (standard v-parameterization identities).
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # Unless this is the last step, re-noise to the next timestep.
        if i < steps - 1:
            # DDIM noise-injection amount (eta scales between deterministic
            # DDIM and ancestral sampling), per Song et al. (2020), eq. 16.
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            # Portion of the next noise level carried by the predicted eps.
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted image and noise at the next timestep.
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Add the fresh stochastic noise, if eta > 0.
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # If we are on the last timestep, output the denoised image.
    return pred
|
|
| |
| |
def get_bmask(i, steps, mask):
    """Binary mask for step i of steps: positions whose mask value is at or
    below the current progress fraction (i+1)/steps become 1, the rest 0."""
    progress = (i + 1) / steps
    return torch.where(mask <= progress, 1, 0)
|
|
def make_cond_model_fn(model, cond_fn):
    """Wrap *model* so its denoised output is shifted by the gradient-based
    guidance from *cond_fn*, scaled by sigma**2 (broadcast over x's dims)."""
    def cond_model_fn(x, sigma, **kwargs):
        with torch.enable_grad():
            # Re-enable autograd on x so cond_fn can differentiate through it.
            x_in = x.detach().requires_grad_()
            denoised = model(x_in, sigma, **kwargs)
            guidance = cond_fn(x_in, sigma, denoised=denoised, **kwargs).detach()
            scale = K.utils.append_dims(sigma**2, x_in.ndim)
            return denoised.detach() + guidance * scale
    return cond_model_fn
|
|
| |
| |
| |
| |
| |
def sample_k(
    model_fn,
    noise,
    init_data=None,
    mask=None,
    steps=100,
    sampler_type="dpmpp-2m-sde",
    sigma_min=0.5,
    sigma_max=50,
    rho=1.0, device="cuda",
    callback=None,
    cond_fn=None,
    **extra_args
):
    """Draws samples with a k-diffusion sampler, with optional init data
    (img2img-style) and optional mask-based inpainting.

    Args:
        model_fn: v-objective model; wrapped in K.external.VDenoiser.
        noise: unit-variance starting noise of shape (batch, ...).
        init_data: optional starting data to noise and refine.
        mask: optional float mask; positions with lower values are revealed
            earlier during sampling (requires init_data).
        steps: number of sampling steps.
        sampler_type: name of the k-diffusion sampler to use.
        sigma_min: smallest sigma of the noise schedule.
        sigma_max: largest sigma of the noise schedule.
        rho: polyexponential schedule exponent.
        device: device the sigma schedule is created on.
        callback: optional per-step callback receiving the sampler's args dict.
        cond_fn: optional gradient guidance function (see make_cond_model_fn).
        **extra_args: forwarded to the model on every call.

    Returns:
        The sampled tensor.

    Raises:
        ValueError: if sampler_type is not one of the supported names.
    """
    denoiser = K.external.VDenoiser(model_fn)

    if cond_fn is not None:
        denoiser = make_cond_model_fn(denoiser, cond_fn)

    # Polyexponential noise schedule from sigma_max down to sigma_min.
    sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
    # Scale the unit-variance noise to the first sigma.
    noise = noise * sigmas[0]

    wrapped_callback = callback

    if mask is None and init_data is not None:
        # img2img: start from the noised init data.
        x = init_data + noise
    elif mask is not None and init_data is not None:
        # Inpainting: start from a blend of noised init data and pure noise,
        # chosen per-position by the progress-dependent binary mask.
        bmask = get_bmask(0, steps, mask)
        input_noised = init_data + noise
        x = input_noised * bmask + noise * (1 - bmask)

        # At every step, re-impose the known (masked-in) region at the
        # current noise level so only the rest is freely sampled.
        def inpainting_callback(args):
            i = args["i"]
            x = args["x"]
            sigma = args["sigma"]
            # Noise the init data up to the current sigma.
            input_noised = init_data + torch.randn_like(init_data) * sigma
            bmask = get_bmask(i, steps, mask)
            new_x = input_noised * bmask + x * (1 - bmask)
            # Mutate the sampler's tensor in place.
            x[:, :, :] = new_x[:, :, :]

        if callback is None:
            wrapped_callback = inpainting_callback
        else:
            wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
    else:
        # Unconditional generation: start from pure noise.
        x = noise

    with torch.cuda.amp.autocast():
        if sampler_type == "k-heun":
            return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-lms":
            return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpmpp-2s-ancestral":
            return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-2":
            return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-fast":
            return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-adaptive":
            return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-2m-sde":
            return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-3m-sde":
            return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        else:
            # Previously an unknown sampler fell through and returned None;
            # fail loudly instead.
            raise ValueError(f"Unknown sampler type: {sampler_type}")
|
|
| |
| |
| |
| |
| |
def sample_rf(
    model_fn,
    noise,
    init_data=None,
    steps=100,
    sigma_max=1,
    device="cuda",
    callback=None,
    cond_fn=None,
    **extra_args
):
    """Draws samples from a rectified-flow model via discrete Euler integration.

    Args:
        model_fn: rectified-flow model called as model_fn(x, t, **extra_args).
        noise: unit-variance starting noise of shape (batch, ...).
        init_data: optional starting data; when given, sampling starts from an
            interpolation of init_data and noise at t = sigma_max.
        steps: number of Euler steps.
        sigma_max: starting time in (0, 1]; values above 1 are clamped to 1.
        device: unused here; kept for interface parity with sample_k.
        callback: accepted for interface parity, but currently unused because
            sample_discrete_euler has no callback support.
        cond_fn: optional gradient guidance function (see make_cond_model_fn).
        **extra_args: forwarded to the model on every call.

    Returns:
        The sampled tensor at t = 0.
    """
    # Rectified flow is parameterized on t in [0, 1]; clamp the start time.
    if sigma_max > 1:
        sigma_max = 1

    if cond_fn is not None:
        # Bug fix: this previously read `denoiser = make_cond_model_fn(denoiser,
        # cond_fn)`, raising NameError since `denoiser` was never defined.
        model_fn = make_cond_model_fn(model_fn, cond_fn)

    if init_data is not None:
        # img2img: interpolate between the init data and noise at t = sigma_max.
        x = init_data * (1 - sigma_max) + noise * sigma_max
    else:
        # Unconditional generation: start from pure noise.
        x = noise

    with torch.cuda.amp.autocast():
        # TODO: add callback support to sample_discrete_euler and thread
        # `callback` through here.
        return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)