import os
import time
import copy
import shutil
from pathlib import Path
from math import ceil
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from collections.abc import Iterable

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.cuda.amp import autocast, GradScaler

import pytorch_warmup as warmup

import esm
from einops import rearrange

from packaging import version

from ema_pytorch import EMA

from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs

from fsspec.core import url_to_fs
from fsspec.implementations.local import LocalFileSystem

__version__ = '1.9.3'

device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu"
)


def cycle(dl):
    while True:
        for data in dl:
            yield data
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
from PD_pLMProbXDiff.ModelPack import (
    resize_image_to, ProteinDesigner_B,
)
from PD_pLMProbXDiff.UtilityPack import (
    get_Model_A_error, convert_into_tokens, convert_into_tokens_using_prob,
    decode_one_ems_token_rec, decode_many_ems_token_rec,
    decode_one_ems_token_rec_for_folding,
    decode_many_ems_token_rec_for_folding,
    decode_one_ems_token_rec_for_folding_with_mask,
    decode_many_ems_token_rec_for_folding_with_mask,
    read_mask_from_input,
    get_DSSP_result,
    string_diff,
    load_in_pLM,
)
from PD_pLMProbXDiff.DataSetPack import (
    pad_a_np_arr
)
|
|
| |
| criterion_MSE_sum = nn.MSELoss(reduction='sum') |
| criterion_MAE_sum = nn.L1Loss(reduction='sum') |
|
|
| |
|
|
| def exists(val): |
| return val is not None |
|
|
| def default(val, d): |
| if exists(val): |
| return val |
| return d() if callable(d) else d |
|
|
| def cast_tuple(val, length = 1): |
| if isinstance(val, list): |
| val = tuple(val) |
| |
| return val if isinstance(val, tuple) else ((val,) * length) |
|
|
| def find_first(fn, arr): |
| for ind, el in enumerate(arr): |
| if fn(el): |
| return ind |
| return -1 |
|
|
| def pick_and_pop(keys, d): |
| values = list(map(lambda key: d.pop(key), keys)) |
| return dict(zip(keys, values)) |
|
|
| def group_dict_by_key(cond, d): |
| return_val = [dict(),dict()] |
| for key in d.keys(): |
| match = bool(cond(key)) |
| ind = int(not match) |
| return_val[ind][key] = d[key] |
| return (*return_val,) |
|
|
| def string_begins_with(prefix, str): |
| return str.startswith(prefix) |
|
|
| def group_by_key_prefix(prefix, d): |
| return group_dict_by_key(partial(string_begins_with, prefix), d) |
|
|
| def groupby_prefix_and_trim(prefix, d): |
| kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) |
| kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) |
| return kwargs_without_prefix, kwargs |
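# Illustrative example (not executed): groupby_prefix_and_trim('ema_', {'ema_beta': 0.995, 'lr': 1e-4})
# returns ({'beta': 0.995}, {'lr': 1e-4}) - prefixed keys are stripped of the prefix and separated
# from the remaining kwargs, which is how 'ema_*' options are routed to the EMA wrapper below.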
|
|
| def num_to_groups(num, divisor): |
| groups = num // divisor |
| remainder = num % divisor |
| arr = [divisor] * groups |
| if remainder > 0: |
| arr.append(remainder) |
| return arr |
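# Illustrative example (not executed): num_to_groups(10, 4) -> [4, 4, 2]; used when sampling in
# chunks to break a requested batch size into sub-batches no larger than max_batch_size.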
|
|
| |
|
|
def url_to_bucket(url):
    if '://' not in url:
        return url

    # split the storage scheme off the url
    prefix, suffix = url.split('://')

    if prefix in {'gs', 's3'}:
        return suffix.split('/')[0]
    else:
        raise ValueError(f'storage type prefix "{prefix}" is not supported yet')
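# Illustrative example (not executed): url_to_bucket('gs://my-bucket/checkpoints') -> 'my-bucket';
# a plain local path such as './checkpoints' is returned unchanged.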
|
|
| |
|
|
| def eval_decorator(fn): |
| def inner(model, *args, **kwargs): |
| was_training = model.training |
| model.eval() |
| out = fn(model, *args, **kwargs) |
| model.train(was_training) |
| return out |
| return inner |
|
|
| def cast_torch_tensor(fn, cast_fp16 = False): |
| @wraps(fn) |
| def inner(model, *args, **kwargs): |
| device = kwargs.pop('_device', model.device) |
| cast_device = kwargs.pop('_cast_device', True) |
|
|
| should_cast_fp16 = cast_fp16 and model.cast_half_at_training |
|
|
| kwargs_keys = kwargs.keys() |
| all_args = (*args, *kwargs.values()) |
| split_kwargs_index = len(all_args) - len(kwargs_keys) |
| all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args)) |
|
|
| if cast_device: |
| all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args)) |
|
|
| if should_cast_fp16: |
| all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args)) |
|
|
| args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:] |
| kwargs = dict(tuple(zip(kwargs_keys, kwargs_values))) |
|
|
| out = fn(model, *args, **kwargs) |
| return out |
| return inner |
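# cast_torch_tensor decorates sample() and forward() below: numpy array arguments are converted to
# tensors, tensors are moved to the trainer's device, and floating point tensors are cast to half
# precision when cast_fp16 is requested and the model is set up for fp16 training.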
|
|
| |
|
|
| def split_iterable(it, split_size): |
| accum = [] |
| for ind in range(ceil(len(it) / split_size)): |
| start_index = ind * split_size |
| accum.append(it[start_index: (start_index + split_size)]) |
| return accum |
|
|
def split(t, split_size = None):
    if not exists(split_size):
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    raise TypeError(f'cannot split object of type {type(t)}')
|
|
| def find_first(cond, arr): |
| for el in arr: |
| if cond(el): |
| return el |
| return None |
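# NOTE: this definition of find_first supersedes the index-returning version defined earlier in the
# module; split_args_and_kwargs below relies on receiving the matching element itself.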
|
|
| def split_args_and_kwargs(*args, split_size = None, **kwargs): |
| all_args = (*args, *kwargs.values()) |
| len_all_args = len(all_args) |
| first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args) |
| assert exists(first_tensor) |
|
|
| batch_size = len(first_tensor) |
| split_size = default(split_size, batch_size) |
| num_chunks = ceil(batch_size / split_size) |
|
|
| dict_len = len(kwargs) |
| dict_keys = kwargs.keys() |
| split_kwargs_index = len_all_args - dict_len |
|
|
| split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args] |
| chunk_sizes = tuple(map(len, split_all_args[0])) |
|
|
| for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)): |
| chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:] |
| chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values))) |
| chunk_size_frac = chunk_size / batch_size |
| yield chunk_size_frac, (chunked_args, chunked_kwargs) |
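# Illustrative example (not executed): for a batch of 10 samples and split_size = 4,
# split_args_and_kwargs yields three chunks with fractions 0.4, 0.4 and 0.2; forward() below scales
# each chunk's loss by its fraction so the accumulated loss matches a single full-batch pass.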
|
|
| |
|
|
| def imagen_sample_in_chunks(fn): |
| @wraps(fn) |
| def inner(self, *args, max_batch_size = None, **kwargs): |
| if not exists(max_batch_size): |
| return fn(self, *args, **kwargs) |
|
|
| if self.imagen.unconditional: |
| batch_size = kwargs.get('batch_size') |
| batch_sizes = num_to_groups(batch_size, max_batch_size) |
| outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes] |
| else: |
| outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)] |
|
|
| if isinstance(outputs[0], torch.Tensor): |
| return torch.cat(outputs, dim = 0) |
|
|
| return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs)))) |
|
|
| return inner |
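# imagen_sample_in_chunks lets sample() honour a max_batch_size keyword by splitting the request
# into sub-batches (or sub-batch sizes when sampling unconditionally) and concatenating the outputs
# along the batch dimension.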
|
|
|
|
def restore_parts(state_dict_target, state_dict_from):
    for name, param in state_dict_from.items():

        if name not in state_dict_target:
            continue

        if param.size() == state_dict_target[name].size():
            state_dict_target[name].copy_(param)
        else:
            print(f"layer {name} ({param.size()}) does not match target size {state_dict_target[name].size()}; skipping")

    return state_dict_target
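# restore_parts enables partial checkpoint loading: parameters whose shapes match are copied over,
# while mismatched or missing entries keep their current values.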
|
|
| class ImagenTrainer(nn.Module): |
| locked = False |
|
|
| def __init__( |
| self, |
| |
| model = None, |
| |
| imagen_checkpoint_path = None, |
| use_ema = True, |
| lr = 1e-4, |
| eps = 1e-8, |
| beta1 = 0.9, |
| beta2 = 0.99, |
| max_grad_norm = None, |
| group_wd_params = True, |
| warmup_steps = None, |
| cosine_decay_max_steps = None, |
| only_train_unet_number = None, |
| fp16 = False, |
| precision = None, |
| split_batches = True, |
| dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), |
| verbose = True, |
| split_valid_fraction = 0.025, |
| split_valid_from_train = False, |
| split_random_seed = 42, |
| checkpoint_path = None, |
| checkpoint_every = None, |
| checkpoint_fs = None, |
| fs_kwargs: dict = None, |
| max_checkpoints_keep = 20, |
| |
| CKeys=None, |
| |
| **kwargs |
| ): |
| super().__init__() |
| assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)' |
| assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' |
|
|
| |
|
|
| self.fs = checkpoint_fs |
|
|
| if not exists(self.fs): |
| fs_kwargs = default(fs_kwargs, {}) |
| self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| self.CKeys = CKeys |
| |
| ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) |
|
|
| |
| self.imagen = model.imagen |
| |
| |
|
|
| self.model=model |
| self.is_elucidated = self.model.is_elucidated |
| |
|
|
| accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs) |
|
|
| assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' |
| accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no') |
|
|
| self.accelerator = Accelerator(**{ |
| 'split_batches': split_batches, |
| 'mixed_precision': accelerator_mixed_precision, |
| 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] |
| , **accelerate_kwargs}) |
|
|
| ImagenTrainer.locked = self.is_distributed |
|
|
| |
|
|
| self.cast_half_at_training = accelerator_mixed_precision == 'fp16' |
|
|
| |
|
|
| grad_scaler_enabled = fp16 |
| |
| self.num_unets = len(self.imagen.unets) |
|
|
| self.use_ema = use_ema and self.is_main |
| self.ema_unets = nn.ModuleList([]) |
|
|
| |
| |
|
|
| self.ema_unet_being_trained_index = -1 |
|
|
| |
|
|
| self.train_dl_iter = None |
| self.train_dl = None |
|
|
| self.valid_dl_iter = None |
| self.valid_dl = None |
|
|
| self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names |
|
|
| |
|
|
| self.split_valid_from_train = split_valid_from_train |
|
|
| assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1' |
| self.split_valid_fraction = split_valid_fraction |
| self.split_random_seed = split_random_seed |
|
|
| |
| |
|
|
| lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps)) |
|
|
| for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)): |
| optimizer = Adam( |
| unet.parameters(), |
| lr = unet_lr, |
| eps = unet_eps, |
| betas = (beta1, beta2), |
| **kwargs |
| ) |
|
|
| if self.use_ema: |
| self.ema_unets.append(EMA(unet, **ema_kwargs)) |
|
|
| scaler = GradScaler(enabled = grad_scaler_enabled) |
|
|
| scheduler = warmup_scheduler = None |
|
|
| if exists(unet_cosine_decay_max_steps): |
| scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) |
|
|
| if exists(unet_warmup_steps): |
| warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) |
|
|
| if not exists(scheduler): |
| scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) |
|
|
| |
|
|
| setattr(self, f'optim{ind}', optimizer) |
| setattr(self, f'scaler{ind}', scaler) |
| setattr(self, f'scheduler{ind}', scheduler) |
| setattr(self, f'warmup{ind}', warmup_scheduler) |
|
|
| |
|
|
| self.max_grad_norm = max_grad_norm |
|
|
| |
|
|
| self.register_buffer('steps', torch.tensor([0] * self.num_unets)) |
|
|
| self.verbose = verbose |
|
|
| |
|
|
| self.imagen.to(self.device) |
| self.to(self.device) |
|
|
| |
|
|
| assert not (exists(checkpoint_path) ^ exists(checkpoint_every)) |
| self.checkpoint_path = checkpoint_path |
| self.checkpoint_every = checkpoint_every |
| self.max_checkpoints_keep = max_checkpoints_keep |
|
|
| self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main |
|
|
| if exists(checkpoint_path) and self.can_checkpoint: |
| bucket = url_to_bucket(checkpoint_path) |
|
|
| if not self.fs.exists(bucket): |
| self.fs.mkdir(bucket) |
|
|
| self.load_from_checkpoint_folder() |
|
|
| |
|
|
| self.only_train_unet_number = only_train_unet_number |
| self.validate_and_set_unet_being_trained(only_train_unet_number) |
|
|
| |
|
|
| @property |
| def device(self): |
| return self.accelerator.device |
|
|
| @property |
| def is_distributed(self): |
| return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) |
|
|
| @property |
| def is_main(self): |
| return self.accelerator.is_main_process |
|
|
| @property |
| def is_local_main(self): |
| return self.accelerator.is_local_main_process |
|
|
| @property |
| def unwrapped_unet(self): |
| return self.accelerator.unwrap_model(self.unet_being_trained) |
|
|
| |
|
|
| def get_lr(self, unet_number): |
| self.validate_unet_number(unet_number) |
| unet_index = unet_number - 1 |
|
|
| optim = getattr(self, f'optim{unet_index}') |
|
|
| return optim.param_groups[0]['lr'] |
|
|
| |
|
|
| def validate_and_set_unet_being_trained(self, unet_number = None): |
| if exists(unet_number): |
| self.validate_unet_number(unet_number) |
|
|
| assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet' |
|
|
| self.only_train_unet_number = unet_number |
| self.imagen.only_train_unet_number = unet_number |
|
|
| if not exists(unet_number): |
| return |
|
|
| self.wrap_unet(unet_number) |
|
|
| def wrap_unet(self, unet_number): |
| if hasattr(self, 'one_unet_wrapped'): |
| return |
|
|
| unet = self.imagen.get_unet(unet_number) |
| self.unet_being_trained = self.accelerator.prepare(unet) |
| unet_index = unet_number - 1 |
|
|
| optimizer = getattr(self, f'optim{unet_index}') |
| scheduler = getattr(self, f'scheduler{unet_index}') |
|
|
| optimizer = self.accelerator.prepare(optimizer) |
|
|
| if exists(scheduler): |
| scheduler = self.accelerator.prepare(scheduler) |
|
|
| setattr(self, f'optim{unet_index}', optimizer) |
| setattr(self, f'scheduler{unet_index}', scheduler) |
|
|
| self.one_unet_wrapped = True |
|
|
| |
|
|
| def set_accelerator_scaler(self, unet_number): |
| unet_number = self.validate_unet_number(unet_number) |
| scaler = getattr(self, f'scaler{unet_number - 1}') |
|
|
| self.accelerator.scaler = scaler |
| for optimizer in self.accelerator._optimizers: |
| optimizer.scaler = scaler |
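        # Accelerate tracks a single GradScaler, while this trainer keeps one scaler per unet, so the
        # scaler belonging to the unet currently being trained is swapped in before each optimizer step.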
|
|
| |
|
|
| def print(self, msg): |
| if not self.is_main: |
| return |
|
|
| if not self.verbose: |
| return |
|
|
| return self.accelerator.print(msg) |
|
|
| |
|
|
| def validate_unet_number(self, unet_number = None): |
| if self.num_unets == 1: |
| unet_number = default(unet_number, 1) |
|
|
| assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}' |
| return unet_number |
|
|
| |
|
|
| def num_steps_taken(self, unet_number = None): |
| if self.num_unets == 1: |
| unet_number = default(unet_number, 1) |
|
|
| return self.steps[unet_number - 1].item() |
|
|
| def print_untrained_unets(self): |
| print_final_error = False |
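        # NOTE: NullUnet is referenced in the loop below but is not imported in this module; it is
        # assumed to be provided by the underlying imagen implementation (a placeholder unet that
        # requires no training).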
|
|
| for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)): |
| if steps > 0 or isinstance(unet, NullUnet): |
| continue |
|
|
| self.print(f'unet {ind + 1} has not been trained') |
| print_final_error = True |
|
|
| if print_final_error: |
| self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets') |
|
|
| |
|
|
| def add_train_dataloader(self, dl = None): |
| if not exists(dl): |
| return |
|
|
| assert not exists(self.train_dl), 'training dataloader was already added' |
| self.train_dl = self.accelerator.prepare(dl) |
|
|
| def add_valid_dataloader(self, dl): |
| if not exists(dl): |
| return |
|
|
| assert not exists(self.valid_dl), 'validation dataloader was already added' |
| self.valid_dl = self.accelerator.prepare(dl) |
|
|
| def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs): |
| if not exists(ds): |
| return |
|
|
| assert not exists(self.train_dl), 'training dataloader was already added' |
|
|
| valid_ds = None |
| if self.split_valid_from_train: |
| train_size = int((1 - self.split_valid_fraction) * len(ds)) |
| valid_size = len(ds) - train_size |
|
|
| ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed)) |
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')
|
|
| dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) |
| self.train_dl = self.accelerator.prepare(dl) |
|
|
| if not self.split_valid_from_train: |
| return |
|
|
| self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs) |
|
|
| def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs): |
| if not exists(ds): |
| return |
|
|
| assert not exists(self.valid_dl), 'validation dataloader was already added' |
|
|
| dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) |
| self.valid_dl = self.accelerator.prepare(dl) |
|
|
| def create_train_iter(self): |
| assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet' |
|
|
| if exists(self.train_dl_iter): |
| return |
|
|
| self.train_dl_iter = cycle(self.train_dl) |
|
|
| def create_valid_iter(self): |
| assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet' |
|
|
| if exists(self.valid_dl_iter): |
| return |
|
|
| self.valid_dl_iter = cycle(self.valid_dl) |
|
|
| def train_step(self, unet_number = None, **kwargs): |
| self.create_train_iter() |
| loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs) |
| self.update(unet_number = unet_number) |
| return loss |
|
|
| @torch.no_grad() |
| @eval_decorator |
| def valid_step(self, **kwargs): |
| self.create_valid_iter() |
|
|
| context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext |
|
|
| with context(): |
| loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs) |
| return loss |
|
|
| def step_with_dl_iter(self, dl_iter, **kwargs): |
| dl_tuple_output = cast_tuple(next(dl_iter)) |
| model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output))) |
| loss = self.forward(**{**kwargs, **model_input}) |
| return loss |
|
|
| |
|
|
| @property |
| def all_checkpoints_sorted(self): |
| glob_pattern = os.path.join(self.checkpoint_path, '*.pt') |
| checkpoints = self.fs.glob(glob_pattern) |
| sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True) |
| return sorted_checkpoints |
|
|
| def load_from_checkpoint_folder(self, last_total_steps = -1): |
| if last_total_steps != -1: |
| filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt') |
| self.load(filepath) |
| return |
|
|
| sorted_checkpoints = self.all_checkpoints_sorted |
|
|
| if len(sorted_checkpoints) == 0: |
| self.print(f'no checkpoints found to load from at {self.checkpoint_path}') |
| return |
|
|
| last_checkpoint = sorted_checkpoints[0] |
| self.load(last_checkpoint) |
|
|
| def save_to_checkpoint_folder(self): |
| self.accelerator.wait_for_everyone() |
|
|
| if not self.can_checkpoint: |
| return |
|
|
| total_steps = int(self.steps.sum().item()) |
| filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt') |
|
|
| self.save(filepath) |
|
|
| if self.max_checkpoints_keep <= 0: |
| return |
|
|
| sorted_checkpoints = self.all_checkpoints_sorted |
| checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:] |
|
|
| for checkpoint in checkpoints_to_discard: |
| self.fs.rm(checkpoint) |
|
|
| |
|
|
| def save( |
| self, |
| path, |
| overwrite = True, |
| without_optim_and_sched = False, |
| **kwargs |
| ): |
| self.accelerator.wait_for_everyone() |
|
|
| if not self.can_checkpoint: |
| return |
|
|
| fs = self.fs |
|
|
| assert not (fs.exists(path) and not overwrite) |
|
|
| self.reset_ema_unets_all_one_device() |
|
|
| save_obj = dict( |
| model = self.imagen.state_dict(), |
| version = __version__, |
| steps = self.steps.cpu(), |
| **kwargs |
| ) |
|
|
| save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple() |
|
|
| for ind in save_optim_and_sched_iter: |
| scaler_key = f'scaler{ind}' |
| optimizer_key = f'optim{ind}' |
| scheduler_key = f'scheduler{ind}' |
| warmup_scheduler_key = f'warmup{ind}' |
|
|
| scaler = getattr(self, scaler_key) |
| optimizer = getattr(self, optimizer_key) |
| scheduler = getattr(self, scheduler_key) |
| warmup_scheduler = getattr(self, warmup_scheduler_key) |
|
|
| if exists(scheduler): |
| save_obj = {**save_obj, scheduler_key: scheduler.state_dict()} |
|
|
| if exists(warmup_scheduler): |
| save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()} |
|
|
| save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} |
|
|
| if self.use_ema: |
| save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} |
|
|
| |
|
|
| if hasattr(self.imagen, '_config'): |
| self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""') |
|
|
| save_obj = { |
| **save_obj, |
| 'imagen_type': 'elucidated' if self.is_elucidated else 'original', |
| 'imagen_params': self.imagen._config |
| } |
|
|
| |
|
|
| with fs.open(path, 'wb') as f: |
| torch.save(save_obj, f) |
|
|
| self.print(f'checkpoint saved to {path}') |
|
|
| def load(self, path, only_model = False, strict = True, noop_if_not_exist = False): |
| fs = self.fs |
|
|
| if noop_if_not_exist and not fs.exists(path): |
| self.print(f'trainer checkpoint not found at {str(path)}') |
| return |
|
|
| assert fs.exists(path), f'{path} does not exist' |
|
|
| self.reset_ema_unets_all_one_device() |
|
|
| |
|
|
| with fs.open(path) as f: |
| loaded_obj = torch.load(f, map_location='cpu') |
|
|
| if version.parse(__version__) != version.parse(loaded_obj['version']): |
| self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}') |
|
|
| try: |
| self.imagen.load_state_dict(loaded_obj['model'], strict = strict) |
| except RuntimeError: |
| print("Failed loading state dict. Trying partial load") |
| self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(), |
| loaded_obj['model'])) |
|
|
| if only_model: |
| return loaded_obj |
|
|
| self.steps.copy_(loaded_obj['steps']) |
|
|
| for ind in range(0, self.num_unets): |
| scaler_key = f'scaler{ind}' |
| optimizer_key = f'optim{ind}' |
| scheduler_key = f'scheduler{ind}' |
| warmup_scheduler_key = f'warmup{ind}' |
|
|
| scaler = getattr(self, scaler_key) |
| optimizer = getattr(self, optimizer_key) |
| scheduler = getattr(self, scheduler_key) |
| warmup_scheduler = getattr(self, warmup_scheduler_key) |
|
|
| if exists(scheduler) and scheduler_key in loaded_obj: |
| scheduler.load_state_dict(loaded_obj[scheduler_key]) |
|
|
| if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj: |
| warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key]) |
|
|
| if exists(optimizer): |
| try: |
| optimizer.load_state_dict(loaded_obj[optimizer_key]) |
| scaler.load_state_dict(loaded_obj[scaler_key]) |
                except Exception:
| self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers') |
|
|
| if self.use_ema: |
| assert 'ema' in loaded_obj |
| try: |
| self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) |
| except RuntimeError: |
| print("Failed loading state dict. Trying partial load") |
| self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(), |
| loaded_obj['ema'])) |
|
|
| self.print(f'checkpoint loaded from {path}') |
| return loaded_obj |
|
|
| |
|
|
| @property |
| def unets(self): |
| return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) |
|
|
| def get_ema_unet(self, unet_number = None): |
| if not self.use_ema: |
| return |
|
|
| unet_number = self.validate_unet_number(unet_number) |
| index = unet_number - 1 |
|
|
| if isinstance(self.unets, nn.ModuleList): |
| unets_list = [unet for unet in self.ema_unets] |
| delattr(self, 'ema_unets') |
| self.ema_unets = unets_list |
|
|
| if index != self.ema_unet_being_trained_index: |
| for unet_index, unet in enumerate(self.ema_unets): |
| unet.to(self.device if unet_index == index else 'cpu') |
|
|
| self.ema_unet_being_trained_index = index |
| return self.ema_unets[index] |
|
|
| def reset_ema_unets_all_one_device(self, device = None): |
| if not self.use_ema: |
| return |
|
|
| device = default(device, self.device) |
| self.ema_unets = nn.ModuleList([*self.ema_unets]) |
| self.ema_unets.to(device) |
|
|
| self.ema_unet_being_trained_index = -1 |
|
|
| @torch.no_grad() |
| @contextmanager |
| def use_ema_unets(self): |
| if not self.use_ema: |
| output = yield |
| return output |
|
|
| self.reset_ema_unets_all_one_device() |
| self.imagen.reset_unets_all_one_device() |
|
|
| self.unets.eval() |
|
|
| trainable_unets = self.imagen.unets |
| self.imagen.unets = self.unets |
|
|
| output = yield |
|
|
| self.imagen.unets = trainable_unets |
|
|
| |
| for ema in self.ema_unets: |
| ema.restore_ema_model_device() |
|
|
| return output |
|
|
| def print_unet_devices(self): |
| self.print('unet devices:') |
| for i, unet in enumerate(self.imagen.unets): |
| device = next(unet.parameters()).device |
| self.print(f'\tunet {i}: {device}') |
|
|
| if not self.use_ema: |
| return |
|
|
| self.print('\nema unet devices:') |
| for i, ema_unet in enumerate(self.ema_unets): |
| device = next(ema_unet.parameters()).device |
| self.print(f'\tema unet {i}: {device}') |
|
|
| |
|
|
| def state_dict(self, *args, **kwargs): |
| self.reset_ema_unets_all_one_device() |
| return super().state_dict(*args, **kwargs) |
|
|
| def load_state_dict(self, *args, **kwargs): |
| self.reset_ema_unets_all_one_device() |
| return super().load_state_dict(*args, **kwargs) |
|
|
| |
|
|
| def encode_text(self, text, **kwargs): |
| return self.imagen.encode_text(text, **kwargs) |
|
|
| |
|
|
| def update(self, unet_number = None): |
| unet_number = self.validate_unet_number(unet_number) |
| self.validate_and_set_unet_being_trained(unet_number) |
| self.set_accelerator_scaler(unet_number) |
|
|
| index = unet_number - 1 |
| unet = self.unet_being_trained |
|
|
| optimizer = getattr(self, f'optim{index}') |
| scaler = getattr(self, f'scaler{index}') |
| scheduler = getattr(self, f'scheduler{index}') |
| warmup_scheduler = getattr(self, f'warmup{index}') |
|
|
| |
|
|
| if exists(self.max_grad_norm): |
| self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm) |
|
|
| optimizer.step() |
| optimizer.zero_grad() |
|
|
| if self.use_ema: |
| ema_unet = self.get_ema_unet(unet_number) |
| ema_unet.update() |
|
|
| |
|
|
| maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening() |
|
|
| with maybe_warmup_context: |
| if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: |
| scheduler.step() |
|
|
| self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps)) |
|
|
| if not exists(self.checkpoint_path): |
| return |
|
|
| total_steps = int(self.steps.sum().item()) |
|
|
| if total_steps % self.checkpoint_every: |
| return |
|
|
| self.save_to_checkpoint_folder() |
|
|
| @torch.no_grad() |
| @cast_torch_tensor |
| @imagen_sample_in_chunks |
| def sample(self, *args, **kwargs): |
| context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets |
|
|
| self.print_untrained_unets() |
| |
| if not self.is_main: |
| kwargs['use_tqdm'] = False |
|
|
| with context(): |
| output = self.imagen.sample(*args, device = self.device, **kwargs) |
|
|
| return output |
|
|
| @partial(cast_torch_tensor, cast_fp16 = True) |
| def forward( |
| self, |
| *args, |
| unet_number = None, |
| max_batch_size = None, |
| **kwargs |
| ): |
| unet_number = self.validate_unet_number(unet_number) |
| self.validate_and_set_unet_being_trained(unet_number) |
| self.set_accelerator_scaler(unet_number) |
|
|
| assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' |
|
|
| total_loss = 0. |
| |
| |
| |
| if self.CKeys['Debug_TrainerPack']==1: |
| print("In Trainer:Forward, check inputs:") |
| print('args: ', len(args)) |
| print('args in:',args[0].shape) |
| print('kwargs: ', kwargs.keys()) |
| for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): |
| |
| if self.CKeys['Debug_TrainerPack']==1: |
| print("after chunks,...") |
| print('chun_frac: ', chunk_size_frac) |
| print('chun_args: ', chunked_args) |
| print('chun_kwargs: ', chunked_kwargs) |
| |
| with self.accelerator.autocast(): |
| loss = self.model( |
| *chunked_args, |
| unet_number = unet_number, |
| **chunked_kwargs |
| ) |
| loss = loss * chunk_size_frac |
| |
| |
| if self.CKeys['Debug_TrainerPack']==1: |
| print('part chun loss: ', loss) |
|
|
| total_loss += loss |
|
|
| if self.training: |
| self.accelerator.backward(loss) |
|
|
| return total_loss |
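

# Minimal usage sketch (assumptions: `designer` is a ProteinDesigner_B-style model exposing
# `.imagen`, `train_loader` yields tuples matching dl_tuple_output_keywords_names, and `num_steps`
# is chosen by the caller):
#
#     trainer = ImagenTrainer(model = designer, lr = 1e-4, use_ema = True, CKeys = {'Debug_TrainerPack': 0})
#     trainer.add_train_dataloader(train_loader)
#     for _ in range(num_steps):
#         loss = trainer.train_step(unet_number = 1)
#
# ImagenTrainer_ModelB below mirrors ImagenTrainer; the substantive difference is that it asserts
# the wrapped model is a ProteinDesigner_B instance.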
| |
| |
| |
| class ImagenTrainer_ModelB(nn.Module): |
| locked = False |
|
|
| def __init__( |
| self, |
| |
| model = None, |
| |
| imagen_checkpoint_path = None, |
| use_ema = True, |
| lr = 1e-4, |
| eps = 1e-8, |
| beta1 = 0.9, |
| beta2 = 0.99, |
| max_grad_norm = None, |
| group_wd_params = True, |
| warmup_steps = None, |
| cosine_decay_max_steps = None, |
| only_train_unet_number = None, |
| fp16 = False, |
| precision = None, |
| split_batches = True, |
| dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), |
| verbose = True, |
| split_valid_fraction = 0.025, |
| split_valid_from_train = False, |
| split_random_seed = 42, |
| checkpoint_path = None, |
| checkpoint_every = None, |
| checkpoint_fs = None, |
| fs_kwargs: dict = None, |
| max_checkpoints_keep = 20, |
| |
| CKeys=None, |
| |
| **kwargs |
| ): |
| super().__init__() |
        assert not ImagenTrainer_ModelB.locked, 'ImagenTrainer_ModelB can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
| assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' |
|
|
| |
|
|
| self.fs = checkpoint_fs |
|
|
| if not exists(self.fs): |
| fs_kwargs = default(fs_kwargs, {}) |
| self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs) |
| |
| |
| |
| |
| |
| |
| assert isinstance(model, (ProteinDesigner_B)) |
| |
| |
| self.CKeys = CKeys |
| |
| ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) |
|
|
| |
| self.imagen = model.imagen |
| |
| |
|
|
| self.model=model |
| self.is_elucidated = self.model.is_elucidated |
| |
|
|
| accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs) |
|
|
| assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' |
| accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no') |
|
|
| self.accelerator = Accelerator(**{ |
| 'split_batches': split_batches, |
| 'mixed_precision': accelerator_mixed_precision, |
| 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] |
| , **accelerate_kwargs}) |
|
|
        ImagenTrainer_ModelB.locked = self.is_distributed
|
|
| |
|
|
| self.cast_half_at_training = accelerator_mixed_precision == 'fp16' |
|
|
| |
|
|
| grad_scaler_enabled = fp16 |
| |
| self.num_unets = len(self.imagen.unets) |
|
|
| self.use_ema = use_ema and self.is_main |
| self.ema_unets = nn.ModuleList([]) |
|
|
| |
| |
|
|
| self.ema_unet_being_trained_index = -1 |
|
|
| |
|
|
| self.train_dl_iter = None |
| self.train_dl = None |
|
|
| self.valid_dl_iter = None |
| self.valid_dl = None |
|
|
| self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names |
|
|
| |
|
|
| self.split_valid_from_train = split_valid_from_train |
|
|
| assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1' |
| self.split_valid_fraction = split_valid_fraction |
| self.split_random_seed = split_random_seed |
|
|
| |
| |
|
|
| lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps)) |
|
|
| for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)): |
| optimizer = Adam( |
| unet.parameters(), |
| lr = unet_lr, |
| eps = unet_eps, |
| betas = (beta1, beta2), |
| **kwargs |
| ) |
|
|
| if self.use_ema: |
| self.ema_unets.append(EMA(unet, **ema_kwargs)) |
|
|
| scaler = GradScaler(enabled = grad_scaler_enabled) |
|
|
| scheduler = warmup_scheduler = None |
|
|
| if exists(unet_cosine_decay_max_steps): |
| scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) |
|
|
| if exists(unet_warmup_steps): |
| warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) |
|
|
| if not exists(scheduler): |
| scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) |
|
|
| |
|
|
| setattr(self, f'optim{ind}', optimizer) |
| setattr(self, f'scaler{ind}', scaler) |
| setattr(self, f'scheduler{ind}', scheduler) |
| setattr(self, f'warmup{ind}', warmup_scheduler) |
|
|
| |
|
|
| self.max_grad_norm = max_grad_norm |
|
|
| |
|
|
| self.register_buffer('steps', torch.tensor([0] * self.num_unets)) |
|
|
| self.verbose = verbose |
|
|
| |
|
|
| self.imagen.to(self.device) |
| self.to(self.device) |
|
|
| |
|
|
| assert not (exists(checkpoint_path) ^ exists(checkpoint_every)) |
| self.checkpoint_path = checkpoint_path |
| self.checkpoint_every = checkpoint_every |
| self.max_checkpoints_keep = max_checkpoints_keep |
|
|
| self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main |
|
|
| if exists(checkpoint_path) and self.can_checkpoint: |
| bucket = url_to_bucket(checkpoint_path) |
|
|
| if not self.fs.exists(bucket): |
| self.fs.mkdir(bucket) |
|
|
| self.load_from_checkpoint_folder() |
|
|
| |
|
|
| self.only_train_unet_number = only_train_unet_number |
| self.validate_and_set_unet_being_trained(only_train_unet_number) |
|
|
| |
|
|
| @property |
| def device(self): |
| return self.accelerator.device |
|
|
| @property |
| def is_distributed(self): |
| return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) |
|
|
| @property |
| def is_main(self): |
| return self.accelerator.is_main_process |
|
|
| @property |
| def is_local_main(self): |
| return self.accelerator.is_local_main_process |
|
|
| @property |
| def unwrapped_unet(self): |
| return self.accelerator.unwrap_model(self.unet_being_trained) |
|
|
| |
|
|
| def get_lr(self, unet_number): |
| self.validate_unet_number(unet_number) |
| unet_index = unet_number - 1 |
|
|
| optim = getattr(self, f'optim{unet_index}') |
|
|
| return optim.param_groups[0]['lr'] |
|
|
| |
|
|
| def validate_and_set_unet_being_trained(self, unet_number = None): |
| if exists(unet_number): |
| self.validate_unet_number(unet_number) |
|
|
| assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet' |
|
|
| self.only_train_unet_number = unet_number |
| self.imagen.only_train_unet_number = unet_number |
|
|
| if not exists(unet_number): |
| return |
|
|
| self.wrap_unet(unet_number) |
|
|
| def wrap_unet(self, unet_number): |
| if hasattr(self, 'one_unet_wrapped'): |
| return |
|
|
| unet = self.imagen.get_unet(unet_number) |
| self.unet_being_trained = self.accelerator.prepare(unet) |
| unet_index = unet_number - 1 |
|
|
| optimizer = getattr(self, f'optim{unet_index}') |
| scheduler = getattr(self, f'scheduler{unet_index}') |
|
|
| optimizer = self.accelerator.prepare(optimizer) |
|
|
| if exists(scheduler): |
| scheduler = self.accelerator.prepare(scheduler) |
|
|
| setattr(self, f'optim{unet_index}', optimizer) |
| setattr(self, f'scheduler{unet_index}', scheduler) |
|
|
| self.one_unet_wrapped = True |
|
|
| |
|
|
| def set_accelerator_scaler(self, unet_number): |
| unet_number = self.validate_unet_number(unet_number) |
| scaler = getattr(self, f'scaler{unet_number - 1}') |
|
|
| self.accelerator.scaler = scaler |
| for optimizer in self.accelerator._optimizers: |
| optimizer.scaler = scaler |
|
|
| |
|
|
| def print(self, msg): |
| if not self.is_main: |
| return |
|
|
| if not self.verbose: |
| return |
|
|
| return self.accelerator.print(msg) |
|
|
| |
|
|
| def validate_unet_number(self, unet_number = None): |
| if self.num_unets == 1: |
| unet_number = default(unet_number, 1) |
|
|
| assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}' |
| return unet_number |
|
|
| |
|
|
| def num_steps_taken(self, unet_number = None): |
| if self.num_unets == 1: |
| unet_number = default(unet_number, 1) |
|
|
| return self.steps[unet_number - 1].item() |
|
|
| def print_untrained_unets(self): |
| print_final_error = False |
|
|
| for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)): |
| if steps > 0 or isinstance(unet, NullUnet): |
| continue |
|
|
| self.print(f'unet {ind + 1} has not been trained') |
| print_final_error = True |
|
|
| if print_final_error: |
| self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets') |
|
|
| |
|
|
| def add_train_dataloader(self, dl = None): |
| if not exists(dl): |
| return |
|
|
| assert not exists(self.train_dl), 'training dataloader was already added' |
| self.train_dl = self.accelerator.prepare(dl) |
|
|
| def add_valid_dataloader(self, dl): |
| if not exists(dl): |
| return |
|
|
| assert not exists(self.valid_dl), 'validation dataloader was already added' |
| self.valid_dl = self.accelerator.prepare(dl) |
|
|
| def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs): |
| if not exists(ds): |
| return |
|
|
| assert not exists(self.train_dl), 'training dataloader was already added' |
|
|
| valid_ds = None |
| if self.split_valid_from_train: |
| train_size = int((1 - self.split_valid_fraction) * len(ds)) |
| valid_size = len(ds) - train_size |
|
|
| ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed)) |
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')
|
|
| dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) |
| self.train_dl = self.accelerator.prepare(dl) |
|
|
| if not self.split_valid_from_train: |
| return |
|
|
| self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs) |
|
|
| def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs): |
| if not exists(ds): |
| return |
|
|
| assert not exists(self.valid_dl), 'validation dataloader was already added' |
|
|
| dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) |
| self.valid_dl = self.accelerator.prepare(dl) |
|
|
| def create_train_iter(self): |
| assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet' |
|
|
| if exists(self.train_dl_iter): |
| return |
|
|
| self.train_dl_iter = cycle(self.train_dl) |
|
|
| def create_valid_iter(self): |
| assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet' |
|
|
| if exists(self.valid_dl_iter): |
| return |
|
|
| self.valid_dl_iter = cycle(self.valid_dl) |
|
|
| def train_step(self, unet_number = None, **kwargs): |
| self.create_train_iter() |
| loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs) |
| self.update(unet_number = unet_number) |
| return loss |
|
|
| @torch.no_grad() |
| @eval_decorator |
| def valid_step(self, **kwargs): |
| self.create_valid_iter() |
|
|
| context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext |
|
|
| with context(): |
| loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs) |
| return loss |
|
|
| def step_with_dl_iter(self, dl_iter, **kwargs): |
| dl_tuple_output = cast_tuple(next(dl_iter)) |
| model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output))) |
| loss = self.forward(**{**kwargs, **model_input}) |
| return loss |
|
|
| |
|
|
| @property |
| def all_checkpoints_sorted(self): |
| glob_pattern = os.path.join(self.checkpoint_path, '*.pt') |
| checkpoints = self.fs.glob(glob_pattern) |
| sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True) |
| return sorted_checkpoints |
|
|
| def load_from_checkpoint_folder(self, last_total_steps = -1): |
| if last_total_steps != -1: |
| filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt') |
| self.load(filepath) |
| return |
|
|
| sorted_checkpoints = self.all_checkpoints_sorted |
|
|
| if len(sorted_checkpoints) == 0: |
| self.print(f'no checkpoints found to load from at {self.checkpoint_path}') |
| return |
|
|
| last_checkpoint = sorted_checkpoints[0] |
| self.load(last_checkpoint) |
|
|
| def save_to_checkpoint_folder(self): |
| self.accelerator.wait_for_everyone() |
|
|
| if not self.can_checkpoint: |
| return |
|
|
| total_steps = int(self.steps.sum().item()) |
| filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt') |
|
|
| self.save(filepath) |
|
|
| if self.max_checkpoints_keep <= 0: |
| return |
|
|
| sorted_checkpoints = self.all_checkpoints_sorted |
| checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:] |
|
|
| for checkpoint in checkpoints_to_discard: |
| self.fs.rm(checkpoint) |
|
|
| |
|
|
| def save( |
| self, |
| path, |
| overwrite = True, |
| without_optim_and_sched = False, |
| **kwargs |
| ): |
| self.accelerator.wait_for_everyone() |
|
|
| if not self.can_checkpoint: |
| return |
|
|
| fs = self.fs |
|
|
| assert not (fs.exists(path) and not overwrite) |
|
|
| self.reset_ema_unets_all_one_device() |
|
|
| save_obj = dict( |
| model = self.imagen.state_dict(), |
| version = __version__, |
| steps = self.steps.cpu(), |
| **kwargs |
| ) |
|
|
| save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple() |
|
|
| for ind in save_optim_and_sched_iter: |
| scaler_key = f'scaler{ind}' |
| optimizer_key = f'optim{ind}' |
| scheduler_key = f'scheduler{ind}' |
| warmup_scheduler_key = f'warmup{ind}' |
|
|
| scaler = getattr(self, scaler_key) |
| optimizer = getattr(self, optimizer_key) |
| scheduler = getattr(self, scheduler_key) |
| warmup_scheduler = getattr(self, warmup_scheduler_key) |
|
|
| if exists(scheduler): |
| save_obj = {**save_obj, scheduler_key: scheduler.state_dict()} |
|
|
| if exists(warmup_scheduler): |
| save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()} |
|
|
| save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} |
|
|
| if self.use_ema: |
| save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} |
|
|
| |
|
|
| if hasattr(self.imagen, '_config'): |
| self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""') |
|
|
| save_obj = { |
| **save_obj, |
| 'imagen_type': 'elucidated' if self.is_elucidated else 'original', |
| 'imagen_params': self.imagen._config |
| } |
|
|
| |
|
|
| with fs.open(path, 'wb') as f: |
| torch.save(save_obj, f) |
|
|
| self.print(f'checkpoint saved to {path}') |
|
|
| def load(self, path, only_model = False, strict = True, noop_if_not_exist = False): |
| fs = self.fs |
|
|
| if noop_if_not_exist and not fs.exists(path): |
| self.print(f'trainer checkpoint not found at {str(path)}') |
| return |
|
|
| assert fs.exists(path), f'{path} does not exist' |
|
|
| self.reset_ema_unets_all_one_device() |
|
|
| |
|
|
| with fs.open(path) as f: |
| loaded_obj = torch.load(f, map_location='cpu') |
|
|
| if version.parse(__version__) != version.parse(loaded_obj['version']): |
| self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}') |
|
|
| try: |
| self.imagen.load_state_dict(loaded_obj['model'], strict = strict) |
| except RuntimeError: |
| print("Failed loading state dict. Trying partial load") |
| self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(), |
| loaded_obj['model'])) |
|
|
| if only_model: |
| return loaded_obj |
|
|
| self.steps.copy_(loaded_obj['steps']) |
|
|
| for ind in range(0, self.num_unets): |
| scaler_key = f'scaler{ind}' |
| optimizer_key = f'optim{ind}' |
| scheduler_key = f'scheduler{ind}' |
| warmup_scheduler_key = f'warmup{ind}' |
|
|
| scaler = getattr(self, scaler_key) |
| optimizer = getattr(self, optimizer_key) |
| scheduler = getattr(self, scheduler_key) |
| warmup_scheduler = getattr(self, warmup_scheduler_key) |
|
|
| if exists(scheduler) and scheduler_key in loaded_obj: |
| scheduler.load_state_dict(loaded_obj[scheduler_key]) |
|
|
| if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj: |
| warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key]) |
|
|
| if exists(optimizer): |
| try: |
| optimizer.load_state_dict(loaded_obj[optimizer_key]) |
| scaler.load_state_dict(loaded_obj[scaler_key]) |
                except Exception:
| self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers') |
|
|
| if self.use_ema: |
| assert 'ema' in loaded_obj |
| try: |
| self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) |
| except RuntimeError: |
| print("Failed loading state dict. Trying partial load") |
| self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(), |
| loaded_obj['ema'])) |
|
|
| self.print(f'checkpoint loaded from {path}') |
| return loaded_obj |
|
|
| |
|
|
| @property |
| def unets(self): |
| return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) |
|
|
| def get_ema_unet(self, unet_number = None): |
| if not self.use_ema: |
| return |
|
|
| unet_number = self.validate_unet_number(unet_number) |
| index = unet_number - 1 |
|
|
| if isinstance(self.unets, nn.ModuleList): |
| unets_list = [unet for unet in self.ema_unets] |
| delattr(self, 'ema_unets') |
| self.ema_unets = unets_list |
|
|
| if index != self.ema_unet_being_trained_index: |
| for unet_index, unet in enumerate(self.ema_unets): |
| unet.to(self.device if unet_index == index else 'cpu') |
|
|
| self.ema_unet_being_trained_index = index |
| return self.ema_unets[index] |
|
|
| def reset_ema_unets_all_one_device(self, device = None): |
| if not self.use_ema: |
| return |
|
|
| device = default(device, self.device) |
| self.ema_unets = nn.ModuleList([*self.ema_unets]) |
| self.ema_unets.to(device) |
|
|
| self.ema_unet_being_trained_index = -1 |
|
|
| @torch.no_grad() |
| @contextmanager |
| def use_ema_unets(self): |
| if not self.use_ema: |
| output = yield |
| return output |
|
|
| self.reset_ema_unets_all_one_device() |
| self.imagen.reset_unets_all_one_device() |
|
|
| self.unets.eval() |
|
|
| trainable_unets = self.imagen.unets |
| self.imagen.unets = self.unets |
|
|
| output = yield |
|
|
| self.imagen.unets = trainable_unets |
|
|
| |
| for ema in self.ema_unets: |
| ema.restore_ema_model_device() |
|
|
| return output |
|
|
| def print_unet_devices(self): |
| self.print('unet devices:') |
| for i, unet in enumerate(self.imagen.unets): |
| device = next(unet.parameters()).device |
| self.print(f'\tunet {i}: {device}') |
|
|
| if not self.use_ema: |
| return |
|
|
| self.print('\nema unet devices:') |
| for i, ema_unet in enumerate(self.ema_unets): |
| device = next(ema_unet.parameters()).device |
| self.print(f'\tema unet {i}: {device}') |
|
|
| |
|
|
| def state_dict(self, *args, **kwargs): |
| self.reset_ema_unets_all_one_device() |
| return super().state_dict(*args, **kwargs) |
|
|
| def load_state_dict(self, *args, **kwargs): |
| self.reset_ema_unets_all_one_device() |
| return super().load_state_dict(*args, **kwargs) |
|
|
| |
|
|
| def encode_text(self, text, **kwargs): |
| return self.imagen.encode_text(text, **kwargs) |
|
|
| |
|
|
| def update(self, unet_number = None): |
| unet_number = self.validate_unet_number(unet_number) |
| self.validate_and_set_unet_being_trained(unet_number) |
| self.set_accelerator_scaler(unet_number) |
|
|
| index = unet_number - 1 |
| unet = self.unet_being_trained |
|
|
| optimizer = getattr(self, f'optim{index}') |
| scaler = getattr(self, f'scaler{index}') |
| scheduler = getattr(self, f'scheduler{index}') |
| warmup_scheduler = getattr(self, f'warmup{index}') |
|
|
| |
|
|
| if exists(self.max_grad_norm): |
| self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm) |
|
|
| optimizer.step() |
| optimizer.zero_grad() |
|
|
| if self.use_ema: |
| ema_unet = self.get_ema_unet(unet_number) |
| ema_unet.update() |
|
|
| |
|
|
| maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening() |
|
|
| with maybe_warmup_context: |
| if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: |
| scheduler.step() |
|
|
| self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps)) |
|
|
| if not exists(self.checkpoint_path): |
| return |
|
|
| total_steps = int(self.steps.sum().item()) |
|
|
| if total_steps % self.checkpoint_every: |
| return |
|
|
| self.save_to_checkpoint_folder() |
|
|
| @torch.no_grad() |
| @cast_torch_tensor |
| @imagen_sample_in_chunks |
| def sample(self, *args, **kwargs): |
| context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets |
|
|
| self.print_untrained_unets() |
| |
| if not self.is_main: |
| kwargs['use_tqdm'] = False |
|
|
| with context(): |
| output = self.imagen.sample(*args, device = self.device, **kwargs) |
|
|
| return output |
|
|
| @partial(cast_torch_tensor, cast_fp16 = True) |
| def forward( |
| self, |
| *args, |
| unet_number = None, |
| max_batch_size = None, |
| **kwargs |
| ): |
| unet_number = self.validate_unet_number(unet_number) |
| self.validate_and_set_unet_being_trained(unet_number) |
| self.set_accelerator_scaler(unet_number) |
|
|
| assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' |
|
|
| total_loss = 0. |
| |
| |
| |
| if self.CKeys['Debug_TrainerPack']==1: |
| print("In Trainer:Forward, check inputs:") |
| print('args: ', len(args)) |
| print('args in:',args[0].shape) |
| print('kwargs: ', kwargs.keys()) |
| for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): |
| |
| if self.CKeys['Debug_TrainerPack']==1: |
| print("after chunks,...") |
| print('chun_frac: ', chunk_size_frac) |
| print('chun_args: ', chunked_args) |
| print('chun_kwargs: ', chunked_kwargs) |
| |
| with self.accelerator.autocast(): |
| loss = self.model( |
| *chunked_args, |
| unet_number = unet_number, |
| **chunked_kwargs |
| ) |
| loss = loss * chunk_size_frac |
| |
| |
| if self.CKeys['Debug_TrainerPack']==1: |
| print('part chun loss: ', loss) |
|
|
| total_loss += loss |
|
|
| if self.training: |
| self.accelerator.backward(loss) |
|
|
| return total_loss |
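

# ImagenTrainer_Old below appears to be an earlier copy of ImagenTrainer kept for reference; it
# omits the CKeys debug hooks used by the classes above.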
| |
| class ImagenTrainer_Old(nn.Module): |
| locked = False |
|
|
| def __init__( |
| self, |
| |
| model = None, |
| |
| imagen_checkpoint_path = None, |
| use_ema = True, |
| lr = 1e-4, |
| eps = 1e-8, |
| beta1 = 0.9, |
| beta2 = 0.99, |
| max_grad_norm = None, |
| group_wd_params = True, |
| warmup_steps = None, |
| cosine_decay_max_steps = None, |
| only_train_unet_number = None, |
| fp16 = False, |
| precision = None, |
| split_batches = True, |
| dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), |
| verbose = True, |
| split_valid_fraction = 0.025, |
| split_valid_from_train = False, |
| split_random_seed = 42, |
| checkpoint_path = None, |
| checkpoint_every = None, |
| checkpoint_fs = None, |
| fs_kwargs: dict = None, |
| max_checkpoints_keep = 20, |
| **kwargs |
| ): |
| super().__init__() |
| assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)' |
| assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' |
|
|
| |
|
|
| self.fs = checkpoint_fs |
|
|
| if not exists(self.fs): |
| fs_kwargs = default(fs_kwargs, {}) |
| self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs) |
| |
| |
| |
| |
| |
| ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) |
|
|
| |
| self.imagen = model.imagen |
| |
| |
|
|
| self.model=model |
| self.is_elucidated = self.model.is_elucidated |
| |
|
|
| accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs) |
|
|
| assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' |
| accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no') |
|
|
| self.accelerator = Accelerator(**{ |
| 'split_batches': split_batches, |
| 'mixed_precision': accelerator_mixed_precision, |
| 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] |
| , **accelerate_kwargs}) |
|
|
| ImagenTrainer.locked = self.is_distributed |
|
|
| |
|
|
| self.cast_half_at_training = accelerator_mixed_precision == 'fp16' |
|
|
| |
|
|
| grad_scaler_enabled = fp16 |
| |
| self.num_unets = len(self.imagen.unets) |
|
|
| self.use_ema = use_ema and self.is_main |
| self.ema_unets = nn.ModuleList([]) |
|
|
| |
| |
|
|
| self.ema_unet_being_trained_index = -1 |
|
|
| |
|
|
| self.train_dl_iter = None |
| self.train_dl = None |
|
|
| self.valid_dl_iter = None |
| self.valid_dl = None |
|
|
| self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names |
|
|
| |
|
|
| self.split_valid_from_train = split_valid_from_train |
|
|
| assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1' |
| self.split_valid_fraction = split_valid_fraction |
| self.split_random_seed = split_random_seed |
|
|
| |
| |
|
|
| lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps)) |
|
|
| for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)): |
| optimizer = Adam( |
| unet.parameters(), |
| lr = unet_lr, |
| eps = unet_eps, |
| betas = (beta1, beta2), |
| **kwargs |
| ) |
|
|
| if self.use_ema: |
| self.ema_unets.append(EMA(unet, **ema_kwargs)) |
|
|
| scaler = GradScaler(enabled = grad_scaler_enabled) |
|
|
| scheduler = warmup_scheduler = None |
|
|
| if exists(unet_cosine_decay_max_steps): |
| scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) |
|
|
| if exists(unet_warmup_steps): |
| warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) |
|
|
| if not exists(scheduler): |
| scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) |
|
|
| |
|
|
| setattr(self, f'optim{ind}', optimizer) |
| setattr(self, f'scaler{ind}', scaler) |
| setattr(self, f'scheduler{ind}', scheduler) |
| setattr(self, f'warmup{ind}', warmup_scheduler) |
|
|
| |
|
|
| self.max_grad_norm = max_grad_norm |
|
|
| |
|
|
| self.register_buffer('steps', torch.tensor([0] * self.num_unets)) |
|
|
| self.verbose = verbose |
|
|
| |
|
|
| self.imagen.to(self.device) |
| self.to(self.device) |
|
|
| |
|
|
        assert not (exists(checkpoint_path) ^ exists(checkpoint_every)), 'checkpoint_path and checkpoint_every must be specified together (or both omitted)'
| self.checkpoint_path = checkpoint_path |
| self.checkpoint_every = checkpoint_every |
| self.max_checkpoints_keep = max_checkpoints_keep |
|
|
| self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main |
|
|
| if exists(checkpoint_path) and self.can_checkpoint: |
| bucket = url_to_bucket(checkpoint_path) |
|
|
| if not self.fs.exists(bucket): |
| self.fs.mkdir(bucket) |
|
|
| self.load_from_checkpoint_folder() |
|
|
| |
|
|
| self.only_train_unet_number = only_train_unet_number |
| self.validate_and_set_unet_being_trained(only_train_unet_number) |
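        # end of construction: if only_train_unet_number was given, the corresponding unet and its
        # optimizer/scheduler have already been wrapped by the accelerator and are ready to train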
|
|
| |
|
|
| @property |
| def device(self): |
| return self.accelerator.device |
|
|
| @property |
| def is_distributed(self): |
| return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) |
|
|
| @property |
| def is_main(self): |
| return self.accelerator.is_main_process |
|
|
| @property |
| def is_local_main(self): |
| return self.accelerator.is_local_main_process |
|
|
| @property |
| def unwrapped_unet(self): |
| return self.accelerator.unwrap_model(self.unet_being_trained) |
|
|
| |
|
|
| def get_lr(self, unet_number): |
| self.validate_unet_number(unet_number) |
| unet_index = unet_number - 1 |
|
|
| optim = getattr(self, f'optim{unet_index}') |
|
|
| return optim.param_groups[0]['lr'] |
|
|
| |
|
|
| def validate_and_set_unet_being_trained(self, unet_number = None): |
| if exists(unet_number): |
| self.validate_unet_number(unet_number) |
|
|
| assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet' |
|
|
| self.only_train_unet_number = unet_number |
| self.imagen.only_train_unet_number = unet_number |
|
|
| if not exists(unet_number): |
| return |
|
|
| self.wrap_unet(unet_number) |
|
|
| def wrap_unet(self, unet_number): |
| if hasattr(self, 'one_unet_wrapped'): |
| return |
|
|
| unet = self.imagen.get_unet(unet_number) |
| self.unet_being_trained = self.accelerator.prepare(unet) |
| unet_index = unet_number - 1 |
|
|
| optimizer = getattr(self, f'optim{unet_index}') |
| scheduler = getattr(self, f'scheduler{unet_index}') |
|
|
| optimizer = self.accelerator.prepare(optimizer) |
|
|
| if exists(scheduler): |
| scheduler = self.accelerator.prepare(scheduler) |
|
|
| setattr(self, f'optim{unet_index}', optimizer) |
| setattr(self, f'scheduler{unet_index}', scheduler) |
|
|
| self.one_unet_wrapped = True |
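        # note: only a single unet is ever wrapped per trainer instance (guarded by one_unet_wrapped);
        # switching to another unet requires checkpointing and creating a fresh trainer, as the assert
        # in validate_and_set_unet_being_trained enforces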
|
|
| |
|
|
| def set_accelerator_scaler(self, unet_number): |
| unet_number = self.validate_unet_number(unet_number) |
| scaler = getattr(self, f'scaler{unet_number - 1}') |
|
|
| self.accelerator.scaler = scaler |
| for optimizer in self.accelerator._optimizers: |
| optimizer.scaler = scaler |
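        # every unet keeps its own GradScaler; swapping the active scaler onto the accelerator (and
        # its prepared optimizers) keeps mixed-precision unscaling consistent with the unet being trained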
|
|
| |
|
|
| def print(self, msg): |
| if not self.is_main: |
| return |
|
|
| if not self.verbose: |
| return |
|
|
| return self.accelerator.print(msg) |
|
|
| |
|
|
| def validate_unet_number(self, unet_number = None): |
| if self.num_unets == 1: |
| unet_number = default(unet_number, 1) |
|
|
| assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}' |
| return unet_number |
|
|
| |
|
|
| def num_steps_taken(self, unet_number = None): |
| if self.num_unets == 1: |
| unet_number = default(unet_number, 1) |
|
|
| return self.steps[unet_number - 1].item() |
|
|
| def print_untrained_unets(self): |
| print_final_error = False |
|
|
| for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)): |
| if steps > 0 or isinstance(unet, NullUnet): |
| continue |
|
|
| self.print(f'unet {ind + 1} has not been trained') |
| print_final_error = True |
|
|
| if print_final_error: |
| self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets') |
|
|
| |
|
|
| def add_train_dataloader(self, dl = None): |
| if not exists(dl): |
| return |
|
|
| assert not exists(self.train_dl), 'training dataloader was already added' |
| self.train_dl = self.accelerator.prepare(dl) |
|
|
| def add_valid_dataloader(self, dl): |
| if not exists(dl): |
| return |
|
|
| assert not exists(self.valid_dl), 'validation dataloader was already added' |
| self.valid_dl = self.accelerator.prepare(dl) |
|
|
| def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs): |
| if not exists(ds): |
| return |
|
|
| assert not exists(self.train_dl), 'training dataloader was already added' |
|
|
| valid_ds = None |
| if self.split_valid_from_train: |
| train_size = int((1 - self.split_valid_fraction) * len(ds)) |
| valid_size = len(ds) - train_size |
|
|
| ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed)) |
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')
|
|
| dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) |
| self.train_dl = self.accelerator.prepare(dl) |
|
|
| if not self.split_valid_from_train: |
| return |
|
|
| self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs) |
|
|
| def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs): |
| if not exists(ds): |
| return |
|
|
| assert not exists(self.valid_dl), 'validation dataloader was already added' |
|
|
| dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) |
| self.valid_dl = self.accelerator.prepare(dl) |
|
|
| def create_train_iter(self): |
| assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet' |
|
|
| if exists(self.train_dl_iter): |
| return |
|
|
| self.train_dl_iter = cycle(self.train_dl) |
|
|
| def create_valid_iter(self): |
| assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet' |
|
|
| if exists(self.valid_dl_iter): |
| return |
|
|
| self.valid_dl_iter = cycle(self.valid_dl) |
|
|
| def train_step(self, unet_number = None, **kwargs): |
| self.create_train_iter() |
| loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs) |
| self.update(unet_number = unet_number) |
| return loss |
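        # a training step therefore: pulls one batch from the cycling train iterator, maps the
        # dataloader tuple onto keyword arguments via dl_tuple_output_keywords_names, runs forward()
        # (which accumulates gradients per chunk), then update() to step optimizer, schedulers and EMA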
|
|
| @torch.no_grad() |
| @eval_decorator |
| def valid_step(self, **kwargs): |
| self.create_valid_iter() |
|
|
| context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext |
|
|
| with context(): |
| loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs) |
| return loss |
|
|
| def step_with_dl_iter(self, dl_iter, **kwargs): |
| dl_tuple_output = cast_tuple(next(dl_iter)) |
| model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output))) |
| loss = self.forward(**{**kwargs, **model_input}) |
| return loss |
|
|
| |
|
|
| @property |
| def all_checkpoints_sorted(self): |
| glob_pattern = os.path.join(self.checkpoint_path, '*.pt') |
| checkpoints = self.fs.glob(glob_pattern) |
| sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True) |
| return sorted_checkpoints |
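        # checkpoints follow the naming pattern checkpoint.<total_steps>.pt, so sorting on the
        # second-to-last dot-separated field orders them from newest to oldest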
|
|
| def load_from_checkpoint_folder(self, last_total_steps = -1): |
| if last_total_steps != -1: |
| filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt') |
| self.load(filepath) |
| return |
|
|
| sorted_checkpoints = self.all_checkpoints_sorted |
|
|
| if len(sorted_checkpoints) == 0: |
| self.print(f'no checkpoints found to load from at {self.checkpoint_path}') |
| return |
|
|
| last_checkpoint = sorted_checkpoints[0] |
| self.load(last_checkpoint) |
|
|
| def save_to_checkpoint_folder(self): |
| self.accelerator.wait_for_everyone() |
|
|
| if not self.can_checkpoint: |
| return |
|
|
| total_steps = int(self.steps.sum().item()) |
| filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt') |
|
|
| self.save(filepath) |
|
|
| if self.max_checkpoints_keep <= 0: |
| return |
|
|
| sorted_checkpoints = self.all_checkpoints_sorted |
| checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:] |
|
|
| for checkpoint in checkpoints_to_discard: |
| self.fs.rm(checkpoint) |
|
|
| |
|
|
| def save( |
| self, |
| path, |
| overwrite = True, |
| without_optim_and_sched = False, |
| **kwargs |
| ): |
| self.accelerator.wait_for_everyone() |
|
|
| if not self.can_checkpoint: |
| return |
|
|
| fs = self.fs |
|
|
| assert not (fs.exists(path) and not overwrite) |
|
|
| self.reset_ema_unets_all_one_device() |
|
|
| save_obj = dict( |
| model = self.imagen.state_dict(), |
| version = __version__, |
| steps = self.steps.cpu(), |
| **kwargs |
| ) |
|
|
| save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple() |
|
|
| for ind in save_optim_and_sched_iter: |
| scaler_key = f'scaler{ind}' |
| optimizer_key = f'optim{ind}' |
| scheduler_key = f'scheduler{ind}' |
| warmup_scheduler_key = f'warmup{ind}' |
|
|
| scaler = getattr(self, scaler_key) |
| optimizer = getattr(self, optimizer_key) |
| scheduler = getattr(self, scheduler_key) |
| warmup_scheduler = getattr(self, warmup_scheduler_key) |
|
|
| if exists(scheduler): |
| save_obj = {**save_obj, scheduler_key: scheduler.state_dict()} |
|
|
| if exists(warmup_scheduler): |
| save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()} |
|
|
| save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} |
|
|
| if self.use_ema: |
| save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} |
|
|
| |
|
|
| if hasattr(self.imagen, '_config'): |
| self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""') |
|
|
| save_obj = { |
| **save_obj, |
| 'imagen_type': 'elucidated' if self.is_elucidated else 'original', |
| 'imagen_params': self.imagen._config |
| } |
|
|
| |
|
|
| with fs.open(path, 'wb') as f: |
| torch.save(save_obj, f) |
|
|
| self.print(f'checkpoint saved to {path}') |
|
|
| def load(self, path, only_model = False, strict = True, noop_if_not_exist = False): |
| fs = self.fs |
|
|
| if noop_if_not_exist and not fs.exists(path): |
| self.print(f'trainer checkpoint not found at {str(path)}') |
| return |
|
|
| assert fs.exists(path), f'{path} does not exist' |
|
|
| self.reset_ema_unets_all_one_device() |
|
|
| |
|
|
| with fs.open(path) as f: |
| loaded_obj = torch.load(f, map_location='cpu') |
|
|
| if version.parse(__version__) != version.parse(loaded_obj['version']): |
| self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}') |
|
|
| try: |
| self.imagen.load_state_dict(loaded_obj['model'], strict = strict) |
| except RuntimeError: |
| print("Failed loading state dict. Trying partial load") |
| self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(), |
| loaded_obj['model'])) |
|
|
| if only_model: |
| return loaded_obj |
|
|
| self.steps.copy_(loaded_obj['steps']) |
|
|
| for ind in range(0, self.num_unets): |
| scaler_key = f'scaler{ind}' |
| optimizer_key = f'optim{ind}' |
| scheduler_key = f'scheduler{ind}' |
| warmup_scheduler_key = f'warmup{ind}' |
|
|
| scaler = getattr(self, scaler_key) |
| optimizer = getattr(self, optimizer_key) |
| scheduler = getattr(self, scheduler_key) |
| warmup_scheduler = getattr(self, warmup_scheduler_key) |
|
|
| if exists(scheduler) and scheduler_key in loaded_obj: |
| scheduler.load_state_dict(loaded_obj[scheduler_key]) |
|
|
| if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj: |
| warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key]) |
|
|
| if exists(optimizer): |
| try: |
| optimizer.load_state_dict(loaded_obj[optimizer_key]) |
| scaler.load_state_dict(loaded_obj[scaler_key]) |
                except Exception:
                    self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scaler')
|
|
| if self.use_ema: |
| assert 'ema' in loaded_obj |
| try: |
| self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) |
| except RuntimeError: |
| print("Failed loading state dict. Trying partial load") |
| self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(), |
| loaded_obj['ema'])) |
|
|
| self.print(f'checkpoint loaded from {path}') |
| return loaded_obj |
|
|
| |
|
|
| @property |
| def unets(self): |
| return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) |
|
|
| def get_ema_unet(self, unet_number = None): |
| if not self.use_ema: |
| return |
|
|
| unet_number = self.validate_unet_number(unet_number) |
| index = unet_number - 1 |
|
|
| if isinstance(self.unets, nn.ModuleList): |
| unets_list = [unet for unet in self.ema_unets] |
| delattr(self, 'ema_unets') |
| self.ema_unets = unets_list |
|
|
| if index != self.ema_unet_being_trained_index: |
| for unet_index, unet in enumerate(self.ema_unets): |
| unet.to(self.device if unet_index == index else 'cpu') |
|
|
| self.ema_unet_being_trained_index = index |
| return self.ema_unets[index] |
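        # only the EMA copy of the unet currently being trained stays on the training device;
        # all other EMA unets are parked on the CPU to save memory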
|
|
| def reset_ema_unets_all_one_device(self, device = None): |
| if not self.use_ema: |
| return |
|
|
| device = default(device, self.device) |
| self.ema_unets = nn.ModuleList([*self.ema_unets]) |
| self.ema_unets.to(device) |
|
|
| self.ema_unet_being_trained_index = -1 |
|
|
| @torch.no_grad() |
| @contextmanager |
| def use_ema_unets(self): |
| if not self.use_ema: |
| output = yield |
| return output |
|
|
| self.reset_ema_unets_all_one_device() |
| self.imagen.reset_unets_all_one_device() |
|
|
| self.unets.eval() |
|
|
| trainable_unets = self.imagen.unets |
| self.imagen.unets = self.unets |
|
|
| output = yield |
|
|
| self.imagen.unets = trainable_unets |
|
|
| |
| for ema in self.ema_unets: |
| ema.restore_ema_model_device() |
|
|
| return output |
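        # use_ema_unets temporarily swaps the EMA copies in for the trainable unets (e.g. during
        # sampling or validation), then restores the trainable unets and moves the EMA models back
        # to their original devices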
|
|
| def print_unet_devices(self): |
| self.print('unet devices:') |
| for i, unet in enumerate(self.imagen.unets): |
| device = next(unet.parameters()).device |
| self.print(f'\tunet {i}: {device}') |
|
|
| if not self.use_ema: |
| return |
|
|
| self.print('\nema unet devices:') |
| for i, ema_unet in enumerate(self.ema_unets): |
| device = next(ema_unet.parameters()).device |
| self.print(f'\tema unet {i}: {device}') |
|
|
| |
|
|
| def state_dict(self, *args, **kwargs): |
| self.reset_ema_unets_all_one_device() |
| return super().state_dict(*args, **kwargs) |
|
|
| def load_state_dict(self, *args, **kwargs): |
| self.reset_ema_unets_all_one_device() |
| return super().load_state_dict(*args, **kwargs) |
|
|
| |
|
|
| def encode_text(self, text, **kwargs): |
| return self.imagen.encode_text(text, **kwargs) |
|
|
| |
|
|
| def update(self, unet_number = None): |
| unet_number = self.validate_unet_number(unet_number) |
| self.validate_and_set_unet_being_trained(unet_number) |
| self.set_accelerator_scaler(unet_number) |
|
|
| index = unet_number - 1 |
| unet = self.unet_being_trained |
|
|
| optimizer = getattr(self, f'optim{index}') |
| scaler = getattr(self, f'scaler{index}') |
| scheduler = getattr(self, f'scheduler{index}') |
| warmup_scheduler = getattr(self, f'warmup{index}') |
|
|
| |
|
|
| if exists(self.max_grad_norm): |
| self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm) |
|
|
| optimizer.step() |
| optimizer.zero_grad() |
|
|
| if self.use_ema: |
| ema_unet = self.get_ema_unet(unet_number) |
| ema_unet.update() |
|
|
| |
|
|
| maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening() |
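        # pytorch_warmup applies its warmup factor inside dampening(); stepping the LR scheduler
        # within this context prevents the warmup and cosine/lambda schedules from compounding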
|
|
| with maybe_warmup_context: |
| if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: |
| scheduler.step() |
|
|
| self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps)) |
|
|
| if not exists(self.checkpoint_path): |
| return |
|
|
| total_steps = int(self.steps.sum().item()) |
|
|
| if total_steps % self.checkpoint_every: |
| return |
|
|
| self.save_to_checkpoint_folder() |
|
|
| @torch.no_grad() |
| @cast_torch_tensor |
| @imagen_sample_in_chunks |
| def sample(self, *args, **kwargs): |
| context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets |
|
|
| self.print_untrained_unets() |
| |
| if not self.is_main: |
| kwargs['use_tqdm'] = False |
|
|
| with context(): |
| output = self.imagen.sample(*args, device = self.device, **kwargs) |
|
|
| return output |
|
|
| @partial(cast_torch_tensor, cast_fp16 = True) |
| def forward( |
| self, |
| *args, |
| unet_number = None, |
| max_batch_size = None, |
| **kwargs |
| ): |
| unet_number = self.validate_unet_number(unet_number) |
| self.validate_and_set_unet_being_trained(unet_number) |
| self.set_accelerator_scaler(unet_number) |
|
|
| assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' |
|
|
| total_loss = 0. |
| |
| |
| |
        print('forward: number of positional args: ', len(args))
        print('forward: first arg shape: ', args[0].shape)
        print('forward: keyword args: ', kwargs.keys())
| for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): |
| |
            print('chunk fraction: ', chunk_size_frac)
            print('chunked args: ', chunked_args)
            print('chunked kwargs: ', chunked_kwargs)
| |
| with self.accelerator.autocast(): |
| loss = self.model(*chunked_args, unet_number = unet_number, **chunked_kwargs) |
| loss = loss * chunk_size_frac |
| |
| print('loss: ', loss) |
|
|
| total_loss += loss |
|
|
| if self.training: |
| self.accelerator.backward(loss) |
|
|
        print('forward pass finished; returning accumulated loss')
| return total_loss |
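    # A minimal usage sketch (a hedged example, not part of the trainer: `designer` stands for a
    # ProteinDesigner_B instance and `train_ds` for a PyTorch Dataset; keyword names follow the
    # constructor arguments above):
    #
    #   trainer = ImagenTrainer(designer, lr = 1e-4, use_ema = True)
    #   trainer.add_train_dataset(train_ds, batch_size = 8)
    #   for _ in range(10_000):
    #       loss = trainer.train_step(unet_number = 1)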
|
|
| |
def write_fasta (sequence, filename_out):
    # write a single-record FASTA file; the output filename doubles as the record header
    with open(filename_out, mode='w') as f:
        f.write(f'>{filename_out}\n')
        f.write(f'{sequence}')
|
|
|
|
| |
| def sample_sequence_FromModelB ( |
| model, |
| X=None, |
| flag=0, |
| cond_scales=1., |
| foldproteins=False, |
| X_string=None, |
| x_data=None, |
| skip_steps=0, |
| inpaint_images = None, |
| inpaint_masks = None, |
| inpaint_resample_times = None, |
| init_images = None, |
| num_cycle=16, |
| |
| ynormfac=1, |
| train_unet_number=1, |
| tokenizer_X=None, |
| Xnormfac=1., |
| max_length=1., |
| prefix=None, |
| tokenizer_y=None, |
| ): |
| steps=0 |
| e=flag |
|
|
|
|
| |
|
|
| |
| if X!=None: |
| print (f"Producing {len(X)} samples...from text conditioning X...") |
| lenn_val=len(X) |
| if X_string!=None: |
| lenn_val=len(X_string) |
| print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") |
| |
| if x_data!=None: |
        print (f"Producing {len(x_data)} samples...from image conditioning x_data ...")
| lenn_val=len(x_data) |
| print (x_data) |
| |
| print ('Device: ', device) |
|
|
|
|
| for iisample in range (lenn_val): |
| X_cond=None |
| if X_string==None and X != None: |
| X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) |
| if X_string !=None: |
| X = tokenizer_X.texts_to_sequences(X_string[iisample]) |
| X= sequence.pad_sequences(X, maxlen=max_length, padding='post', truncating='post') |
| X=np.array(X) |
| X_cond=torch.from_numpy(X).float()/Xnormfac |
| print ('Tokenized and processed: ', X_cond) |
| |
| print ("X_cond=", X_cond) |
| |
| result=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales , |
| x_data=x_data, skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images,device=device, |
| |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| ) |
| result=torch.round(result*ynormfac) |
| |
| plt.plot (result[0,0,:].cpu().detach().numpy(),label= f'Predicted') |
| |
| plt.legend() |
|
|
| outname = prefix+ f"sampled_from_X_{flag}_condscale-{str (cond_scales)}_{e}_{steps}.jpg" |
| |
| plt.savefig(outname, dpi=200) |
| plt.show () |
|
|
| to_rev=result[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| print (to_rev.shape) |
| y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") |
| |
| |
| if X_cond != None: |
| X_cond=torch.round(X_cond*Xnormfac) |
|
|
| to_rev=X_cond[:,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| print (to_rev.shape) |
| X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
| if x_data !=None: |
| X_data_reversed=x_data |
| |
|
|
| print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed[iisample]) |
| if foldproteins: |
| |
| if X_cond != None: |
| xbc=X_cond[iisample,:].cpu().detach().numpy() |
| out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})+f'_{flag}_{steps}' |
| if x_data !=None: |
| |
| out_nam=x_data[iisample] |
| |
| |
| tempname='temp' |
| pdb_file=foldandsavePDB ( |
| sequence=y_data_reversed[0], |
| filename_out=tempname, |
| num_cycle=num_cycle, |
| flag=flag, |
| |
| prefix=prefix |
| ) |
|
|
| out_nam_fasta=f'{prefix}{out_nam}_{flag}_{steps}.fasta' |
|
|
| write_fasta (y_data_reversed[0], out_nam_fasta) |
| |
| |
| out_nam=f'{prefix}{X_data_reversed[iisample]}_{flag}_{steps}.pdb' |
| |
| |
| shutil.copy (pdb_file, out_nam) |
| |
| |
| |
| |
| pdb_file=out_nam |
| |
| |
| |
| |
| |
| print (f"Properly named PDB file produced: {pdb_file}") |
| |
| view=show_pdb(pdb_file=pdb_file, flag=flag, |
| show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color) |
| view.show() |
|
|
|
|
| steps=steps+1 |
| |
| return pdb_file |
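# A hedged call sketch for sample_sequence_FromModelB (placeholder values; `designer`, `tok_X` and
# `tok_y` are assumed to be a trained ProteinDesigner_B and the keras tokenizers used during
# training, and the conditioning vector is illustrative only):
#
#   pdb_file = sample_sequence_FromModelB(
#       designer, X=[cond_vector], cond_scales=1., foldproteins=True,
#       tokenizer_X=tok_X, tokenizer_y=tok_y, Xnormfac=1., ynormfac=1.,
#       max_length=64, prefix='./samples/',
#   )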
| |
| def sample_loop_FromModelB (model, |
| train_loader, |
| cond_scales=[7.5], |
| num_samples=2, |
| timesteps=100, |
| flag=0,foldproteins=False, |
| use_text_embedd=True,skip_steps=0, |
| |
| train_unet_number=1, |
| ynormfac=1, |
| prefix=None, |
| tokenizer_y=None, |
| Xnormfac=1, |
| tokenizer_X=None, |
| |
| ): |
| steps=0 |
| e=flag |
| for item in train_loader: |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
|
|
| GT=y_train_batch.cpu().detach() |
| |
| GT= GT.unsqueeze(1) |
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
| |
| print ('X_train_batch shape: ', X_train_batch.shape) |
|
|
| for iisample in range (len (cond_scales)): |
| |
| if use_text_embedd: |
| result=model.sample (x= X_train_batch,stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample], device=device, skip_steps=skip_steps) |
| else: |
| result=model.sample (x= None, x_data_tokenized= X_train_batch, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample],device=device,skip_steps=skip_steps) |
| |
| result=torch.round(result*ynormfac) |
| GT=torch.round (GT*ynormfac) |
|
|
| for samples in range (num_samples): |
| print ("sample ", samples, "out of ", num_samples) |
| |
| plt.plot (result[samples,0,:].cpu().detach().numpy(),label= f'Predicted') |
| plt.plot (GT[samples,0,:],label= f'GT {0}') |
| plt.legend() |
|
|
| outname = prefix+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg" |
| |
| plt.savefig(outname, dpi=200) |
| plt.show () |
| |
| |
| to_rev=result[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| |
| y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") |
| |
| |
| to_rev=GT[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| |
| GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "") |
| |
| |
| to_rev=torch.round (X_train_batch[:,:]*Xnormfac) |
| to_rev=to_rev.long().cpu().detach().numpy() |
| |
| X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
|
|
| print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, predicted sequence: ", y_data_reversed[samples]) |
| print (f"Ground truth: {GT_y_data_reversed[samples]}") |
| |
| if foldproteins: |
| xbc=X_train_batch[samples,:].cpu().detach().numpy() |
| out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc}) |
| tempname='temp' |
| pdb_file=foldandsavePDB ( |
| sequence=y_data_reversed[samples], |
| filename_out=tempname, |
| num_cycle=16, flag=flag, |
| |
| prefix=prefix |
| ) |
| |
| |
| out_nam=f'{prefix}{X_data_reversed[samples]}.pdb' |
| print (f'Original PDB: {pdb_file} OUT: {out_nam}') |
| shutil.copy (pdb_file, out_nam) |
| pdb_file=out_nam |
| print (f"Properly named PDB file produced: {pdb_file}") |
| |
| view=show_pdb(pdb_file=pdb_file, flag=flag, show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color) |
| view.show() |
|
|
| steps=steps+1 |
| if steps>num_samples: |
| break |
| |
| def sample_sequence_omegafold_pLM_ModelB ( |
| model, |
| X=None, |
| flag=0, |
| cond_scales=1., |
| foldproteins=False, |
| X_string=None, |
| x_data=None, |
| skip_steps=0, |
| inpaint_images = None, |
| inpaint_masks = None, |
| inpaint_resample_times = None, |
| init_images = None, |
| num_cycle=16, |
| |
| ynormfac=1, |
| train_unet_number=1, |
| tokenizer_X=None, |
| Xnormfac=1., |
| max_length=1., |
| prefix=None, |
| tokenizer_y=None, |
| |
| CKeys=None, |
| sample_dir=None, |
| steps=None, |
| e=None, |
| IF_showfig=True, |
| |
| pLM_Model=None, |
| pLM_Model_Name=None, |
| image_channels=None, |
| pLM_alphabet=None, |
| ): |
| |
| |
| |
|
|
| |
| if X!=None: |
| print (f"Producing {len(X)} samples...from text conditioning X...") |
| lenn_val=len(X) |
| if X_string!=None: |
| lenn_val=len(X_string) |
| print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") |
| |
| if x_data!=None: |
        print (f"Producing {len(x_data)} samples...from image conditioning x_data ...")
| lenn_val=len(x_data) |
| print (x_data) |
| |
| print ('Device: ', device) |
| |
| pdb_file_list=[] |
| fasta_file_list=[] |
|
|
| |
    print('total samples to produce: ', lenn_val)
| for iisample in range (lenn_val): |
| print("Working on ", iisample) |
| X_cond=None |
| if X_string==None and X != None: |
| X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) |
| if X_string !=None: |
| XX = tokenizer_X.texts_to_sequences(X_string[iisample]) |
| XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post') |
| XX=np.array(XX) |
| X_cond=torch.from_numpy(XX).float()/Xnormfac |
| print ('Tokenized and processed: ', X_cond) |
| |
| print ("X_cond=", X_cond) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if tokenizer_X!=None: |
| |
| result_embedding=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales , |
| x_data=x_data[iisample], |
| |
| x_data_tokenized=None, |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images,device=device, |
| |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| ) |
| else: |
| |
| |
| |
| x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac) |
| x_data_tokenized=x_data_tokenized.to(torch.float) |
| |
| x_data_tokenized=x_data_tokenized.unsqueeze(0) |
| |
| |
| |
| if CKeys['Debug_TrainerPack']==3: |
| print("x_data_tokenized dim: ", x_data_tokenized.shape) |
| print("x_data_tokenized dtype: ", x_data_tokenized.dtype) |
| print("test x_data_tokenized!=None: ", x_data_tokenized!=None) |
| |
| result_embedding=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales , |
| x_data=None, |
| |
| x_data_tokenized=x_data_tokenized, |
| |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images,device=device, |
| |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| if image_channels==33: |
| result_tokens,result_logits = convert_into_tokens_using_prob( |
| result_embedding, |
| pLM_Model_Name, |
| ) |
| else: |
| result_tokens,result_logits = convert_into_tokens( |
| pLM_Model, |
| result_embedding, |
| pLM_Model_Name, |
| ) |
| |
| result=result_tokens.unsqueeze(1) |
| |
| |
| print('result dim: ', result.shape) |
| |
| |
| fig=plt.figure() |
| plt.plot ( |
| result[0,0,:].cpu().detach().numpy(), |
| label= f'Predicted' |
| ) |
| |
| plt.legend() |
| outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg" |
| |
| if IF_showfig==1: |
| plt.show () |
| else: |
| plt.savefig(outname, dpi=200) |
| plt.close() |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if tokenizer_X!=None: |
| |
| result_mask = read_mask_from_input( |
| tokenized_data=None, |
| mask_value=None, |
| seq_data=x_data[iisample], |
| max_seq_length=max_length, |
| ) |
| else: |
| |
| result_mask = read_mask_from_input( |
| tokenized_data=x_data_tokenized, |
| mask_value=0, |
| seq_data=None, |
| max_seq_length=None, |
| ) |
| |
| to_rev=result[:,0,:] |
| if CKeys['Debug_TrainerPack']==3: |
| print("on foldable result: ", to_rev[0]) |
| print("on result_logits: ", result_logits[0]) |
| print("on mask: ", result_mask[0]) |
| a = decode_one_ems_token_rec_for_folding_with_mask( |
| to_rev[0], |
| result_logits[0], |
| pLM_alphabet, |
| pLM_Model, |
| result_mask[0], |
| ) |
| print('One resu: ', a) |
|
|
| y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask( |
| to_rev, |
| result_logits, |
| pLM_alphabet, |
| pLM_Model, |
| result_mask, |
| ) |
| if CKeys['Debug_TrainerPack']==3: |
| print("on y_data_reversed[0]: ", y_data_reversed[0]) |
| |
| |
| |
| |
| if X_cond != None: |
| X_cond=torch.round(X_cond*Xnormfac) |
|
|
| to_rev=X_cond[:,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| print (to_rev.shape) |
| X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
| if x_data !=None: |
| X_data_reversed=x_data |
| |
|
|
| |
| print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed) |
| |
| |
| print("================================================") |
| print("foldproteins: ", foldproteins) |
| |
| if not foldproteins: |
| pdb_file=None |
| |
| else: |
| |
| if X_cond != None: |
| xbc=X_cond[iisample,:].cpu().detach().numpy() |
| out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})+f'_{flag}_{steps}' |
| if x_data !=None: |
| |
| |
| |
| out_nam=x_data[iisample] |
| |
| |
| out_nam=iisample |
| |
| |
| tempname='temp' |
| pdb_file, fasta_file=foldandsavePDB_pdb_fasta ( |
| sequence=y_data_reversed[0], |
| filename_out=tempname, |
| num_cycle=num_cycle, |
| flag=flag, |
| |
| |
| prefix=sample_dir, |
| ) |
|
|
| |
| |
| |
| |
| |
| out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' |
|
|
| write_fasta (y_data_reversed[0], out_nam_fasta) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb' |
| out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' |
| |
| |
| |
| shutil.copy (pdb_file, out_nam) |
| shutil.copy (fasta_file, out_nam_fasta) |
| |
| |
| |
| |
| |
| os.remove (pdb_file) |
| os.remove (fasta_file) |
| |
| pdb_file=out_nam |
| fasta_file=out_nam_fasta |
| pdb_file_list.append(pdb_file) |
| fasta_file_list.append(fasta_file) |
| |
| |
| print (f"Properly named PDB file produced: {pdb_file}") |
| if IF_showfig==1: |
| |
| view=show_pdb( |
| pdb_file=pdb_file, |
| flag=flag, |
| show_sidechains=show_sidechains, |
| show_mainchains=show_mainchains, |
| color=color |
| ) |
| view.show() |
|
|
|
|
| |
| |
| return pdb_file_list, fasta_file_list |
|
|
| |
| def sample_sequence_omegafold_ModelB ( |
| model, |
| X=None, |
| flag=0, |
| cond_scales=1., |
| foldproteins=False, |
| X_string=None, |
| x_data=None, |
| skip_steps=0, |
| inpaint_images = None, |
| inpaint_masks = None, |
| inpaint_resample_times = None, |
| init_images = None, |
| num_cycle=16, |
| |
| ynormfac=1, |
| train_unet_number=1, |
| tokenizer_X=None, |
| Xnormfac=1., |
| max_length=1., |
| prefix=None, |
| tokenizer_y=None, |
| |
| CKeys=None, |
| sample_dir=None, |
| steps=None, |
| e=None, |
| IF_showfig=True, |
| ): |
| |
| |
| |
|
|
| |
| if X!=None: |
| print (f"Producing {len(X)} samples...from text conditioning X...") |
| lenn_val=len(X) |
| if X_string!=None: |
| lenn_val=len(X_string) |
| print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") |
| |
| if x_data!=None: |
        print (f"Producing {len(x_data)} samples...from image conditioning x_data ...")
| lenn_val=len(x_data) |
| print (x_data) |
| |
| print ('Device: ', device) |
|
|
| |
    print('total samples to produce: ', lenn_val)
| for iisample in range (lenn_val): |
| print("Working on ", iisample) |
| X_cond=None |
| if X_string==None and X != None: |
| X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) |
| if X_string !=None: |
| XX = tokenizer_X.texts_to_sequences(X_string[iisample]) |
| XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post') |
| XX=np.array(XX) |
| X_cond=torch.from_numpy(XX).float()/Xnormfac |
| print ('Tokenized and processed: ', X_cond) |
| |
| print ("X_cond=", X_cond) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if tokenizer_X!=None: |
| result=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales , |
| x_data=x_data[iisample], |
| |
| x_data_tokenized=None, |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images,device=device, |
| |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| ) |
| else: |
| x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac) |
| x_data_tokenized=x_data_tokenized.to(torch.float) |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("x_data_tokenized dim: ", x_data_tokenized.shape) |
| print("x_data_tokenized dtype: ", x_data_tokenized.dtype) |
| print("test: ", x_data_tokenized!=None) |
| result=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales , |
| x_data=None, |
| |
| x_data_tokenized=x_data_tokenized, |
| |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images,device=device, |
| |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| ) |
| |
| |
| result=torch.round(result*ynormfac) |
| |
| print('result dim: ', result.shape) |
| |
| fig=plt.figure() |
| plt.plot ( |
| result[0,0,:].cpu().detach().numpy(), |
| label= f'Predicted' |
| ) |
| |
| plt.legend() |
| outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg" |
| |
| if IF_showfig==1: |
| plt.show () |
| else: |
| plt.savefig(outname, dpi=200) |
| plt.close() |
| |
|
|
| to_rev=result[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| print (to_rev.shape) |
| y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") |
| |
| |
| if X_cond != None: |
| X_cond=torch.round(X_cond*Xnormfac) |
|
|
| to_rev=X_cond[:,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| print (to_rev.shape) |
| X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
| if x_data !=None: |
| X_data_reversed=x_data |
| |
|
|
| |
| print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed) |
| |
| |
| print("================================================") |
| print("foldproteins: ", foldproteins) |
| |
| if not foldproteins: |
| pdb_file=None |
| |
| else: |
| |
| if X_cond != None: |
| xbc=X_cond[iisample,:].cpu().detach().numpy() |
| out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc})+f'_{flag}_{steps}' |
| if x_data !=None: |
| |
| |
| |
| out_nam=x_data[iisample] |
| |
| |
| out_nam=iisample |
| |
| |
| tempname='temp' |
| pdb_file=foldandsavePDB ( |
| sequence=y_data_reversed[0], |
| filename_out=tempname, |
| num_cycle=num_cycle, |
| flag=flag, |
| |
| |
| prefix=sample_dir, |
| ) |
|
|
| |
| |
| |
| |
| |
| out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' |
|
|
| write_fasta (y_data_reversed[0], out_nam_fasta) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb' |
| |
| |
| |
| shutil.copy (pdb_file, out_nam) |
| |
| |
| |
| |
| pdb_file=out_nam |
| |
| |
| |
| print (f"Properly named PDB file produced: {pdb_file}") |
| if IF_showfig==1: |
| |
| view=show_pdb( |
| pdb_file=pdb_file, |
| flag=flag, |
| show_sidechains=show_sidechains, |
| show_mainchains=show_mainchains, |
| color=color |
| ) |
| view.show() |
|
|
|
|
| |
| |
| return pdb_file |
|
|
| |
| |
| def sample_sequence_omegafold_pLM_ModelB_For_ForcPath ( |
| model, |
| X=None, |
| flag=0, |
| cond_scales=[1.], |
| foldproteins=False, |
| X_string=None, |
| x_data=None, |
| skip_steps=0, |
| inpaint_images = None, |
| inpaint_masks = None, |
| inpaint_resample_times = None, |
| init_images = None, |
| num_cycle=16, |
| |
| ynormfac=1, |
| train_unet_number=1, |
| tokenizer_X=None, |
| Xnormfac=1., |
| max_length=1., |
| prefix=None, |
| tokenizer_y=None, |
| |
| CKeys=None, |
| sample_dir=None, |
| steps=None, |
| e=None, |
| IF_showfig=True, |
| |
| pLM_Model=None, |
| pLM_Model_Name=None, |
| image_channels=None, |
| pLM_alphabet=None, |
| ): |
| |
| |
| |
| X_cond=None |
| if X_string==None and X != None: |
| print (f"Producing {len(X)} samples...from text conditioning X...") |
| lenn_val=len(X) |
| |
| X_cond=torch.Tensor(X).to(device) |
| |
| |
| if X_string !=None: |
| print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") |
| lenn_val=len(X_string) |
| |
        # XX = tokenizer_X.texts_to_sequences(X_string[iisample])   # per-sample tokenization, superseded by the batch call below
        XX = tokenizer_X.texts_to_sequences(X_string)
| XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post') |
| XX=np.array(XX) |
| X_cond=torch.from_numpy(XX).float()/Xnormfac |
| print ('Tokenized and processed: ', X_cond) |
| |
| if x_data!=None: |
        print (f"Producing {len(x_data)} samples...from image conditioning x_data ...")
| lenn_val=len(x_data) |
| if tokenizer_X==None: |
| |
| |
| x_data_tokenized=[] |
| for ii in range(lenn_val): |
| x_data_one_line=pad_a_np_arr(x_data[ii], 0.0, max_length) |
| x_data_tokenized.append(x_data_one_line) |
| x_data_tokenized=np.array(x_data_tokenized) |
| x_data_tokenized=torch.from_numpy(x_data_tokenized/Xnormfac) |
| else: |
| |
| pass |
| |
| |
| |
| |
| if tokenizer_X!=None: |
| |
| result_mask = read_mask_from_input( |
| tokenized_data=None, |
| mask_value=None, |
| seq_data=x_data, |
| max_seq_length=max_length, |
| ) |
| else: |
| |
| result_mask = read_mask_from_input( |
| tokenized_data=x_data_tokenized, |
| mask_value=0, |
| seq_data=None, |
| max_seq_length=None, |
| ) |
| |
| |
| print ("Input contents:") |
| print ("cond_img condition: x_data=\n", x_data) |
| print ("Text condition: X_cond=\n", X_cond) |
| |
| |
| pdb_file_list=[] |
| fasta_file_list=[] |
| |
| |
| for idx_cond, this_cond_scale in enumerate(cond_scales): |
| print(f"Working on cond_scale {str(this_cond_scale)}") |
| |
| |
| |
| if tokenizer_X!=None: |
| |
| result_embedding=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=this_cond_scale, |
| x_data=x_data, |
| |
| x_data_tokenized=None, |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images,device=device, |
| |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| ) |
| else: |
| |
| |
| x_data_tokenized=x_data_tokenized.to(torch.float) |
| |
| |
| |
| if CKeys['Debug_TrainerPack']==3: |
| print("x_data_tokenized dim: ", x_data_tokenized.shape) |
| print("x_data_tokenized dtype: ", x_data_tokenized.dtype) |
| print("test x_data_tokenized!=None: ", x_data_tokenized!=None) |
| |
| result_embedding=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=this_cond_scale, |
| x_data=None, |
| |
| x_data_tokenized=x_data_tokenized, |
| |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images,device=device, |
| |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| ) |
| |
| |
| |
| if image_channels==33: |
| |
| result_tokens,result_logits = convert_into_tokens_using_prob( |
| result_embedding, |
| pLM_Model_Name, |
| ) |
| else: |
| |
| |
| |
| result_tokens,result_logits = convert_into_tokens( |
| pLM_Model, |
| result_embedding, |
| pLM_Model_Name, |
| ) |
| |
| result=result_tokens.unsqueeze(1) |
| |
| |
| print('result dim: ', result.shape) |
| |
| |
| fig=plt.figure() |
| for ii in range(lenn_val): |
| plt.plot ( |
| result[ii,0,:].cpu().detach().numpy(), |
| label= f'Predicted for Input#{str(ii)}' |
| ) |
| |
| plt.legend() |
| outname = sample_dir+ f"DenovoInputXs_CondScale_No{str(idx_cond)}_Val_{str(this_cond_scale)}_{e}_{steps}.jpg" |
| |
| if IF_showfig==1: |
| plt.show () |
| else: |
| plt.savefig(outname, dpi=200) |
| plt.close() |
| |
| |
| to_rev=result[:,0,:] |
| if CKeys['Debug_TrainerPack']==3: |
| print("on foldable result: ", to_rev[0]) |
| print("on result_logits: ", result_logits[0]) |
| print("on mask: ", result_mask[0]) |
| a = decode_one_ems_token_rec_for_folding_with_mask( |
| to_rev[0], |
| result_logits[0], |
| pLM_alphabet, |
| pLM_Model, |
| result_mask[0], |
| ) |
| print('One resu: ', a) |
|
|
| y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask( |
| to_rev, |
| result_logits, |
| pLM_alphabet, |
| pLM_Model, |
| result_mask, |
| ) |
| if CKeys['Debug_TrainerPack']==3: |
| print("on y_data_reversed[0]: ", y_data_reversed[0]) |
| |
| |
| if X_cond != None: |
| X_cond=torch.round(X_cond*Xnormfac) |
|
|
| to_rev=X_cond[:,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| print (to_rev.shape) |
| X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
| if x_data !=None: |
| |
| |
| X_data_reversed=x_data |
| |
| |
| for iisample in range(lenn_val): |
| print (f"For {X} or {X_data_reversed[iisample]}, predicted sequence: ", y_data_reversed[iisample]) |
| |
| out_nam_fasta=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.fasta' |
| write_fasta (y_data_reversed[iisample], out_nam_fasta) |
| fasta_file_list.append(out_nam_fasta) |
| |
| |
| print("================================================") |
| print("foldproteins: ", foldproteins) |
| |
| if not foldproteins: |
| pdb_file=None |
|
|
| else: |
|
|
| if X_cond != None: |
| |
| xbc=X_cond[iisample,:].cpu().detach().numpy() |
| out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc})+f'_{flag}_{steps}' |
| if x_data !=None: |
| pass |
| |
| |
| |
| |
| |
| |
| |
|
|
| tempname='temp' |
| pdb_file, fasta_file=foldandsavePDB_pdb_fasta ( |
| sequence=y_data_reversed[iisample], |
| filename_out=tempname, |
| num_cycle=num_cycle, |
| flag=flag, |
| |
| |
| prefix=sample_dir, |
| ) |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| out_nam=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}.pdb' |
| |
|
|
| |
| |
| shutil.copy (pdb_file, out_nam) |
| |
| |
| |
| |
| |
| |
| os.remove (pdb_file) |
| os.remove (fasta_file) |
|
|
| pdb_file=out_nam |
| |
| pdb_file_list.append(pdb_file) |
| |
| |
| out_nam_inX=f'{sample_dir}DN_{iisample}_CondS_No_{idx_cond}_Val_{this_cond_scale}_epo_{e}_step_{steps}_input.txt' |
| if torch.is_tensor(X_data_reversed[iisample]): |
| |
| xbc=X_data_reversed[iisample].cpu().detach().numpy() |
| else: |
| xbc=X_data_reversed[iisample] |
| if tokenizer_X==None: |
| |
| out_inX=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc}) |
| else: |
| |
| out_inX=xbc |
| with open(out_nam_inX, "w") as inX_file: |
| inX_file.write(out_inX) |
|
|
|
|
| print (f"Properly named PDB file produced: {pdb_file}") |
| if IF_showfig==1: |
| |
| view=show_pdb( |
| pdb_file=pdb_file, |
| flag=flag, |
| show_sidechains=show_sidechains, |
| show_mainchains=show_mainchains, |
| color=color |
| ) |
| view.show() |
| |
| |
| return pdb_file_list, fasta_file_list |
|
|
| |
| def sample_sequence_pLM_ModelB_For_ForcPath_Predictor ( |
| model, |
| X=None, |
| flag=0, |
| cond_scales=[1.], |
| foldproteins=False, |
| X_string=None, |
| x_data=None, |
| skip_steps=0, |
| inpaint_images = None, |
| inpaint_masks = None, |
| inpaint_resample_times = None, |
| init_images = None, |
| num_cycle=16, |
| |
| ynormfac=1, |
| train_unet_number=1, |
| tokenizer_X=None, |
| Xnormfac=1., |
| max_length=1., |
| prefix=None, |
| tokenizer_y=None, |
| |
| CKeys=None, |
| sample_dir=None, |
| steps=None, |
| e=None, |
| IF_showfig=True, |
| |
| pLM_Model=None, |
| pLM_Model_Name=None, |
| image_channels=None, |
| pLM_alphabet=None, |
| |
| esm_layer=None, |
| ): |
| |
| |
| |
| |
| |
| X_cond=None |
| if X_string==None and X != None: |
| print (f"Producing {len(X)} samples...from text conditioning X...") |
| lenn_val=len(X) |
| |
| X_cond=torch.Tensor(X).to(device) |
| |
| |
| if X_string !=None: |
| print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") |
| lenn_val=len(X_string) |
| |
        # XX = tokenizer_X.texts_to_sequences(X_string[iisample])   # per-sample tokenization, superseded by the batch call below
        XX = tokenizer_X.texts_to_sequences(X_string)
| XX= sequence.pad_sequences(XX, maxlen=max_length, padding='post', truncating='post') |
| XX=np.array(XX) |
| X_cond=torch.from_numpy(XX).float()/Xnormfac |
| print ('Tokenized and processed: ', X_cond) |
| |
| |
| if x_data!=None: |
| |
        print (f"Producing {len(x_data)} samples...from image conditioning x_data ...")
| lenn_val=len(x_data) |
| seq_len_list=[] |
| for this_AA in x_data: |
| seq_len_list.append(len(this_AA)) |
| |
| |
| print ("Input contents:") |
| print ("cond_img condition: x_data=\n", x_data) |
| print ("Text condition: X_cond=\n", X_cond) |
| |
| |
| |
| resu_prediction={} |
| |
| for idx_cond, this_cond_scale in enumerate(cond_scales): |
| print(f"Working on cond_scale {str(this_cond_scale)}") |
| |
| |
| |
| |
| result_embedding=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=this_cond_scale, |
| x_data=x_data, |
| |
| x_data_tokenized=None, |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images,device=device, |
| |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| |
| pLM_Model=pLM_Model, |
| pLM_alphabet=pLM_alphabet, |
| esm_layer=esm_layer, |
| pLM_Model_Name=pLM_Model_Name, |
| |
| ) |
| |
| |
| |
| |
| |
| result_embedding=torch.mean(result_embedding, 1) |
| |
| |
| |
| result_mask = read_mask_from_input( |
| tokenized_data=None, |
| mask_value=0.0, |
| seq_data=x_data, |
| max_seq_length=max_length, |
| ) |
| |
| |
| result = result_embedding.cpu()*result_mask |
| |
| y_data_reversed = result*ynormfac |
| |
| prediction_list = [] |
| for ii in range(len(x_data)): |
| prediction_list.append( |
| y_data_reversed[ii, :seq_len_list[ii]+1] |
| ) |
| if CKeys['Debug_TrainerPack']==3: |
| print("check prediction dim:") |
| print(f"model output: ", y_data_reversed[0]) |
| print(f"prediction output: ", prediction_list[0]) |
| |
| |
| resu_prediction[str(this_cond_scale)]=prediction_list |
| |
| return resu_prediction,seq_len_list |
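# The predictor variant returns, for each conditioning scale, the masked and mean-pooled model
# output rescaled by ynormfac and trimmed to the first (sequence length + 1) positions per input.
# Hedged call sketch with placeholder names (`designer`, `pLM`, `alphabet` assumed to come from
# load_in_pLM; the sequence and parameter values are illustrative only):
#
#   preds, seq_lens = sample_sequence_pLM_ModelB_For_ForcPath_Predictor(
#       designer, x_data=['AGAVSEG...'], cond_scales=[1.], max_length=64,
#       pLM_Model=pLM, pLM_alphabet=alphabet, esm_layer=33,
#       pLM_Model_Name='esm2_t33_650M_UR50D',
#       CKeys={'Debug_TrainerPack': 0}, ynormfac=1.,
#   )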
| |
| |
| |
| def sample_loop_omegafold_pLM_ModelB ( |
| model, |
| train_loader, |
| cond_scales=[7.5], |
| num_samples=2, |
| timesteps=100, |
| flag=0, |
| foldproteins=False, |
| use_text_embedd=True, |
| skip_steps=0, |
| |
| train_unet_number=1, |
| ynormfac=1, |
| prefix=None, |
| tokenizer_y=None, |
| Xnormfac=1, |
| tokenizer_X=None, |
| |
| CKeys=None, |
| sample_dir=None, |
| steps=None, |
| e=None, |
| IF_showfig=True, |
| |
| pLM_Model=None, |
| pLM_Model_Name=None, |
| image_channels=None, |
| pLM_alphabet=None, |
| ): |
| |
| |
| |
| |
| |
| |
| for idx, item in enumerate(train_loader): |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| GT=y_train_batch.cpu().detach() |
|
|
| GT= GT.unsqueeze(1) |
| if num_samples>y_train_batch.shape[0]: |
| print("Warning: sampling # > len(mini_batch)") |
|
|
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
| X_train_batch_picked = X_train_batch[:num_samples,:] |
| print ('After pLM, (TEST) X_batch shape: ', X_train_batch_picked.shape) |
| |
| for iisample in range (len (cond_scales)): |
|
|
| if use_text_embedd: |
| result_embedding=model.sample ( |
| |
| x= X_train_batch_picked, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample], |
| device=device, |
| skip_steps=skip_steps |
| ) |
| else: |
| result_embedding=model.sample ( |
| x= None, |
| |
| x_data_tokenized= X_train_batch_picked, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample], |
| device=device, |
| skip_steps=skip_steps |
| ) |
| |
| if image_channels==33: |
| result_tokens,result_logits = convert_into_tokens_using_prob( |
| result_embedding, |
| pLM_Model_Name, |
| ) |
| else: |
| |
| |
| |
| result_tokens,result_logits = convert_into_tokens( |
| pLM_Model, |
| result_embedding, |
| pLM_Model_Name, |
| ) |
|
|
| |
| |
| |
| |
| result=result_tokens.unsqueeze(1) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| result_mask = read_mask_from_input( |
| tokenized_data=X_train_batch[:num_samples], |
| mask_value=0, |
| seq_data=None, |
| max_seq_length=None, |
| ) |
| to_rev=result[:,0,:] |
| if CKeys['Debug_TrainerPack']==3: |
| print("on foldable result: ", to_rev[0]) |
| print("on result_logits: ", result_logits[0]) |
| print("on mask: ", result_mask[0]) |
| a = decode_one_ems_token_rec_for_folding_with_mask( |
| to_rev[0], |
| result_logits[0], |
| pLM_alphabet, |
| pLM_Model, |
| result_mask[0], |
| ) |
| print('One resu: ', a) |
|
|
| y_data_reversed=decode_many_ems_token_rec_for_folding_with_mask( |
| to_rev, |
| result_logits, |
| pLM_alphabet, |
| pLM_Model, |
| result_mask, |
| ) |
| if CKeys['Debug_TrainerPack']==3: |
| print("on y_data_reversed[0]: ", y_data_reversed[0]) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| to_rev=GT[:,0,:] |
| GT_y_data_reversed=decode_many_ems_token_rec( |
| to_rev, |
| pLM_alphabet, |
| ) |
|
|
|
|
| |
| |
| |
| |
| |
| |
| to_rev=X_train_batch[:,:]*Xnormfac |
| to_rev=to_rev.cpu().detach().numpy() |
| |
| |
| |
| |
| |
| n_batch=to_rev.shape[0] |
| n_embed=to_rev.shape[1] |
| to_rev_1 = np.zeros(to_rev.shape) |
| to_rev_1[:,0:n_embed-1]=to_rev[:,1:n_embed] |
|
|
| |
| if tokenizer_X!=None: |
| |
| to_rev_1 = np.round(to_rev_1) |
| X_data_reversed=tokenizer_X.sequences_to_texts (to_rev_1) |
| for iii in range (len(y_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
| else: |
| X_data_reversed=to_rev_1.copy() |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("X_data_reversed: ", X_data_reversed) |
| |
|
|
| for samples in range (num_samples): |
| print ("sample ", samples+1, "out of ", num_samples) |
|
|
| fig=plt.figure() |
| plt.plot ( |
| result[samples,0,:].cpu().detach().numpy(), |
| label= f'Predicted' |
| ) |
| plt.plot ( |
| GT[samples,0,:], |
| label= f'GT {0}' |
| ) |
| plt.legend() |
| outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg" |
| if IF_showfig==1: |
| plt.show() |
| else: |
| plt.savefig(outname, dpi=200) |
| plt.close () |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} \nor\n {X_data_reversed[samples]}, ") |
| print (f"predicted sequence: {y_data_reversed[samples]}") |
| print (f"Ground truth: {GT_y_data_reversed[samples]}") |
| error=string_diff (y_data_reversed[samples], GT_y_data_reversed[samples])/len (GT_y_data_reversed[samples]) |
| print(f"Recovery ratio(Ref): {1.-error}") |
| |
| |
| |
| |
| |
| |
| |
| xbc = X_data_reversed[samples] |
| if type(xbc)==str: |
| out_nam_content=xbc |
| else: |
| out_nam_content=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.4f" % xbc}) |
| |
| out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt' |
| |
| |
| with open(out_nam_inX, "w") as inX_file: |
| |
| inX_file.write(out_nam_content) |
| |
| out_nam_OuY_PR=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}_predict.fasta' |
| with open(out_nam_OuY_PR, "w") as ouY_fasta: |
| ouY_fasta.write(f">Predicted\n") |
| ouY_fasta.write(y_data_reversed[samples]) |
| |
| out_nam_OuY_GT=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}_recGT.fasta' |
| with open(out_nam_OuY_GT, "w") as ouY_fasta: |
                    ouY_fasta.write(f">reconstructed GT, recoverability: {1.-error}\n")
| ouY_fasta.write(GT_y_data_reversed[samples]) |
| |
| |
|
|
| if foldproteins: |
| |
| tempname='temp' |
| pdb_file,fasta_file=foldandsavePDB_pdb_fasta ( |
| sequence=y_data_reversed[samples], |
| filename_out=tempname, |
| num_cycle=16, flag=flag, |
| |
| prefix=prefix |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb' |
| out_nam_seq=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.fasta' |
| |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("pdb_file: ", pdb_file) |
| print("out_nam: ", out_nam) |
| |
| print (f'Original PDB: {pdb_file} OUT: {out_nam}') |
| shutil.copy (pdb_file, out_nam) |
| shutil.copy (fasta_file, out_nam_seq) |
| |
| |
| |
| os.remove (pdb_file) |
| os.remove (fasta_file) |
| |
| |
| pdb_file=out_nam |
| print (f"Properly named PDB file produced: {pdb_file}") |
| print (f"input X for sampling stored: {pdb_file}") |
| |
| if IF_showfig==1: |
| view=show_pdb( |
| pdb_file=pdb_file, |
| flag=flag, |
| show_sidechains=show_sidechains, |
| show_mainchains=show_mainchains, |
| color=color |
| ) |
| view.show() |
| |
| |
| def sample_loop_omegafold_pLM_ModelB_Predictor ( |
| model, |
| train_loader, |
| cond_scales=[7.5], |
| num_samples=2, |
| timesteps=100, |
| flag=0, |
| foldproteins=False, |
| use_text_embedd=True, |
| skip_steps=0, |
| |
| train_unet_number=1, |
| ynormfac=1, |
| prefix=None, |
| tokenizer_y=None, |
| Xnormfac=1, |
| tokenizer_X=None, |
| |
| CKeys=None, |
| sample_dir=None, |
| steps=None, |
| e=None, |
| IF_showfig=True, |
| |
| pLM_Model=None, |
| pLM_Model_Name=None, |
| image_channels=None, |
| pLM_alphabet=None, |
| |
| esm_layer=None, |
| ): |
| |
| |
| |
| |
| |
| |
| val_epoch_MSE_list=[] |
| resu_pred = {} |
| resu_grou = {} |
| |
| for iisample in range (len (cond_scales)): |
| |
| |
| val_epoch_MSE=0. |
| num_rec=0 |
| this_prediction = [] |
| this_groundtruth = [] |
| |
| for idx, item in enumerate(train_loader): |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
|
|
| |
| |
| if num_samples>y_train_batch.shape[0]: |
| print("Warning: sampling # > len(mini_batch)") |
| num_samples = min (num_samples,y_train_batch.shape[0]) |
| print (f"Producing {num_samples} samples...") |
| X_train_batch_picked = X_train_batch[:num_samples,:] |
| GT=y_train_batch.cpu().detach() |
| GT_picked = GT[:num_samples,:] |
| |
| |
| |
| |
| |
| |
| if use_text_embedd: |
| result_embedding=model.sample ( |
| |
| x= X_train_batch_picked, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample], |
| device=device, |
| skip_steps=skip_steps, |
| |
| pLM_Model_Name=pLM_Model_Name, |
| pLM_Model=pLM_Model, |
| pLM_alphabet=pLM_alphabet, |
| esm_layer=esm_layer, |
| ) |
| else: |
| result_embedding=model.sample ( |
| x= None, |
| |
| x_data_tokenized= X_train_batch_picked, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample], |
| device=device, |
| skip_steps=skip_steps, |
| |
| pLM_Model_Name=pLM_Model_Name, |
| pLM_Model=pLM_Model, |
| pLM_alphabet=pLM_alphabet, |
| esm_layer=esm_layer, |
| ) |
| |
| |
| |
| |
| |
| result_embedding=torch.mean(result_embedding, 1) |
| |
| |
| |
| result_mask = read_mask_from_input( |
| tokenized_data=X_train_batch[:num_samples], |
| mask_value=0.0, |
| seq_data=None, |
| max_seq_length=None, |
| ) |
| |
| result = result_embedding*result_mask |
| result = result.cpu() |
| |
| |
| |
| |
| with torch.no_grad(): |
| val_loss_MSE = criterion_MSE_sum( |
| result, |
| GT_picked, |
| ) |
| val_epoch_MSE += val_loss_MSE.item()/GT_picked.shape[1] |
| num_rec += len(GT_picked) |
| |
| |
| y_data_reversed = result*ynormfac |
| |
| GT_y_data_reversed = GT_picked*ynormfac |
| |
| for ibat in range (GT_picked.shape[0]): |
| this_prediction.append (np.array( y_data_reversed[ibat,:].cpu() )) |
| this_groundtruth.append (np.array( GT_y_data_reversed[ibat,:].cpu() )) |
| |
| |
| |
| |
| |
| |
| |
| TestSet_MSE = val_epoch_MSE/num_rec |
| resu_pred[str(cond_scales[iisample])] = this_prediction |
| resu_grou[str(cond_scales[iisample])] = this_groundtruth |
| |
| |
| val_epoch_MSE_list.append(TestSet_MSE) |
| |
| return val_epoch_MSE_list, resu_pred, resu_grou |
| |
|
|
| |
| |
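| # --- Illustrative usage sketch (not called anywhere in this file) --------------- |
| # Minimal example of how the predictor-style sampling loop above can be run on a |
| # test DataLoader after training.  The model, loader and ESM objects are assumed |
| # to be created elsewhere (names are placeholders); use_text_embedd=False mirrors |
| # the cond_image=True path used by train_loop_Model_B_Predictor further below. |
| def _example_eval_predictor(model, test_loader, pLM_Model, pLM_Model_Name, pLM_alphabet, esm_layer): |
|     mse_list, pred, grou = sample_loop_omegafold_pLM_ModelB_Predictor( |
|         model, |
|         test_loader, |
|         cond_scales=[1.], |
|         num_samples=2, |
|         use_text_embedd=False, |
|         train_unet_number=1, |
|         ynormfac=1., |
|         steps=0, |
|         e=0, |
|         IF_showfig=False, |
|         pLM_Model=pLM_Model, |
|         pLM_Model_Name=pLM_Model_Name, |
|         pLM_alphabet=pLM_alphabet, |
|         esm_layer=esm_layer, |
|     ) |
|     # mse_list has one MSE value per conditioning scale; pred/grou map each scale |
|     # (as a string key) to lists of per-sample numpy arrays |
|     return mse_list, pred, grou |
| |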
| def sample_loop_omegafold_ModelB ( |
| model, |
| train_loader, |
| cond_scales=[7.5], |
| num_samples=2, |
| timesteps=100, |
| flag=0, |
| foldproteins=False, |
| use_text_embedd=True, |
| skip_steps=0, |
| |
| train_unet_number=1, |
| ynormfac=1, |
| prefix=None, |
| tokenizer_y=None, |
| Xnormfac=1, |
| tokenizer_X=None, |
| |
| CKeys=None, |
| sample_dir=None, |
| steps=None, |
| e=None, |
| IF_showfig=True, |
| ): |
| |
| |
| |
| |
| |
| |
| for idx, item in enumerate(train_loader): |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
|
|
| GT=y_train_batch.cpu().detach() |
|
|
| GT= GT.unsqueeze(1) |
| if num_samples>y_train_batch.shape[0]: |
| print("Warning: sampling # > len(mini_batch)") |
|
|
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
| X_train_batch_picked = X_train_batch[:num_samples,:] |
| print ('(TEST) X_batch shape: ', X_train_batch_picked.shape) |
|
|
| for iisample in range (len (cond_scales)): |
|
|
| if use_text_embedd: |
| result=model.sample ( |
| |
| x= X_train_batch_picked, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample], |
| device=device, |
| skip_steps=skip_steps |
| ) |
| else: |
| result=model.sample ( |
| x= None, |
| |
| x_data_tokenized= X_train_batch_picked, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample], |
| device=device, |
| skip_steps=skip_steps |
| ) |
| |
| result=torch.round(result*ynormfac) |
| GT=torch.round (GT*ynormfac) |
|
|
| for samples in range (num_samples): |
| print ("sample ", samples+1, "out of ", num_samples) |
|
|
| fig=plt.figure() |
| plt.plot ( |
| result[samples,0,:].cpu().detach().numpy(), |
| label= f'Predicted' |
| ) |
| plt.plot ( |
| GT[samples,0,:], |
| label= f'GT {0}' |
| ) |
| plt.legend() |
| outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg" |
| if IF_showfig==1: |
| plt.show() |
| else: |
| plt.savefig(outname, dpi=200) |
| plt.close () |
|
|
| |
| to_rev=result[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
|
|
| y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") |
|
|
| |
| to_rev=GT[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
|
|
| GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(GT_y_data_reversed)): |
| GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "") |
|
|
| |
| to_rev=torch.round (X_train_batch[:,:]*Xnormfac) |
| to_rev=to_rev.long().cpu().detach().numpy() |
| |
| |
| if tokenizer_X is not None: |
| X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) |
| for iii in range (len(X_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
| else: |
| X_data_reversed=to_rev.copy() |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("X_data_reversed: ", X_data_reversed) |
| |
|
|
| print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples]) |
| print (f"Ground truth: {GT_y_data_reversed[samples]}") |
|
|
| if foldproteins: |
| xbc=X_train_batch[samples,:].cpu().detach().numpy() |
| out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc}) |
| tempname='temp' |
| pdb_file=foldandsavePDB ( |
| sequence=y_data_reversed[samples], |
| filename_out=tempname, |
| num_cycle=16, flag=flag, |
| |
| prefix=prefix |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb' |
| out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt' |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("pdb_file: ", pdb_file) |
| print("out_nam: ", out_nam) |
| |
| print (f'Original PDB: {pdb_file} OUT: {out_nam}') |
| shutil.copy (pdb_file, out_nam) |
| |
| with open(out_nam_inX, "w") as inX_file: |
| inX_file.write(f'{X_data_reversed[samples]}\n') |
| |
| pdb_file=out_nam |
| print (f"Properly named PDB file produced: {pdb_file}") |
| print (f"input X for sampling stored: {pdb_file}") |
| |
| if IF_showfig==1: |
| view=show_pdb( |
| pdb_file=pdb_file, |
| flag=flag, |
| show_sidechains=show_sidechains, |
| show_mainchains=show_mainchains, |
| color=color |
| ) |
| view.show() |
|
|
| |
| |
| |
| |
|
|
| |
| |
| def sample_loop_FromModelB (model, |
| train_loader, |
| cond_scales=[7.5], |
| num_samples=2, |
| timesteps=100, |
| flag=0,foldproteins=False, |
| use_text_embedd=True,skip_steps=0, |
| |
| train_unet_number=1, |
| ynormfac=1, |
| prefix=None, |
| tokenizer_y=None, |
| Xnormfac=1, |
| tokenizer_X=None, |
| |
| ): |
| steps=0 |
| e=flag |
| for item in train_loader: |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
|
|
| GT=y_train_batch.cpu().detach() |
| |
| GT= GT.unsqueeze(1) |
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
| |
| print ('X_train_batch shape: ', X_train_batch.shape) |
|
|
| for iisample in range (len (cond_scales)): |
| |
| if use_text_embedd: |
| result=model.sample (x= X_train_batch,stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample], device=device, skip_steps=skip_steps) |
| else: |
| result=model.sample (x= None, x_data_tokenized= X_train_batch, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales[iisample],device=device,skip_steps=skip_steps) |
| |
| result=torch.round(result*ynormfac) |
| GT=torch.round (GT*ynormfac) |
|
|
| for samples in range (num_samples): |
| print ("sample ", samples, "out of ", num_samples) |
| |
| plt.plot (result[samples,0,:].cpu().detach().numpy(),label= f'Predicted') |
| plt.plot (GT[samples,0,:],label= f'GT {0}') |
| plt.legend() |
|
|
| outname = prefix+ f"sample-{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg" |
| |
| plt.savefig(outname, dpi=200) |
| plt.show () |
| |
| |
| to_rev=result[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| |
| y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") |
| |
| |
| to_rev=GT[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| |
| GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(GT_y_data_reversed)): |
| GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "") |
| |
| |
| to_rev=torch.round (X_train_batch[:,:]*Xnormfac) |
| to_rev=to_rev.long().cpu().detach().numpy() |
| |
| X_data_reversed=tokenizer_X.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(X_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
|
|
| print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, predicted sequence: ", y_data_reversed[samples]) |
| print (f"Ground truth: {GT_y_data_reversed[samples]}") |
| |
| if foldproteins: |
| xbc=X_train_batch[samples,:].cpu().detach().numpy() |
| out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc}) |
| tempname='temp' |
| pdb_file=foldandsavePDB ( |
| sequence=y_data_reversed[samples], |
| filename_out=tempname, |
| num_cycle=16, flag=flag, |
| |
| prefix=prefix |
| ) |
| |
| |
| out_nam=f'{prefix}{X_data_reversed[samples]}.pdb' |
| print (f'Original PDB: {pdb_file} OUT: {out_nam}') |
| shutil.copy (pdb_file, out_nam) |
| pdb_file=out_nam |
| print (f"Properly named PDB file produced: {pdb_file}") |
| |
| view=show_pdb(pdb_file=pdb_file, flag=flag, show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color) |
| view.show() |
|
|
| steps=steps+1 |
| if steps>num_samples: |
| break |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
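| # --- Sketch: the token-index -> sequence decoding used in the loops above ------- |
| # The sampling loops assume a Keras-style tokenizer exposing sequences_to_texts(), |
| # which returns space-separated lower-case tokens; the post-processing (upper, |
| # strip, remove spaces) turns that back into a plain residue string.  The tiny |
| # stand-in tokenizer below exists only for illustration. |
| class _ToyTokenizer: |
|     index_word = {1: 'a', 2: 'g', 3: 'k'}  # token index -> character, as in a fitted tokenizer |
|     def sequences_to_texts(self, seqs): |
|         return [" ".join(self.index_word.get(int(i), '') for i in s) for s in seqs] |
| |
| def _example_decode_tokens(): |
|     tok = _ToyTokenizer() |
|     decoded = tok.sequences_to_texts([[1, 2, 3, 0]])        # ['a g k '] |
|     cleaned = [d.upper().strip().replace(" ", "") for d in decoded] |
|     return cleaned                                           # ['AGK'] |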
| |
| cal_norm_prob = nn.Softmax(dim=2) |
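| # cal_norm_prob converts the ESM lm_head logits of shape (batch, length, tokens) |
| # into per-position probabilities over the token dimension (dim=2); it is applied |
| # before the (b l c -> b c l) rearrange inside train_loop_Model_B below. |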
|
|
| def train_loop_Model_B ( |
| model, |
| train_loader, |
| test_loader, |
| |
| optimizer=None, |
| print_every=10, |
| epochs= 300, |
| start_ep=0, |
| start_step=0, |
| train_unet_number=1, |
| print_loss_every_steps=1000, |
| |
| trainer=None, |
| plot_unscaled=False, |
| max_batch_size=4, |
| save_model=False, |
| cond_scales=[7.5], |
| num_samples=2, |
| foldproteins=False, |
| cond_image=False, |
| |
| |
| device=None, |
| loss_list=[], |
| epoch_list=[], |
| train_hist_file=None, |
| train_hist_file_full=None, |
| prefix=None, |
| Xnormfac=1., |
| ynormfac=1., |
| tokenizer_X=None, |
| tokenizer_y=None, |
| test_condition_list=[], |
| max_length=1, |
| CKeys=None, |
| sample_steps=1, |
| sample_dir=None, |
| save_every_epoch=1, |
| save_point_info_file=None, |
| store_dir=None, |
| |
| pLM_Model_Name=None, |
| image_channels=None, |
| print_error=False, |
| ): |
| |
| if not exists (trainer): |
| if not exists (optimizer): |
| print ("ERROR: If trainer not used, need to provide optimizer.") |
| if exists (trainer): |
| print ("Trainer provided... will be used") |
| |
| |
| |
| |
| |
| steps=start_step |
| added_steps=0 |
| |
| loss_total=0 |
| |
| |
| if pLM_Model_Name=='None': |
| pLM_Model=None |
| # keep these names defined so later calls that pass them on do not raise NameError |
| esm_alphabet=None |
| esm_layer=None |
| |
| elif pLM_Model_Name=='esm2_t33_650M_UR50D': |
| |
| esm_layer=33 |
| pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D() |
| len_toks=len(esm_alphabet.all_toks) |
| pLM_Model.eval() |
| pLM_Model.to(device) |
| |
| elif pLM_Model_Name=='esm2_t36_3B_UR50D': |
| |
| esm_layer=36 |
| pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D() |
| len_toks=len(esm_alphabet.all_toks) |
| pLM_Model.eval() |
| pLM_Model.to(device) |
| |
| elif pLM_Model_Name=='esm2_t30_150M_UR50D': |
| |
| esm_layer=30 |
| pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D() |
| len_toks=len(esm_alphabet.all_toks) |
| pLM_Model.eval() |
| pLM_Model.to(device) |
| |
| elif pLM_Model_Name=='esm2_t12_35M_UR50D': |
| |
| esm_layer=12 |
| pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D() |
| len_toks=len(esm_alphabet.all_toks) |
| pLM_Model.eval() |
| pLM_Model.to(device) |
| |
| else: |
| print(f"ERROR: unrecognized pLM_Model_Name: {pLM_Model_Name}") |
|
|
| |
| for e in range(1, epochs+1): |
| |
| |
|
|
| torch.cuda.empty_cache() |
| print ("######################################################################################") |
| start = time.time() |
| print ("NOW: Training epoch: ", e+start_ep) |
|
|
| |
| train_epoch_loss = 0 |
| model.train() |
|
|
| print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") |
|
|
| for item in train_loader: |
| steps += 1 |
| added_steps += 1 |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch= item[1].to(device) |
| |
| if CKeys["Debug_TrainerPack"]==1: |
| print("Initial unload the dataloader items: ...") |
| print("X_train_batch.dim: ", X_train_batch.shape) |
| print("y_train_batch.dim: ", y_train_batch.shape) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if pLM_Model_Name=='None': |
| |
| y_train_batch_in = y_train_batch.unsqueeze(1) |
| X_train_batch_in = X_train_batch.unsqueeze(1) |
| |
| else: |
| with torch.no_grad(): |
| results = pLM_Model( |
| y_train_batch, |
| repr_layers=[esm_layer], |
| return_contacts=False, |
| ) |
| y_train_batch_in = results["representations"][esm_layer] |
| |
| if image_channels==33: |
| with torch.no_grad(): |
| |
| y_train_batch_in = pLM_Model.lm_head( |
| y_train_batch_in |
| ) |
| |
| y_train_batch_in = cal_norm_prob(y_train_batch_in) |
| |
| |
| y_train_batch_in = rearrange( |
| y_train_batch_in, |
| 'b l c -> b c l' |
| ) |
| |
| |
| X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) |
|
|
| |
| |
| if CKeys["Debug_TrainerPack"]==1: |
| print("After pLM model, the shape of X and y for training:") |
| print("X_train_batch_in.dim: ", X_train_batch_in.shape) |
| print("y_train_batch_in.dim: ", y_train_batch_in.shape) |
| |
| |
| |
| if exists (trainer): |
| |
| if cond_image==False: |
| loss = trainer( |
| y_train_batch.unsqueeze(1) , |
| x=X_train_batch, |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if cond_image==True: |
| loss = trainer( |
| y_train_batch_in, |
| x=None, |
| cond_images=X_train_batch_in, |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| |
| trainer.update(unet_number = train_unet_number) |
|
|
| else: |
| optimizer.zero_grad() |
| if cond_image==False: |
| loss=model ( |
| y_train_batch.unsqueeze(1) , |
| x=X_train_batch, |
| unet_number=train_unet_number |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if cond_image==True: |
| loss=model ( |
| y_train_batch_in , |
| x=None, |
| cond_images=X_train_batch_in, |
| unet_number=train_unet_number |
| ) |
| |
| loss.backward( ) |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) |
| optimizer.step() |
|
|
| loss_total=loss_total+loss.item() |
| |
| train_epoch_loss=train_epoch_loss+loss.item() |
|
|
| if steps % print_every == 0: |
| |
| print(".", end="") |
|
|
| |
| if added_steps>0: |
| |
| if steps % print_loss_every_steps == 0: |
| |
| if CKeys['Debug_TrainerPack']==2: |
| print('I am here') |
| print("Here is steps: ", steps) |
| |
| norm_loss=loss_total/print_loss_every_steps |
| print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") |
| |
| |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' |
| with open(train_hist_file,'a') as f: |
| f.write(add_line) |
|
|
|
|
| loss_list.append (norm_loss) |
| loss_total=0 |
| |
| epoch_list.append(e+start_ep) |
| |
| |
| fig = plt.figure() |
| plt.plot (epoch_list, loss_list, label='Loss') |
| plt.legend() |
|
|
| |
| outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" |
| |
| |
| if CKeys['SlientRun']==1: |
| plt.savefig(outname, dpi=200) |
| else: |
| plt.show() |
| plt.close(fig) |
| |
| |
| if added_steps>0: |
| |
| if steps % sample_steps == 0: |
| |
| if CKeys['Debug_TrainerPack']==2: |
| print('I am here') |
| print("Here is steps: ", steps) |
| |
| if plot_unscaled: |
| |
| plt.plot ( |
| y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), |
| label= 'Unscaled GT' |
| ) |
| plt.legend() |
| plt.show() |
|
|
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
| print ("I. SAMPLING IN TEST SET: ") |
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
| |
| |
| print (f"Producing {num_samples} samples...") |
|
|
| |
| if cond_image == True: |
| use_text_embedd=False |
| |
| |
| |
| cond_scales_extended=cond_scales |
| else: |
| use_text_embedd=True |
| cond_scales_extended=cond_scales |
|
|
| |
| sample_loop_omegafold_pLM_ModelB ( |
| model, |
| test_loader, |
| cond_scales=cond_scales_extended, |
| num_samples=num_samples, |
| timesteps=64, |
| flag=steps, |
| |
| foldproteins=foldproteins, |
| use_text_embedd= use_text_embedd, |
| |
| train_unet_number=train_unet_number, |
| ynormfac=ynormfac, |
| prefix=prefix, |
| tokenizer_y=tokenizer_y, |
| Xnormfac=Xnormfac, |
| tokenizer_X=tokenizer_X, |
| |
| |
| CKeys=CKeys, |
| sample_dir=sample_dir, |
| steps=steps, |
| e=e+start_ep, |
| IF_showfig= CKeys['SlientRun']!=1, |
| |
| pLM_Model=pLM_Model, |
| pLM_Model_Name=pLM_Model_Name, |
| image_channels=image_channels, |
| pLM_alphabet=esm_alphabet, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| AH_code=2/Xnormfac |
| BS_code=3/Xnormfac |
| unstr_code= 1/Xnormfac |
|
|
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
| print ("II. SAMPLING FOR DE NOVO:") |
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
| |
| |
| DeNovoSam_pdbs, fasta_file_list=\ |
| sample_sequence_omegafold_pLM_ModelB ( |
| model, |
| x_data=test_condition_list, |
| flag=steps, |
| cond_scales=1., |
| foldproteins=foldproteins, |
| |
| ynormfac=ynormfac, |
| train_unet_number=train_unet_number, |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| prefix=prefix, |
| tokenizer_y=tokenizer_y, |
| |
| CKeys=CKeys, |
| sample_dir=sample_dir, |
| steps=steps, |
| e=e+start_ep, |
| IF_showfig= CKeys['SlientRun']!=1, |
| |
| pLM_Model=pLM_Model, |
| pLM_Model_Name=pLM_Model_Name, |
| image_channels=image_channels, |
| pLM_alphabet=esm_alphabet, |
| ) |
| |
| if print_error and len(DeNovoSam_pdbs)>0: |
| print("Calculate SecStr and design error:") |
| |
| for ii in range(len(test_condition_list)): |
| seq=test_condition_list[ii][0] |
| DSSPresult,_,sequence_res=get_DSSP_result(DeNovoSam_pdbs[ii]) |
| print (f"INPUT: {seq}\nRESULT: {DSSPresult}\nAA sequence: {sequence_res}") |
| error=string_diff (DSSPresult, seq)/len (seq) |
| print ("Error: ", error) |
| |
| |
| |
| |
| |
| norm_loss_over_e = train_epoch_loss/len(train_loader) |
| print("\nnorm_loss over 1 epoch: ", norm_loss_over_e) |
| |
| |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n' |
| with open(train_hist_file_full,'a') as f: |
| f.write(add_line) |
| |
| |
| |
| if save_model and (e+start_ep) % save_every_epoch==0 and e>1: |
| |
| fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt" |
| trainer.save(fname) |
| print (f"Model saved: ", fname) |
| |
| fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt" |
| torch.save(model.state_dict(), fname) |
| print (f"Statedict model saved: ", fname) |
| |
| top_line='epoch,steps,norm_loss'+'\n' |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'  # per-epoch loss; norm_loss may be undefined at this point |
| with open(save_point_info_file, "w") as f: |
| f.write(top_line) |
| f.write(add_line) |
| |
|
|
|
|
| print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------") |
| |
| |
| |
| |
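| # --- Sketch: the pLM pre-processing step used inside train_loop_Model_B --------- |
| # Given a batch of ESM token indices (batch, length), the loop extracts per-residue |
| # representations, optionally maps them to token probabilities via lm_head when |
| # image_channels==33, and rearranges to (batch, channels, length) for the 1D U-Net. |
| # This stand-alone helper repeats those calls for illustration and is never invoked |
| # by the training code. |
| def _example_plm_channels(pLM_Model, esm_layer, y_tokens, image_channels=33): |
|     with torch.no_grad(): |
|         out = pLM_Model(y_tokens, repr_layers=[esm_layer], return_contacts=False) |
|         rep = out["representations"][esm_layer]     # (b, l, embed_dim) |
|         if image_channels == 33: |
|             rep = pLM_Model.lm_head(rep)            # (b, l, 33) logits over the ESM alphabet |
|             rep = cal_norm_prob(rep)                # softmax over the token dimension |
|     return rearrange(rep, 'b l c -> b c l')         # channels-first layout for the U-Net |
| |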
| def train_loop_Model_B_Predictor ( |
| model, |
| train_loader, |
| test_loader, |
| |
| optimizer=None, |
| print_every=10, |
| epochs= 300, |
| start_ep=0, |
| start_step=0, |
| train_unet_number=1, |
| print_loss_every_steps=1000, |
| |
| trainer=None, |
| plot_unscaled=False, |
| max_batch_size=4, |
| save_model=False, |
| cond_scales=[1.], |
| num_samples=2, |
| foldproteins=False, |
| cond_image=False, |
| |
| |
| device=None, |
| loss_list=[], |
| epoch_list=[], |
| train_hist_file=None, |
| train_hist_file_full=None, |
| prefix=None, |
| Xnormfac=1., |
| ynormfac=1., |
| tokenizer_X=None, |
| tokenizer_y=None, |
| test_condition_list=[], |
| max_length=1, |
| CKeys=None, |
| sample_steps=1, |
| sample_dir=None, |
| save_every_epoch=1, |
| save_point_info_file=None, |
| store_dir=None, |
| |
| pLM_Model_Name=None, |
| image_channels=None, |
| print_error=False, |
| |
| train_hist_file_on_testset=None, |
| ): |
| |
| if not exists (trainer): |
| if not exists (optimizer): |
| print ("ERROR: If trainer not used, need to provide optimizer.") |
| if exists (trainer): |
| print ("Trainer provided... will be used") |
| |
| |
| |
| |
| |
| steps=start_step |
| added_steps=0 |
| |
| loss_total=0 |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| pLM_Model, esm_alphabet, \ |
| esm_layer, len_toks = load_in_pLM( |
| pLM_Model_Name, |
| device, |
| ) |
|
|
| |
| for e in range(1, epochs+1): |
| |
| |
|
|
| torch.cuda.empty_cache() |
| print ("######################################################################################") |
| start = time.time() |
| print ("NOW: Training epoch: ", e+start_ep) |
|
|
| |
| train_epoch_loss = 0 |
| model.train() |
|
|
| print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") |
|
|
| for item in train_loader: |
| steps += 1 |
| added_steps += 1 |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch= item[1].to(device) |
| |
| if CKeys["Debug_TrainerPack"]==1: |
| print("Initial unload the dataloader items: ...") |
| print("X_train_batch.dim: ", X_train_batch.shape) |
| print("y_train_batch.dim: ", y_train_batch.shape) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if pLM_Model_Name=='trivial': |
| |
| y_train_batch_in = y_train_batch.unsqueeze(1) |
| X_train_batch_in = X_train_batch.unsqueeze(1) |
| |
| else: |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| with torch.no_grad(): |
| results = pLM_Model( |
| X_train_batch, |
| repr_layers=[esm_layer], |
| return_contacts=False, |
| ) |
| X_train_batch_in = results["representations"][esm_layer] |
| X_train_batch_in = rearrange( |
| X_train_batch_in, |
| 'b l c -> b c l' |
| ) |
| y_train_batch_in = y_train_batch.unsqueeze(1).repeat(1,image_channels,1) |
| |
|
|
| |
| |
| if CKeys["Debug_TrainerPack"]==1: |
| print("After pLM model, the shape of X and y for training:") |
| print("X_train_batch_in.dim: ", X_train_batch_in.shape) |
| print("y_train_batch_in.dim: ", y_train_batch_in.shape) |
| |
| |
| |
| if exists (trainer): |
| |
| if cond_image==False: |
| loss = trainer( |
| y_train_batch.unsqueeze(1) , |
| x=X_train_batch, |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if cond_image==True: |
| loss = trainer( |
| y_train_batch_in, |
| x=None, |
| cond_images=X_train_batch_in, |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| |
| trainer.update(unet_number = train_unet_number) |
|
|
| else: |
| optimizer.zero_grad() |
| if cond_image==False: |
| loss=model ( |
| y_train_batch.unsqueeze(1) , |
| x=X_train_batch, |
| unet_number=train_unet_number |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if cond_image==True: |
| loss=model ( |
| y_train_batch_in , |
| x=None, |
| cond_images=X_train_batch_in, |
| unet_number=train_unet_number |
| ) |
| |
| loss.backward( ) |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) |
| optimizer.step() |
|
|
| loss_total=loss_total+loss.item() |
| |
| train_epoch_loss=train_epoch_loss+loss.item() |
|
|
| if steps % print_every == 0: |
| |
| print(".", end="") |
|
|
| |
| if added_steps>0: |
| |
| if steps % print_loss_every_steps == 0: |
| |
| if CKeys['Debug_TrainerPack']==2: |
| print('I am here') |
| print("Here is steps: ", steps) |
| |
| norm_loss=loss_total/print_loss_every_steps |
| print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") |
| |
| |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' |
| with open(train_hist_file,'a') as f: |
| f.write(add_line) |
|
|
|
|
| loss_list.append (norm_loss) |
| loss_total=0 |
| |
| epoch_list.append(e+start_ep) |
| |
| |
| fig = plt.figure() |
| plt.plot (epoch_list, loss_list, label='Loss') |
| plt.legend() |
|
|
| |
| outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" |
| |
| |
| if CKeys['SlientRun']==1: |
| plt.savefig(outname, dpi=200) |
| else: |
| plt.show() |
| plt.close(fig) |
| |
| |
| if added_steps>0: |
| |
| if steps % sample_steps == 0: |
| |
| if CKeys['Debug_TrainerPack']==2: |
| print('I am here') |
| print("Here is steps: ", steps) |
| |
| if plot_unscaled: |
| |
| plt.plot ( |
| y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), |
| label= 'Unscaled GT' |
| ) |
| plt.legend() |
| plt.show() |
|
|
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
| print ("I. SAMPLING IN TEST SET: ") |
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
| |
| |
| print (f"Producing {num_samples} samples...") |
|
|
| |
| if cond_image == True: |
| use_text_embedd=False |
| |
| |
| |
| cond_scales_extended=cond_scales |
| else: |
| use_text_embedd=True |
| cond_scales_extended=cond_scales |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| val_epoch_MSE_list, \ |
| resu_pred, resu_grou = \ |
| sample_loop_omegafold_pLM_ModelB_Predictor ( |
| model, |
| test_loader, |
| cond_scales=[1.], |
| num_samples=num_samples, |
| timesteps=64, |
| flag=steps, |
| |
| foldproteins=foldproteins, |
| use_text_embedd= use_text_embedd, |
| |
| train_unet_number=train_unet_number, |
| ynormfac=ynormfac, |
| prefix=prefix, |
| tokenizer_y=tokenizer_y, |
| Xnormfac=Xnormfac, |
| tokenizer_X=tokenizer_X, |
| |
| |
| CKeys=CKeys, |
| sample_dir=sample_dir, |
| steps=steps, |
| e=e+start_ep, |
| IF_showfig= CKeys['SlientRun']!=1, |
| |
| pLM_Model=pLM_Model, |
| pLM_Model_Name=pLM_Model_Name, |
| image_channels=image_channels, |
| pLM_alphabet=esm_alphabet, |
| |
| esm_layer=esm_layer, |
| ) |
| |
| print(f"Epo {str(e+start_ep)}, on TestSet, MSE: {val_epoch_MSE_list[0]}") |
| |
| add_line = str(e+start_ep)+','+str(steps)+','+str(val_epoch_MSE_list[0])+'\n' |
| with open(train_hist_file_on_testset,'a') as f: |
| f.write(add_line) |
|
|
|
|
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
| print ("II. SAMPLING FOR DE NOVO: NOT USED in predictor mode") |
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
| |
| |
| |
| norm_loss_over_e = train_epoch_loss/len(train_loader) |
| print("\nnorm_loss over 1 epoch: ", norm_loss_over_e) |
| |
| |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n' |
| with open(train_hist_file_full,'a') as f: |
| f.write(add_line) |
| |
| |
| |
| if save_model and (e+start_ep) % save_every_epoch==0 and e>1: |
| |
| fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt" |
| trainer.save(fname) |
| print (f"Model saved: ", fname) |
| |
| fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt" |
| torch.save(model.state_dict(), fname) |
| print (f"Statedict model saved: ", fname) |
| |
| top_line='epoch,steps,norm_loss'+'\n' |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'  # per-epoch loss; norm_loss may be undefined at this point |
| with open(save_point_info_file, "w") as f: |
| f.write(top_line) |
| f.write(add_line) |
| |
|
|
|
|
| print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------") |
| |
| |
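| # --- Sketch: wiring up the predictor training loop ------------------------------ |
| # A hypothetical call showing which bookkeeping files train_loop_Model_B_Predictor |
| # (defined above) expects.  model/trainer/DataLoaders and image_channels must match |
| # the ProteinDesigner_B configuration built elsewhere; all paths are placeholders. |
| def _example_train_predictor(model, trainer, train_loader, test_loader, image_channels, |
|                              store_dir='./ckpt/', sample_dir='./sampling/'): |
|     train_loop_Model_B_Predictor( |
|         model, train_loader, test_loader, |
|         trainer=trainer, |
|         epochs=10, |
|         device=device, |
|         cond_image=True, |
|         save_model=True, |
|         sample_steps=1000, |
|         print_loss_every_steps=100, |
|         train_hist_file=store_dir + 'hist_by_steps.csv', |
|         train_hist_file_full=store_dir + 'hist_by_epoch.csv', |
|         train_hist_file_on_testset=store_dir + 'hist_testset_MSE.csv', |
|         save_point_info_file=store_dir + 'last_save_point.csv', |
|         store_dir=store_dir, |
|         sample_dir=sample_dir, |
|         pLM_Model_Name='esm2_t33_650M_UR50D', |
|         image_channels=image_channels,  # must match the channel setup of the model |
|         CKeys={'Debug_TrainerPack': 0, 'SlientRun': 1}, |
|     ) |
| |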
| def train_loop_Old_FromModelB (model, |
| train_loader, |
| test_loader, |
| optimizer=None, |
| print_every=10, |
| epochs= 300, |
| start_ep=0, |
| start_step=0, |
| train_unet_number=1, |
| print_loss=1000, |
| trainer=None, |
| plot_unscaled=False, |
| max_batch_size=4, |
| save_model=False, |
| cond_scales=[7.5], |
| num_samples=2, |
| foldproteins=False, |
| cond_image=False, |
| |
| |
| device=None, |
| loss_list=[], |
| prefix=None, |
| ynormfac=1, |
| test_condition_list=[], |
| tokenizer_y=None, |
| Xnormfac=1, |
| tokenizer_X=None, |
| max_length=1, |
| |
| ): |
| |
| if not exists (trainer): |
| if not exists (optimizer): |
| print ("ERROR: If trainer not used, need to provide optimizer.") |
| if exists (trainer): |
| print ("Trainer provided... will be used") |
| |
| steps=start_step |
| |
| loss_total=0 |
| for e in range(1, epochs+1): |
| |
| start = time.time() |
|
|
| torch.cuda.empty_cache() |
| print ("######################################################################################") |
| start = time.time() |
| print ("NOW: Training epoch: ", e+start_ep) |
|
|
| train_epoch_loss = 0 |
| model.train() |
|
|
| print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") |
|
|
| for item in train_loader: |
|
|
| X_train_batch= item[0].to(device) |
|
|
| y_train_batch=item[1].to(device) |
|
|
| if exists (trainer): |
| if cond_image==False: |
| loss = trainer( |
| y_train_batch.unsqueeze(1) , |
| x=X_train_batch, |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| if cond_image==True: |
|
|
| loss = trainer( |
| y_train_batch.unsqueeze(1) ,x=None, |
| cond_images=X_train_batch.unsqueeze(1), |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| trainer.update(unet_number = train_unet_number) |
|
|
| else: |
| optimizer.zero_grad() |
| if cond_image==False: |
| loss=model (y_train_batch.unsqueeze(1) , x=X_train_batch, unet_number=train_unet_number) |
| if cond_image==True: |
| loss=model (y_train_batch.unsqueeze(1) ,x=None, cond_images=X_train_batch.unsqueeze(1), unet_number=train_unet_number) |
|
|
| loss.backward( ) |
|
|
| torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) |
|
|
| optimizer.step() |
|
|
| loss_total=loss_total+loss.item() |
|
|
| if steps % print_every == 0: |
| print(".", end="") |
|
|
| if steps>0: |
| if steps % print_loss == 0: |
|
|
| if plot_unscaled: |
| |
| plt.plot (y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),label= 'Unscaled GT') |
| plt.legend() |
| plt.show() |
|
|
|
|
| GT=y_train_batch.cpu().detach() |
|
|
| GT=resize_image_to( |
| GT.unsqueeze(1), |
| model.imagen.image_sizes[train_unet_number-1], |
|
|
| ) |
|
|
| norm_loss=loss_total/print_loss |
| print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}") |
|
|
| loss_list.append (norm_loss) |
| loss_total=0 |
|
|
| plt.plot (loss_list, label='Loss') |
| plt.legend() |
|
|
| outname = prefix+ f"loss_{e}_{steps}.jpg" |
| plt.savefig(outname, dpi=200) |
| plt.show() |
|
|
| |
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
|
|
|
|
| if cond_image == True: |
| use_text_embedd=False |
| else: |
| use_text_embedd=True |
|
|
| sample_loop_FromModelB ( |
| model, |
| test_loader, |
| cond_scales=cond_scales, |
| num_samples=num_samples, |
| timesteps=64, |
| flag=steps, |
| |
| foldproteins=foldproteins, |
| use_text_embedd= use_text_embedd, |
| |
| train_unet_number=train_unet_number, |
| ynormfac=ynormfac, |
| prefix=prefix, |
| tokenizer_y=tokenizer_y, |
| Xnormfac=Xnormfac, |
| tokenizer_X=tokenizer_X, |
| ) |
|
|
| |
| |
|
|
| AH_code=2/Xnormfac |
| BS_code=3/Xnormfac |
| unstr_code= 1/Xnormfac |
|
|
| print ("SAMPLING FOR DE NOVO:") |
| |
| |
| for this_x_data in test_condition_list: |
| sample_sequence_FromModelB ( |
| model, |
| x_data=this_x_data, |
| flag=steps,cond_scales=1., |
| foldproteins=True, |
| |
| ynormfac=ynormfac, |
| train_unet_number=train_unet_number, |
| tokenizer_X=tokenizer_X, |
| Xnormfac=Xnormfac, |
| max_length=max_length, |
| prefix=prefix, |
| tokenizer_y=tokenizer_y, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| if steps>0: |
| if save_model and steps % print_loss==0: |
| fname=f"{prefix}trainer_save-model-epoch_{e}.pt" |
| trainer.save(fname) |
| print (f"Model saved: ", fname) |
| fname=f"{prefix}statedict_save-model-epoch_{e}.pt" |
| torch.save(model.state_dict(), fname) |
| print (f"Statedict model saved: ", fname) |
|
|
| steps=steps+1 |
|
|
| print (f"\n\n-------------------\nTime for epoch {e}={(time.time()-start)/60}\n-------------------") |
| |
|
|
| |
| def foldandsavePDB_pdb_fasta ( |
| sequence, |
| filename_out, |
| num_cycle=16, |
| flag=0, |
| |
| prefix=None, |
| ): |
| |
| filename=f"{prefix}fasta_in_{flag}.fasta" |
| print ("Writing FASTA file: ", filename) |
| OUTFILE=f"{filename_out}_{flag}" |
| with open (filename, mode ='w') as f: |
| f.write (f'>{OUTFILE}\n') |
| f.write (f'{sequence}') |
| |
| print (f"Now run OmegaFold.... on device={device}") |
| |
| cmd_line=F"omegafold {filename} {prefix} --num_cycle {num_cycle} --device={device}" |
| print(os.popen(cmd_line).read()) |
| |
| print ("Done OmegaFold") |
| |
| |
| PDB_result=f"{prefix}{OUTFILE}.pdb" |
| print (f"Resulting PDB file...: {PDB_result}") |
| |
| return PDB_result, filename |
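| |
| # --- Sketch: folding a sampled sequence and keeping descriptive file names ------ |
| # This mirrors the pattern used in the sampling loops above: fold with OmegaFold, |
| # then copy the temporary PDB/FASTA to names that encode the sample identity.  The |
| # directories are placeholders and OmegaFold must be available on the PATH. |
| def _example_fold_and_rename(sequence, sample_dir='./sampling/', prefix='./tmp/', tag='demo'): |
|     pdb_file, fasta_file = foldandsavePDB_pdb_fasta( |
|         sequence=sequence, |
|         filename_out='temp', |
|         num_cycle=16, |
|         flag=0, |
|         prefix=prefix, |
|     ) |
|     out_pdb = f"{sample_dir}{tag}.pdb" |
|     out_fasta = f"{sample_dir}{tag}.fasta" |
|     shutil.copy(pdb_file, out_pdb) |
|     shutil.copy(fasta_file, out_fasta) |
|     os.remove(pdb_file) |
|     os.remove(fasta_file) |
|     return out_pdb, out_fasta |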
|
|
|
|
|
|
| def foldandsavePDB ( |
| sequence, |
| filename_out, |
| num_cycle=16, |
| flag=0, |
| |
| prefix=None, |
| ): |
| |
| filename=f"{prefix}fasta_in_{flag}.fasta" |
| print ("Writing FASTA file: ", filename) |
| OUTFILE=f"{filename_out}_{flag}" |
| with open (filename, mode ='w') as f: |
| f.write (f'>{OUTFILE}\n') |
| f.write (f'{sequence}') |
| |
| print (f"Now run OmegaFold.... on device={device}") |
| |
| cmd_line=F"omegafold {filename} {prefix} --num_cycle {num_cycle} --device={device}" |
| print(os.popen(cmd_line).read()) |
| |
| print ("Done OmegaFold") |
| |
| |
| PDB_result=f"{prefix}{OUTFILE}.pdb" |
| print (f"Resulting PDB file...: {PDB_result}") |
| |
| return PDB_result |
|
|
| import py3Dmol |
| def plot_plddt_legend(dpi=100): |
| thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)'] |
| plt.figure(figsize=(1,0.1),dpi=dpi) |
| |
| for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]: |
| plt.bar(0, 0, color=c) |
| plt.legend(thresh, frameon=False, |
| loc='center', ncol=6, |
| handletextpad=1, |
| columnspacing=1, |
| markerscale=0.5,) |
| plt.axis(False) |
| return plt |
| color = "lDDT" |
| show_sidechains = False |
| show_mainchains = False |
|
|
| def show_pdb(pdb_file, flag=0, show_sidechains=False, show_mainchains=False, color="lDDT"): |
| model_name = f"Flag_{flag}" |
| view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',) |
| view.addModel(open(pdb_file,'r').read(),'pdb') |
|
|
| if color == "lDDT": |
| view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}}) |
| elif color == "rainbow": |
| view.setStyle({'cartoon': {'color':'spectrum'}}) |
| elif color == "chain": |
| chains = len(queries[0][1]) + 1 if is_complex else 1 |
| for n,chain,color in zip(range(chains),list("ABCDEFGH"), |
| ["lime","cyan","magenta","yellow","salmon","white","blue","orange"]): |
| view.setStyle({'chain':chain},{'cartoon': {'color':color}}) |
| if show_sidechains: |
| BB = ['C','O','N'] |
| view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]}, |
| {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) |
| view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]}, |
| {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) |
| view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]}, |
| {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) |
| if show_mainchains: |
| BB = ['C','O','N','CA'] |
| view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) |
|
|
| view.zoomTo() |
| if color == "lDDT": |
| plot_plddt_legend().show() |
| return view |
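| |
| # --- Sketch: rendering a folded structure --------------------------------------- |
| # show_pdb returns a py3Dmol view; .show() only renders inside a Jupyter notebook. |
| # The path below is a placeholder for any PDB produced by foldandsavePDB above. |
| def _example_view_structure(pdb_file='./sampling/demo.pdb'): |
|     view = show_pdb(pdb_file=pdb_file, flag=0, |
|                     show_sidechains=False, show_mainchains=False, color='lDDT') |
|     view.show() |
|     return view |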
|
|
| def get_avg_Bfac (file='./output_v3/[0.0 0.5 0.0 0.0 0.0 0.0 0.0 0.0].pdb'): |
| from Bio.PDB import PDBParser  # Biopython; imported locally since it is only needed here |
| p = PDBParser() |
| avg_B=0 |
| bfac_list=[] |
| |
| structure = p.get_structure("X", file) |
| for PDBmodel in structure: |
| for chain in PDBmodel: |
| for residue in chain: |
| for atom in residue: |
| |
| Bfac=atom.get_bfactor() |
| bfac_list.append(Bfac) |
| avg_B=avg_B+Bfac |
| |
| avg_B=avg_B/len (bfac_list) |
| print (f"For {file}, average B-factor={avg_B}") |
| plt.plot (bfac_list, label='lDDT') |
| plt.xlabel ('Atom #' ) |
| plt.ylabel ('lDDT') |
| plt.legend() |
| plt.show() |
| return avg_B, bfac_list |
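| |
| # --- Sketch: quick confidence check on a folded structure ----------------------- |
| # OmegaFold writes per-residue pLDDT into the B-factor column, so get_avg_Bfac can |
| # serve as a cheap confidence score for a sampled design.  The cutoff of 70 below |
| # is illustrative only and is not used anywhere else in this code. |
| def _example_confidence_filter(pdb_file, min_plddt=70.0): |
|     avg_plddt, _ = get_avg_Bfac(file=pdb_file) |
|     return avg_plddt >= min_plddt |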
|
|
| def sample_sequence_normalized_Bfac (seccs=[0.3, 0.3, 0.1, 0., 0., 0., 0., 0. ]): |
| sample_numbers=torch.tensor([seccs]) |
| sample_numbers=torch.nn.functional.normalize (sample_numbers, dim=1) |
| sample_numbers=sample_numbers/torch.sum(sample_numbers) |
|
|
| PDB=sample_sequence (model, |
| X=sample_numbers, |
| flag=0,cond_scales=1, foldproteins=True, |
| ) |
|
|
| avg,_ = get_avg_Bfac (file=PDB[0]) |
|
|
| return PDB, avg |
|
|
| |
| |
| |
| def train_loop_Old_FromModelA ( |
| model, |
| train_loader, |
| test_loader, |
| |
| optimizer=None, |
| print_every=1, |
| epochs= 300, |
| start_ep=0, |
| start_step=0, |
| train_unet_number=1, |
| print_loss_every_steps=1000, |
| |
| trainer=None, |
| plot_unscaled=False, |
| max_batch_size=4, |
| save_model=False, |
| cond_scales=[1.0], |
| num_samples=2, |
| foldproteins=False, |
| |
| cond_image=False, |
| cond_text=True, |
| |
| device=None, |
| loss_list=[], |
| epoch_list=[], |
| train_hist_file=None, |
| train_hist_file_full=None, |
| prefix=None, |
| Xnormfac=None, |
| ynormfac=1., |
| tokenizer_X=None, |
| tokenizer_y=None, |
| test_condition_list=[], |
| max_length_Y=1, |
| max_text_len_X=1, |
| CKeys=None, |
| sample_steps=1, |
| sample_dir=None, |
| save_every_epoch=1, |
| save_point_info_file=None, |
| store_dir=None, |
| ): |
| |
| |
| |
| if not exists (trainer): |
| if not exists (optimizer): |
| print ("ERROR: If trainer not used, need to provide optimizer.") |
| if exists (trainer): |
| print ("Trainer provided... will be used") |
| |
| |
| |
| steps=start_step |
| added_steps=0 |
|
|
| loss_total=0 |
| for e in range(1, epochs+1): |
| |
|
|
| torch.cuda.empty_cache() |
| print ("######################################################################################") |
| start = time.time() |
| print ("NOW: Training epoch: ", e+start_ep) |
|
|
| |
| train_epoch_loss = 0 |
| model.train() |
|
|
| print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") |
|
|
| for item in train_loader: |
| |
| steps += 1 |
| added_steps += 1 |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
|
|
| if exists (trainer): |
| if cond_image==False: |
| |
| |
| |
| |
| loss = trainer( |
| |
| |
| |
| |
| y_train_batch.unsqueeze(1) , |
| x=X_train_batch, |
| |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| if cond_image==True: |
| |
| |
| |
| |
| loss = trainer( |
| y_train_batch.unsqueeze(1) , |
| x=None, |
| cond_images=X_train_batch.unsqueeze(1), |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| |
| |
| trainer.update(unet_number = train_unet_number) |
|
|
| else: |
| optimizer.zero_grad() |
| if cond_image==False: |
| |
| loss=model ( |
| |
| |
| |
| |
| y_train_batch.unsqueeze(1) , |
| x=X_train_batch, |
| |
| unet_number=train_unet_number |
| ) |
| if cond_image==True: |
| |
| loss=model ( |
| y_train_batch.unsqueeze(1) , |
| x=None, |
| cond_images=X_train_batch.unsqueeze(1), |
| unet_number=train_unet_number |
| ) |
| |
| loss.backward( ) |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) |
| optimizer.step() |
|
|
| loss_total=loss_total+loss.item() |
| |
| train_epoch_loss=train_epoch_loss+loss.item() |
|
|
| if steps % print_every == 0: |
| |
| print(".", end="") |
|
|
| |
| |
| |
| |
| if added_steps>0: |
|
|
| if steps % print_loss_every_steps == 0: |
| |
| if CKeys['Debug_TrainerPack']==2: |
| print("Here is step: ", steps) |
|
|
| norm_loss=loss_total/print_loss_every_steps |
| print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") |
| |
| |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' |
| with open(train_hist_file,'a') as f: |
| f.write(add_line) |
|
|
| loss_list.append (norm_loss) |
| loss_total=0 |
| |
| epoch_list.append(e+start_ep) |
|
|
| fig = plt.figure() |
| plt.plot (epoch_list, loss_list, label='Loss') |
| plt.legend() |
| |
| outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" |
| |
| |
| if CKeys['SlientRun']==1: |
| plt.savefig(outname, dpi=200) |
| else: |
| plt.show() |
| plt.close(fig) |
|
|
| |
| |
| |
| |
| |
| if added_steps>0: |
| if steps % sample_steps == 0 and sample_steps > 0: |
| |
| if CKeys['Debug_TrainerPack']==2: |
| print("Here is steps: ", steps) |
|
|
| if plot_unscaled: |
| |
| plt.plot ( |
| y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), |
| label= 'Unscaled GT' |
| ) |
| plt.legend() |
| plt.show() |
|
|
| |
| GT=y_train_batch.cpu().detach() |
|
|
| GT=resize_image_to( |
| GT.unsqueeze(1), |
| model.imagen.image_sizes[train_unet_number-1], |
|
|
| ) |
| |
| print ("I. SAMPLING IN TEST SET: ") |
| |
|
|
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
|
|
| sample_loop_omegafold_ModelA ( |
| model, |
| test_loader, |
| cond_scales=cond_scales, |
| num_samples=num_samples, |
| timesteps=None, |
| flag=e+start_ep, |
| foldproteins=foldproteins, |
| |
| cond_image=cond_image, |
| cond_text=cond_text, |
| skip_steps=0, |
| |
| max_text_len=max_text_len_X, |
| max_length=max_length_Y, |
| |
| train_unet_number=train_unet_number, |
| ynormfac=ynormfac, |
| prefix=prefix, |
| tokenizer_y=tokenizer_y, |
| Xnormfac_CondiText=Xnormfac, |
| tokenizer_X_CondiText=tokenizer_X, |
| |
| CKeys=CKeys, |
| sample_dir=sample_dir, |
| steps=steps, |
| e=e+start_ep, |
| IF_showfig= CKeys['SlientRun']!=1 , |
| ) |
|
|
| print ("II. SAMPLING FOR DE NOVO:") |
|
|
| sample_sequence_omegafold_ModelA ( |
| |
| |
| |
| |
| |
| |
| model, |
| X=test_condition_list, |
| flag=e+start_ep, |
| cond_scales=cond_scales, |
| foldproteins=True, |
| X_string=None, |
| x_data=None, |
| skip_steps=0, |
| inpaint_images=None, |
| inpaint_masks = None, |
| inpaint_resample_times = None, |
| init_images = None, |
| num_cycle=16, |
| calc_error=False, |
| |
| |
| tokenizer_X_forImageCondi=None, |
| Xnormfac_forImageCondi=1., |
| tokenizer_X_forTextCondi=None, |
| Xnormfac_forTextCondi=1., |
| tokenizer_y=tokenizer_y, |
| ynormfac=ynormfac, |
| |
| train_unet_number=1, |
| max_length_Y=max_length_Y, |
| max_text_len=max_text_len_X, |
| |
| steps=steps, |
| e=e, |
| sample_dir=sample_dir, |
| prefix=prefix, |
| IF_showfig= CKeys['SlientRun']!=1, |
| CKeys=CKeys, |
| |
| normalize_X_cond_to_one=False, |
| ) |
|
|
| |
| |
| |
| |
|
|
| |
| norm_loss_over_e = train_epoch_loss/len(train_loader) |
| print("\nnorm_loss over 1 epoch: ", norm_loss_over_e) |
| |
| |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n' |
| with open(train_hist_file_full,'a') as f: |
| f.write(add_line) |
| |
| |
| |
| if save_model and (e+start_ep) % save_every_epoch==0 and e>1: |
| |
| fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt" |
| trainer.save(fname) |
| print (f"Model saved: ", fname) |
| |
| fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt" |
| torch.save(model.state_dict(), fname) |
| print (f"Statedict model saved: ", fname) |
| |
| top_line='epoch,steps,norm_loss'+'\n' |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'  # per-epoch loss; norm_loss may be undefined at this point |
| with open(save_point_info_file, "w") as f: |
| f.write(top_line) |
| f.write(add_line) |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------") |
| |
| def train_loop_ForModelA_II ( |
| model, |
| train_loader, |
| test_loader, |
| |
| optimizer=None, |
| print_every=1, |
| epochs= 300, |
| start_ep=0, |
| start_step=0, |
| train_unet_number=1, |
| print_loss_every_steps=1000, |
| |
| trainer=None, |
| plot_unscaled=False, |
| max_batch_size=4, |
| save_model=False, |
| cond_scales=[1.0], |
| num_samples=2, |
| foldproteins=False, |
| |
| cond_image=False, |
| cond_text=True, |
| |
| device=None, |
| loss_list=[], |
| epoch_list=[], |
| train_hist_file=None, |
| train_hist_file_full=None, |
| prefix=None, |
| Xnormfac=None, |
| ynormfac=1., |
| tokenizer_X=None, |
| tokenizer_y=None, |
| test_condition_list=[], |
| max_length_Y=1, |
| max_text_len_X=1, |
| CKeys=None, |
| sample_steps=1, |
| sample_dir=None, |
| save_every_epoch=1, |
| save_point_info_file=None, |
| store_dir=None, |
| |
| pLM_Model_Name=None, |
| image_channels=None, |
| print_error=False, |
| ): |
| |
| |
| |
| if not exists (trainer): |
| if not exists (optimizer): |
| print ("ERROR: If trainer not used, need to provide optimizer.") |
| if exists (trainer): |
| print ("Trainer provided... will be used") |
| |
| |
| |
| steps=start_step |
| added_steps=0 |
|
|
| loss_total=0 |
| |
| |
| if pLM_Model_Name=='None': |
| pLM_Model=None |
| # keep these names defined so later calls that pass them on do not raise NameError |
| esm_alphabet=None |
| esm_layer=None |
| |
| elif pLM_Model_Name=='esm2_t33_650M_UR50D': |
| |
| esm_layer=33 |
| pLM_Model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D() |
| len_toks=len(esm_alphabet.all_toks) |
| pLM_Model.eval() |
| pLM_Model.to(device) |
| |
| elif pLM_Model_Name=='esm2_t36_3B_UR50D': |
| |
| esm_layer=36 |
| pLM_Model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D() |
| len_toks=len(esm_alphabet.all_toks) |
| pLM_Model.eval() |
| pLM_Model.to(device) |
| |
| elif pLM_Model_Name=='esm2_t30_150M_UR50D': |
| |
| esm_layer=30 |
| pLM_Model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D() |
| len_toks=len(esm_alphabet.all_toks) |
| pLM_Model.eval() |
| pLM_Model.to(device) |
| |
| elif pLM_Model_Name=='esm2_t12_35M_UR50D': |
| |
| esm_layer=12 |
| pLM_Model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D() |
| len_toks=len(esm_alphabet.all_toks) |
| pLM_Model.eval() |
| pLM_Model.to(device) |
| |
| else: |
| print(f"ERROR: unrecognized pLM_Model_Name: {pLM_Model_Name}") |
| |
| |
| for e in range(1, epochs+1): |
| |
|
|
| torch.cuda.empty_cache() |
| print ("######################################################################################") |
| start = time.time() |
| print ("NOW: Training epoch: ", e+start_ep) |
|
|
| |
| train_epoch_loss = 0 |
| model.train() |
|
|
| print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") |
|
|
| for item in train_loader: |
| |
| steps += 1 |
| added_steps += 1 |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
| |
| if CKeys["Debug_TrainerPack"]==1: |
| print("Initial unload the dataloader items: ...") |
| print("X_train_batch.dim: ", X_train_batch.shape) |
| print("y_train_batch.dim: ", y_train_batch.shape) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if pLM_Model_Name=='None': |
| |
| y_train_batch_in = y_train_batch.unsqueeze(1) |
| |
| else: |
| with torch.no_grad(): |
| results = pLM_Model( |
| y_train_batch, |
| repr_layers=[esm_layer], |
| return_contacts=False, |
| ) |
| y_train_batch_in = results["representations"][esm_layer] |
| y_train_batch_in = rearrange( |
| y_train_batch_in, |
| 'b l c -> b c l' |
| ) |
| |
| |
| |
| |
| if cond_image==False: |
| |
| X_train_batch_in = X_train_batch |
| else: |
| |
| X_train_batch_in = X_train_batch.unsqueeze(1).repeat(1,image_channels,1) |
| |
| |
| if CKeys["Debug_TrainerPack"]==1: |
| print("After pLM model, the shape of X and y for training:") |
| print("X_train_batch_in.dim: ", X_train_batch_in.shape) |
| print("y_train_batch_in.dim: ", y_train_batch_in.shape) |
| |
| |
| |
|
|
| if exists (trainer): |
| if cond_image==False: |
| |
| |
| |
| |
| loss = trainer( |
| |
| |
| |
| |
| |
| |
| |
| y_train_batch_in, |
| x=X_train_batch_in, |
| |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| if cond_image==True: |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| loss = trainer( |
| y_train_batch_in, |
| x=None, |
| cond_images=X_train_batch_in, |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| |
| |
| trainer.update(unet_number = train_unet_number) |
|
|
| else: |
| optimizer.zero_grad() |
| if cond_image==False: |
| |
| loss=model ( |
| |
| |
| |
| |
| |
| |
| |
| y_train_batch_in, |
| x=X_train_batch_in, |
| |
| unet_number=train_unet_number |
| ) |
| if cond_image==True: |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| loss=model ( |
| y_train_batch_in , |
| x=None, |
| cond_images=X_train_batch_in, |
| unet_number=train_unet_number |
| ) |
| |
| loss.backward( ) |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) |
| optimizer.step() |
|
|
| loss_total=loss_total+loss.item() |
| |
| train_epoch_loss=train_epoch_loss+loss.item() |
|
|
| if steps % print_every == 0: |
| |
| print(".", end="") |
|
|
| |
| |
| |
| |
| if added_steps>0: |
|
|
| if steps % print_loss_every_steps == 0: |
| |
| if CKeys['Debug_TrainerPack']==2: |
| print("Here is step: ", steps) |
|
|
| norm_loss=loss_total/print_loss_every_steps |
| print (f"\nTOTAL LOSS at epoch={e+start_ep}, step={steps}: {norm_loss}") |
| |
| |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss)+'\n' |
| with open(train_hist_file,'a') as f: |
| f.write(add_line) |
|
|
| loss_list.append (norm_loss) |
| loss_total=0 |
| |
| epoch_list.append(e+start_ep) |
|
|
| fig = plt.figure() |
| plt.plot (epoch_list, loss_list, label='Loss') |
| plt.legend() |
| |
| outname = sample_dir+ f"loss_{e+start_ep}_{steps}.jpg" |
| |
| |
| if CKeys['SlientRun']==1: |
| plt.savefig(outname, dpi=200) |
| else: |
| plt.show() |
| plt.close(fig) |
|
|
| |
| |
| |
| |
| |
| if added_steps>0: |
| if steps % sample_steps == 0 and sample_steps > 0: |
| |
| if CKeys['Debug_TrainerPack']==2: |
| print("Here is steps: ", steps) |
|
|
| if plot_unscaled: |
| |
| plt.plot ( |
| y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(), |
| label= 'Unscaled GT' |
| ) |
| plt.legend() |
| plt.show() |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ") |
| print ("I. SAMPLING IN TEST SET: ") |
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
| |
|
|
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
|
|
| sample_loop_omegafold_pLM_ModelA ( |
| model, |
| test_loader, |
| cond_scales=cond_scales, |
| num_samples=num_samples, |
| timesteps=None, |
| flag=e+start_ep, |
| foldproteins=foldproteins, |
| |
| cond_image=cond_image, |
| cond_text=cond_text, |
| skip_steps=0, |
| |
| max_text_len=max_text_len_X, |
| max_length=max_length_Y, |
| |
| train_unet_number=train_unet_number, |
| ynormfac=ynormfac, |
| prefix=prefix, |
| tokenizer_y=tokenizer_y, |
| Xnormfac_CondiText=Xnormfac, |
| tokenizer_X_CondiText=tokenizer_X, |
| |
| CKeys=CKeys, |
| sample_dir=sample_dir, |
| steps=steps, |
| e=e+start_ep, |
| IF_showfig= CKeys['SlientRun']!=1 , |
| |
| pLM_Model=pLM_Model, |
| pLM_Model_Name=pLM_Model_Name, |
| image_channels=image_channels, |
| pLM_alphabet=esm_alphabet, |
| ) |
|
|
| print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ") |
| print ("II. SAMPLING FOR DE NOVO:") |
| print ("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< ") |
|
|
| DeNovoSam_pdbs, fasta_file_list=\ |
| sample_sequence_omegafold_pLM_ModelA ( |
| |
| |
| |
| |
| |
| |
| model, |
| X=test_condition_list, |
| flag=e+start_ep, |
| cond_scales=cond_scales, |
| foldproteins=True, |
| X_string=None, |
| x_data=None, |
| skip_steps=0, |
| inpaint_images=None, |
| inpaint_masks = None, |
| inpaint_resample_times = None, |
| init_images = None, |
| num_cycle=16, |
| calc_error=False, |
| |
| |
| tokenizer_X_forImageCondi=None, |
| Xnormfac_forImageCondi=1., |
| tokenizer_X_forTextCondi=None, |
| Xnormfac_forTextCondi=1., |
| tokenizer_y=tokenizer_y, |
| ynormfac=ynormfac, |
| |
| train_unet_number=1, |
| max_length_Y=max_length_Y, |
| max_text_len=max_text_len_X, |
| |
| steps=steps, |
| e=e, |
| sample_dir=sample_dir, |
| prefix=prefix, |
| IF_showfig= CKeys['SlientRun']!=1, |
| CKeys=CKeys, |
| |
| normalize_X_cond_to_one=False, |
| |
| pLM_Model=pLM_Model, |
| pLM_Model_Name=pLM_Model_Name, |
| image_channels=image_channels, |
| pLM_alphabet=esm_alphabet, |
| ) |
|
|
| |
| |
| |
| |
|
|
| |
| norm_loss_over_e = train_epoch_loss/len(train_loader) |
| print("\nnorm_loss over 1 epoch: ", norm_loss_over_e) |
| |
| |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n' |
| with open(train_hist_file_full,'a') as f: |
| f.write(add_line) |
| |
| |
| |
| if save_model and (e+start_ep) % save_every_epoch==0 and e>1: |
| |
| fname=f"{store_dir}trainer_save-model-epoch_{e+start_ep}.pt" |
| trainer.save(fname) |
| print (f"Model saved: ", fname) |
| |
| fname=f"{store_dir}statedict_save-model-epoch_{e+start_ep}.pt" |
| torch.save(model.state_dict(), fname) |
| print (f"Statedict model saved: ", fname) |
| |
| top_line='epoch,steps,norm_loss'+'\n' |
| add_line = str(e+start_ep)+','+str(steps)+','+str(norm_loss_over_e)+'\n'  # per-epoch loss; norm_loss may be undefined at this point |
| with open(save_point_info_file, "w") as f: |
| f.write(top_line) |
| f.write(add_line) |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| print (f"\n\n-------------------\nTime for epoch {e+start_ep}={(time.time()-start)/60}\n-------------------") |
|
|
| |
| def train_loop_Model_A ( |
| model, |
| train_loader, |
| test_loader, |
| optimizer=None, |
| print_every=10, |
| epochs= 300, |
| start_ep=0, |
| start_step=0, |
| train_unet_number=1, |
| print_loss=1000, |
| trainer=None, |
| plot_unscaled=False, |
| max_batch_size=4, |
| save_model=False, |
| cond_scales=[1.0], |
| num_samples=2, |
| foldproteins=False, |
| ): |
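| # Training loop for Model A (token-level output, no pLM embedding space).
| # Per batch: if an Imagen-style trainer is provided it performs the update
| # (max_batch_size is passed through to it), otherwise a manual optimizer step
| # with gradient clipping is used. Every `print_loss` steps the running loss is
| # plotted and samples are drawn from the test set and from fixed de novo
| # conditioning vectors; checkpoints are written when save_model is set.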
| |
| |
| if not exists (trainer): |
| if not exists (optimizer): |
| print ("ERROR: If trainer not used, need to provide optimizer.") |
| if exists (trainer): |
| print ("Trainer provided... will be used") |
| steps=start_step |
|
|
| loss_total=0
| loss_list=[]
| for e in range(1, epochs+1): |
| start = time.time() |
|
|
| torch.cuda.empty_cache() |
| print ("######################################################################################") |
| start = time.time() |
| print ("NOW: Training epoch: ", e+start_ep) |
|
|
| |
| train_epoch_loss = 0 |
| model.train() |
| |
| print ("Loop over ", len(train_loader), " batches (print . every ", print_every, " steps)") |
| for item in train_loader: |
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
|
|
| if exists (trainer): |
| loss = trainer( |
| X_train_batch, y_train_batch.unsqueeze(1) , |
| unet_number=train_unet_number, |
| max_batch_size = max_batch_size, |
| ) |
| trainer.update(unet_number = train_unet_number) |
|
|
| else: |
| optimizer.zero_grad() |
| loss=model ( X_train_batch, y_train_batch.unsqueeze(1) ,unet_number=train_unet_number) |
| loss.backward( ) |
| |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) |
|
|
| optimizer.step() |
|
|
| loss_total=loss_total+loss.item() |
| |
| if steps % print_every == 0: |
| print(".", end="") |
|
|
| if steps>0: |
| if steps % print_loss == 0: |
|
|
| if plot_unscaled: |
| |
| plt.plot (y_train_batch.unsqueeze(1)[0,0,:].cpu().detach().numpy(),label= 'Unscaled GT') |
| plt.legend() |
| plt.show() |
| |
| |
| GT=y_train_batch.cpu().detach() |
| |
| GT=resize_image_to( |
| GT.unsqueeze(1), |
| model.imagen.image_sizes[train_unet_number-1], |
|
|
| ) |
| |
| norm_loss=loss_total/print_loss |
| print (f"\nTOTAL LOSS at epoch={e}, step={steps}: {norm_loss}") |
|
|
| loss_list.append (norm_loss) |
| loss_total=0 |
|
|
| plt.plot (loss_list, label='Loss') |
| plt.legend() |
|
|
| outname = prefix+ f"loss_{e}_{steps}.jpg" |
| plt.savefig(outname, dpi=200) |
| plt.show() |
| |
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
| |
| sample_loop (model, |
| test_loader, |
| cond_scales=cond_scales, |
| num_samples=1, |
| timesteps=64, |
| flag=steps,foldproteins=foldproteins, |
| ) |
| |
| print ("SAMPLING FOR DE NOVO:") |
| sample_sequence (model, |
| X=[[0, 0.7, 0.07, 0.1, 0.01, 0.02, 0.01, 0.11]],foldproteins=foldproteins, |
| flag=steps,cond_scales=1., |
| ) |
| sample_sequence (model, |
| X=[[0., 0.0, 0.0, 0.0, 0., 0., 0., 0., ]],foldproteins=foldproteins, |
| flag=steps,cond_scales=1., |
| ) |
|
|
| if steps>0: |
| if save_model and steps % print_loss==0: |
| fname=f"{prefix}trainer_save-model-epoch_{e}.pt" |
| trainer.save(fname) |
| fname=f"{prefix}statedict_save-model-epoch_{e}.pt" |
| torch.save(model.state_dict(), fname) |
| print (f"Model saved: ") |
| |
| steps=steps+1 |
| |
| print (f"\n\n-------------------\nTime for epoch {e}={(time.time()-start)/60}\n-------------------") |
|
|
| |
| def sample_sequence_omegafold_ModelA ( |
| model, |
| X=[[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], |
| flag=0, |
| cond_scales=1., |
| foldproteins=False, |
| X_string=None, |
| x_data=None, |
| skip_steps=0, |
| inpaint_images=None, |
| inpaint_masks = None, |
| inpaint_resample_times = None, |
| init_images = None, |
| num_cycle=16, |
| calc_error=False, |
| |
| |
| tokenizer_X_forImageCondi=None, |
| Xnormfac_forImageCondi=1., |
| tokenizer_X_forTextCondi=None, |
| Xnormfac_forTextCondi=1., |
| tokenizer_y=None, |
| ynormfac=1, |
| |
| train_unet_number=1, |
| max_length_Y=1, |
| max_text_len=1, |
| |
| steps=None, |
| e=None, |
| sample_dir=None, |
| prefix=None, |
| IF_showfig=True, |
| CKeys=None, |
| |
| normalize_X_cond_to_one=False, |
| ): |
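| # De novo sampler for the non-pLM Model A. Builds the conditioning (numeric
| # text vector X, tokenized X_string, or image-like x_data), runs reverse
| # diffusion via model.sample, rescales and rounds the output into sequence
| # tokens, decodes them with tokenizer_y and, if foldproteins is set, folds the
| # sequence (foldandsavePDB) and writes named PDB/FASTA files.
| # Returns the list of PDB files, one per conditioning input.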
| |
| |
| |
|
|
| |
| |
| |
| if X!=None: |
| print (f"Producing {len(X)} samples...from text conditioning X...") |
| lenn_val=len(X) |
| if X_string!=None: |
| lenn_val=len(X_string) |
| print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") |
| if x_data!=None: |
| print (f"Producing {len(x_data)} samples...from image conditingig x_data ...") |
| lenn_val=len(x_data) |
| |
| |
| print ('Device: ', model.device) |
| |
| |
| for iisample in range (lenn_val): |
| print(f"Working on {iisample}") |
| X_cond=None |
| |
| if X_string==None and X!=None: |
| X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) |
| if X_string!=None: |
| |
| |
| |
| |
| |
| |
| XX = tokenizer_X_forTextCondi.texts_to_sequences(X_string[iisample]) |
| XX = sequence.pad_sequences(XX, maxlen=max_text_len, padding='post', truncating='post') |
| XX = np.array(XX) |
| X_cond = torch.from_numpy(XX).float()/Xnormfac_forTextCondi |
| |
| print ('Tokenized and processed: ', X_cond) |
| |
| if X_cond!=None: |
| if normalize_X_cond_to_one: |
| X_cond=X_cond/X_cond.sum() |
| |
| print ("Text conditoning used: ", X_cond, "...sum: ", X_cond.sum(), "cond scale: ", cond_scales) |
| else: |
| print ("Text conditioning used: None") |
| |
| |
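| # Two image-conditioning paths:
| # - tokenizer_X_forImageCondi is None: x_data is assumed to be numeric already,
| #   so it is normalized here and passed to model.sample as x_data_tokenized;
| # - a tokenizer is provided: the raw x_data[iisample] is handed to model.sample,
| #   which is expected to tokenize/normalize it internally via tokenizer_X/Xnormfac.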
| if tokenizer_X_forImageCondi==None: |
| |
| |
| |
| |
| if x_data!=None: |
| x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac_forImageCondi) |
| x_data_tokenized=x_data_tokenized.to(torch.float) |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("x_data_tokenized dim: ", x_data_tokenized.shape) |
| print("x_data_tokenized dtype: ", x_data_tokenized.dtype) |
| print("test: ", x_data_tokenized!=None) |
| else: |
| x_data_tokenized=None |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("x_data_tokenized and x_data: None") |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| result=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales , |
| x_data=None, |
| |
| x_data_tokenized=x_data_tokenized, |
| |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images, |
| device=model.device, |
| |
| tokenizer_X=tokenizer_X_forImageCondi, |
| Xnormfac=Xnormfac_forImageCondi, |
| |
| max_length=max_length_Y, |
| max_text_len=max_text_len, |
| ) |
| else: |
| |
| result=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales , |
| x_data=x_data[iisample], |
| |
| x_data_tokenized=None, |
| |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images, |
| device=model.device, |
| |
| tokenizer_X=tokenizer_X_forImageCondi, |
| Xnormfac=Xnormfac_forImageCondi, |
| |
| max_length=max_length_Y, |
| max_text_len=max_text_len, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| result=torch.round(result*ynormfac) |
| |
| print("result.dim: ", result.shape) |
| |
| fig=plt.figure() |
| plt.plot ( |
| result[0,0,:].cpu().detach().numpy(), |
| label= f'Predicted' |
| ) |
| |
| plt.legend() |
| outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg" |
| |
| if IF_showfig==1: |
| plt.show () |
| else: |
| plt.savefig(outname, dpi=200) |
| plt.close() |
| |
| |
| |
| |
| |
| |
| |
|
|
| to_rev=result[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| print("to_rev.dim: ", to_rev.shape) |
| y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") |
| |
| |
| |
| pdb_list=[] |
| if X_cond != None: |
| |
| if X_string!=None: |
| X_cond=torch.round(X_cond*Xnormfac_forTextCondi) |
|
|
| to_rev=X_cond[:,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| print ("to_rev.dim: ", to_rev.shape) |
| |
| |
| |
| X_text_reversed=tokenizer_X_forTextCondi.sequences_to_texts (to_rev) |
| for iii in range (len(X_text_reversed)):
| X_text_reversed[iii]=X_text_reversed[iii].upper().strip().replace(" ", "") |
| |
| if X_string==None: |
| |
| X_text_reversed=X_cond |
| else: |
| X_text_reversed=None |
| |
| if x_data !=None: |
| x_data_reversed=x_data |
| else: |
| x_data_reversed=None |
| |
| |
| |
| print (f"For {X_text_reversed} or {X[iisample]} on Text_Condi,\n and {x_data_reversed} on Image_Condi,") |
| print (f"predicted sequence full: {y_data_reversed}") |
| |
| print (f"predicted sequence: {y_data_reversed[0]}") |
| |
| |
| print("================================================") |
| print("foldproteins: ", foldproteins) |
| |
| if not foldproteins: |
| pdb_file=None |
| else: |
| |
| |
| if X_cond != None: |
| pass |
| |
| tempname='temp' |
| pdb_file=foldandsavePDB ( |
| sequence=y_data_reversed[0], |
| filename_out=tempname, |
| num_cycle=num_cycle, |
| flag=flag, |
| |
| |
| prefix=sample_dir, |
| ) |
| |
| out_nam=iisample |
| |
| out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' |
| write_fasta (y_data_reversed[0], out_nam_fasta) |
| |
| out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb' |
| shutil.copy (pdb_file, out_nam) |
| pdb_file=out_nam |
| |
| print (f"Properly named PDB file produced: {pdb_file}") |
| if IF_showfig==1: |
| |
| view=show_pdb( |
| pdb_file=pdb_file, |
| flag=flag, |
| show_sidechains=show_sidechains, |
| show_mainchains=show_mainchains, |
| color=color |
| ) |
| view.show() |
| |
| if calc_error: |
| |
| if CKeys['Problem_ID']==7: |
| get_Model_A_error (pdb_file, X[iisample], plotit=True) |
| else: |
| print ("Error calculation on the predicted results is not applicable") |
| |
| pdb_list.append(pdb_file) |
| |
| return pdb_list |
| 
| 
| 
| def sample_sequence_omegafold_pLM_ModelA ( |
| model, |
| X=[[0.92, 0., 0.04, 0.04, 0., 0., 0., 0., ]], |
| flag=0, |
| cond_scales=1., |
| foldproteins=False, |
| X_string=None, |
| x_data=None, |
| skip_steps=0, |
| inpaint_images=None, |
| inpaint_masks = None, |
| inpaint_resample_times = None, |
| init_images = None, |
| num_cycle=16, |
| calc_error=False, |
| |
| |
| tokenizer_X_forImageCondi=None, |
| Xnormfac_forImageCondi=1., |
| tokenizer_X_forTextCondi=None, |
| Xnormfac_forTextCondi=1., |
| tokenizer_y=None, |
| ynormfac=1, |
| |
| train_unet_number=1, |
| max_length_Y=1, |
| max_text_len=1, |
| |
| steps=None, |
| e=None, |
| sample_dir=None, |
| prefix=None, |
| IF_showfig=True, |
| CKeys=None, |
| |
| normalize_X_cond_to_one=False, |
| |
| pLM_Model=None, |
| pLM_Model_Name=None, |
| image_channels=None, |
| pLM_alphabet=None, |
| ): |
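| # pLM variant of the de novo sampler: model.sample returns residue embeddings in
| # the pLM (ESM) space rather than token ids; convert_into_tokens maps them back
| # to token ids and logits, which are decoded into amino-acid strings with the pLM
| # alphabet and, if foldproteins is set, folded into PDB + FASTA files.
| # Returns (pdb_list, fasta_list), one entry per conditioning input.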
| |
| |
| |
|
|
| |
| |
| |
| if X!=None: |
| print (f"Producing {len(X)} samples...from text conditioning X...") |
| lenn_val=len(X) |
| if X_string!=None: |
| lenn_val=len(X_string) |
| print (f"Producing {len(X_string)} samples...from text conditioning X_String (from string)...") |
| if x_data!=None: |
| print (f"Producing {len(x_data)} samples...from image conditingig x_data ...") |
| lenn_val=len(x_data) |
| |
| |
| print ('Device: ', model.device) |
| |
| pdb_list=[] |
| fasta_list=[] |
| |
| for iisample in range (lenn_val): |
| print(f"Working on {iisample}") |
| X_cond=None |
| |
| if X_string==None and X!=None: |
| X_cond=torch.Tensor (X[iisample]).to(device).unsqueeze (0) |
| if X_string!=None: |
| |
| |
| |
| |
| |
| |
| XX = tokenizer_X_forTextCondi.texts_to_sequences(X_string[iisample]) |
| XX = sequence.pad_sequences(XX, maxlen=max_text_len, padding='post', truncating='post') |
| XX = np.array(XX) |
| X_cond = torch.from_numpy(XX).float()/Xnormfac_forTextCondi |
| |
| print ('Tokenized and processed: ', X_cond) |
| |
| if X_cond!=None: |
| if normalize_X_cond_to_one: |
| X_cond=X_cond/X_cond.sum() |
| |
| print ("Text conditoning used: ", X_cond, "...sum: ", X_cond.sum(), "cond scale: ", cond_scales) |
| else: |
| print ("Text conditioning used: None") |
| |
| |
| if tokenizer_X_forImageCondi==None: |
| |
| |
| |
| |
| if x_data!=None: |
| x_data_tokenized=torch.from_numpy(x_data[iisample]/Xnormfac_forImageCondi) |
| x_data_tokenized=x_data_tokenized.to(torch.float) |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("x_data_tokenized dim: ", x_data_tokenized.shape) |
| print("x_data_tokenized dtype: ", x_data_tokenized.dtype) |
| print("test: ", x_data_tokenized!=None) |
| else: |
| x_data_tokenized=None |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("x_data_tokenized and x_data: None") |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| result_embedding=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales , |
| x_data=None, |
| |
| x_data_tokenized=x_data_tokenized, |
| |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images, |
| device=model.device, |
| |
| tokenizer_X=tokenizer_X_forImageCondi, |
| Xnormfac=Xnormfac_forImageCondi, |
| |
| max_length=max_length_Y, |
| max_text_len=max_text_len, |
| ) |
| else: |
| |
| |
| |
| |
| |
| result_embedding=model.sample ( |
| x=X_cond, |
| stop_at_unet_number=train_unet_number , |
| cond_scale=cond_scales , |
| x_data=x_data[iisample], |
| |
| x_data_tokenized=None, |
| |
| skip_steps=skip_steps, |
| inpaint_images = inpaint_images, |
| inpaint_masks = inpaint_masks, |
| inpaint_resample_times = inpaint_resample_times, |
| init_images = init_images, |
| device=model.device, |
| |
| tokenizer_X=tokenizer_X_forImageCondi, |
| Xnormfac=Xnormfac_forImageCondi, |
| |
| max_length=max_length_Y, |
| max_text_len=max_text_len, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
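| # The diffusion model works in the pLM embedding space, so the sampled
| # embeddings are projected back to discrete residues here: convert_into_tokens
| # returns token ids plus logits, which the decode_* helpers below consume.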
| result_tokens,result_logits = convert_into_tokens( |
| pLM_Model, |
| result_embedding, |
| pLM_Model_Name, |
| ) |
| result=result_tokens.unsqueeze(1) |
| |
| |
| print("result.dim: ", result.shape) |
| |
| fig=plt.figure() |
| plt.plot ( |
| result[0,0,:].cpu().detach().numpy(), |
| label= f'Predicted' |
| ) |
| |
| plt.legend() |
| outname = sample_dir+ f"sampled_from_X_{iisample}_condscale-{str (cond_scales)}_{e}_{steps}.jpg" |
| |
| if IF_showfig==1: |
| plt.show () |
| else: |
| plt.savefig(outname, dpi=200) |
| plt.close() |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| to_rev=result[:,0,:] |
| |
| y_data_reversed=decode_many_ems_token_rec_for_folding( |
| to_rev, |
| result_logits, |
| pLM_alphabet, |
| pLM_Model, |
| ) |
| if CKeys['Debug_TrainerPack']==3: |
| print("on y_data_reversed[0]: ", y_data_reversed[0]) |
| |
| |
| |
| |
| |
| |
| if X_cond != None: |
| |
| if X_string!=None: |
| X_cond=torch.round(X_cond*Xnormfac_forTextCondi) |
|
|
| to_rev=X_cond[:,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
| print ("to_rev.dim: ", to_rev.shape) |
| |
| |
| |
| X_text_reversed=tokenizer_X_forTextCondi.sequences_to_texts (to_rev) |
| for iii in range (len(X_text_reversed)):
| X_text_reversed[iii]=X_text_reversed[iii].upper().strip().replace(" ", "") |
| |
| if X_string==None: |
| |
| X_text_reversed=X_cond |
| else: |
| X_text_reversed=None |
| |
| if x_data !=None: |
| x_data_reversed=x_data |
| else: |
| x_data_reversed=None |
| |
| |
| print (f"For {X_text_reversed} or {X[iisample]} on Text_Condi,\n and {x_data_reversed} on Image_Condi,\n predicted sequence: ", y_data_reversed) |
| |
| |
| print("================================================") |
| print("foldproteins: ", foldproteins) |
| |
| if not foldproteins: |
| pdb_file=None |
| else: |
| |
| |
| if X_cond != None: |
| pass |
| |
| tempname='temp' |
| pdb_file, fasta_file=foldandsavePDB_pdb_fasta ( |
| sequence=y_data_reversed[0], |
| filename_out=tempname, |
| num_cycle=num_cycle, |
| flag=flag, |
| |
| |
| prefix=sample_dir, |
| ) |
| |
| out_nam=iisample |
| |
| |
| |
| |
| out_nam=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.pdb' |
| out_nam_fasta=f'{sample_dir}DeNovoSampling_{iisample}_epo_{e}_step_{steps}.fasta' |
| shutil.copy (pdb_file, out_nam) |
| shutil.copy (fasta_file, out_nam_fasta) |
| |
| os.remove (pdb_file) |
| os.remove (fasta_file) |
| |
| pdb_file=out_nam |
| fasta_file=out_nam_fasta |
| |
| pdb_list.append(pdb_file) |
| fasta_list.append(fasta_file) |
| |
| print (f"Properly named PDB file produced: {pdb_file}") |
| if IF_showfig==1: |
| |
| view=show_pdb( |
| pdb_file=pdb_file, |
| flag=flag, |
| show_sidechains=show_sidechains, |
| show_mainchains=show_mainchains, |
| color=color |
| ) |
| view.show() |
| |
| if calc_error: |
| if CKeys['Problem_ID']==7: |
| |
| get_Model_A_error (pdb_file, X[iisample], plotit=True) |
| else: |
| print("Error calculation on the predicted results is not applicable...") |
| |
| |
| |
| return pdb_list, fasta_list |
|
|
| |
| |
| |
| def sample_loop_omegafold_ModelA ( |
| model, |
| train_loader, |
| cond_scales=None, |
| num_samples=None, |
| timesteps=None, |
| flag=None, |
| foldproteins=False, |
| |
| cond_image=False, |
| cond_text=True, |
| skip_steps=0, |
| |
| max_text_len=None, |
| max_length=None, |
| |
| train_unet_number=1, |
| ynormfac=None, |
| prefix=None, |
| tokenizer_y=None, |
| Xnormfac_CondiText=1, |
| tokenizer_X_CondiText=None, |
| |
| CKeys=None, |
| sample_dir=None, |
| steps=None, |
| e=None, |
| IF_showfig=True, |
| ): |
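| # Test-set sampling loop for the non-pLM Model A: for each batch in the loader,
| # sample sequences conditioned on the batch inputs at every requested cond_scale,
| # plot prediction vs. ground truth, decode both with tokenizer_y, and optionally
| # fold the predictions and store PDB files plus the conditioning text.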
| |
| |
| |
| |
| |
| |
| for idx, item in enumerate(train_loader): |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
|
|
| GT=y_train_batch.cpu().detach() |
|
|
| GT= GT.unsqueeze(1) |
| if num_samples>y_train_batch.shape[0]: |
| print("Warning: sampling # > len(mini_batch)") |
|
|
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
| X_train_batch_picked = X_train_batch[:num_samples,:] |
| print ('(TEST) X_batch shape: ', X_train_batch_picked.shape) |
|
|
| |
| for iisample in range (len (cond_scales)): |
|
|
| |
| result=model.sample ( |
| x=X_train_batch_picked, |
| stop_at_unet_number=train_unet_number, |
| cond_scale=cond_scales[iisample], |
| |
| skip_steps=skip_steps, |
| device=model.device, |
| |
| max_length=max_length, |
| max_text_len=max_text_len, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| result=torch.round(result*ynormfac) |
| GT=torch.round (GT*ynormfac) |
|
|
| for samples in range (num_samples): |
| print ("sample ", samples+1, "out of ", num_samples) |
|
|
| fig=plt.figure() |
| plt.plot ( |
| result[samples,0,:].cpu().detach().numpy(), |
| label= f'Predicted' |
| ) |
| plt.plot ( |
| GT[samples,0,:], |
| label= f'GT {0}' |
| ) |
| plt.legend() |
| outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg" |
| if IF_showfig==1: |
| plt.show() |
| else: |
| plt.savefig(outname, dpi=200) |
| plt.close () |
|
|
| |
| to_rev=result[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
|
|
| y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") |
|
|
| |
| to_rev=GT[:,0,:] |
| to_rev=to_rev.long().cpu().detach().numpy() |
|
|
| GT_y_data_reversed=tokenizer_y.sequences_to_texts (to_rev) |
|
|
| for iii in range (len(y_data_reversed)): |
| GT_y_data_reversed[iii]=GT_y_data_reversed[iii].upper().strip().replace(" ", "") |
|
|
| |
| |
| |
| |
| |
| # use torch.tensor so scalar or list Xnormfac_CondiText are both handled correctly
| to_rev=torch.round (X_train_batch[:,:]*torch.tensor(Xnormfac_CondiText, dtype=torch.float).to(model.device))
| to_rev=to_rev.long().cpu().detach().numpy() |
| |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("tokenizer_X_CondiText: ", tokenizer_X_CondiText) |
| print("Xnormfac_CondiText: ", Xnormfac_CondiText) |
| |
| if tokenizer_X_CondiText!=None: |
| X_data_reversed=tokenizer_X_CondiText.sequences_to_texts (to_rev) |
| for iii in range (len(y_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
| else: |
| X_data_reversed=to_rev.copy() |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("X_data_reversed: ", X_data_reversed) |
| |
|
|
| print (f"For {X_train_batch[samples,:].cpu().detach().numpy()} or {X_data_reversed[samples]}, \npredicted sequence: ", y_data_reversed[samples]) |
| print (f"Ground truth: {GT_y_data_reversed[samples]}") |
|
|
| if foldproteins: |
| xbc=X_train_batch[samples,:].cpu().detach().numpy() |
| out_nam=np.array2string(xbc, formatter={'float_kind':lambda xbc: "%.1f" % xbc}) |
| tempname='temp' |
| pdb_file=foldandsavePDB ( |
| sequence=y_data_reversed[samples], |
| filename_out=tempname, |
| num_cycle=16, flag=flag, |
| |
| prefix=prefix |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb' |
| out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt' |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("pdb_file: ", pdb_file) |
| print("out_nam: ", out_nam) |
| |
| print (f'Original PDB: {pdb_file} OUT: {out_nam}') |
| shutil.copy (pdb_file, out_nam) |
| |
| with open(out_nam_inX, "w") as inX_file: |
| inX_file.write(f'{X_data_reversed[samples]}\n') |
| |
| pdb_file=out_nam |
| print (f"Properly named PDB file produced: {pdb_file}") |
| print (f"input X for sampling stored: {pdb_file}") |
| |
| if IF_showfig==1: |
| view=show_pdb( |
| pdb_file=pdb_file, |
| flag=flag, |
| show_sidechains=show_sidechains, |
| show_mainchains=show_mainchains, |
| color=color |
| ) |
| view.show() |
|
|
| |
| |
| |
| |
|
|
| |
| |
| def sample_loop_omegafold_pLM_ModelA ( |
| model, |
| train_loader, |
| cond_scales=None, |
| num_samples=None, |
| timesteps=None, |
| flag=None, |
| foldproteins=False, |
| |
| cond_image=False, |
| cond_text=True, |
| skip_steps=0, |
| |
| max_text_len=None, |
| max_length=None, |
| |
| train_unet_number=1, |
| ynormfac=None, |
| prefix=None, |
| tokenizer_y=None, |
| Xnormfac_CondiText=1, |
| tokenizer_X_CondiText=None, |
| |
| CKeys=None, |
| sample_dir=None, |
| steps=None, |
| e=None, |
| IF_showfig=True, |
| |
| pLM_Model=None, |
| pLM_Model_Name=None, |
| image_channels=None, |
| pLM_alphabet=None, |
| |
| calc_error=False, |
| ): |
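| # pLM variant of the test-set sampling loop: sampled embeddings are converted
| # back to tokens (convert_into_tokens) and decoded with the pLM alphabet; ground
| # truth tokens are decoded the same way for comparison, and PDB/FASTA/.txt
| # outputs are written when foldproteins is set.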
| |
| |
| |
| |
| |
| |
| for idx, item in enumerate(train_loader): |
|
|
| X_train_batch= item[0].to(device) |
| y_train_batch=item[1].to(device) |
|
|
| GT=y_train_batch.cpu().detach() |
|
|
| GT= GT.unsqueeze(1) |
| if num_samples>y_train_batch.shape[0]: |
| print("Warning: sampling # > len(mini_batch)") |
|
|
| num_samples = min (num_samples,y_train_batch.shape[0] ) |
| print (f"Producing {num_samples} samples...") |
| X_train_batch_picked = X_train_batch[:num_samples,:] |
| print ('(TEST) X_batch shape: ', X_train_batch_picked.shape) |
|
|
| |
| for iisample in range (len (cond_scales)): |
|
|
| |
| result_embedding = model.sample ( |
| x=X_train_batch_picked, |
| stop_at_unet_number=train_unet_number, |
| cond_scale=cond_scales[iisample], |
| |
| skip_steps=skip_steps, |
| device=model.device, |
| |
| max_length=max_length, |
| max_text_len=max_text_len, |
| |
| x_data=None, |
| x_data_tokenized=None, |
| |
| tokenizer_X=tokenizer_X_CondiText, |
| Xnormfac=Xnormfac_CondiText, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| result_tokens,result_logits = convert_into_tokens( |
| pLM_Model, |
| result_embedding, |
| pLM_Model_Name, |
| ) |
| |
| |
| |
| |
| |
| result=result_tokens.unsqueeze(1) |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| to_rev=result[:,0,:] |
| y_data_reversed=decode_many_ems_token_rec_for_folding( |
| to_rev, |
| result_logits, |
| pLM_alphabet, |
| pLM_Model, |
| ) |
| if CKeys['Debug_TrainerPack']==3: |
| print("on y_data_reversed[0]: ", y_data_reversed[0]) |
| |
|
|
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| to_rev=GT[:,0,:] |
| GT_y_data_reversed=decode_many_ems_token_rec( |
| to_rev, |
| pLM_alphabet, |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| to_rev=X_train_batch[:,:]*torch.tensor(Xnormfac_CondiText).to(model.device) |
| |
| |
| |
| |
| to_rev=to_rev.cpu().detach().numpy() |
| |
|
|
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("tokenizer_X_CondiText: ", tokenizer_X_CondiText) |
| print("Xnormfac_CondiText: ", Xnormfac_CondiText) |
|
|
| if tokenizer_X_CondiText!=None: |
| |
| to_rev = np.round(to_rev).astype(int)
| X_data_reversed=tokenizer_X_CondiText.sequences_to_texts (to_rev) |
| for iii in range (len(y_data_reversed)): |
| X_data_reversed[iii]=X_data_reversed[iii].upper().strip().replace(" ", "") |
| else: |
| X_data_reversed=to_rev.copy() |
| |
| if CKeys['Debug_TrainerPack']==1: |
| print("X_data_reversed: ", X_data_reversed) |
| print("X_data_reversed.dim: ", X_data_reversed.shape) |
|
|
| for samples in range (num_samples): |
| print ("sample ", samples+1, "out of ", num_samples) |
|
|
| fig=plt.figure() |
| plt.plot ( |
| result[samples,0,:].cpu().detach().numpy(), |
| label= f'Predicted' |
| ) |
| plt.plot ( |
| GT[samples,0,:], |
| label= f'GT {0}' |
| ) |
| plt.legend() |
| outname = sample_dir+ f"Batch_{idx}_sample_{samples}_condscale-{str (cond_scales[iisample])}_{e}_{steps}.jpg" |
| if IF_showfig==1: |
| plt.show() |
| else: |
| plt.savefig(outname, dpi=200) |
| plt.close () |
| |
|
|
| print (f"For input in dataloader: {X_train_batch[samples,:].cpu().detach().numpy()} or \n recovered input {X_data_reversed[samples]}") |
| print (f"predicted sequence: {y_data_reversed[samples]}") |
| print (f"Ground truth: {GT_y_data_reversed[samples]}") |
|
|
| if foldproteins: |
| |
| if len(y_data_reversed[samples])>0: |
| |
| |
| |
| |
| |
| xbc=X_data_reversed[samples]
| # xbc may be a decoded string (when a tokenizer is used) or a numeric array
| out_nam_content=xbc if isinstance(xbc, str) else np.array2string(xbc, formatter={'float_kind':lambda x: "%.4f" % x})
| |
| tempname='temp' |
| pdb_file,fasta_file=foldandsavePDB_pdb_fasta ( |
| sequence=y_data_reversed[samples], |
| filename_out=tempname, |
| num_cycle=16, flag=flag, |
| |
| prefix=prefix |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| out_nam=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.pdb' |
| out_nam_seq=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.fasta' |
| out_nam_inX=f'{sample_dir}SamplingLoop_B_{idx}_Sample_{samples}_condscale-{str (cond_scales[iisample])}_epo_{e}_step_{steps}.txt' |
|
|
| if CKeys['Debug_TrainerPack']==1: |
| print("pdb_file: ", pdb_file) |
| print("out_nam: ", out_nam) |
|
|
| print (f'Original PDB: {pdb_file} OUT: {out_nam}') |
| shutil.copy (pdb_file, out_nam) |
| shutil.copy (fasta_file, out_nam_seq) |
| |
| with open(out_nam_inX, "w") as inX_file: |
| |
| inX_file.write(out_nam_content) |
| |
| os.remove (pdb_file) |
| os.remove (fasta_file) |
|
|
|
|
| pdb_file=out_nam |
| print (f"Properly named PDB file produced: {pdb_file}") |
| print (f"input X for sampling stored: {pdb_file}") |
|
|
| if IF_showfig==1: |
| view=show_pdb( |
| pdb_file=pdb_file, |
| flag=flag, |
| show_sidechains=show_sidechains, |
| show_mainchains=show_mainchains, |
| color=color |
| ) |
| view.show() |
| |
| if calc_error: |
| print('On-fly check...') |
| if CKeys['Problem_ID']==7: |
| |
| get_Model_A_error (pdb_file, X_data_reversed[samples], plotit=True) |
| else: |
| print("Error calculation on the predicted results is not applicable...") |
| |
| |
| else: |
| print("The predicted sequence is EMPTY...") |
| |
|
|