| """ |
| Marketing Mix Model Diffusion (MMM-Diffusion) |
| ============================================== |
| A generative diffusion model for Marketing Mix Modeling, adapted from NVIDIA's |
| Kimodo dual-denoiser architecture (GMD/MDM family). |
| |
| Architecture mapping: |
| Kimodo → MMM-Diffusion |
| ------- --------------- |
| Text prompts → Media spend, non-marketing vars, total sales |
| Motion/position constraints → Sign constraints (β_media ≥ 0) + prior constraints |
| Root denoiser (trajectory) → Campaign/Geo-level denoiser (aggregate patterns) |
| Body denoiser (joint rotations)→ Channel-level denoiser (per-channel coefficients) |
| Skeleton positions/rotations → Time-varying coefficients for sales decomposition |
| |
| References: |
| - GMD (arxiv:2305.12577) — Two-stage trajectory + body diffusion |
| - MDM (arxiv:2209.14916) — Transformer denoiser, x₀-prediction, geometric losses |
| - PhysDiff (arxiv:2212.02500) — Projection during denoising for constraints |
| - PDM (arxiv:2402.03559) — Projected diffusion for hard constraint satisfaction |
| - NNN (arxiv:2504.06212) — Neural network MMM architecture from Google |
| """ |
|
|
| import math |
| import json |
| import os |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from scipy.signal import lfilter |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
|
|
|
|
| |
| |
| |
|
|
class MMMDataGenerator:
    """
    Generate synthetic Marketing Mix Model data with known ground truth.

    Model:
        Sales_t = β_base + Σ_m β_m,t · Hill(Adstock(spend_m; α_m)_t / P90_m; ec50_m, k_m)
                  + Σ_c β_c,t · ctrl_c,t + ε_t
    where P90_m is the 90th percentile of channel m's adstocked spend, the β_·,t
    are time-varying (see ``_make_time_varying``) and ε_t ~ N(0, noise_std²).
    Total sales are clipped at 0.

    Channels: 5 media channels + 3 non-marketing (control) variables.

    Based on the NNN paper (arxiv:2504.06212) simulation recipe and the Meridian framework.
    """

    MEDIA_CHANNELS = ['TV', 'Digital', 'Social', 'Print', 'Radio']
    CONTROL_VARS = ['Seasonality', 'Trend', 'Competitor_Price']

    def __init__(self, n_weeks=104, n_geos=1, seed=None):
        """
        Args:
            n_weeks: number of weekly time steps in each generated dataset.
            n_geos: stored on the instance; not used by the generation methods of this class.
            seed: seed for the internal ``np.random.RandomState``.
        """
        self.n_weeks = n_weeks
        self.n_geos = n_geos
        self.n_media = len(self.MEDIA_CHANNELS)
        self.n_ctrl = len(self.CONTROL_VARS)
        self.rng = np.random.RandomState(seed)

    def _generate_media_spend(self):
        """
        Generate media spend with yearly seasonality, campaign bursts and noise.

        Returns:
            (n_weeks, n_media) array of non-negative spend.
        """
        spend = np.zeros((self.n_weeks, self.n_media))
        t = np.arange(self.n_weeks)

        # Per-channel baseline spend level.
        base_levels = self.rng.uniform(50, 500, size=self.n_media)

        # Campaign start weeks are drawn from [0, n_weeks - 4). Keep the upper bound
        # at least 1 so short series (n_weeks <= 4) do not make randint raise.
        max_start = max(1, self.n_weeks - 4)

        for m in range(self.n_media):
            # ±20% yearly (52-week period) modulation of the baseline.
            base = base_levels[m] * (1 + 0.2 * np.sin(2 * np.pi * t / 52))

            # 3–8 campaign bursts, each multiplying spend by 1.5–4x for 1–5 weeks.
            n_campaigns = self.rng.randint(3, 9)
            for _ in range(n_campaigns):
                start = self.rng.randint(0, max_start)
                duration = self.rng.randint(1, 6)
                intensity = self.rng.uniform(1.5, 4.0)
                end = min(start + duration, self.n_weeks)
                base[start:end] *= intensity

            # Gaussian noise at 10% of the baseline level, clipped so spend stays >= 0.
            spend[:, m] = np.maximum(base + self.rng.normal(0, base_levels[m] * 0.1, self.n_weeks), 0)

        return spend

    def _adstock(self, x, alpha):
        """
        Geometric adstock: y_0 = x_0, y_t = x_t + α·y_{t-1}. α ∈ [0, 1] is the retention rate.

        Implemented as the first-order IIR filter b = [1], a = [1, -α] via
        ``scipy.signal.lfilter`` along axis 0. Always returns a float array of the
        same shape as ``x``; an empty input yields an empty output.
        """
        x = np.asarray(x, dtype=float)
        if x.size == 0:
            return x.copy()
        return lfilter([1.0], [1.0, -alpha], x, axis=0)

    def _hill(self, x, ec50, slope):
        """
        Hill saturation x^k / (x^k + ec50^k) with ec50 > 0, slope k > 0.

        Negative inputs are clipped to 0; the 1e-10 term guards the division.
        """
        x_safe = np.maximum(x, 0)
        return x_safe**slope / (x_safe**slope + ec50**slope + 1e-10)

    def _generate_controls(self):
        """
        Generate the control (non-marketing) variables.

        Returns:
            (n_weeks, 3) array with columns [Seasonality, Trend, Competitor_Price].
        """
        t = np.arange(self.n_weeks)
        controls = np.zeros((self.n_weeks, self.n_ctrl))

        # Seasonality: sum of yearly harmonics (52- and 26-week periods).
        controls[:, 0] = (np.sin(2 * np.pi * t / 52) +
                          0.5 * np.sin(4 * np.pi * t / 52) +
                          0.3 * np.cos(2 * np.pi * t / 52))

        # Trend: quadratic in normalized time t / n_weeks.
        trend = t / self.n_weeks
        controls[:, 1] = trend + 0.5 * trend**2

        # Competitor price: AR(1) process that reverts towards 1.0.
        price = np.zeros(self.n_weeks)
        price[0] = 1.0
        for i in range(1, self.n_weeks):
            price[i] = 0.95 * price[i-1] + 0.05 * 1.0 + self.rng.normal(0, 0.05)
        controls[:, 2] = price

        return controls

    def _sample_true_params(self):
        """
        Sample ground-truth MMM parameters.

        Returns:
            dict with 'beta_media' (n_media,), 'adstock_alpha' (n_media,),
            'hill_ec50' (n_media,), 'hill_slope' (n_media,), 'beta_base' (float),
            'beta_ctrl' (n_ctrl,), 'noise_std' (float).
        """
        params = {}

        # Media effects: half-normal plus a 0.05 floor, so they are strictly positive.
        params['beta_media'] = np.abs(self.rng.normal(0, 0.5, self.n_media)) + 0.05

        # Adstock retention ~ Beta(2, 2), clipped to [0.1, 0.95].
        params['adstock_alpha'] = self.rng.beta(2, 2, self.n_media)
        params['adstock_alpha'] = np.clip(params['adstock_alpha'], 0.1, 0.95)

        # Hill half-saturation point (on the P90-normalized scale): lognormal + 0.1.
        params['hill_ec50'] = np.abs(self.rng.lognormal(0, 0.5, self.n_media)) + 0.1

        # Hill slope ~ U(0.5, 3).
        params['hill_slope'] = self.rng.uniform(0.5, 3.0, self.n_media)

        # Constant baseline sales level.
        params['beta_base'] = self.rng.uniform(500, 2000)

        # Control effects may take either sign.
        params['beta_ctrl'] = self.rng.normal(0, 50, self.n_ctrl)

        # Observation noise standard deviation.
        params['noise_std'] = self.rng.uniform(20, 100)

        return params

    def _make_time_varying(self, base_coeff, n_weeks, volatility=0.1):
        """
        Make a coefficient time-varying: β_t = β · exp(z_t), z_0 = 0,
        z_t = 0.9·z_{t-1} + N(0, volatility²) (a discrete mean-reverting AR(1) process).
        The sign of ``base_coeff`` is preserved.
        """
        z = np.zeros(n_weeks)
        for t in range(1, n_weeks):
            z[t] = 0.9 * z[t-1] + self.rng.normal(0, volatility)
        return base_coeff * np.exp(z)

    def generate_single(self):
        """
        Generate a single MMM dataset with known ground truth.

        Returns:
            dict with keys:
                - media_spend: (T, 5) raw media spend
                - controls: (T, 3) control variables
                - total_sales: (T,) total sales (clipped at 0)
                - true_coefficients: (T, 8) time-varying coefficients [5 media + 3 ctrl]
                - true_contributions: (T, 8) sales contribution per variable
                - base_sales: (T,) constant baseline
                - true_params: dict of ground-truth parameters
        """
        spend = self._generate_media_spend()
        controls = self._generate_controls()
        params = self._sample_true_params()

        # Media transform: adstock -> divide by the series' 90th percentile -> Hill.
        transformed_media = np.zeros_like(spend)
        for m in range(self.n_media):
            adstocked = self._adstock(spend[:, m], params['adstock_alpha'][m])
            adstocked_norm = adstocked / (np.percentile(adstocked, 90) + 1e-10)
            transformed_media[:, m] = self._hill(
                adstocked_norm, params['hill_ec50'][m], params['hill_slope'][m]
            )

        tv_coeffs = np.zeros((self.n_weeks, self.n_media + self.n_ctrl))

        # Media coefficients: time-varying and floored at 0.01 to stay positive.
        for m in range(self.n_media):
            tv_coeffs[:, m] = self._make_time_varying(
                params['beta_media'][m], self.n_weeks, volatility=0.05
            )
            tv_coeffs[:, m] = np.maximum(tv_coeffs[:, m], 0.01)

        # Control coefficients: time-varying, sign unconstrained.
        for c in range(self.n_ctrl):
            tv_coeffs[:, self.n_media + c] = self._make_time_varying(
                params['beta_ctrl'][c], self.n_weeks, volatility=0.03
            )

        contributions = np.zeros((self.n_weeks, self.n_media + self.n_ctrl))
        for m in range(self.n_media):
            contributions[:, m] = tv_coeffs[:, m] * transformed_media[:, m]
        for c in range(self.n_ctrl):
            contributions[:, self.n_media + c] = tv_coeffs[:, self.n_media + c] * controls[:, c]

        base = params['beta_base']
        noise = self.rng.normal(0, params['noise_std'], self.n_weeks)
        total_sales = base + contributions.sum(axis=1) + noise
        total_sales = np.maximum(total_sales, 0)

        return {
            'media_spend': spend,
            'controls': controls,
            'total_sales': total_sales,
            'true_coefficients': tv_coeffs,
            'true_contributions': contributions,
            'base_sales': np.full(self.n_weeks, base),
            'true_params': params
        }

    def generate_dataset(self, n_samples):
        """
        Generate ``n_samples`` MMM instances.

        Before each sample, ``self.rng`` is replaced by a fresh RandomState seeded
        from the current one, so results are reproducible from the constructor seed
        and successive calls continue the same seed chain.
        """
        samples = []
        for _ in range(n_samples):
            self.rng = np.random.RandomState(self.rng.randint(0, 2**31))
            samples.append(self.generate_single())
        return samples
