| import os |
| import gc |
| import torch |
|
|
| import torch.nn.functional as F |
| import lightning as pl |
|
|
| from typing import Optional |
| from transformers import AutoModelForMaskedLM, AutoTokenizer |
|
|
| from src.utils.model_utils import _print |
| from src.utils.optimizer_utils import get_optimizer, get_scheduler |
|
|
|
|
class MembraneDiffusion(pl.LightningModule):
    """Masked-diffusion fine-tuning of a pretrained masked language model.

    Each step draws one integer timestep per sequence, replaces a
    timestep-dependent fraction of the maskable tokens with the mask token,
    runs the model on the noised sequence and scores its predictions with a
    timestep-weighted cross-entropy loss.
    """

    def __init__(self, config):
        """
        Args:
            config (OmegaConf): config.yaml file with all training parameters
        """
        super().__init__()
        self.config = config
        self.save_hyperparameters(logger=True)

        # Model and tokenizer come from the same pretrained checkpoint.
        self.model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_evoflow, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)

        # Cached special-token ids: mask_id is used for noising, pad_id for loss masking.
        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id

    def forward(self, input_ids, attention_mask, guidance: Optional[bool] = False):
        """
        Forward pass through the language model.

        Args:
            input_ids (torch.Tensor): [B, L], token ids
            attention_mask (torch.Tensor): [B, L], pad/non-pad binary mask
            guidance (bool, optional): currently unused; kept for interface compatibility
        Returns:
            torch.Tensor: [B, L, V], unnormalized model outputs (logits)
        """
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

    def step(self, batch):
        """
        Run one noising + denoising pass and compute the weighted loss.

        Args:
            batch (dict): must contain 'input_ids' and 'attention_mask' tensors
        Returns:
            tuple[torch.Tensor, torch.Tensor]: (batch-mean loss, batch-mean perplexity)
        """
        labels = batch['input_ids']

        # One timestep per sequence; maskable tokens are masked with probability t / T.
        t1 = self.sample_t(labels)
        xt, _ = self.noise_x0(labels, t1, maskable_mask=self.is_maskable(labels))
        logits = self.forward(input_ids=xt, attention_mask=batch['attention_mask'])

        weight = self.get_weight(t1, weight_type=self.config.lm.weight_type)
        loss_out = self.compute_loss(logits, labels, weight)

        self.cleanup()
        return loss_out['loss'], loss_out['ppl']

    def sample_t(self, labels, rdm_coupling=False):
        """
        Sample integer diffusion timesteps uniformly from [1, num_diffusion_timesteps].

        Non-coupling RDM only uses one timestep (t1) per sequence.

        Args:
            labels (torch.Tensor): token ids; its first dimension is the batch size B
            rdm_coupling (bool): if True, draw two independent timesteps per sequence
        Returns:
            A [B] tensor of timesteps, or, when rdm_coupling is True, a tuple of two [B] tensors.
        """
        batch_size = labels.size(0)
        # The size argument must be a 1-tuple holding the total number of draws.
        # Multiplying the tuple itself (e.g. ``2 * (B,)``) yields ``(B, B)``, i.e. a 2-D draw.
        num_draws = 2 * batch_size if rdm_coupling else batch_size
        timesteps = torch.randint(
            1,
            self.config.lm.num_diffusion_timesteps + 1,  # upper bound is exclusive
            (num_draws,),
            device=labels.device,
        )

        if rdm_coupling:
            return timesteps.chunk(2)
        return timesteps

    def noise_x0(self, x0, t1, maskable_mask):
        """
        Apply masking noise to the clean sequence x0.

        Each maskable position is independently replaced by the mask token with
        probability t1 / num_diffusion_timesteps, using that sequence's own t1.

        Args:
            x0 (torch.Tensor): [B, L], clean token ids
            t1 (torch.Tensor): [B], integer timesteps
            maskable_mask (torch.Tensor): [B, L], boolean; True where masking is allowed
        Returns:
            tuple: (x_t1, t1_mask) - the noised ids and the boolean mask of replaced positions
        """
        u = torch.rand_like(x0, dtype=torch.float)
        t1_mask = (u < (t1 / self.config.lm.num_diffusion_timesteps)[:, None]) & maskable_mask
        x_t1 = x0.masked_fill(t1_mask, self.mask_id)
        return x_t1, t1_mask

    def get_weight(self, t, weight_type):
        """
        Compute the per-sample weighting factor for the RDM-derived loss (weighted cross-entropy).

        - "linear":   (T - (t - 1)) / T, i.e. 1.0 at t=1 down to 1/T at t=T
        - "constant": 1.0 for every sample

        Args:
            t (torch.Tensor): [B], integer timesteps
            weight_type (str): "linear" or "constant"
        Returns:
            torch.Tensor: [B], float weights (the batch dimension is kept even when B == 1)
        Raises:
            KeyError: if weight_type is not one of the supported names
        """
        num_timesteps = self.config.lm.num_diffusion_timesteps
        weight = {
            "linear": (num_timesteps - (t - 1)),
            "constant": num_timesteps * torch.ones_like(t),
        }[weight_type]
        return weight.float() / num_timesteps

    def compute_loss(self, logits, labels, weight):
        """
        Compute the weighted cross-entropy loss per sample, then average over the batch.

        First the per-token loss is computed with no reduction. It is then averaged
        over each sequence's non-pad positions, scaled by the per-sample weight,
        and finally averaged over the batch.

        Args:
            logits (torch.Tensor): [B, L, vocab_size], unnormalized model outputs
            labels (torch.Tensor): [B, L], target token ids; positions equal to the
                tokenizer's pad id are ignored
            weight (torch.Tensor): per-sample weight, [B] (a [B, 1] or scalar weight is
                flattened before use)
        Returns:
            dict: {'loss': batch-mean weighted loss,
                   'ppl': batch-mean of exp(weighted per-sample loss)}
        """
        # NOTE(review): every non-pad position is scored, including tokens that were not
        # masked and are therefore visible in the model input. Masked-diffusion objectives
        # usually score only the masked positions - confirm this is intended.
        loss_token = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            reduction='none',
            ignore_index=self.pad_id,
        )
        loss_token = loss_token.view(labels.size(0), labels.size(1))

        valid_mask = (labels != self.pad_id).float()
        # clamp(min=1) avoids division by zero for an all-padding row.
        sample_loss = (loss_token * valid_mask).sum(dim=1) / valid_mask.sum(dim=1).clamp(min=1)
        # Flatten so a [B, 1] weight cannot broadcast against [B] into a [B, B] matrix.
        sample_loss = sample_loss * weight.reshape(-1)
        # NOTE(review): perplexity is taken from the *weighted* loss, so it matches a
        # conventional perplexity only when the weight is 1 ("constant" weighting).
        ppl = torch.exp(sample_loss)

        return {'ppl': ppl.mean(), 'loss': sample_loss.mean()}

    def training_step(self, batch):
        """Compute and log the training loss/perplexity for one batch; return the loss."""
        loss, ppl = self.step(batch)
        # Log the tensors directly: calling .item() forces a host-device sync on every step.
        self.log("train/loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log("train/ppl", ppl, on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch):
        """Compute and log the validation loss/perplexity (epoch-level, synced across devices)."""
        # step() already calls self.cleanup(); a second call here only repeated that work.
        loss, ppl = self.step(batch)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def test_step(self, batch):
        """Compute and log the test loss/perplexity (epoch-level, synced across devices)."""
        # step() already calls self.cleanup(); a second call here only repeated that work.
        loss, ppl = self.step(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def is_maskable(self, input_ids: torch.Tensor):
        """Return a boolean mask that is True for every token except pad, cls and eos."""
        return (
            (input_ids != self.tokenizer.pad_token_id)
            & (input_ids != self.tokenizer.cls_token_id)
            & (input_ids != self.tokenizer.eos_token_id)
        )

    def configure_optimizers(self):
        """
        Choose the optimizer and LR scheduler from the config.

        Any extra kwargs returned by get_scheduler are merged into the
        "lr_scheduler" entry next to the scheduler itself.
        """
        optimizer = get_optimizer(self.config, self.model.parameters())
        lr_scheduler, extra_kwargs = get_scheduler(self.config, optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, **extra_kwargs},
        }

    def validate_config(self):
        """Assert that the checkpoint directory exists and that training.mode is a known mode."""
        assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
        assert self.config.training.mode in ["train", "test", "resume_from_checkpoint"], "invalid mode"

    def get_state_dict(self, ckpt_path):
        """
        Load a checkpoint and return its state dict with any leading "model." prefix removed.

        Args:
            ckpt_path (str): path to a file loadable with torch.load
        Returns:
            dict: parameter name -> value mapping
        """
        def remove_model_prefix(state_dict):
            # Build a new dict with only a *leading* "model." stripped. The previous
            # version called k.replace(...) and dropped the result, so keys were unchanged.
            prefix = "model."
            return {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

        # NOTE: torch.load unpickles the file, so ckpt_path must come from a trusted source.
        # NOTE(review): newer PyTorch releases default to weights_only=True, which may reject
        # checkpoints that also store non-tensor objects - confirm against the PyTorch version.
        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        # Accept either a checkpoint dict with a "state_dict" entry or a bare state dict.
        state_dict = checkpoint.get("state_dict", checkpoint)

        if any(k.startswith("model.") for k in state_dict.keys()):
            state_dict = remove_model_prefix(state_dict)

        return state_dict

    def cleanup(self):
        """Release cached CUDA allocator memory and run Python's garbage collector."""
        torch.cuda.empty_cache()
        gc.collect()