| | """PyTorch Sybil model for lung cancer risk prediction""" |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torchvision |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import BaseModelOutput |
| | from typing import Optional, Dict, List, Tuple |
| | import numpy as np |
| | from dataclasses import dataclass |
| |
|
| | try: |
| | from .configuration_sybil import SybilConfig |
| | except ImportError: |
| | from configuration_sybil import SybilConfig |
| |
|
| |
|
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Base class for Sybil model outputs.

    Args:
        risk_scores: (`torch.FloatTensor` of shape `(batch_size, max_followup)`):
            Predicted risk scores for each year up to max_followup.
        image_attention: (`torch.FloatTensor` of shape `(batch_size, num_slices, height, width)`, *optional*):
            Attention weights over image pixels.
        volume_attention: (`torch.FloatTensor` of shape `(batch_size, num_slices)`, *optional*):
            Attention weights over CT scan slices.
        hidden_states: (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`, *optional*):
            Hidden states from the pooling layer.
    """
    # All fields default to None; the model's forward pass fills in
    # `risk_scores` always, and the attention/hidden fields only when
    # `return_attentions=True`.
    risk_scores: Optional[torch.FloatTensor] = None
    image_attention: Optional[torch.FloatTensor] = None
    volume_attention: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None
| |
|
| |
|
class CumulativeProbabilityLayer(nn.Module):
    """
    Cumulative probability layer for survival prediction.

    Mirrors the original Sybil head: a year-specific hazard branch
    (`hazard_fc`, clamped non-negative by ReLU), a shared base hazard
    (`base_hazard_fc`), and a fixed triangular mask that turns per-year
    hazards into cumulative hazards.
    """

    def __init__(self, hidden_dim: int, max_followup: int = 6):
        """
        Args:
            hidden_dim: Size of the incoming feature vector.
            max_followup: Number of follow-up years to predict.
        """
        super().__init__()
        self.max_followup = max_followup

        # One hazard per follow-up year, plus a single base hazard shared
        # across all years.
        self.hazard_fc = nn.Linear(hidden_dim, max_followup)
        self.base_hazard_fc = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU(inplace=True)

        # Upper-triangular mask (tril transposed): entry [i, j] is 1 iff
        # i <= j, so summing over rows accumulates hazards up to year j.
        # Registered as a frozen Parameter (not a buffer) to keep checkpoint
        # compatibility with the original implementation.
        tri = torch.tril(torch.ones([max_followup, max_followup]), diagonal=0)
        frozen_mask = torch.nn.Parameter(torch.t(tri), requires_grad=False)
        self.register_parameter("upper_triangular_mask", frozen_mask)

    def hazards(self, x):
        """Return non-negative per-year hazards for features `x`."""
        return self.relu(self.hazard_fc(x))

    def forward(self, x):
        """
        Compute cumulative probabilities matching original Sybil.

        Args:
            x: Hidden features [B, hidden_dim]

        Returns:
            Cumulative probabilities [B, max_followup]
        """
        per_year = self.hazards(x)
        batch, horizon = per_year.size()

        # Tile hazards along a new axis, zero out future years via the mask,
        # then sum rows to obtain the cumulative hazard per target year.
        tiled = per_year.unsqueeze(-1).expand(batch, horizon, horizon)
        gated = tiled * self.upper_triangular_mask

        # Shared base hazard broadcasts across all follow-up years.
        return torch.sum(gated, dim=1) + self.base_hazard_fc(x)
| |
|
| |
|
class GlobalMaxPool(nn.Module):
    """Pool to obtain the maximum value for each channel"""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, T, W, H)
        Returns:
            - output: dict. output['hidden'] is (B, C)
        """
        # Flatten everything past (B, C) and take the channel-wise maximum.
        flat = x.view(*x.size()[:2], -1)
        return {'hidden': torch.max(flat, dim=-1).values}
| |
|
| |
|
class PerFrameMaxPool(nn.Module):
    """Pool to obtain the maximum value for each slice in 3D input"""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, T, W, H)
        Returns:
            - output: dict.
                + output['multi_image_hidden'] is (B, C, T)
        """
        assert len(x.shape) == 5
        # Collapse the spatial dims and take the per-slice maximum.
        b, c, t = x.size()[:3]
        flat = x.view(b, c, t, -1)
        return {'multi_image_hidden': torch.max(flat, dim=-1).values}
| |
|
| |
|
class Simple_AttentionPool(nn.Module):
    """Pool to learn an attention over the slices"""

    def __init__(self, **kwargs):
        super().__init__()
        # Scores each position from its channel vector.
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, N)
        Returns:
            - output: dict
                + output['volume_attention']: tensor (B, N)
                + output['hidden']: tensor (B, C)
        """
        batch = x.shape[0]
        x = x.view(*x.size()[:2], -1)

        # Score each of the N positions: (B, N, 1) -> (B, 1, N).
        scores = self.attention_fc(x.transpose(1, 2)).transpose(1, 2)

        output = {
            # Log-attention is returned for interpretability.
            'volume_attention': self.logsoftmax(scores).view(batch, -1),
        }

        # Attention-weighted sum over positions gives the pooled features.
        weights = self.softmax(scores)
        output['hidden'] = torch.sum(x * weights, dim=-1)
        return output
| |
|
| |
|
class Simple_AttentionPool_MultiImg(nn.Module):
    """Pool to learn an attention over the slices and the volume"""

    def __init__(self, **kwargs):
        super().__init__()
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, T, W, H)
        Returns:
            - output: dict
                + output['image_attention']: tensor (B, T, W*H)
                + output['multi_image_hidden']: tensor (B, C, T)
                + output['hidden']: tensor (B, T*C)
        """
        b, c, t, w, h = x.size()

        # Treat each slice independently: (B*T, C, W*H).
        flat = x.permute([0, 2, 1, 3, 4]).contiguous().view(b * t, c, w * h)

        # Per-pixel attention scores within each slice: (B*T, 1, W*H).
        scores = self.attention_fc(flat.transpose(1, 2)).transpose(1, 2)

        output = {'image_attention': self.logsoftmax(scores).view(b, t, -1)}

        # Attention-weighted spatial pooling per slice.
        weights = self.softmax(scores)
        pooled = torch.sum(flat * weights, dim=-1)

        output['multi_image_hidden'] = pooled.view(b, t, c).permute([0, 2, 1]).contiguous()
        output['hidden'] = pooled.view(b, t * c)
        return output
| |
|
| |
|
class Conv1d_AttnPool(nn.Module):
    """Pool to learn an attention over the slices after convolution"""

    def __init__(self, **kwargs):
        super().__init__()
        # Same-padding 1-D convolution smooths features along the slice axis
        # before attention pooling.
        self.conv1d = nn.Conv1d(
            kwargs['num_chan'],
            kwargs['num_chan'],
            kernel_size=kwargs['conv_pool_kernel_size'],
            stride=kwargs['stride'],
            padding=kwargs['conv_pool_kernel_size'] // 2,
            bias=False,
        )
        self.aggregate = Simple_AttentionPool(**kwargs)

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, T)
        Returns:
            - output: dict
                + output['attention_scores']: tensor (B, C)
                + output['hidden']: tensor (B, C)
        """
        smoothed = self.conv1d(x)
        return self.aggregate(smoothed)
| |
|
| |
|
class MultiAttentionPool(nn.Module):
    """Multi-attention pooling layer for CT scan aggregation - matches original Sybil architecture"""

    def __init__(self, channels: int = 512):
        """
        Args:
            channels: Number of encoder feature channels (512 for the Sybil
                ResNet-3D backbone).
        """
        super().__init__()
        # Fix: honor the `channels` argument — it was previously accepted but
        # ignored in favor of a hard-coded 512. Default keeps old behavior.
        params = {
            'num_chan': channels,
            'conv_pool_kernel_size': 11,
            'stride': 1
        }

        # Two parallel pooling paths plus a global max-pool baseline.
        self.image_pool1 = Simple_AttentionPool_MultiImg(**params)
        self.volume_pool1 = Simple_AttentionPool(**params)
        self.image_pool2 = PerFrameMaxPool()
        self.volume_pool2 = Conv1d_AttnPool(**params)
        self.global_max_pool = GlobalMaxPool()

        # Fuse the two per-slice feature sets and the three volume-level
        # feature sets back down to `channels` dimensions.
        self.multi_img_hidden_fc = nn.Linear(2 * channels, channels)
        self.hidden_fc = nn.Linear(3 * channels, channels)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H) where
                - B: batch size
                - C: channels (matching the constructor's `channels`)
                - T: temporal/depth dimension (slices)
                - W, H: spatial dimensions

        Returns:
            output: dict with keys:
                - 'hidden': (B, channels) - final aggregated features
                - 'image_attention_1': (B, T, W*H) - image attention scores
                - 'volume_attention_1': (B, T) - volume attention scores
                - 'image_attention_2': None (no attention for max pool)
                - 'volume_attention_2': (B, T) - volume attention scores
                - 'multi_image_hidden': (B, channels, T) - intermediate features
                - 'maxpool_hidden': (B, channels) - max pooled features
        """
        output = {}

        # Path 1: learned spatial attention per slice, then learned attention
        # over slices.
        image_pool_out1 = self.image_pool1(x)
        volume_pool_out1 = self.volume_pool1(image_pool_out1['multi_image_hidden'])

        # Path 2: per-slice max pooling, then conv-smoothed slice attention.
        image_pool_out2 = self.image_pool2(x)
        volume_pool_out2 = self.volume_pool2(image_pool_out2['multi_image_hidden'])

        # Expose every intermediate under a path-suffixed key. Note: both
        # image_pool_out1 and volume_pool_out1 emit 'hidden', so 'hidden_1'
        # ends up holding volume_pool_out1's value (original behavior kept).
        for pool_out, num in [(image_pool_out1, 1), (volume_pool_out1, 1),
                              (image_pool_out2, 2), (volume_pool_out2, 2)]:
            for key, val in pool_out.items():
                output['{}_{}'.format(key, num)] = val

        # Global max-pool baseline over the whole volume.
        maxpool_out = self.global_max_pool(x)
        output['maxpool_hidden'] = maxpool_out['hidden']

        # Fuse the per-slice features of both paths: concat on the channel
        # axis, project back to `channels` via a Linear applied per slice.
        multi_image_hidden = torch.cat(
            [image_pool_out1['multi_image_hidden'], image_pool_out2['multi_image_hidden']],
            dim=-2
        )
        output['multi_image_hidden'] = self.multi_img_hidden_fc(
            multi_image_hidden.permute([0, 2, 1]).contiguous()
        ).permute([0, 2, 1]).contiguous()

        # Fuse the three volume-level feature vectors into the final hidden.
        hidden = torch.cat(
            [volume_pool_out1['hidden'], volume_pool_out2['hidden'], output['maxpool_hidden']],
            dim=-1
        )
        output['hidden'] = self.hidden_fc(hidden)

        return output
| |
|
| |
|
class SybilPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """
    config_class = SybilConfig
    base_model_prefix = "sybil"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize the weights of linear and 3-D convolution layers."""
        if isinstance(module, nn.Linear):
            # Truncated-normal-style init as configured.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Conv3d):
            # He initialization suited to ReLU activations.
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        else:
            return
        # Both layer kinds zero their bias when present.
        if module.bias is not None:
            module.bias.data.zero_()
| |
|
| |
|
class SybilForRiskPrediction(SybilPreTrainedModel):
    """
    Sybil model for lung cancer risk prediction from CT scans.

    This model takes 3D CT scan volumes as input and predicts cancer risk scores
    for multiple future time points (typically 1-6 years).
    """

    def __init__(self, config: SybilConfig):
        super().__init__(config)
        self.config = config

        # 3D video ResNet-18 backbone; dropping the last two children removes
        # the average-pool and classification head so the encoder emits
        # spatial feature maps for the attention pooler.
        # NOTE(review): `pretrained=True` is deprecated in recent torchvision
        # (replaced by `weights=`) and downloads weights at construction time
        # — confirm the pinned torchvision version / offline behavior.
        encoder = torchvision.models.video.r3d_18(pretrained=True)
        self.image_encoder = nn.Sequential(*list(encoder.children())[:-2])

        # Multi-attention pooling over slices and the whole volume.
        self.pool = MultiAttentionPool(channels=512)

        # Non-inplace ReLU so the pooler's output tensor is not mutated.
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=config.dropout)

        # Survival head producing cumulative risk over `max_followup` years.
        self.prob_of_failure_layer = CumulativeProbabilityLayer(
            config.hidden_dim,
            max_followup=config.max_followup
        )

        # Optional per-year calibration data (see `_calibrate_scores`).
        self.calibrator = None
        if config.calibrator_data:
            self.set_calibrator(config.calibrator_data)

        # HF hook: applies `_init_weights` and any final setup.
        self.post_init()

    def set_calibrator(self, calibrator_data: Dict) -> None:
        """Set calibration data for risk score adjustment"""
        self.calibrator = calibrator_data

    def _calibrate_scores(self, scores: torch.Tensor) -> torch.Tensor:
        """Apply calibration to raw risk scores.

        NOTE(review): scores are detached and routed through NumPy, so no
        gradients flow through calibration — inference-time use only.
        """
        if self.calibrator is None:
            return scores

        scores_np = scores.detach().cpu().numpy()
        calibrated = np.zeros_like(scores_np)

        # Calibrator entries are keyed "Year1".."YearN"; years without an
        # entry pass through uncalibrated.
        for year in range(scores_np.shape[1]):
            year_key = f"Year{year + 1}"
            if year_key in self.calibrator:

                calibrated[:, year] = self._apply_calibration(
                    scores_np[:, year],
                    self.calibrator[year_key]
                )
            else:
                calibrated[:, year] = scores_np[:, year]

        return torch.from_numpy(calibrated).to(scores.device)

    def _apply_calibration(self, scores: np.ndarray, calibrator_params: Dict) -> np.ndarray:
        """Apply specific calibration transformation.

        Currently an identity placeholder — `calibrator_params` is unused
        until a concrete calibration transform is wired in.
        """

        return scores

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
        return_dict: bool = True,
    ) -> SybilOutput:
        """
        Forward pass of the Sybil model.

        Args:
            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, depth, height, width)`):
                Pixel values of CT scan volumes.
            return_attentions: (`bool`, *optional*, defaults to `False`):
                Whether to return attention weights.
            return_dict: (`bool`, *optional*, defaults to `True`):
                Whether to return a `SybilOutput` instead of a plain tuple.

        Returns:
            `SybilOutput` or tuple
        """
        # Encode the 3D volume into spatial feature maps.
        features = self.image_encoder(pixel_values)

        # Aggregate slices/volume into a single hidden vector plus
        # attention maps (keys like 'image_attention_1').
        pool_output = self.pool(features)

        # Non-linearity + dropout on the pooled representation.
        hidden = self.relu(pool_output['hidden'])
        hidden = self.dropout(hidden)

        # Cumulative hazards -> sigmoid to map into [0, 1] risk scores.
        risk_logits = self.prob_of_failure_layer(hidden)
        risk_scores = torch.sigmoid(risk_logits)

        # No-op unless a calibrator was configured.
        risk_scores = self._calibrate_scores(risk_scores)

        if not return_dict:
            outputs = (risk_scores,)
            if return_attentions:
                outputs = outputs + (pool_output.get('image_attention_1'),
                                     pool_output.get('volume_attention_1'))
            return outputs

        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=pool_output.get('image_attention_1') if return_attentions else None,
            volume_attention=pool_output.get('volume_attention_1') if return_attentions else None,
            hidden_states=hidden if return_attentions else None
        )

    @classmethod
    def from_pretrained_ensemble(
        cls,
        pretrained_model_name_or_path,
        checkpoint_paths: List[str],
        calibrator_path: Optional[str] = None,
        **kwargs
    ):
        """
        Load an ensemble of Sybil models from checkpoints.

        Args:
            pretrained_model_name_or_path: Path to the pretrained model or model identifier.
            checkpoint_paths: List of paths to individual model checkpoints.
            calibrator_path: Path to calibration data.
            **kwargs: Additional keyword arguments for model initialization.

        Returns:
            SybilEnsemble: An ensemble of Sybil models.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)

        # Calibration data is stored on the config so each member model
        # picks it up in __init__.
        calibrator_data = None
        if calibrator_path:
            import json
            with open(calibrator_path, 'r') as f:
                calibrator_data = json.load(f)
            config.calibrator_data = calibrator_data

        # Instantiate one model per checkpoint and load its weights.
        models = []
        for checkpoint_path in checkpoint_paths:
            model = cls(config)
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted checkpoints; consider `weights_only=True` where available.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            # Original Sybil checkpoints prefix keys with 'model.'; strip it
            # (len('model.') == 6) so keys match this module's attributes.
            state_dict = {}
            for k, v in checkpoint['state_dict'].items():
                if k.startswith('model.'):
                    state_dict[k[6:]] = v
                else:
                    state_dict[k] = v

            # strict=False: keys that don't map (see _map_checkpoint_weights)
            # are silently skipped.
            mapped_state_dict = model._map_checkpoint_weights(state_dict)
            model.load_state_dict(mapped_state_dict, strict=False)
            models.append(model)

        return SybilEnsemble(models, config)

    def _map_checkpoint_weights(self, state_dict: Dict) -> Dict:
        """Map original Sybil checkpoint weights to new structure.

        Only keys under the three known submodule prefixes are kept; anything
        else is dropped (loading then proceeds with strict=False).
        """
        mapped = {}

        # Keys under these prefixes line up one-to-one with this module's
        # attribute names, so they pass through unchanged.
        for k, v in state_dict.items():
            if k.startswith('image_encoder'):
                mapped[k] = v
            elif k.startswith('pool'):

                mapped[k] = v
            elif k.startswith('prob_of_failure_layer'):

                mapped[k] = v

        return mapped