|
|
|
|
| |
| |
| |
|
|
class MMMDiffusionDataset(Dataset):
    """
    Wraps generated MMM data for diffusion training.

    Each item is a dict with:
        - 'conditioning': (T, n_media + n_ctrl + 1) float32 tensor — [media_spend, controls, total_sales]
        - 'coefficients': (T, n_media + n_ctrl) float32 tensor — time-varying coefficients to denoise

    When ``normalize`` is True, conditioning features are z-scored per column,
    media coefficients are z-scored in log space (so decoding with exp() keeps them
    positive), and control coefficients are z-scored in linear space.

    Pass ``reference=<training dataset>`` to reuse that dataset's normalization
    statistics (e.g. for validation/test data) instead of estimating new ones from
    ``samples``.
    """

    # Normalization statistics copied from ``reference`` when one is given.
    _STAT_ATTRS = ('cond_mean', 'cond_std', 'coeff_mean', 'coeff_std',
                   'media_log_mean', 'media_log_std')

    def __init__(self, samples, normalize=True, reference=None):
        """
        Args:
            samples: list of dicts as produced by ``MMMDataGenerator.generate_single``
                (keys 'media_spend', 'controls', 'total_sales', 'true_coefficients').
            normalize: whether items are normalized (see class docstring).
            reference: optional normalized MMMDiffusionDataset whose statistics are reused.

        Raises:
            ValueError: if ``reference`` was not built with ``normalize=True``.
        """
        self.samples = samples
        self.normalize = normalize
        self.n_media = 5
        self.n_ctrl = 3
        self.n_channels = self.n_media + self.n_ctrl

        if not normalize:
            return

        if reference is not None:
            if not getattr(reference, 'normalize', False):
                raise ValueError("reference dataset must have been built with normalize=True")
            for attr in self._STAT_ATTRS:
                setattr(self, attr, getattr(reference, attr))
            return

        all_cond = np.stack([self._raw_conditioning(s) for s in samples])
        all_coeff = np.stack([s['true_coefficients'] for s in samples])

        # Per-feature statistics pooled over samples and time; 1e-8 avoids division by 0.
        self.cond_mean = all_cond.mean(axis=(0, 1))
        self.cond_std = all_cond.std(axis=(0, 1)) + 1e-8
        self.coeff_mean = all_coeff.mean(axis=(0, 1))
        self.coeff_std = all_coeff.std(axis=(0, 1)) + 1e-8

        # Media coefficients are modelled in log space so that decoding with exp()
        # guarantees positivity; the 1e-8 offset keeps log() finite.
        media_coeffs = all_coeff[:, :, :self.n_media]
        self.media_log_mean = np.log(media_coeffs + 1e-8).mean(axis=(0, 1))
        self.media_log_std = np.log(media_coeffs + 1e-8).std(axis=(0, 1)) + 1e-8

    @staticmethod
    def _raw_conditioning(sample):
        """Stack [media_spend, controls, total_sales] column-wise into one (T, F) array."""
        return np.concatenate(
            [sample['media_spend'], sample['controls'], sample['total_sales'][:, None]], axis=1
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]

        cond = self._raw_conditioning(s).astype(np.float32)
        coeffs = s['true_coefficients'].astype(np.float32)

        if self.normalize:
            cond = (cond - self.cond_mean) / self.cond_std

            # Media: z-score in log space.
            log_media = np.log(coeffs[:, :self.n_media] + 1e-8)
            log_media = (log_media - self.media_log_mean) / self.media_log_std

            # Controls: z-score in linear space (sign unconstrained).
            ctrl = (coeffs[:, self.n_media:] - self.coeff_mean[self.n_media:]) / self.coeff_std[self.n_media:]

            coeffs = np.concatenate([log_media, ctrl], axis=1)

        return {
            'conditioning': torch.tensor(cond, dtype=torch.float32),
            'coefficients': torch.tensor(coeffs, dtype=torch.float32),
        }

    def decode_coefficients(self, coeffs_normalized):
        """
        Inverse-transform normalized coefficients back to the original scale.
        Applies exp() to the media channels, so decoded media coefficients are > 0.

        Args:
            coeffs_normalized: tensor whose last dimension is n_media + n_ctrl,
                e.g. (batch, T, 8) or (T, 8).
        Returns:
            tensor of the same shape on the original scale. Returned unchanged when
            the dataset was built with ``normalize=False``.
        """
        if not self.normalize:
            return coeffs_normalized

        coeffs = coeffs_normalized.clone()
        like = dict(device=coeffs.device, dtype=coeffs.dtype)

        media_log_mean = torch.tensor(self.media_log_mean, **like)
        media_log_std = torch.tensor(self.media_log_std, **like)
        coeffs[..., :self.n_media] = torch.exp(
            coeffs[..., :self.n_media] * media_log_std + media_log_mean
        )

        coeff_mean = torch.tensor(self.coeff_mean[self.n_media:], **like)
        coeff_std = torch.tensor(self.coeff_std[self.n_media:], **like)
        coeffs[..., self.n_media:] = coeffs[..., self.n_media:] * coeff_std + coeff_mean

        return coeffs
