import gc

import torch
import torch.nn as nn
import lightning.pytorch as pl

from omegaconf import OmegaConf
from transformers import AutoModel
from torchmetrics.classification import BinaryAUROC, BinaryAccuracy

from src.utils.model_utils import _print
from src.guidance.utils import CosineWarmup


config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")
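# Expected config schema, inferred from the attribute accesses in this module
# (a reference sketch only; the authoritative schema lives in guidance.yaml):
#   lm.pretrained_esm        - HF hub id or path of the frozen ESM backbone
#   model.{d_model, num_heads, num_layers, dropout, label_pad_value}
#   training.{warmup_steps, max_steps, val_check_interval}
#   optim.lr
#   checkpointing.save_dir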

class SolubilityClassifier(pl.LightningModule):
    """Per-residue protein solubility classifier: frozen ESM embeddings -> Transformer encoder -> MLP head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        self.auroc = BinaryAUROC()
        self.accuracy = BinaryAccuracy()

        # Frozen pre-trained ESM backbone, used purely as a feature extractor.
        self.esm_model = AutoModel.from_pretrained(self.config.lm.pretrained_esm)
        for p in self.esm_model.parameters():
            p.requires_grad = False

        # Trainable Transformer encoder head on top of the frozen embeddings.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.model.d_model,
            nhead=config.model.num_heads,
            dropout=config.model.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.model.num_layers)
        self.layer_norm = nn.LayerNorm(config.model.d_model)
        self.dropout = nn.Dropout(config.model.dropout)

        # Per-token MLP head mapping each residue embedding to a single logit.
        self.mlp = nn.Sequential(
            nn.Linear(config.model.d_model, config.model.d_model // 2),
            nn.ReLU(),
            nn.Dropout(config.model.dropout),
            nn.Linear(config.model.d_model // 2, 1),
        )
    def forward(self, batch):
        """Compute per-residue solubility logits from token ids or precomputed ESM embeddings."""
        if 'input_ids' in batch:
            esm_embeds = self.get_esm_embeddings(batch['input_ids'], batch['attention_mask'])
        elif 'embeds' in batch:
            esm_embeds = batch['embeds']
        else:
            raise KeyError("batch must contain either 'input_ids' or 'embeds'")
        # Exclude padded positions from self-attention.
        encodings = self.encoder(esm_embeds, src_key_padding_mask=(batch['attention_mask'] == 0))
        encodings = self.dropout(self.layer_norm(encodings))
        logits = self.mlp(encodings).squeeze(-1)
        return logits
    def training_step(self, batch, batch_idx):
        train_loss, _ = self.compute_loss(batch)
        self.log(name="train/loss", value=train_loss.item(), on_step=True, on_epoch=False, logger=True, sync_dist=True)
        self.save_ckpt()  # manual mid-epoch checkpointing (see save_ckpt below)
        return train_loss
    def validation_step(self, batch, batch_idx):
        val_loss, _ = self.compute_loss(batch)
        self.log(name="val/loss", value=val_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return val_loss
    def test_step(self, batch, batch_idx):
        test_loss, preds = self.compute_loss(batch)
        auroc, accuracy = self.get_metrics(batch, preds)
        self.log(name="test/loss", value=test_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/AUROC", value=auroc.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/accuracy", value=accuracy.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return test_loss
    def on_test_epoch_end(self):
        # Clear accumulated metric state so a later test run starts fresh.
        self.auroc.reset()
        self.accuracy.reset()
    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        # Aggressively release memory after every step to avoid CUDA fragmentation/OOM.
        gc.collect()
        torch.cuda.empty_cache()
    def configure_optimizers(self):
        train_cfg = self.config.training
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.optim.lr)
        lr_scheduler = CosineWarmup(
            optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.max_steps,
        )
        scheduler_dict = {
            "scheduler": lr_scheduler,
            "interval": "step",
            "frequency": 1,
            "monitor": "val/loss",
            "name": "learning_rate",
        }
        return [optimizer], [scheduler_dict]
    def save_ckpt(self):
        """Manually save a checkpoint every `val_check_interval` training steps."""
        curr_step = self.global_step
        save_every = self.config.training.val_check_interval
        if curr_step % save_every == 0 and curr_step > 0:
            ckpt_path = f"{self.config.checkpointing.save_dir}/step={curr_step}.ckpt"
            self.trainer.save_checkpoint(ckpt_path)
    @torch.no_grad()
    def get_esm_embeddings(self, input_ids, attention_mask):
        """Extract last-layer per-token embeddings from the frozen ESM model."""
        outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state
        return embeddings
    def compute_loss(self, batch):
        """Helper method to compute masked BCE loss over valid (non-padded) positions"""
        labels = batch['labels']
        preds = self.forward(batch)
        loss = self.loss_fn(preds, labels.float())  # BCEWithLogitsLoss expects float targets
        # Average only over real residues; padded positions carry label_pad_value.
        loss_mask = (labels != self.config.model.label_pad_value)
        loss = (loss * loss_mask).sum() / loss_mask.sum()
        return loss, preds
    def get_metrics(self, batch, preds):
        """Helper method to compute AUROC and accuracy over valid (non-padded) positions"""
        labels = batch['labels']

        valid_mask = (labels != self.config.model.label_pad_value)
        labels = labels[valid_mask].long()  # torchmetrics binary metrics expect integer targets
        preds = preds[valid_mask]

        _print(f"labels {labels.shape}")
        _print(f"preds {preds.shape}")

        auroc = self.auroc(preds, labels)
        accuracy = self.accuracy(preds, labels)
        return auroc, accuracy
    def get_state_dict(self, ckpt_path):
        """Helper method to load and process a trained model's state dict from a saved checkpoint"""
        def remove_model_prefix(state_dict):
            # str.replace returns a new string, so build a new dict with renamed keys.
            return {(k[len("model."):] if k.startswith("model.") else k): v
                    for k, v in state_dict.items()}

        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = checkpoint.get("state_dict", checkpoint)

        if any(k.startswith("model.") for k in state_dict.keys()):
            state_dict = remove_model_prefix(state_dict)

        return state_dict
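

# Minimal usage sketch (the checkpoint path and `test_loader` below are
# placeholders, not defined in this module):
#
#     model = SolubilityClassifier(config)
#     state_dict = model.get_state_dict("/path/to/step=1000.ckpt")
#     model.load_state_dict(state_dict)
#     trainer = pl.Trainer(max_steps=config.training.max_steps)
#     trainer.test(model, dataloaders=test_loader)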