| |
|
| |
|
class SybilEnsemble:
    """Ensemble of Sybil models for improved predictions"""

    def __init__(self, models: List[SybilForRiskPrediction], config: SybilConfig):
        self.models = models
        self.config = config
        self.device = None

    def to(self, device):
        """Move all models to device"""
        self.device = device
        for member in self.models:
            member.to(device)
        return self

    def eval(self):
        """Set all models to evaluation mode"""
        for member in self.models:
            member.eval()

    def __call__(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
    ) -> SybilOutput:
        """
        Run inference with ensemble voting.

        Args:
            pixel_values: Input CT scan volumes.
            return_attentions: Whether to return attention maps.

        Returns:
            SybilOutput with averaged predictions from all models.
        """
        score_stack = []
        image_attn_stack = []
        volume_attn_stack = []

        # Collect each member's predictions without tracking gradients.
        with torch.no_grad():
            for member in self.models:
                member_out = member(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions
                )
                score_stack.append(member_out.risk_scores)
                if return_attentions:
                    image_attn_stack.append(member_out.image_attention)
                    volume_attn_stack.append(member_out.volume_attention)

        # Ensemble prediction = mean over members.
        mean_scores = torch.stack(score_stack).mean(dim=0)

        # Attention maps are averaged only when requested.
        mean_image_attn = (
            torch.stack(image_attn_stack).mean(dim=0) if return_attentions else None
        )
        mean_volume_attn = (
            torch.stack(volume_attn_stack).mean(dim=0) if return_attentions else None
        )

        return SybilOutput(
            risk_scores=mean_scores,
            image_attention=mean_image_attn,
            volume_attention=mean_volume_attn
        )