|
|
|
|
| |
| |
| |
|
|
def cosine_beta_schedule(T, s=0.008):
    """
    Build the cosine noise schedule of 'Improved DDPM' (Nichol & Dhariwal, 2021).

    With f(t) = cos²(((t/T + s) / (1 + s)) · π/2) and ᾱ(t) = f(t) / f(0):
        β_t = 1 − ᾱ(t) / ᾱ(t − 1),  clipped to [0, 0.999].
    The computation is done in float64 and the result cast to float32.

    Args:
        T: number of diffusion steps.
        s: small offset added to t/T (0.008 by default, as in the paper).

    Returns:
        float32 tensor of shape (T,) holding β_1 … β_T.
    """
    grid = torch.arange(T + 1, dtype=torch.float64)
    angle = (grid / T + s) / (1 + s) * math.pi / 2
    alpha_bar = angle.cos() ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return (1 - step_ratio).clamp(0, 0.999).to(torch.float32)
|
|
|
|
class DiffusionSchedule:
    """
    DDPM quantities derived from the cosine β schedule.

    Every buffer is a 1-D tensor of length T indexed by timestep. This is a plain
    object, not an ``nn.Module``, so it is not moved by ``model.to(...)``; call
    ``schedule.to(device)`` explicitly.
    """

    def __init__(self, T=1000):
        self.T = T
        self.betas = cosine_beta_schedule(T)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # ᾱ_{t-1}, with ᾱ_{-1} := 1 so the posterior at t = 0 is well defined.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        # Posterior q(x_{t-1} | x_t, x_0) variance: β_t (1 − ᾱ_{t-1}) / (1 − ᾱ_t).
        # It is exactly 0 at t = 0, hence the clamp before taking the log.
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_log_variance_clipped = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )
        # Posterior mean = coef1 · x_0 + coef2 · x_t.
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
        )

    def to(self, device):
        """Move all schedule tensors to ``device`` (in place) and return self."""
        for attr in ['betas', 'alphas', 'alphas_cumprod', 'sqrt_alphas_cumprod',
                     'sqrt_one_minus_alphas_cumprod', 'sqrt_recip_alphas',
                     'alphas_cumprod_prev', 'posterior_variance',
                     'posterior_log_variance_clipped', 'posterior_mean_coef1',
                     'posterior_mean_coef2']:
            setattr(self, attr, getattr(self, attr).to(device))
        return self

    @staticmethod
    def _extract(values, t, like):
        """Gather ``values[t]`` (t: (B,)) and reshape to (B, 1, ..., 1) to broadcast against ``like``."""
        return values[t].reshape(-1, *([1] * (like.dim() - 1)))

    def q_sample(self, x_0, t, noise=None):
        """
        Forward diffusion sample from q(x_t | x_0) = N(√ᾱ_t · x_0, (1 − ᾱ_t) I).

        Args:
            x_0: clean data with the batch on dim 0 (any rank ≥ 1).
            t: (B,) integer timesteps.
            noise: optional Gaussian noise shaped like ``x_0`` (drawn if omitted).
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        sqrt_alpha = self._extract(self.sqrt_alphas_cumprod, t, x_0)
        sqrt_one_minus_alpha = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_0)
        return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise

    def posterior_mean(self, x_0_pred, x_t, t):
        """Mean of q(x_{t-1} | x_t, x_0 = x_0_pred); t is (B,), batch on dim 0 of any rank."""
        coef1 = self._extract(self.posterior_mean_coef1, t, x_t)
        coef2 = self._extract(self.posterior_mean_coef2, t, x_t)
        return coef1 * x_0_pred + coef2 * x_t
|
|
|
|
| |
| |
| |
|
|
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Sinusoidal embeddings of the diffusion timestep t.

    Output layout is [sin(t·ω_0..ω_{h-1}), cos(t·ω_0..ω_{h-1})] with h = dim // 2 and
    ω_k = 10000^(−k/(h−1)). For odd ``dim`` a zero column is appended so the output
    always has exactly ``dim`` features.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        Args:
            t: timesteps, shape (B,) (a 0-d tensor is treated as a batch of one).
        Returns:
            (B, dim) float tensor.
        """
        t = t.reshape(-1)
        half_dim = self.dim // 2
        # max(..., 1) avoids division by zero when dim < 4 (half_dim <= 1).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        args = t[:, None].float() * freqs[None, :]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
class TemporalTransformerBlock(nn.Module):
    """
    Pre-norm Transformer block with self-attention over the time axis.

    Inputs/outputs are batch-first (B, T, d_model):
        h   = x + SelfAttention(LayerNorm₁(x))
        out = h + FeedForward(LayerNorm₂(h))
    The feed-forward network expands to 4·d_model with GELU, with dropout after
    the activation and after the output projection.
    """
    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        hidden_width = 4 * d_model
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed)
        residual = x + attended
        return residual + self.ff(self.norm2(residual))
