| import math |
| from typing import Dict |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| from torch_scatter import scatter_add, scatter_mean |
|
|
| import utils |
|
|
|
|
| class EnVariationalDiffusion(nn.Module): |
| """ |
| The E(n) Diffusion Module. |
| """ |
|
|
| def __init__( |
| self, |
| dynamics: nn.Module, atom_nf: int, residue_nf: int, |
| n_dims: int, size_histogram: Dict, |
| timesteps: int = 1000, parametrization='eps', |
| noise_schedule='learned', noise_precision=1e-4, |
| loss_type='vlb', norm_values=(1., 1.), norm_biases=(None, 0.), |
| virtual_node_idx=None): |
| super().__init__() |
|
|
| assert loss_type in {'vlb', 'l2'} |
| self.loss_type = loss_type |
| if noise_schedule == 'learned': |
| assert loss_type == 'vlb', 'A noise schedule can only be learned' \ |
| ' with a vlb objective.' |
|
|
| |
| assert parametrization == 'eps' |
|
|
| if noise_schedule == 'learned': |
| self.gamma = GammaNetwork() |
| else: |
| self.gamma = PredefinedNoiseSchedule(noise_schedule, |
| timesteps=timesteps, |
| precision=noise_precision) |
|
|
| |
| self.dynamics = dynamics |
|
|
| self.atom_nf = atom_nf |
| self.residue_nf = residue_nf |
| self.n_dims = n_dims |
| self.num_classes = self.atom_nf |
|
|
| self.T = timesteps |
| self.parametrization = parametrization |
|
|
| self.norm_values = norm_values |
| self.norm_biases = norm_biases |
| self.register_buffer('buffer', torch.zeros(1)) |
|
|
| |
| self.size_distribution = DistributionNodes(size_histogram) |
|
|
| |
| self.vnode_idx = virtual_node_idx |
|
|
| if noise_schedule != 'learned': |
| self.check_issues_norm_values() |
|
|
| def check_issues_norm_values(self, num_stdevs=8): |
| zeros = torch.zeros((1, 1)) |
| gamma_0 = self.gamma(zeros) |
| sigma_0 = self.sigma(gamma_0, target_tensor=zeros).item() |
|
|
| |
| |
| norm_value = self.norm_values[1] |
|
|
| if sigma_0 * num_stdevs > 1. / norm_value: |
| raise ValueError( |
| f'Value for normalization value {norm_value} probably too ' |
| f'large with sigma_0 {sigma_0:.5f} and ' |
| f'1 / norm_value = {1. / norm_value}') |
|
|
| def sigma_and_alpha_t_given_s(self, gamma_t: torch.Tensor, |
| gamma_s: torch.Tensor, |
| target_tensor: torch.Tensor): |
| """ |
| Computes sigma t given s, using gamma_t and gamma_s. Used during sampling. |
| These are defined as: |
| alpha t given s = alpha t / alpha s, |
| sigma t given s = sqrt(1 - (alpha t given s) ^2 ). |
| """ |
| sigma2_t_given_s = self.inflate_batch_array( |
| -torch.expm1(F.softplus(gamma_s) - F.softplus(gamma_t)), target_tensor |
| ) |
|
|
| |
| log_alpha2_t = F.logsigmoid(-gamma_t) |
| log_alpha2_s = F.logsigmoid(-gamma_s) |
| log_alpha2_t_given_s = log_alpha2_t - log_alpha2_s |
|
|
| alpha_t_given_s = torch.exp(0.5 * log_alpha2_t_given_s) |
| alpha_t_given_s = self.inflate_batch_array( |
| alpha_t_given_s, target_tensor) |
|
|
| sigma_t_given_s = torch.sqrt(sigma2_t_given_s) |
|
|
| return sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s |
|
|
| def kl_prior_with_pocket(self, xh_lig, xh_pocket, mask_lig, mask_pocket, |
| num_nodes): |
| """Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1). |
| |
| This is essentially a lot of work for something that is in practice |
| negligible in the loss. However, you compute it so that you see it when |
| you've made a mistake in your noise schedule. |
| """ |
| batch_size = len(num_nodes) |
|
|
| |
| ones = torch.ones((batch_size, 1), device=xh_lig.device) |
| gamma_T = self.gamma(ones) |
| alpha_T = self.alpha(gamma_T, xh_lig) |
|
|
| |
| mu_T_lig = alpha_T[mask_lig] * xh_lig |
| mu_T_lig_x, mu_T_lig_h = mu_T_lig[:, :self.n_dims], \ |
| mu_T_lig[:, self.n_dims:] |
|
|
| |
| sigma_T_x = self.sigma(gamma_T, mu_T_lig_x).squeeze() |
| sigma_T_h = self.sigma(gamma_T, mu_T_lig_h).squeeze() |
|
|
| |
| mu_T_pocket = alpha_T[mask_pocket] * xh_pocket |
| mu_T_pocket_x, mu_T_pocket_h = mu_T_pocket[:, :self.n_dims], \ |
| mu_T_pocket[:, self.n_dims:] |
|
|
| |
| zeros_lig = torch.zeros_like(mu_T_lig_h) |
| zeros_pocket = torch.zeros_like(mu_T_pocket_h) |
| ones = torch.ones_like(sigma_T_h) |
| mu_norm2 = self.sum_except_batch((mu_T_lig_h - zeros_lig) ** 2, mask_lig) + \ |
| self.sum_except_batch((mu_T_pocket_h - zeros_pocket) ** 2, mask_pocket) |
| kl_distance_h = self.gaussian_KL(mu_norm2, sigma_T_h, ones, d=1) |
|
|
| |
| zeros_lig = torch.zeros_like(mu_T_lig_x) |
| zeros_pocket = torch.zeros_like(mu_T_pocket_x) |
| ones = torch.ones_like(sigma_T_x) |
| mu_norm2 = self.sum_except_batch((mu_T_lig_x - zeros_lig) ** 2, mask_lig) + \ |
| self.sum_except_batch((mu_T_pocket_x - zeros_pocket) ** 2, mask_pocket) |
| subspace_d = self.subspace_dimensionality(num_nodes) |
| kl_distance_x = self.gaussian_KL(mu_norm2, sigma_T_x, ones, subspace_d) |
|
|
| return kl_distance_x + kl_distance_h |
|
|
| def compute_x_pred(self, net_out, zt, gamma_t, batch_mask): |
| """Commputes x_pred, i.e. the most likely prediction of x.""" |
| if self.parametrization == 'x': |
| x_pred = net_out |
| elif self.parametrization == 'eps': |
| sigma_t = self.sigma(gamma_t, target_tensor=net_out) |
| alpha_t = self.alpha(gamma_t, target_tensor=net_out) |
| eps_t = net_out |
| x_pred = 1. / alpha_t[batch_mask] * (zt - sigma_t[batch_mask] * eps_t) |
| else: |
| raise ValueError(self.parametrization) |
|
|
| return x_pred |
|
|
| def log_constants_p_x_given_z0(self, n_nodes, device): |
| """Computes p(x|z0).""" |
|
|
| batch_size = len(n_nodes) |
| degrees_of_freedom_x = self.subspace_dimensionality(n_nodes) |
|
|
| zeros = torch.zeros((batch_size, 1), device=device) |
| gamma_0 = self.gamma(zeros) |
|
|
| |
| log_sigma_x = 0.5 * gamma_0.view(batch_size) |
|
|
| return degrees_of_freedom_x * (- log_sigma_x - 0.5 * np.log(2 * np.pi)) |
|
|
| def log_pxh_given_z0_without_constants( |
| self, ligand, z_0_lig, eps_lig, net_out_lig, |
| pocket, z_0_pocket, eps_pocket, net_out_pocket, |
| gamma_0, epsilon=1e-10): |
|
|
| |
| z_h_lig = z_0_lig[:, self.n_dims:] |
| z_h_pocket = z_0_pocket[:, self.n_dims:] |
|
|
| |
| eps_lig_x = eps_lig[:, :self.n_dims] |
| net_lig_x = net_out_lig[:, :self.n_dims] |
| eps_pocket_x = eps_pocket[:, :self.n_dims] |
| net_pocket_x = net_out_pocket[:, :self.n_dims] |
|
|
| |
| sigma_0 = self.sigma(gamma_0, target_tensor=z_0_lig) |
| sigma_0_cat = sigma_0 * self.norm_values[1] |
|
|
| |
| |
| |
| log_p_x_given_z0_without_constants_ligand = -0.5 * ( |
| self.sum_except_batch((eps_lig_x - net_lig_x) ** 2, ligand['mask']) |
| ) |
|
|
| log_p_x_given_z0_without_constants_pocket = -0.5 * ( |
| self.sum_except_batch((eps_pocket_x - net_pocket_x) ** 2, |
| pocket['mask']) |
| ) |
|
|
| |
| |
| ligand_onehot = ligand['one_hot'] * self.norm_values[1] + self.norm_biases[1] |
| pocket_onehot = pocket['one_hot'] * self.norm_values[1] + self.norm_biases[1] |
|
|
| estimated_ligand_onehot = z_h_lig * self.norm_values[1] + self.norm_biases[1] |
| estimated_pocket_onehot = z_h_pocket * self.norm_values[1] + self.norm_biases[1] |
|
|
| |
| centered_ligand_onehot = estimated_ligand_onehot - 1 |
| centered_pocket_onehot = estimated_pocket_onehot - 1 |
|
|
| |
| |
| log_ph_cat_proportional_ligand = torch.log( |
| self.cdf_standard_gaussian((centered_ligand_onehot + 0.5) / sigma_0_cat[ligand['mask']]) |
| - self.cdf_standard_gaussian((centered_ligand_onehot - 0.5) / sigma_0_cat[ligand['mask']]) |
| + epsilon |
| ) |
| log_ph_cat_proportional_pocket = torch.log( |
| self.cdf_standard_gaussian((centered_pocket_onehot + 0.5) / sigma_0_cat[pocket['mask']]) |
| - self.cdf_standard_gaussian((centered_pocket_onehot - 0.5) / sigma_0_cat[pocket['mask']]) |
| + epsilon |
| ) |
|
|
| |
| log_Z = torch.logsumexp(log_ph_cat_proportional_ligand, dim=1, |
| keepdim=True) |
| log_probabilities_ligand = log_ph_cat_proportional_ligand - log_Z |
|
|
| log_Z = torch.logsumexp(log_ph_cat_proportional_pocket, dim=1, |
| keepdim=True) |
| log_probabilities_pocket = log_ph_cat_proportional_pocket - log_Z |
|
|
| |
| |
| log_ph_given_z0_ligand = self.sum_except_batch( |
| log_probabilities_ligand * ligand_onehot, ligand['mask']) |
| log_ph_given_z0_pocket = self.sum_except_batch( |
| log_probabilities_pocket * pocket_onehot, pocket['mask']) |
|
|
| |
| log_ph_given_z0 = log_ph_given_z0_ligand + log_ph_given_z0_pocket |
|
|
| return log_p_x_given_z0_without_constants_ligand, \ |
| log_p_x_given_z0_without_constants_pocket, log_ph_given_z0 |
|
|
| def sample_p_xh_given_z0(self, z0_lig, z0_pocket, lig_mask, pocket_mask, |
| batch_size, fix_noise=False): |
| """Samples x ~ p(x|z0).""" |
| t_zeros = torch.zeros(size=(batch_size, 1), device=z0_lig.device) |
| gamma_0 = self.gamma(t_zeros) |
| |
| sigma_x = self.SNR(-0.5 * gamma_0) |
| net_out_lig, net_out_pocket = self.dynamics( |
| z0_lig, z0_pocket, t_zeros, lig_mask, pocket_mask) |
|
|
| |
| mu_x_lig = self.compute_x_pred(net_out_lig, z0_lig, gamma_0, lig_mask) |
| mu_x_pocket = self.compute_x_pred(net_out_pocket, z0_pocket, gamma_0, |
| pocket_mask) |
| xh_lig, xh_pocket = self.sample_normal(mu_x_lig, mu_x_pocket, sigma_x, |
| lig_mask, pocket_mask, fix_noise) |
|
|
| x_lig, h_lig = self.unnormalize( |
| xh_lig[:, :self.n_dims], z0_lig[:, self.n_dims:]) |
| x_pocket, h_pocket = self.unnormalize( |
| xh_pocket[:, :self.n_dims], z0_pocket[:, self.n_dims:]) |
|
|
| h_lig = F.one_hot(torch.argmax(h_lig, dim=1), self.atom_nf) |
| h_pocket = F.one_hot(torch.argmax(h_pocket, dim=1), self.residue_nf) |
|
|
| return x_lig, h_lig, x_pocket, h_pocket |
|
|
| def sample_normal(self, mu_lig, mu_pocket, sigma, lig_mask, pocket_mask, |
| fix_noise=False): |
| """Samples from a Normal distribution.""" |
| if fix_noise: |
| |
| raise NotImplementedError("fix_noise option isn't implemented yet") |
| eps_lig, eps_pocket = self.sample_combined_position_feature_noise( |
| lig_mask, pocket_mask) |
|
|
| return mu_lig + sigma[lig_mask] * eps_lig, \ |
| mu_pocket + sigma[pocket_mask] * eps_pocket |
|
|
| def noised_representation(self, xh_lig, xh_pocket, lig_mask, pocket_mask, |
| gamma_t): |
| |
| alpha_t = self.alpha(gamma_t, xh_lig) |
| sigma_t = self.sigma(gamma_t, xh_lig) |
|
|
| |
| eps_lig, eps_pocket = self.sample_combined_position_feature_noise( |
| lig_mask, pocket_mask) |
|
|
| |
| z_t_lig = alpha_t[lig_mask] * xh_lig + sigma_t[lig_mask] * eps_lig |
| z_t_pocket = alpha_t[pocket_mask] * xh_pocket + \ |
| sigma_t[pocket_mask] * eps_pocket |
|
|
| return z_t_lig, z_t_pocket, eps_lig, eps_pocket |
|
|
| def log_pN(self, N_lig, N_pocket): |
| """ |
| Prior on the sample size for computing |
| log p(x,h,N) = log p(x,h|N) + log p(N), where log p(x,h|N) is the |
| model's output |
| Args: |
| N: array of sample sizes |
| Returns: |
| log p(N) |
| """ |
| log_pN = self.size_distribution.log_prob(N_lig, N_pocket) |
| return log_pN |
|
|
| def delta_log_px(self, num_nodes): |
| return -self.subspace_dimensionality(num_nodes) * \ |
| np.log(self.norm_values[0]) |
|
|
| def forward(self, ligand, pocket, return_info=False): |
| """ |
| Computes the loss and NLL terms |
| """ |
| |
| ligand, pocket = self.normalize(ligand, pocket) |
|
|
| |
| delta_log_px = self.delta_log_px(ligand['size'] + pocket['size']) |
|
|
| |
| |
| |
| lowest_t = 0 if self.training else 1 |
| t_int = torch.randint( |
| lowest_t, self.T + 1, size=(ligand['size'].size(0), 1), |
| device=ligand['x'].device).float() |
| s_int = t_int - 1 |
|
|
| |
| t_is_zero = (t_int == 0).float() |
| t_is_not_zero = 1 - t_is_zero |
|
|
| |
| |
| s = s_int / self.T |
| t = t_int / self.T |
|
|
| |
| gamma_s = self.inflate_batch_array(self.gamma(s), ligand['x']) |
| gamma_t = self.inflate_batch_array(self.gamma(t), ligand['x']) |
|
|
| |
| xh_lig = torch.cat([ligand['x'], ligand['one_hot']], dim=1) |
| xh_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1) |
|
|
| |
| z_t_lig, z_t_pocket, eps_t_lig, eps_t_pocket = \ |
| self.noised_representation(xh_lig, xh_pocket, ligand['mask'], |
| pocket['mask'], gamma_t) |
|
|
| |
| net_out_lig, net_out_pocket = self.dynamics( |
| z_t_lig, z_t_pocket, t, ligand['mask'], pocket['mask']) |
|
|
| |
| xh_lig_hat = self.xh_given_zt_and_epsilon(z_t_lig, net_out_lig, gamma_t, |
| ligand['mask']) |
|
|
| |
| error_t_lig = self.sum_except_batch((eps_t_lig - net_out_lig) ** 2, |
| ligand['mask']) |
|
|
| error_t_pocket = self.sum_except_batch( |
| (eps_t_pocket - net_out_pocket) ** 2, pocket['mask']) |
|
|
| |
| SNR_weight = (1 - self.SNR(gamma_s - gamma_t)).squeeze(1) |
| assert error_t_lig.size() == SNR_weight.size() |
|
|
| |
| |
| neg_log_constants = -self.log_constants_p_x_given_z0( |
| n_nodes=ligand['size'] + pocket['size'], device=error_t_lig.device) |
|
|
| |
| |
| kl_prior = self.kl_prior_with_pocket( |
| xh_lig, xh_pocket, ligand['mask'], pocket['mask'], |
| ligand['size'] + pocket['size']) |
|
|
| if self.training: |
| |
| |
| log_p_x_given_z0_without_constants_ligand, \ |
| log_p_x_given_z0_without_constants_pocket, log_ph_given_z0 = \ |
| self.log_pxh_given_z0_without_constants( |
| ligand, z_t_lig, eps_t_lig, net_out_lig, |
| pocket, z_t_pocket, eps_t_pocket, net_out_pocket, gamma_t) |
|
|
| loss_0_x_ligand = -log_p_x_given_z0_without_constants_ligand * \ |
| t_is_zero.squeeze() |
| loss_0_x_pocket = -log_p_x_given_z0_without_constants_pocket * \ |
| t_is_zero.squeeze() |
| loss_0_h = -log_ph_given_z0 * t_is_zero.squeeze() |
|
|
| |
| error_t_lig = error_t_lig * t_is_not_zero.squeeze() |
| error_t_pocket = error_t_pocket * t_is_not_zero.squeeze() |
|
|
| else: |
| |
| t_zeros = torch.zeros_like(s) |
| gamma_0 = self.inflate_batch_array(self.gamma(t_zeros), ligand['x']) |
|
|
| |
| z_0_lig, z_0_pocket, eps_0_lig, eps_0_pocket = \ |
| self.noised_representation(xh_lig, xh_pocket, ligand['mask'], |
| pocket['mask'], gamma_0) |
|
|
| net_out_0_lig, net_out_0_pocket = self.dynamics( |
| z_0_lig, z_0_pocket, t_zeros, ligand['mask'], pocket['mask']) |
|
|
| log_p_x_given_z0_without_constants_ligand, \ |
| log_p_x_given_z0_without_constants_pocket, log_ph_given_z0 = \ |
| self.log_pxh_given_z0_without_constants( |
| ligand, z_0_lig, eps_0_lig, net_out_0_lig, |
| pocket, z_0_pocket, eps_0_pocket, net_out_0_pocket, gamma_0) |
| loss_0_x_ligand = -log_p_x_given_z0_without_constants_ligand |
| loss_0_x_pocket = -log_p_x_given_z0_without_constants_pocket |
| loss_0_h = -log_ph_given_z0 |
|
|
| |
| log_pN = self.log_pN(ligand['size'], pocket['size']) |
|
|
| info = { |
| 'eps_hat_lig_x': scatter_mean( |
| net_out_lig[:, :self.n_dims].abs().mean(1), ligand['mask'], |
| dim=0).mean(), |
| 'eps_hat_lig_h': scatter_mean( |
| net_out_lig[:, self.n_dims:].abs().mean(1), ligand['mask'], |
| dim=0).mean(), |
| 'eps_hat_pocket_x': scatter_mean( |
| net_out_pocket[:, :self.n_dims].abs().mean(1), pocket['mask'], |
| dim=0).mean(), |
| 'eps_hat_pocket_h': scatter_mean( |
| net_out_pocket[:, self.n_dims:].abs().mean(1), pocket['mask'], |
| dim=0).mean(), |
| } |
| loss_terms = (delta_log_px, error_t_lig, error_t_pocket, SNR_weight, |
| loss_0_x_ligand, loss_0_x_pocket, loss_0_h, |
| neg_log_constants, kl_prior, log_pN, |
| t_int.squeeze(), xh_lig_hat) |
| return (*loss_terms, info) if return_info else loss_terms |
|
|
| def xh_given_zt_and_epsilon(self, z_t, epsilon, gamma_t, batch_mask): |
| """ Equation (7) in the EDM paper """ |
| alpha_t = self.alpha(gamma_t, z_t) |
| sigma_t = self.sigma(gamma_t, z_t) |
| xh = z_t / alpha_t[batch_mask] - epsilon * sigma_t[batch_mask] / \ |
| alpha_t[batch_mask] |
| return xh |
|
|
| def sample_p_zt_given_zs(self, zs_lig, zs_pocket, ligand_mask, pocket_mask, |
| gamma_t, gamma_s, fix_noise=False): |
| sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = \ |
| self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zs_lig) |
|
|
| mu_lig = alpha_t_given_s[ligand_mask] * zs_lig |
| mu_pocket = alpha_t_given_s[pocket_mask] * zs_pocket |
| zt_lig, zt_pocket = self.sample_normal( |
| mu_lig, mu_pocket, sigma_t_given_s, ligand_mask, pocket_mask, |
| fix_noise) |
|
|
| |
| zt_x = self.remove_mean_batch( |
| torch.cat((zt_lig[:, :self.n_dims], zt_pocket[:, :self.n_dims]), |
| dim=0), |
| torch.cat((ligand_mask, pocket_mask)) |
| ) |
| zt_lig = torch.cat((zt_x[:len(ligand_mask)], |
| zt_lig[:, self.n_dims:]), dim=1) |
| zt_pocket = torch.cat((zt_x[len(ligand_mask):], |
| zt_pocket[:, self.n_dims:]), dim=1) |
|
|
| return zt_lig, zt_pocket |
|
|
| def sample_p_zs_given_zt(self, s, t, zt_lig, zt_pocket, ligand_mask, |
| pocket_mask, fix_noise=False): |
| """Samples from zs ~ p(zs | zt). Only used during sampling.""" |
| gamma_s = self.gamma(s) |
| gamma_t = self.gamma(t) |
|
|
| sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = \ |
| self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zt_lig) |
|
|
| sigma_s = self.sigma(gamma_s, target_tensor=zt_lig) |
| sigma_t = self.sigma(gamma_t, target_tensor=zt_lig) |
|
|
| |
| eps_t_lig, eps_t_pocket = self.dynamics( |
| zt_lig, zt_pocket, t, ligand_mask, pocket_mask) |
|
|
| |
| combined_mask = torch.cat((ligand_mask, pocket_mask)) |
| self.assert_mean_zero_with_mask( |
| torch.cat((zt_lig[:, :self.n_dims], |
| zt_pocket[:, :self.n_dims]), dim=0), |
| combined_mask) |
| self.assert_mean_zero_with_mask( |
| torch.cat((eps_t_lig[:, :self.n_dims], |
| eps_t_pocket[:, :self.n_dims]), dim=0), |
| combined_mask) |
|
|
| |
| |
| mu_lig = zt_lig / alpha_t_given_s[ligand_mask] - \ |
| (sigma2_t_given_s / alpha_t_given_s / sigma_t)[ligand_mask] * \ |
| eps_t_lig |
| mu_pocket = zt_pocket / alpha_t_given_s[pocket_mask] - \ |
| (sigma2_t_given_s / alpha_t_given_s / sigma_t)[pocket_mask] * \ |
| eps_t_pocket |
|
|
| |
| sigma = sigma_t_given_s * sigma_s / sigma_t |
|
|
| |
| zs_lig, zs_pocket = self.sample_normal(mu_lig, mu_pocket, sigma, |
| ligand_mask, pocket_mask, |
| fix_noise) |
|
|
| |
| zs_x = self.remove_mean_batch( |
| torch.cat((zs_lig[:, :self.n_dims], |
| zs_pocket[:, :self.n_dims]), dim=0), |
| torch.cat((ligand_mask, pocket_mask)) |
| ) |
| zs_lig = torch.cat((zs_x[:len(ligand_mask)], |
| zs_lig[:, self.n_dims:]), dim=1) |
| zs_pocket = torch.cat((zs_x[len(ligand_mask):], |
| zs_pocket[:, self.n_dims:]), dim=1) |
| return zs_lig, zs_pocket |
|
|
| def sample_combined_position_feature_noise(self, lig_indices, |
| pocket_indices): |
| """ |
| Samples mean-centered normal noise for z_x, and standard normal noise |
| for z_h. |
| """ |
| z_x = self.sample_center_gravity_zero_gaussian_batch( |
| size=(len(lig_indices) + len(pocket_indices), self.n_dims), |
| lig_indices=lig_indices, |
| pocket_indices=pocket_indices |
| ) |
| z_h_lig = self.sample_gaussian( |
| size=(len(lig_indices), self.atom_nf), |
| device=lig_indices.device) |
| z_lig = torch.cat([z_x[:len(lig_indices)], z_h_lig], dim=1) |
| z_h_pocket = self.sample_gaussian( |
| size=(len(pocket_indices), self.residue_nf), |
| device=pocket_indices.device) |
| z_pocket = torch.cat([z_x[len(lig_indices):], z_h_pocket], dim=1) |
| return z_lig, z_pocket |
|
|
| @torch.no_grad() |
| def sample(self, n_samples, num_nodes_lig, num_nodes_pocket, |
| return_frames=1, timesteps=None, device='cpu'): |
| """ |
| Draw samples from the generative model. Optionally, return intermediate |
| states for visualization purposes. |
| """ |
| timesteps = self.T if timesteps is None else timesteps |
| assert 0 < return_frames <= timesteps |
| assert timesteps % return_frames == 0 |
|
|
| lig_mask = utils.num_nodes_to_batch_mask(n_samples, num_nodes_lig, |
| device) |
| pocket_mask = utils.num_nodes_to_batch_mask(n_samples, num_nodes_pocket, |
| device) |
|
|
| combined_mask = torch.cat((lig_mask, pocket_mask)) |
|
|
| z_lig, z_pocket = self.sample_combined_position_feature_noise( |
| lig_mask, pocket_mask) |
|
|
| self.assert_mean_zero_with_mask( |
| torch.cat((z_lig[:, :self.n_dims], z_pocket[:, :self.n_dims]), dim=0), |
| combined_mask |
| ) |
|
|
| out_lig = torch.zeros((return_frames,) + z_lig.size(), |
| device=z_lig.device) |
| out_pocket = torch.zeros((return_frames,) + z_pocket.size(), |
| device=z_pocket.device) |
|
|
| |
| for s in reversed(range(0, timesteps)): |
| s_array = torch.full((n_samples, 1), fill_value=s, |
| device=z_lig.device) |
| t_array = s_array + 1 |
| s_array = s_array / timesteps |
| t_array = t_array / timesteps |
|
|
| z_lig, z_pocket = self.sample_p_zs_given_zt( |
| s_array, t_array, z_lig, z_pocket, lig_mask, pocket_mask) |
|
|
| |
| if (s * return_frames) % timesteps == 0: |
| idx = (s * return_frames) // timesteps |
| out_lig[idx], out_pocket[idx] = \ |
| self.unnormalize_z(z_lig, z_pocket) |
|
|
| |
| x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0( |
| z_lig, z_pocket, lig_mask, pocket_mask, n_samples) |
|
|
| self.assert_mean_zero_with_mask( |
| torch.cat((x_lig, x_pocket), dim=0), combined_mask |
| ) |
|
|
| |
| if return_frames == 1: |
| x = torch.cat((x_lig, x_pocket)) |
| max_cog = scatter_add(x, combined_mask, dim=0).abs().max().item() |
| if max_cog > 5e-2: |
| print(f'Warning CoG drift with error {max_cog:.3f}. Projecting ' |
| f'the positions down.') |
| x = self.remove_mean_batch(x, combined_mask) |
| x_lig, x_pocket = x[:len(x_lig)], x[len(x_lig):] |
|
|
| |
| out_lig[0] = torch.cat([x_lig, h_lig], dim=1) |
| out_pocket[0] = torch.cat([x_pocket, h_pocket], dim=1) |
|
|
| |
| return out_lig.squeeze(0), out_pocket.squeeze(0), lig_mask, pocket_mask |
|
|
| def get_repaint_schedule(self, resamplings, jump_length, timesteps): |
| """ Each integer in the schedule list describes how many denoising steps |
| need to be applied before jumping back """ |
| repaint_schedule = [] |
| curr_t = 0 |
| while curr_t < timesteps: |
| if curr_t + jump_length < timesteps: |
| if len(repaint_schedule) > 0: |
| repaint_schedule[-1] += jump_length |
| repaint_schedule.extend([jump_length] * (resamplings - 1)) |
| else: |
| repaint_schedule.extend([jump_length] * resamplings) |
| curr_t += jump_length |
| else: |
| residual = (timesteps - curr_t) |
| if len(repaint_schedule) > 0: |
| repaint_schedule[-1] += residual |
| else: |
| repaint_schedule.append(residual) |
| curr_t += residual |
|
|
| return list(reversed(repaint_schedule)) |
|
|
| @torch.no_grad() |
| def inpaint(self, ligand, pocket, lig_fixed, pocket_fixed, resamplings=1, |
| jump_length=1, return_frames=1, timesteps=None): |
| """ |
| Draw samples from the generative model while fixing parts of the input. |
| Optionally, return intermediate states for visualization purposes. |
| See: |
| Lugmayr, Andreas, et al. |
| "Repaint: Inpainting using denoising diffusion probabilistic models." |
| Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern |
| Recognition. 2022. |
| """ |
| timesteps = self.T if timesteps is None else timesteps |
| assert 0 < return_frames <= timesteps |
| assert timesteps % return_frames == 0 |
| assert jump_length == 1 or return_frames == 1, \ |
| "Chain visualization is only implemented for jump_length=1" |
|
|
| if len(lig_fixed.size()) == 1: |
| lig_fixed = lig_fixed.unsqueeze(1) |
| if len(pocket_fixed.size()) == 1: |
| pocket_fixed = pocket_fixed.unsqueeze(1) |
|
|
| ligand, pocket = self.normalize(ligand, pocket) |
|
|
| n_samples = len(ligand['size']) |
| combined_mask = torch.cat((ligand['mask'], pocket['mask'])) |
| xh0_lig = torch.cat([ligand['x'], ligand['one_hot']], dim=1) |
| xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1) |
|
|
| |
| mean_known = scatter_mean( |
| torch.cat((ligand['x'][lig_fixed.bool().view(-1)], |
| pocket['x'][pocket_fixed.bool().view(-1)])), |
| torch.cat((ligand['mask'][lig_fixed.bool().view(-1)], |
| pocket['mask'][pocket_fixed.bool().view(-1)])), |
| dim=0 |
| ) |
| xh0_lig[:, :self.n_dims] = \ |
| xh0_lig[:, :self.n_dims] - mean_known[ligand['mask']] |
| xh0_pocket[:, :self.n_dims] = \ |
| xh0_pocket[:, :self.n_dims] - mean_known[pocket['mask']] |
|
|
| |
| z_lig, z_pocket = self.sample_combined_position_feature_noise( |
| ligand['mask'], pocket['mask']) |
|
|
| |
| out_lig = torch.zeros((return_frames,) + z_lig.size(), |
| device=z_lig.device) |
| out_pocket = torch.zeros((return_frames,) + z_pocket.size(), |
| device=z_pocket.device) |
|
|
| |
| schedule = self.get_repaint_schedule(resamplings, jump_length, timesteps) |
| s = timesteps - 1 |
| for i, n_denoise_steps in enumerate(schedule): |
| for j in range(n_denoise_steps): |
| |
| s_array = torch.full((n_samples, 1), fill_value=s, |
| device=z_lig.device) |
| t_array = s_array + 1 |
| s_array = s_array / timesteps |
| t_array = t_array / timesteps |
|
|
| |
| gamma_s = self.inflate_batch_array(self.gamma(s_array), |
| ligand['x']) |
| z_lig_known, z_pocket_known, _, _ = self.noised_representation( |
| xh0_lig, xh0_pocket, ligand['mask'], pocket['mask'], gamma_s) |
|
|
| |
| z_lig_unknown, z_pocket_unknown = self.sample_p_zs_given_zt( |
| s_array, t_array, z_lig, z_pocket, ligand['mask'], |
| pocket['mask']) |
|
|
| |
| |
| |
| com_noised = scatter_mean( |
| torch.cat((z_lig_known[:, :self.n_dims][lig_fixed.bool().view(-1)], |
| z_pocket_known[:, :self.n_dims][pocket_fixed.bool().view(-1)])), |
| torch.cat((ligand['mask'][lig_fixed.bool().view(-1)], |
| pocket['mask'][pocket_fixed.bool().view(-1)])), |
| dim=0 |
| ) |
| com_denoised = scatter_mean( |
| torch.cat((z_lig_unknown[:, :self.n_dims][lig_fixed.bool().view(-1)], |
| z_pocket_unknown[:, :self.n_dims][pocket_fixed.bool().view(-1)])), |
| torch.cat((ligand['mask'][lig_fixed.bool().view(-1)], |
| pocket['mask'][pocket_fixed.bool().view(-1)])), |
| dim=0 |
| ) |
| z_lig_known[:, :self.n_dims] = \ |
| z_lig_known[:, :self.n_dims] + (com_denoised - com_noised)[ligand['mask']] |
| z_pocket_known[:, :self.n_dims] = \ |
| z_pocket_known[:, :self.n_dims] + (com_denoised - com_noised)[pocket['mask']] |
|
|
| |
| z_lig = z_lig_known * lig_fixed + \ |
| z_lig_unknown * (1 - lig_fixed) |
| z_pocket = z_pocket_known * pocket_fixed + \ |
| z_pocket_unknown * (1 - pocket_fixed) |
|
|
| self.assert_mean_zero_with_mask( |
| torch.cat((z_lig[:, :self.n_dims], |
| z_pocket[:, :self.n_dims]), dim=0), combined_mask |
| ) |
|
|
| |
| if n_denoise_steps > jump_length or i == len(schedule) - 1: |
| if (s * return_frames) % timesteps == 0: |
| idx = (s * return_frames) // timesteps |
| out_lig[idx], out_pocket[idx] = \ |
| self.unnormalize_z(z_lig, z_pocket) |
|
|
| |
| if j == n_denoise_steps - 1 and i < len(schedule) - 1: |
| |
| t = s + jump_length |
| t_array = torch.full((n_samples, 1), fill_value=t, |
| device=z_lig.device) |
| t_array = t_array / timesteps |
|
|
| gamma_s = self.inflate_batch_array(self.gamma(s_array), |
| ligand['x']) |
| gamma_t = self.inflate_batch_array(self.gamma(t_array), |
| ligand['x']) |
|
|
| z_lig, z_pocket = self.sample_p_zt_given_zs( |
| z_lig, z_pocket, ligand['mask'], pocket['mask'], |
| gamma_t, gamma_s) |
|
|
| s = t |
|
|
| s -= 1 |
|
|
| |
| x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0( |
| z_lig, z_pocket, ligand['mask'], pocket['mask'], n_samples) |
|
|
| self.assert_mean_zero_with_mask( |
| torch.cat((x_lig, x_pocket), dim=0), combined_mask |
| ) |
|
|
| |
| if return_frames == 1: |
| x = torch.cat((x_lig, x_pocket)) |
| max_cog = scatter_add(x, combined_mask, dim=0).abs().max().item() |
| if max_cog > 5e-2: |
| print(f'Warning CoG drift with error {max_cog:.3f}. Projecting ' |
| f'the positions down.') |
| x = self.remove_mean_batch(x, combined_mask) |
| x_lig, x_pocket = x[:len(x_lig)], x[len(x_lig):] |
|
|
| |
| out_lig[0] = torch.cat([x_lig, h_lig], dim=1) |
| out_pocket[0] = torch.cat([x_pocket, h_pocket], dim=1) |
|
|
| |
| return out_lig.squeeze(0), out_pocket.squeeze(0), ligand['mask'], \ |
| pocket['mask'] |
|
|
| @staticmethod |
| def gaussian_KL(q_mu_minus_p_mu_squared, q_sigma, p_sigma, d): |
| """Computes the KL distance between two normal distributions. |
| Args: |
| q_mu_minus_p_mu_squared: Squared difference between mean of |
| distribution q and distribution p: ||mu_q - mu_p||^2 |
| q_sigma: Standard deviation of distribution q. |
| p_sigma: Standard deviation of distribution p. |
| d: dimension |
| Returns: |
| The KL distance |
| """ |
| return d * torch.log(p_sigma / q_sigma) + \ |
| 0.5 * (d * q_sigma ** 2 + q_mu_minus_p_mu_squared) / \ |
| (p_sigma ** 2) - 0.5 * d |
|
|
| @staticmethod |
| def inflate_batch_array(array, target): |
| """ |
| Inflates the batch array (array) with only a single axis |
| (i.e. shape = (batch_size,), or possibly more empty axes |
| (i.e. shape (batch_size, 1, ..., 1)) to match the target shape. |
| """ |
| target_shape = (array.size(0),) + (1,) * (len(target.size()) - 1) |
| return array.view(target_shape) |
|
|
| def sigma(self, gamma, target_tensor): |
| """Computes sigma given gamma.""" |
| return self.inflate_batch_array(torch.sqrt(torch.sigmoid(gamma)), |
| target_tensor) |
|
|
| def alpha(self, gamma, target_tensor): |
| """Computes alpha given gamma.""" |
| return self.inflate_batch_array(torch.sqrt(torch.sigmoid(-gamma)), |
| target_tensor) |
|
|
| @staticmethod |
| def SNR(gamma): |
| """Computes signal to noise ratio (alpha^2/sigma^2) given gamma.""" |
| return torch.exp(-gamma) |
|
|
| def normalize(self, ligand=None, pocket=None): |
| if ligand is not None: |
| ligand['x'] = ligand['x'] / self.norm_values[0] |
|
|
| |
| ligand['one_hot'] = \ |
| (ligand['one_hot'].float() - self.norm_biases[1]) / \ |
| self.norm_values[1] |
|
|
| if pocket is not None: |
| pocket['x'] = pocket['x'] / self.norm_values[0] |
| pocket['one_hot'] = \ |
| (pocket['one_hot'].float() - self.norm_biases[1]) / \ |
| self.norm_values[1] |
|
|
| return ligand, pocket |
|
|
| def unnormalize(self, x, h_cat): |
| x = x * self.norm_values[0] |
| h_cat = h_cat * self.norm_values[1] + self.norm_biases[1] |
|
|
| return x, h_cat |
|
|
| def unnormalize_z(self, z_lig, z_pocket): |
| |
| x_lig, h_lig = z_lig[:, :self.n_dims], z_lig[:, self.n_dims:] |
| x_pocket, h_pocket = z_pocket[:, :self.n_dims], z_pocket[:, self.n_dims:] |
|
|
| |
| x_lig, h_lig = self.unnormalize(x_lig, h_lig) |
| x_pocket, h_pocket = self.unnormalize(x_pocket, h_pocket) |
| return torch.cat([x_lig, h_lig], dim=1), \ |
| torch.cat([x_pocket, h_pocket], dim=1) |
|
|
| def subspace_dimensionality(self, input_size): |
| """Compute the dimensionality on translation-invariant linear subspace |
| where distributions on x are defined.""" |
| return (input_size - 1) * self.n_dims |
|
|
| @staticmethod |
| def remove_mean_batch(x, indices): |
| mean = scatter_mean(x, indices, dim=0) |
| x = x - mean[indices] |
| return x |
|
|
| @staticmethod |
| def assert_mean_zero_with_mask(x, node_mask, eps=1e-10): |
| largest_value = x.abs().max().item() |
| error = scatter_add(x, node_mask, dim=0).abs().max().item() |
| rel_error = error / (largest_value + eps) |
| assert rel_error < 1e-2, f'Mean is not zero, relative_error {rel_error}' |
|
|
| @staticmethod |
| def sample_center_gravity_zero_gaussian_batch(size, lig_indices, |
| pocket_indices): |
| assert len(size) == 2 |
| x = torch.randn(size, device=lig_indices.device) |
|
|
| |
| |
| x_projected = EnVariationalDiffusion.remove_mean_batch( |
| x, torch.cat((lig_indices, pocket_indices))) |
| return x_projected |
|
|
| @staticmethod |
| def sum_except_batch(x, indices): |
| return scatter_add(x.sum(-1), indices, dim=0) |
|
|
| @staticmethod |
| def cdf_standard_gaussian(x): |
| return 0.5 * (1. + torch.erf(x / math.sqrt(2))) |
|
|
| @staticmethod |
| def sample_gaussian(size, device): |
| x = torch.randn(size, device=device) |
| return x |
|
|
|
|
| class DistributionNodes: |
| def __init__(self, histogram): |
|
|
| histogram = torch.tensor(histogram).float() |
| histogram = histogram + 1e-3 |
|
|
| prob = histogram / histogram.sum() |
|
|
| self.idx_to_n_nodes = torch.tensor( |
| [[(i, j) for j in range(prob.shape[1])] for i in range(prob.shape[0])] |
| ).view(-1, 2) |
|
|
| self.n_nodes_to_idx = {tuple(x.tolist()): i |
| for i, x in enumerate(self.idx_to_n_nodes)} |
|
|
| self.prob = prob |
| self.m = torch.distributions.Categorical(self.prob.view(-1), |
| validate_args=True) |
|
|
| self.n1_given_n2 = \ |
| [torch.distributions.Categorical(prob[:, j], validate_args=True) |
| for j in range(prob.shape[1])] |
| self.n2_given_n1 = \ |
| [torch.distributions.Categorical(prob[i, :], validate_args=True) |
| for i in range(prob.shape[0])] |
|
|
| |
| entropy = self.m.entropy() |
| print("Entropy of n_nodes: H[N]", entropy.item()) |
|
|
| def sample(self, n_samples=1): |
| idx = self.m.sample((n_samples,)) |
| num_nodes_lig, num_nodes_pocket = self.idx_to_n_nodes[idx].T |
| return num_nodes_lig, num_nodes_pocket |
|
|
| def sample_conditional(self, n1=None, n2=None): |
| assert (n1 is None) ^ (n2 is None), \ |
| "Exactly one input argument must be None" |
|
|
| m = self.n1_given_n2 if n2 is not None else self.n2_given_n1 |
| c = n2 if n2 is not None else n1 |
|
|
| return torch.tensor([m[i].sample() for i in c], device=c.device) |
|
|
| def log_prob(self, batch_n_nodes_1, batch_n_nodes_2): |
| assert len(batch_n_nodes_1.size()) == 1 |
| assert len(batch_n_nodes_2.size()) == 1 |
|
|
| idx = torch.tensor( |
| [self.n_nodes_to_idx[(n1, n2)] |
| for n1, n2 in zip(batch_n_nodes_1.tolist(), batch_n_nodes_2.tolist())] |
| ) |
|
|
| |
| log_probs = self.m.log_prob(idx) |
|
|
| return log_probs.to(batch_n_nodes_1.device) |
|
|
| def log_prob_n1_given_n2(self, n1, n2): |
| assert len(n1.size()) == 1 |
| assert len(n2.size()) == 1 |
| log_probs = torch.stack([self.n1_given_n2[c].log_prob(i.cpu()) |
| for i, c in zip(n1, n2)]) |
| return log_probs.to(n1.device) |
|
|
| def log_prob_n2_given_n1(self, n2, n1): |
| assert len(n2.size()) == 1 |
| assert len(n1.size()) == 1 |
| log_probs = torch.stack([self.n2_given_n1[c].log_prob(i.cpu()) |
| for i, c in zip(n2, n1)]) |
| return log_probs.to(n2.device) |
|
|
|
|
| class PositiveLinear(torch.nn.Module): |
| """Linear layer with weights forced to be positive.""" |
|
|
| def __init__(self, in_features: int, out_features: int, bias: bool = True, |
| weight_init_offset: int = -2): |
| super(PositiveLinear, self).__init__() |
| self.in_features = in_features |
| self.out_features = out_features |
| self.weight = torch.nn.Parameter( |
| torch.empty((out_features, in_features))) |
| if bias: |
| self.bias = torch.nn.Parameter(torch.empty(out_features)) |
| else: |
| self.register_parameter('bias', None) |
| self.weight_init_offset = weight_init_offset |
| self.reset_parameters() |
|
|
| def reset_parameters(self) -> None: |
| torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) |
|
|
| with torch.no_grad(): |
| self.weight.add_(self.weight_init_offset) |
|
|
| if self.bias is not None: |
| fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) |
| bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 |
| torch.nn.init.uniform_(self.bias, -bound, bound) |
|
|
| def forward(self, input): |
| positive_weight = F.softplus(self.weight) |
| return F.linear(input, positive_weight, self.bias) |
|
|
|
|
| class GammaNetwork(torch.nn.Module): |
| """The gamma network models a monotonic increasing function. |
| Construction as in the VDM paper.""" |
| def __init__(self): |
| super().__init__() |
|
|
| self.l1 = PositiveLinear(1, 1) |
| self.l2 = PositiveLinear(1, 1024) |
| self.l3 = PositiveLinear(1024, 1) |
|
|
| self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.])) |
| self.gamma_1 = torch.nn.Parameter(torch.tensor([10.])) |
| self.show_schedule() |
|
|
| def show_schedule(self, num_steps=50): |
| t = torch.linspace(0, 1, num_steps).view(num_steps, 1) |
| gamma = self.forward(t) |
| print('Gamma schedule:') |
| print(gamma.detach().cpu().numpy().reshape(num_steps)) |
|
|
| def gamma_tilde(self, t): |
| l1_t = self.l1(t) |
| return l1_t + self.l3(torch.sigmoid(self.l2(l1_t))) |
|
|
| def forward(self, t): |
| zeros, ones = torch.zeros_like(t), torch.ones_like(t) |
| |
| gamma_tilde_0 = self.gamma_tilde(zeros) |
| gamma_tilde_1 = self.gamma_tilde(ones) |
| gamma_tilde_t = self.gamma_tilde(t) |
|
|
| |
| normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / ( |
| gamma_tilde_1 - gamma_tilde_0) |
|
|
| |
| gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma |
|
|
| return gamma |
|
|
|
|
| def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1): |
| """ |
| cosine schedule |
| as proposed in https://openreview.net/forum?id=-NEXDKk8gZ |
| """ |
| steps = timesteps + 2 |
| x = np.linspace(0, steps, steps) |
| alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 |
| alphas_cumprod = alphas_cumprod / alphas_cumprod[0] |
| betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) |
| betas = np.clip(betas, a_min=0, a_max=0.999) |
| alphas = 1. - betas |
| alphas_cumprod = np.cumprod(alphas, axis=0) |
|
|
| if raise_to_power != 1: |
| alphas_cumprod = np.power(alphas_cumprod, raise_to_power) |
|
|
| return alphas_cumprod |
|
|
|
|
| def clip_noise_schedule(alphas2, clip_value=0.001): |
| """ |
| For a noise schedule given by alpha^2, this clips alpha_t / alpha_t-1. |
| This may help improve stability during |
| sampling. |
| """ |
| alphas2 = np.concatenate([np.ones(1), alphas2], axis=0) |
|
|
| alphas_step = (alphas2[1:] / alphas2[:-1]) |
|
|
| alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.) |
| alphas2 = np.cumprod(alphas_step, axis=0) |
|
|
| return alphas2 |
|
|
|
|
| def polynomial_schedule(timesteps: int, s=1e-4, power=3.): |
| """ |
| A noise schedule based on a simple polynomial equation: 1 - x^power. |
| """ |
| steps = timesteps + 1 |
| x = np.linspace(0, steps, steps) |
| alphas2 = (1 - np.power(x / steps, power))**2 |
|
|
| alphas2 = clip_noise_schedule(alphas2, clip_value=0.001) |
|
|
| precision = 1 - 2 * s |
|
|
| alphas2 = precision * alphas2 + s |
|
|
| return alphas2 |
|
|
|
|
| class PredefinedNoiseSchedule(torch.nn.Module): |
| """ |
| Predefined noise schedule. Essentially creates a lookup array for predefined |
| (non-learned) noise schedules. |
| """ |
| def __init__(self, noise_schedule, timesteps, precision): |
| super(PredefinedNoiseSchedule, self).__init__() |
| self.timesteps = timesteps |
|
|
| if noise_schedule == 'cosine': |
| alphas2 = cosine_beta_schedule(timesteps) |
| elif 'polynomial' in noise_schedule: |
| splits = noise_schedule.split('_') |
| assert len(splits) == 2 |
| power = float(splits[1]) |
| alphas2 = polynomial_schedule(timesteps, s=precision, power=power) |
| else: |
| raise ValueError(noise_schedule) |
|
|
| sigmas2 = 1 - alphas2 |
|
|
| log_alphas2 = np.log(alphas2) |
| log_sigmas2 = np.log(sigmas2) |
|
|
| log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2 |
|
|
| self.gamma = torch.nn.Parameter( |
| torch.from_numpy(-log_alphas2_to_sigmas2).float(), |
| requires_grad=False) |
|
|
| def forward(self, t): |
| t_int = torch.round(t * self.timesteps).long() |
| return self.gamma[t_int] |
|
|