|
|
|
|
class CampaignDenoiser(nn.Module):
    """
    Stage 1: Campaign/Geo-level denoiser.

    Analogous to GMD's trajectory DPM / Kimodo's root denoiser. It denoises an
    aggregate-level signal conditioned on non-marketing variables + total sales.
    It predicts x_0 directly (not the noise ε), which lets the sampler project
    constraints onto the prediction at each step.

    Input:  x_t  (B, T, n_agg)     — noisy aggregate signal
            t    (B,)              — integer diffusion timesteps
            cond (B, T, cond_dim)  — conditioning features
    Output: x_0_hat (B, T, n_agg)  — predicted clean aggregate signal

    Conditioning is injected additively: input projection + conditioning projection
    + timestep embedding + learned positional embedding, followed by a stack of
    TemporalTransformerBlocks. There is no cross-attention at this stage.

    ``max_len`` sets the size of the learned positional table, i.e. the longest
    accepted sequence. ``T_diff`` is accepted for interface symmetry but unused.
    """

    def __init__(self, n_agg_channels=3, cond_dim=4, d_model=256, nhead=4, n_layers=4,
                 T_diff=1000, max_len=256):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # Timestep embedding: sinusoidal features followed by a 2-layer MLP.
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

        self.cond_proj = nn.Linear(cond_dim, d_model)
        self.input_proj = nn.Linear(n_agg_channels, d_model)

        # Learned positional embedding over (at most max_len) time steps.
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)

        self.blocks = nn.ModuleList([
            TemporalTransformerBlock(d_model, nhead) for _ in range(n_layers)
        ])

        self.output_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_agg_channels),
        )

    def forward(self, x_t, t, cond):
        """
        Predict the clean aggregate signal x_0 from x_t.

        Raises:
            ValueError: if the sequence length exceeds ``max_len``.
        """
        T_seq = x_t.shape[1]
        if T_seq > self.max_len:
            raise ValueError(
                f"sequence length {T_seq} exceeds max_len={self.max_len} of the positional embedding"
            )

        t_emb = self.time_embed(t)  # (B, d_model), broadcast over time below

        h = (self.input_proj(x_t) + self.cond_proj(cond)
             + t_emb.unsqueeze(1) + self.pos_embed[:, :T_seq, :])

        for block in self.blocks:
            h = block(h)

        return self.output_proj(h)
|
|
|
|
class ChannelDenoiser(nn.Module):
    """
    Stage 2: Channel-level denoiser.

    Analogous to GMD's full-body DPM / Kimodo's body denoiser. It denoises the
    per-channel time-varying coefficients, conditioned on:
        - the Stage 1 output (aggregate signal),
        - media spend,
        - total sales.
    It predicts x_0 directly so the sampler can project constraints onto it.

    Input:  x_t          (B, T, n_channels) — noisy channel coefficients
            t            (B,)               — integer diffusion timesteps
            campaign_ctx (B, T, n_agg)      — from Stage 1
            media_spend  (B, T, n_media)
            total_sales  (B, T, 1)
    Output: x_0_hat      (B, T, n_channels) — predicted clean coefficients

    The three conditioning inputs are each projected to d_model and summed into one
    context sequence. The (timestep + position)-augmented input attends to it once
    via cross-attention, then passes through a stack of TemporalTransformerBlocks.

    ``max_len`` sets the size of the learned positional table, i.e. the longest
    accepted sequence. ``T_diff`` is accepted for interface symmetry but unused.
    """

    def __init__(self, n_channels=8, n_media=5, n_agg=3, d_model=384, nhead=8, n_layers=6,
                 T_diff=1000, max_len=256):
        super().__init__()
        self.d_model = d_model
        self.n_media = n_media
        self.n_channels = n_channels
        self.max_len = max_len

        # Timestep embedding: sinusoidal features followed by a 2-layer MLP.
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

        self.input_proj = nn.Linear(n_channels, d_model)

        # Conditioning projections (summed into a single context sequence).
        self.campaign_proj = nn.Linear(n_agg, d_model)
        self.spend_proj = nn.Linear(n_media, d_model)
        self.sales_proj = nn.Linear(1, d_model)

        # One pre-norm cross-attention layer from the input stream to the context.
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_norm = nn.LayerNorm(d_model)

        # Learned positional embedding over (at most max_len) time steps.
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)

        self.blocks = nn.ModuleList([
            TemporalTransformerBlock(d_model, nhead) for _ in range(n_layers)
        ])

        self.output_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_channels),
        )

    def forward(self, x_t, t, campaign_ctx, media_spend, total_sales):
        """
        Predict the clean channel coefficients x_0 from x_t.

        Raises:
            ValueError: if the sequence length exceeds ``max_len``.
        """
        T_seq = x_t.shape[1]
        if T_seq > self.max_len:
            raise ValueError(
                f"sequence length {T_seq} exceeds max_len={self.max_len} of the positional embedding"
            )

        t_emb = self.time_embed(t)  # (B, d_model), broadcast over time below

        cond_ctx = (self.campaign_proj(campaign_ctx)
                    + self.spend_proj(media_spend)
                    + self.sales_proj(total_sales))

        h = self.input_proj(x_t) + t_emb.unsqueeze(1) + self.pos_embed[:, :T_seq, :]

        # Residual pre-norm cross-attention: queries from h, keys/values from context.
        h_normed = self.cross_norm(h)
        h = h + self.cross_attn(h_normed, cond_ctx, cond_ctx)[0]

        for block in self.blocks:
            h = block(h)

        return self.output_proj(h)
|
|
|
|
| |
| |
| |
|
|
class MMMDiffusionModel(nn.Module):
    """
    Full MMM-Diffusion model with a dual (two-stage) denoiser.

    Stage 1 (CampaignDenoiser): denoises an n_agg-dimensional aggregate signal
        conditioned on the control variables and total sales.
    Stage 2 (ChannelDenoiser): denoises the per-channel coefficients conditioned on
        the Stage 1 aggregate, media spend and total sales.

    Constraint handling as implemented here:
        1. Media coefficients are assumed to arrive in normalized log space (as
           produced by MMMDiffusionDataset), so decoding with exp() keeps them positive.
        2. During sampling, the predicted x_0 of the media channels is clamped to
           [-8, 8] whenever t % constraint_every_k == 0 (PhysDiff-style projection).
        3. Soft training penalties: temporal smoothness (L_smooth) and a hinge on
           log-media predictions below -5 (L_sign). No sales-consistency loss is applied.
    """

    def __init__(self, n_media=5, n_ctrl=3, d_model_campaign=256, d_model_channel=384,
                 n_layers_campaign=4, n_layers_channel=6, T_diff=1000):
        super().__init__()
        self.n_media = n_media
        self.n_ctrl = n_ctrl
        self.n_channels = n_media + n_ctrl
        self.T_diff = T_diff

        # Width of the Stage 1 aggregate signal.
        self.n_agg = 3

        # Stage 1 is conditioned on [controls, total_sales].
        self.campaign_denoiser = CampaignDenoiser(
            n_agg_channels=self.n_agg,
            cond_dim=n_ctrl + 1,
            d_model=d_model_campaign,
            nhead=4,
            n_layers=n_layers_campaign,
            T_diff=T_diff,
        )

        self.channel_denoiser = ChannelDenoiser(
            n_channels=self.n_channels,
            n_media=n_media,
            n_agg=self.n_agg,
            d_model=d_model_channel,
            nhead=8,
            n_layers=n_layers_channel,
            T_diff=T_diff,
        )

        # Linear map from full coefficients to the Stage 1 aggregate target.
        self.coeff_to_agg = nn.Linear(self.n_channels, self.n_agg)
        # NOTE(review): not referenced by the methods of this class.
        self.agg_to_coeff_init = nn.Linear(self.n_agg, self.n_channels)

        # Plain object (not a submodule): kept on the right device by _sync_schedule.
        self.schedule = DiffusionSchedule(T_diff)

    def compute_aggregate(self, coefficients):
        """Project full coefficients to the aggregate representation used by Stage 1."""
        return self.coeff_to_agg(coefficients)

    def _split_conditioning(self, cond):
        """
        Split [media_spend | controls | total_sales] conditioning along the last dim.

        Returns:
            (media_spend, total_sales, stage1_cond) with stage1_cond = [controls, total_sales].
        """
        media_spend = cond[:, :, :self.n_media]
        controls = cond[:, :, self.n_media:self.n_media + self.n_ctrl]
        total_sales = cond[:, :, -1:]
        return media_spend, total_sales, torch.cat([controls, total_sales], dim=-1)

    def _sync_schedule(self, device):
        """Move the noise schedule to ``device``; ``model.to()`` does not move it."""
        if self.schedule.betas.device != device:
            self.schedule = self.schedule.to(device)

    def _sampling_timesteps(self, n_steps):
        """
        Descending timesteps to visit during sampling.

        ``None``/0 or ``n_steps >= T_diff`` visits every step T_diff-1 … 0. Otherwise
        ``n_steps`` roughly evenly spaced steps from T_diff-1 down to 0 are used
        (just T_diff-1 when n_steps == 1).
        """
        if not n_steps or n_steps >= self.T_diff:
            return list(range(self.T_diff - 1, -1, -1))
        if n_steps < 0:
            raise ValueError(f"n_steps must be positive, got {n_steps}")
        if n_steps == 1:
            return [self.T_diff - 1]
        stride = (self.T_diff - 1) / (n_steps - 1)
        return sorted({int(round(i * stride)) for i in range(n_steps)}, reverse=True)

    def _reverse_step(self, x_0_pred, x_t, t, t_prev, t_batch):
        """
        Sample x_{t_prev} ~ q(x_{t_prev} | x_t, x_0 = x_0_pred); return x_0_pred when t_prev < 0.

        For consecutive steps (t_prev == t - 1) the schedule's precomputed posterior is
        used. For skipped steps the posterior is recomputed from ᾱ_t and ᾱ_{t_prev}
        with the effective β = 1 − ᾱ_t / ᾱ_{t_prev}.
        """
        if t_prev < 0:
            return x_0_pred
        if t_prev == t - 1:
            mean = self.schedule.posterior_mean(x_0_pred, x_t, t_batch)
            var = self.schedule.posterior_variance[t]
        else:
            ab_t = self.schedule.alphas_cumprod[t]
            ab_prev = self.schedule.alphas_cumprod[t_prev]
            beta_eff = 1.0 - ab_t / ab_prev
            coef1 = torch.sqrt(ab_prev) * beta_eff / (1.0 - ab_t)
            coef2 = torch.sqrt(ab_t / ab_prev) * (1.0 - ab_prev) / (1.0 - ab_t)
            mean = coef1 * x_0_pred + coef2 * x_t
            var = beta_eff * (1.0 - ab_prev) / (1.0 - ab_t)
        noise = torch.randn_like(x_t)
        return mean + torch.sqrt(var) * noise

    def forward_train(self, batch):
        """
        Compute training losses using x_0-prediction (the denoisers predict clean data, not noise).

        Args:
            batch: dict with 'conditioning' (B, T, n_media + n_ctrl + 1) laid out as
                [media_spend, controls, total_sales], and 'coefficients' (B, T, n_channels).

        Returns:
            dict with 'loss' (tensor to backpropagate) and float values of its
            components 'loss_campaign', 'loss_channel', 'loss_smooth', 'loss_sign'.
        """
        cond = batch['conditioning']
        coeffs = batch['coefficients']

        B = coeffs.shape[0]
        device = coeffs.device
        self._sync_schedule(device)

        media_spend, total_sales, stage1_cond = self._split_conditioning(cond)

        # Stage 1 target: linear projection of the true coefficients. It is computed
        # under no_grad, so coeff_to_agg receives no gradient here and keeps its
        # initial weights. NOTE(review): presumably deliberate, so the targets cannot
        # be shrunk to make loss_campaign trivial.
        with torch.no_grad():
            agg_target = self.coeff_to_agg(coeffs)

        # Stage 1: noise the aggregate at a random timestep and predict it back.
        t1 = torch.randint(0, self.T_diff, (B,), device=device)
        noise1 = torch.randn_like(agg_target)
        agg_noisy = self.schedule.q_sample(agg_target, t1, noise1)
        agg_pred = self.campaign_denoiser(agg_noisy, t1, stage1_cond)
        loss_campaign = F.mse_loss(agg_pred, agg_target)

        # Stage 2: independent timestep; conditioned on the *true* aggregate (teacher forcing).
        t2 = torch.randint(0, self.T_diff, (B,), device=device)
        noise2 = torch.randn_like(coeffs)
        coeffs_noisy = self.schedule.q_sample(coeffs, t2, noise2)
        campaign_ctx = agg_target.detach()
        coeffs_pred = self.channel_denoiser(
            coeffs_noisy, t2, campaign_ctx, media_spend, total_sales
        )
        loss_channel = F.mse_loss(coeffs_pred, coeffs)

        # Smoothness: match week-over-week changes of the prediction to the target's.
        delta_pred = coeffs_pred[:, 1:, :] - coeffs_pred[:, :-1, :]
        delta_true = coeffs[:, 1:, :] - coeffs[:, :-1, :]
        loss_smooth = F.mse_loss(delta_pred, delta_true)

        # Sign penalty: hinge on normalized log-media predictions below -5.
        media_pred_log = coeffs_pred[:, :, :self.n_media]
        loss_sign = F.relu(-media_pred_log - 5.0).mean()

        loss = (
            1.0 * loss_campaign +
            1.0 * loss_channel +
            0.1 * loss_smooth +
            0.01 * loss_sign
        )

        return {
            'loss': loss,
            'loss_campaign': loss_campaign.item(),
            'loss_channel': loss_channel.item(),
            'loss_smooth': loss_smooth.item(),
            'loss_sign': loss_sign.item(),
        }

    @torch.no_grad()
    def sample(self, conditioning, n_steps=None, constraint_every_k=10, guidance_scale=1.0):
        """
        Generate time-varying coefficients via dual-denoiser reverse diffusion.

        Stage 1 is sampled to completion first; its result conditions every Stage 2
        step. Modules run in eval mode during sampling (dropout disabled) and the
        previous train/eval mode is restored afterwards.

        Args:
            conditioning: (B, T, n_media + n_ctrl + 1) — [media_spend, controls, total_sales]
            n_steps: number of denoising steps; None/0 uses all T_diff steps, fewer
                steps use an evenly strided subset of timesteps.
            constraint_every_k: clamp media x_0 predictions to [-8, 8] at timesteps with
                t % constraint_every_k == 0; None or <= 0 disables the clamp.
            guidance_scale: currently unused; accepted for API compatibility.

        Returns:
            (B, T, n_channels) predicted coefficients in the normalized training space.
        """
        B, T_seq, _ = conditioning.shape
        device = conditioning.device
        self._sync_schedule(device)

        media_spend, total_sales, stage1_cond = self._split_conditioning(conditioning)

        timesteps = self._sampling_timesteps(n_steps)
        # Pair each timestep with the one it steps to; -1 marks the final x_0 step.
        transitions = list(zip(timesteps, timesteps[1:] + [-1]))
        clamp_every = constraint_every_k if constraint_every_k and constraint_every_k > 0 else None

        was_training = self.training
        self.eval()
        try:
            # Stage 1: aggregate signal.
            z_t = torch.randn(B, T_seq, self.n_agg, device=device)
            for t, t_prev in transitions:
                t_batch = torch.full((B,), t, device=device, dtype=torch.long)
                z_0_pred = self.campaign_denoiser(z_t, t_batch, stage1_cond)
                z_t = self._reverse_step(z_0_pred, z_t, t, t_prev, t_batch)
            campaign_ctx = z_t

            # Stage 2: channel coefficients conditioned on the Stage 1 result.
            x_t = torch.randn(B, T_seq, self.n_channels, device=device)
            for t, t_prev in transitions:
                t_batch = torch.full((B,), t, device=device, dtype=torch.long)
                x_0_pred = self.channel_denoiser(
                    x_t, t_batch, campaign_ctx, media_spend, total_sales
                )
                if clamp_every is not None and t % clamp_every == 0:
                    x_0_pred[:, :, :self.n_media] = torch.clamp(
                        x_0_pred[:, :, :self.n_media], min=-8.0, max=8.0
                    )
                x_t = self._reverse_step(x_0_pred, x_t, t, t_prev, t_batch)
        finally:
            self.train(was_training)

        return x_t
|
|
|
|
| |
| |
| |
|
|
def train_mmm_diffusion(
    model, dataset,
    n_epochs=50, batch_size=16, lr=1e-4,
    device='cpu', log_every=50, save_path='mmm_diffusion_model.pt'
):
    """
    Train an MMM diffusion model with AdamW and a per-epoch cosine LR schedule.

    Batches are shuffled, incomplete final batches are dropped, and gradients are
    clipped to norm 1.0. After training, the model's state_dict, the loss history
    and a small config dict are written to ``save_path`` with ``torch.save``.

    Args:
        model: object exposing ``forward_train(batch) -> dict`` with the keys
            'loss', 'loss_campaign', 'loss_channel', 'loss_smooth', 'loss_sign',
            plus ``schedule``, ``T_diff``, ``n_media`` and ``n_ctrl`` attributes
            (e.g. MMMDiffusionModel).
        dataset: map-style dataset yielding dicts of tensors.
        log_every: print running averages every this many optimizer steps
            (<= 0 disables step logging).

    Returns:
        history: dict mapping each loss name to its list of per-epoch means.

    Raises:
        ValueError: if the dataset cannot fill a single batch (len(dataset) < batch_size).
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    # With drop_last=True a too-small dataset yields zero batches; fail clearly
    # instead of crashing later on an empty loss history.
    if len(dataloader) == 0:
        raise ValueError(
            f"dataset has {len(dataset)} samples, fewer than batch_size={batch_size}; "
            "with drop_last=True no training batch would be produced"
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    model = model.to(device)
    # The diffusion schedule is not a submodule, so it must be moved explicitly.
    model.schedule = model.schedule.to(device)

    history = {'loss': [], 'loss_campaign': [], 'loss_channel': [], 'loss_smooth': [], 'loss_sign': []}

    print(f"\nTraining MMM-Diffusion Model")
    print(f"  Device: {device}")
    print(f"  Samples: {len(dataset)}, Batch size: {batch_size}")
    print(f"  Epochs: {n_epochs}, LR: {lr}")
    print(f"  Model params: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  Diffusion steps: {model.T_diff}")
    print("-" * 60)

    step = 0
    for epoch in range(n_epochs):
        model.train()
        epoch_losses = {k: [] for k in history}

        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}

            losses = model.forward_train(batch)

            optimizer.zero_grad()
            losses['loss'].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            for k in history:
                val = losses[k].item() if isinstance(losses[k], torch.Tensor) else losses[k]
                epoch_losses[k].append(val)

            step += 1
            if log_every > 0 and step % log_every == 0:
                avg = {k: np.mean(v[-log_every:]) for k, v in epoch_losses.items() if v}
                print(f"  Step {step:5d} | loss={avg['loss']:.4f} "
                      f"camp={avg['loss_campaign']:.4f} chan={avg['loss_channel']:.4f} "
                      f"smooth={avg['loss_smooth']:.4f} sign={avg['loss_sign']:.6f}")

        # One LR-schedule step per epoch (T_max=n_epochs).
        scheduler.step()

        avg = {k: np.mean(v) for k, v in epoch_losses.items() if v}
        for k, v in avg.items():
            history[k].append(v)

        print(f"Epoch {epoch+1:3d}/{n_epochs} | loss={avg['loss']:.4f} "
              f"camp={avg['loss_campaign']:.4f} chan={avg['loss_channel']:.4f} "
              f"smooth={avg['loss_smooth']:.4f} sign={avg['loss_sign']:.6f} "
              f"lr={scheduler.get_last_lr()[0]:.6f}")

    torch.save({
        'model_state_dict': model.state_dict(),
        'history': history,
        'config': {
            'n_media': model.n_media,
            'n_ctrl': model.n_ctrl,
            'T_diff': model.T_diff,
        }
    }, save_path)
    print(f"\nModel saved to {save_path}")

    return history
|
|
|
|
| |
| |
| |
|
|
def decompose_sales(coefficients, media_spend, controls, adstock_alphas=None,
                    hill_ec50=None, hill_slope=None, base_sales=None):
    """
    Decompose sales into per-variable contributions from time-varying coefficients.

    Media response for channel m:
        s_m = adstock(spend_m; α_m)     if ``adstock_alphas`` is given, else raw spend
        r_m = s_m / (P90(s_m) + 1e-10)
        r_m = Hill(r_m; ec50_m, k_m)    if ``hill_ec50`` and ``hill_slope`` are given
        contribution_m,t = β_m,t · r_m,t
    Without the Hill parameters the response is linear in normalized spend, which
    does not match data from MMMDataGenerator (it applies Hill saturation after the
    same P90 normalization).

    Args:
        coefficients: (T, n_media + n_ctrl) decoded coefficients [media..., controls...]
        media_spend: (T, n_media) raw media spend; n_media must not exceed
            len(MMMDataGenerator.MEDIA_CHANNELS)
        controls: (T, n_ctrl) control variables; n_ctrl must not exceed
            len(MMMDataGenerator.CONTROL_VARS)
        adstock_alphas: optional (n_media,) geometric adstock retention rates
        hill_ec50, hill_slope: optional (n_media,) Hill parameters; give both or neither
        base_sales: optional scalar or (T,) baseline, reported as 'Base' and added to
            'Predicted_Sales'

    Returns:
        dict of (T,) arrays keyed by channel / control name, plus 'Total_Media',
        'Total_Controls', 'Predicted_Sales' (and 'Base' when ``base_sales`` is given).

    Raises:
        ValueError: if only one of ``hill_ec50`` / ``hill_slope`` is provided.
    """
    if (hill_ec50 is None) != (hill_slope is None):
        raise ValueError("hill_ec50 and hill_slope must be provided together")

    T = coefficients.shape[0]
    n_media = media_spend.shape[1]
    n_ctrl = controls.shape[1]

    contributions = {}
    total_media = np.zeros(T)

    for m in range(n_media):
        name = MMMDataGenerator.MEDIA_CHANNELS[m]
        response = np.asarray(media_spend[:, m], dtype=float)

        # Geometric adstock y_t = x_t + α·y_{t-1} as a first-order IIR filter.
        if adstock_alphas is not None:
            response = lfilter([1.0], [1.0, -adstock_alphas[m]], response)

        # Normalize by the 90th percentile, as the generator does before saturation.
        response = response / (np.percentile(response, 90) + 1e-10)

        if hill_ec50 is not None:
            x_safe = np.maximum(response, 0)
            k = hill_slope[m]
            response = x_safe**k / (x_safe**k + hill_ec50[m]**k + 1e-10)

        contrib = coefficients[:, m] * response
        contributions[name] = contrib
        total_media += contrib

    total_ctrl = np.zeros(T)
    for c in range(n_ctrl):
        name = MMMDataGenerator.CONTROL_VARS[c]
        contrib = coefficients[:, n_media + c] * controls[:, c]
        contributions[name] = contrib
        total_ctrl += contrib

    contributions['Total_Media'] = total_media
    contributions['Total_Controls'] = total_ctrl

    predicted = total_media + total_ctrl
    if base_sales is not None:
        base = np.broadcast_to(np.asarray(base_sales, dtype=float), (T,)).copy()
        contributions['Base'] = base
        predicted = predicted + base
    contributions['Predicted_Sales'] = predicted

    return contributions
|
|
|
|
| |
| |
| |
|
|
def plot_training_history(history, save_path='training_history.png'):
    """
    Plot one loss curve per history key in a two-column grid and save it to ``save_path``.

    The grid grows with the number of keys, so every series is plotted. Spare
    panels are hidden. A series uses a log y-axis only if it is non-empty and
    strictly positive.

    Args:
        history: dict mapping loss name -> sequence of per-epoch values.
        save_path: output image path.
    """
    keys = list(history)
    n_cols = 2
    n_rows = max(1, (len(keys) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 5 * n_rows), squeeze=False)
    flat_axes = axes.flatten()

    for ax, key in zip(flat_axes, keys):
        values = history[key]
        ax.plot(values, linewidth=1.5)
        ax.set_title(f'{key}', fontsize=12)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.grid(True, alpha=0.3)
        # Log scale requires strictly positive data; min() of an empty list would raise.
        use_log = len(values) > 0 and min(values) > 0
        ax.set_yscale('log' if use_log else 'linear')

    for ax in flat_axes[len(keys):]:
        ax.set_visible(False)

    plt.suptitle('MMM-Diffusion Training History', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Training history plot saved to {save_path}")
|
|
|
|
def plot_coefficient_comparison(true_coeffs, pred_coeffs, channel_names,
                                save_path='coeff_comparison.png', n_media=5):
    """
    Plot true vs predicted time-varying coefficients, one row per channel, and save the figure.

    Args:
        true_coeffs: (T, n_channels) ground-truth coefficients.
        pred_coeffs: (T, n_channels) predicted coefficients.
        channel_names: names of the channels plotted (rows beyond len(channel_names)
            are left empty).
        save_path: output image path.
        n_media: the first ``n_media`` rows are labelled as sign-constrained media
            coefficients and get a zero reference line.
    """
    n_channels = true_coeffs.shape[1]
    # squeeze=False keeps a 2-D array even for a single channel (otherwise a bare
    # Axes is returned and cannot be iterated).
    fig, axes = plt.subplots(n_channels, 1, figsize=(14, 2.5 * n_channels), squeeze=False)
    axes = axes[:, 0]

    for i, (ax, name) in enumerate(zip(axes, channel_names)):
        ax.plot(true_coeffs[:, i], 'b-', label='Ground Truth', linewidth=1.5)
        ax.plot(pred_coeffs[:, i], 'r--', label='Predicted', linewidth=1.5, alpha=0.8)
        ax.set_title(f'{name} — Time-Varying Coefficient', fontsize=11)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        if i < n_media:
            ax.axhline(y=0, color='gray', linestyle=':', alpha=0.5)
            ax.set_ylabel('β (≥0)')
        else:
            ax.set_ylabel('β')

    axes[-1].set_xlabel('Week')
    plt.suptitle('MMM-Diffusion: Coefficient Prediction Quality', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Coefficient comparison plot saved to {save_path}")
|
|
|
|
def plot_sales_decomposition(contributions, total_sales, save_path='sales_decomposition.png'):
    """
    Save a two-panel figure of a sales decomposition.

    Top panel: per-media-channel contributions (negative values clipped to 0)
    stacked as filled areas, with actual total sales overlaid. Bottom panel:
    actual sales vs ``contributions['Predicted_Sales']``.

    Args:
        contributions: dict with one (T,) array per name in
            MMMDataGenerator.MEDIA_CHANNELS plus 'Predicted_Sales'
            (e.g. the output of decompose_sales).
        total_sales: (T,) observed sales.
        save_path: output image path.
    """
    _, (stack_ax, fit_ax) = plt.subplots(2, 1, figsize=(14, 10))

    n_weeks = len(total_sales)
    weeks = np.arange(n_weeks)
    channels = MMMDataGenerator.MEDIA_CHANNELS
    palette = plt.cm.Set2(np.linspace(0, 1, len(channels)))

    # Top panel: stacked channel contributions against actual sales.
    floor = np.zeros(n_weeks)
    for channel, colour in zip(channels, palette):
        layer = np.maximum(contributions[channel], 0)
        ceiling = floor + layer
        stack_ax.fill_between(weeks, floor, ceiling, alpha=0.7, label=channel, color=colour)
        floor = ceiling

    stack_ax.plot(weeks, total_sales, 'k-', linewidth=2, label='Total Sales', alpha=0.8)
    stack_ax.set_title('Sales Decomposition: Media Channel Contributions', fontsize=12)
    stack_ax.legend(loc='upper left', fontsize=9)
    stack_ax.set_xlabel('Week')
    stack_ax.set_ylabel('Sales Contribution')
    stack_ax.grid(True, alpha=0.3)

    # Bottom panel: actual vs reconstructed total sales.
    fit_ax.plot(weeks, total_sales, 'b-', linewidth=2, label='Actual Sales')
    fit_ax.plot(weeks, contributions['Predicted_Sales'], 'r--', linewidth=2,
                label='Predicted (Media + Controls)', alpha=0.8)
    fit_ax.set_title('Total Sales: Actual vs Predicted Decomposition', fontsize=12)
    fit_ax.legend(fontsize=10)
    fit_ax.set_xlabel('Week')
    fit_ax.set_ylabel('Sales')
    fit_ax.grid(True, alpha=0.3)

    plt.suptitle('MMM-Diffusion: Sales Decomposition', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Sales decomposition plot saved to {save_path}")
|
|
|
|
| |
| |
| |
|
|
def main(output_dir='/app'):
    """Run the end-to-end MMM-Diffusion proof of concept.

    Pipeline:
      1. Generate synthetic MMM scenarios with known ground-truth coefficients.
      2. Wrap them in normalized training / validation datasets.
      3. Build the dual-denoiser (campaign + channel) diffusion model.
      4. Train it, saving a checkpoint and a training-history plot.
      5. Sample coefficients for one held-out scenario, check the media sign
         constraint, report per-coefficient correlation / RMSE, and plot the
         coefficient comparison and the implied sales decomposition.

    Args:
        output_dir: Directory that receives the model checkpoint and the three
            PNG plots. It is created before training starts, so a bad path
            fails immediately instead of after the training run. Defaults to
            ``'/app'``, which reproduces the previously hard-coded paths.

    Returns:
        Tuple ``(model, history, train_dataset, val_dataset)``.
    """
    import time

    # All artifact paths are derived from one directory instead of being
    # repeated as string literals.
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'mmm_diffusion_model.pt')
    history_plot_path = os.path.join(output_dir, 'training_history.png')
    coeff_plot_path = os.path.join(output_dir, 'coeff_comparison.png')
    decomposition_plot_path = os.path.join(output_dir, 'sales_decomposition.png')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("=" * 60)
    print("MMM-DIFFUSION: Marketing Mix Model via Diffusion")
    print("Adapted from Kimodo/GMD dual-denoiser architecture")
    print(f"Device: {device}")
    print("=" * 60)

    # ---- 1. Synthetic data -------------------------------------------------
    print("\n[1/5] Generating synthetic MMM data...")
    t0 = time.time()

    gen = MMMDataGenerator(n_weeks=104, seed=42)
    # Use the generator's own channel counts rather than magic numbers.
    n_media, n_ctrl = gen.n_media, gen.n_ctrl

    n_train = 500
    n_val = 50

    # NOTE(review): train and val come from the same generator instance. This
    # assumes generate_dataset advances the generator's RNG rather than
    # reseeding; otherwise val would duplicate the first train scenarios.
    train_samples = gen.generate_dataset(n_train)
    val_samples = gen.generate_dataset(n_val)

    print(f" Generated {n_train} train + {n_val} val scenarios")
    print(f" Each: {gen.n_weeks} weeks, {gen.n_media} media channels, {gen.n_ctrl} control vars")
    print(f" Time: {time.time()-t0:.1f}s")

    sample = train_samples[0]
    true_media = sample['true_coefficients'][:, :n_media]
    print("\n Data audit (sample 0):")
    print(f" Media spend shape: {sample['media_spend'].shape}")
    print(f" Media spend range: [{sample['media_spend'].min():.1f}, {sample['media_spend'].max():.1f}]")
    print(f" Sales range: [{sample['total_sales'].min():.1f}, {sample['total_sales'].max():.1f}]")
    print(f" Media coeff range: [{true_media.min():.4f}, {true_media.max():.4f}]")
    print(f" All media coeffs positive: {(true_media > 0).all()}")

    # ---- 2. Datasets -------------------------------------------------------
    print("\n[2/5] Creating training datasets...")
    train_dataset = MMMDiffusionDataset(train_samples, normalize=True)
    # NOTE(review): the validation set is normalized independently of the
    # training set. If MMMDiffusionDataset derives its normalization statistics
    # from the samples it is given, validation inputs are scaled differently
    # from what the model was trained on. Confirm, and reuse the training
    # statistics if so.
    val_dataset = MMMDiffusionDataset(val_samples, normalize=True)

    item = train_dataset[0]
    print(f" Conditioning shape: {item['conditioning'].shape}")
    print(f" Coefficients shape: {item['coefficients'].shape}")

    # ---- 3. Model ----------------------------------------------------------
    print("\n[3/5] Building MMM-Diffusion model...")

    # Fewer diffusion steps on CPU to keep the run tractable.
    T_DIFF = 200 if device == 'cpu' else 500

    model = MMMDiffusionModel(
        n_media=n_media, n_ctrl=n_ctrl,
        d_model_campaign=128,
        d_model_channel=192,
        n_layers_campaign=3,
        n_layers_channel=4,
        T_diff=T_DIFF,
    )

    total_params = sum(p.numel() for p in model.parameters())
    print(f" Total parameters: {total_params:,}")
    print(f" Campaign denoiser: {sum(p.numel() for p in model.campaign_denoiser.parameters()):,}")
    print(f" Channel denoiser: {sum(p.numel() for p in model.channel_denoiser.parameters()):,}")
    print(f" Diffusion steps: {T_DIFF}")

    # ---- 4. Training -------------------------------------------------------
    print("\n[4/5] Training...")

    N_EPOCHS = 30 if device == 'cpu' else 50
    BATCH_SIZE = 8 if device == 'cpu' else 16

    history = train_mmm_diffusion(
        model, train_dataset,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        lr=3e-4,
        device=device,
        log_every=25,
        save_path=checkpoint_path,
    )

    plot_training_history(history, save_path=history_plot_path)

    # ---- 5. Validation -----------------------------------------------------
    print("\n[5/5] Validation: generating coefficients for held-out sample...")

    model.eval()
    model = model.to(device)

    val_item = val_dataset[0]
    cond = val_item['conditioning'].unsqueeze(0).to(device)
    true_coeffs_norm = val_item['coefficients'].unsqueeze(0)

    t0 = time.time()
    pred_coeffs_norm = model.sample(
        cond,
        n_steps=T_DIFF,
        constraint_every_k=10,
    )
    gen_time = time.time() - t0
    print(f" Generation time: {gen_time:.1f}s")

    # Map both prediction and ground truth back to the original coefficient scale.
    pred_coeffs = val_dataset.decode_coefficients(pred_coeffs_norm.cpu())
    true_coeffs = val_dataset.decode_coefficients(true_coeffs_norm)

    pred_np = pred_coeffs[0].numpy()
    true_np = true_coeffs[0].numpy()

    media_pred = pred_np[:, :n_media]
    print("\n Constraint check:")
    print(f" Media coefficients all positive: {(media_pred > 0).all()}")
    print(f" Media coeff min: {media_pred.min():.6f}")
    print(f" Media coeff max: {media_pred.max():.6f}")

    print("\n Per-channel correlation (true vs predicted):")
    channel_names = MMMDataGenerator.MEDIA_CHANNELS + MMMDataGenerator.CONTROL_VARS
    for i, name in enumerate(channel_names):
        true_col = true_np[:, i]
        pred_col = pred_np[:, i]
        rmse = np.sqrt(np.mean((true_col - pred_col) ** 2))
        # Pearson correlation is undefined for a constant series: np.corrcoef
        # would return NaN and emit a RuntimeWarning. The data model writes the
        # control coefficients as time-invariant, so this is expected there.
        if true_col.std() > 0 and pred_col.std() > 0:
            corr = np.corrcoef(true_col, pred_col)[0, 1]
        else:
            corr = float('nan')
        print(f" {name:20s}: corr={corr:.3f}, RMSE={rmse:.4f}")

    plot_coefficient_comparison(
        true_np, pred_np, channel_names,
        save_path=coeff_plot_path
    )

    val_raw = val_samples[0]
    contributions = decompose_sales(
        pred_np, val_raw['media_spend'], val_raw['controls']
    )
    plot_sales_decomposition(
        contributions, val_raw['total_sales'],
        save_path=decomposition_plot_path
    )

    # ---- Summary -----------------------------------------------------------
    print(f"\n{'='*60}")
    print("MMM-DIFFUSION POC COMPLETE")
    print(f"{'='*60}")
    print(" Architecture: Dual-denoiser (Campaign + Channel) diffusion")
    print(" Based on: Kimodo/GMD pattern with PhysDiff constraint projection")
    print(f" Data: {n_train} synthetic MMM scenarios, {gen.n_weeks} weeks each")
    print(f" Channels: {gen.n_media} media + {gen.n_ctrl} non-marketing")
    print(" Constraints: Log-space media (guaranteed positive) + soft sign loss + PhysDiff projection")
    print(f" Model size: {total_params:,} parameters")
    print(f" Final training loss: {history['loss'][-1]:.4f}")
    print("\nOutputs:")
    print(f" Model checkpoint: {checkpoint_path}")
    print(f" Training history: {history_plot_path}")
    print(f" Coefficient comparison: {coeff_plot_path}")
    print(f" Sales decomposition: {decomposition_plot_path}")

    return model, history, train_dataset, val_dataset
|
|
|
|
if __name__ == '__main__':
    # Run the full pipeline. The returned objects are bound as module-level
    # names, presumably so they remain inspectable after the run
    # (e.g. under ``python -i``).
    (model, history, train_dataset, val_dataset) = main()
|
|