# mmm-diffusion / mmm_diffusion.py
# Commit e34a047 (verified) by sujimenon: "Add MMM-Diffusion model implementation"
"""
Marketing Mix Model Diffusion (MMM-Diffusion)
==============================================
A generative diffusion model for Marketing Mix Modeling, adapted from NVIDIA's
Kimodo dual-denoiser architecture (GMD/MDM family).
Architecture mapping:
Kimodo → MMM-Diffusion
------- ---------------
Text prompts → Media spend, non-marketing vars, total sales
Motion/position constraints → Sign constraints (β_media ≥ 0) + prior constraints
Root denoiser (trajectory) → Campaign/Geo-level denoiser (aggregate patterns)
Body denoiser (joint rotations)→ Channel-level denoiser (per-channel coefficients)
Skeleton positions/rotations → Time-varying coefficients for sales decomposition
References:
- GMD (arxiv:2305.12577) — Two-stage trajectory + body diffusion
- MDM (arxiv:2209.14916) — Transformer denoiser, x₀-prediction, geometric losses
- PhysDiff (arxiv:2212.02500) — Projection during denoising for constraints
- PDM (arxiv:2402.03559) — Projected diffusion for hard constraint satisfaction
- NNN (arxiv:2504.06212) — Neural network MMM architecture from Google
"""
import math
import json
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from scipy.signal import lfilter
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# =============================================================================
# 1. SYNTHETIC MMM DATA GENERATOR
# =============================================================================
class MMMDataGenerator:
    """
    Generate synthetic Marketing Mix Model data with known ground truth.

    Model:
        Sales_t = β_base + Σ_m β_m,t * Hill(Adstock(spend_m,t; α_m) / P90; ec50_m, k_m)
                         + Σ_c β_c,t * ctrl_c,t + ε_t
    where β_base is a constant intercept and β_m,t / β_c,t are time-varying
    coefficients. Total sales are clipped at 0.

    Channels: 5 media channels + 3 non-marketing (control) variables.
    Based on the NNN paper (arxiv:2504.06212) simulation recipe and the Meridian framework.

    Args:
        n_weeks: number of weekly time steps per generated dataset (must be >= 1).
        n_geos: stored on the instance; not used by any generation method here.
        seed: seed for the internal ``np.random.RandomState`` (None = unseeded).

    Raises:
        ValueError: if ``n_weeks < 1``.
    """
    MEDIA_CHANNELS = ['TV', 'Digital', 'Social', 'Print', 'Radio']
    CONTROL_VARS = ['Seasonality', 'Trend', 'Competitor_Price']

    def __init__(self, n_weeks=104, n_geos=1, seed=None):
        if n_weeks < 1:
            raise ValueError(f"n_weeks must be >= 1, got {n_weeks}")
        self.n_weeks = n_weeks
        self.n_geos = n_geos  # NOTE: currently unused; data is single-geo.
        self.n_media = 5
        self.n_ctrl = 3
        self.rng = np.random.RandomState(seed)

    def _generate_media_spend(self):
        """Generate (n_weeks, n_media) non-negative spend with an annual cycle and campaign bursts."""
        spend = np.zeros((self.n_weeks, self.n_media))
        t = np.arange(self.n_weeks)
        # Base spend levels (different per channel)
        base_levels = self.rng.uniform(50, 500, size=self.n_media)
        # randint requires high > low; clamp so very short series (n_weeks <= 4)
        # still work. For n_weeks >= 5 this equals the original bound n_weeks - 4.
        max_start = max(self.n_weeks - 4, 1)
        for m in range(self.n_media):
            # Base pattern with an annual (52-week) sinusoidal cycle
            base = base_levels[m] * (1 + 0.2 * np.sin(2 * np.pi * t / 52))
            # Random campaign bursts (3-8 over the whole series)
            n_campaigns = self.rng.randint(3, 9)
            for _ in range(n_campaigns):
                start = self.rng.randint(0, max_start)
                duration = self.rng.randint(1, 6)
                intensity = self.rng.uniform(1.5, 4.0)
                end = min(start + duration, self.n_weeks)
                base[start:end] *= intensity
            # Add noise, then clip at zero (spend cannot be negative)
            spend[:, m] = np.maximum(base + self.rng.normal(0, base_levels[m] * 0.1, self.n_weeks), 0)
        return spend

    def _adstock(self, x, alpha):
        """
        Geometric adstock: y[t] = x[t] + alpha * y[t-1], with y[-1] = 0.

        alpha in [0, 1] is the retention rate. Implemented as an IIR filter
        (b=[1], a=[1, -alpha]), which computes the same recursion in C and
        also handles empty input.
        """
        return lfilter([1.0], [1.0, -alpha], np.asarray(x, dtype=float))

    def _hill(self, x, ec50, slope):
        """Hill saturation x^k / (x^k + ec50^k); negative inputs are clipped to 0. ec50>0, slope>0."""
        x_safe = np.maximum(x, 0)
        return x_safe**slope / (x_safe**slope + ec50**slope + 1e-10)

    def _generate_controls(self):
        """Generate (n_weeks, 3) control variables: seasonality, trend, competitor price."""
        t = np.arange(self.n_weeks)
        controls = np.zeros((self.n_weeks, self.n_ctrl))
        # Seasonality: annual cycle with harmonics
        controls[:, 0] = (np.sin(2 * np.pi * t / 52) +
                          0.5 * np.sin(4 * np.pi * t / 52) +
                          0.3 * np.cos(2 * np.pi * t / 52))
        # Trend: slow linear + mild quadratic
        trend = t / self.n_weeks
        controls[:, 1] = trend + 0.5 * trend**2
        # Competitor price: AR(1) random walk reverting towards 1.0
        price = np.zeros(self.n_weeks)
        price[0] = 1.0
        for i in range(1, self.n_weeks):
            price[i] = 0.95 * price[i-1] + 0.05 * 1.0 + self.rng.normal(0, 0.05)
        controls[:, 2] = price
        return controls

    def _sample_true_params(self):
        """Sample ground truth MMM parameters from the priors below."""
        params = {}
        # Media coefficients (MUST BE POSITIVE — this is the key constraint)
        # |N(0, 0.5)| + 0.05, so always >= 0.05
        params['beta_media'] = np.abs(self.rng.normal(0, 0.5, self.n_media)) + 0.05
        # Adstock retention rates α ~ Beta(2, 2), clipped to [0.1, 0.95]
        params['adstock_alpha'] = self.rng.beta(2, 2, self.n_media)
        params['adstock_alpha'] = np.clip(params['adstock_alpha'], 0.1, 0.95)
        # Hill EC50 (half-saturation) on the P90-normalized scale — positive
        params['hill_ec50'] = np.abs(self.rng.lognormal(0, 0.5, self.n_media)) + 0.1
        # Hill slope k ∈ [0.5, 3] — controls steepness
        params['hill_slope'] = self.rng.uniform(0.5, 3.0, self.n_media)
        # Base sales (intercept) — positive
        params['beta_base'] = self.rng.uniform(500, 2000)
        # Control coefficients (can be positive or negative)
        params['beta_ctrl'] = self.rng.normal(0, 50, self.n_ctrl)
        # Noise level
        params['noise_std'] = self.rng.uniform(20, 100)
        return params

    def _make_time_varying(self, base_coeff, n_weeks, volatility=0.1):
        """
        Make a coefficient time-varying: β_t = β * exp(z_t).

        z_t is an AR(1) process z_t = 0.9 z_{t-1} + N(0, volatility), z_0 = 0,
        so the sign of ``base_coeff`` is preserved at every step.
        """
        z = np.zeros(n_weeks)
        for t in range(1, n_weeks):
            z[t] = 0.9 * z[t-1] + self.rng.normal(0, volatility)
        return base_coeff * np.exp(z)

    def generate_single(self):
        """
        Generate a single MMM dataset with known ground truth.

        Returns:
            dict with keys:
            - media_spend: (T, 5) raw media spend
            - controls: (T, 3) control variables
            - total_sales: (T,) total sales (clipped at 0)
            - true_coefficients: (T, 8) time-varying coefficients [5 media + 3 ctrl]
            - true_contributions: (T, 8) sales contribution per variable
            - base_sales: (T,) constant intercept repeated per week
            - true_params: dict of ground truth parameters
        """
        spend = self._generate_media_spend()
        controls = self._generate_controls()
        params = self._sample_true_params()
        # Apply adstock and Hill transformation to media
        transformed_media = np.zeros_like(spend)
        for m in range(self.n_media):
            adstocked = self._adstock(spend[:, m], params['adstock_alpha'][m])
            # Normalize by the 90th percentile so ec50 is on a comparable scale
            adstocked_norm = adstocked / (np.percentile(adstocked, 90) + 1e-10)
            transformed_media[:, m] = self._hill(
                adstocked_norm, params['hill_ec50'][m], params['hill_slope'][m]
            )
        # Generate time-varying coefficients
        tv_coeffs = np.zeros((self.n_weeks, self.n_media + self.n_ctrl))
        # Media coefficients — time-varying but ALWAYS POSITIVE
        for m in range(self.n_media):
            tv_coeffs[:, m] = self._make_time_varying(
                params['beta_media'][m], self.n_weeks, volatility=0.05
            )
            tv_coeffs[:, m] = np.maximum(tv_coeffs[:, m], 0.01)  # enforce positivity
        # Control coefficients — mild time variation, can be negative
        for c in range(self.n_ctrl):
            tv_coeffs[:, self.n_media + c] = self._make_time_varying(
                params['beta_ctrl'][c], self.n_weeks, volatility=0.03
            )
        # Compute contributions
        contributions = np.zeros((self.n_weeks, self.n_media + self.n_ctrl))
        for m in range(self.n_media):
            contributions[:, m] = tv_coeffs[:, m] * transformed_media[:, m]
        for c in range(self.n_ctrl):
            contributions[:, self.n_media + c] = tv_coeffs[:, self.n_media + c] * controls[:, c]
        # Total sales
        base = params['beta_base']
        noise = self.rng.normal(0, params['noise_std'], self.n_weeks)
        total_sales = base + contributions.sum(axis=1) + noise
        total_sales = np.maximum(total_sales, 0)  # sales can't be negative
        return {
            'media_spend': spend,
            'controls': controls,
            'total_sales': total_sales,
            'true_coefficients': tv_coeffs,
            'true_contributions': contributions,
            'base_sales': np.full(self.n_weeks, base),
            'true_params': params
        }

    def generate_dataset(self, n_samples):
        """Generate ``n_samples`` MMM instances, reseeding the RNG from itself before each one."""
        samples = []
        for _ in range(n_samples):
            self.rng = np.random.RandomState(self.rng.randint(0, 2**31))
            samples.append(self.generate_single())
        return samples
# =============================================================================
# 2. DATASET CLASS FOR TRAINING
# =============================================================================
class MMMDiffusionDataset(Dataset):
    """
    Wraps generated MMM data for diffusion training.

    Each item is a dict with:
      - 'conditioning': (T, n_media + n_ctrl + 1) float32 — [media_spend | controls | total_sales]
      - 'coefficients': (T, n_media + n_ctrl) float32 — time-varying coefficients to denoise

    When ``normalize`` is True, conditioning is z-scored per column. Media
    coefficients are mapped to log-space and then z-scored, so ``decode_coefficients``
    can recover them with ``exp()``; control coefficients are plainly z-scored.
    Statistics are computed over all samples and time steps.

    Args:
        samples: list of dicts with keys 'media_spend' (T, 5), 'controls' (T, 3),
            'total_sales' (T,), 'true_coefficients' (T, 8).
        normalize: whether to apply the normalization described above.
    """

    def __init__(self, samples, normalize=True):
        self.samples = samples
        self.normalize = normalize
        self.n_media = 5
        self.n_ctrl = 3
        self.n_channels = self.n_media + self.n_ctrl  # 8 total
        # Compute normalization statistics
        if normalize:
            all_cond = np.stack([self._conditioning(s) for s in samples])
            all_coeff = np.stack([s['true_coefficients'] for s in samples])
            self.cond_mean = all_cond.mean(axis=(0, 1))
            self.cond_std = all_cond.std(axis=(0, 1)) + 1e-8
            # Only the control slice of these is used (media uses log-space stats below)
            self.coeff_mean = all_coeff.mean(axis=(0, 1))
            self.coeff_std = all_coeff.std(axis=(0, 1)) + 1e-8
            # Media coefficients: log-space normalization (decoding via exp() keeps them positive)
            log_media = np.log(all_coeff[:, :, :self.n_media] + 1e-8)
            self.media_log_mean = log_media.mean(axis=(0, 1))
            self.media_log_std = log_media.std(axis=(0, 1)) + 1e-8

    @staticmethod
    def _conditioning(sample):
        """Concatenate [media_spend | controls | total_sales] column-wise for one sample."""
        return np.concatenate(
            [sample['media_spend'], sample['controls'], sample['total_sales'][:, None]], axis=1
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        cond = self._conditioning(s).astype(np.float32)  # (T, 9)
        coeffs = s['true_coefficients'].astype(np.float32)  # (T, 8)
        if self.normalize:
            cond = (cond - self.cond_mean) / self.cond_std
            # Log-space for media (positive) coefficients
            log_media = np.log(coeffs[:, :self.n_media] + 1e-8)
            log_media = (log_media - self.media_log_mean) / self.media_log_std
            # Z-score for control coefficients
            ctrl = (coeffs[:, self.n_media:] - self.coeff_mean[self.n_media:]) / self.coeff_std[self.n_media:]
            coeffs = np.concatenate([log_media, ctrl], axis=1)
        return {
            'conditioning': torch.tensor(cond, dtype=torch.float32),
            'coefficients': torch.tensor(coeffs, dtype=torch.float32),
        }

    def decode_coefficients(self, coeffs_normalized):
        """
        Inverse-transform normalized coefficients back to original scale.

        Applies exp() to media channels, so decoded media coefficients are positive.

        Args:
            coeffs_normalized: tensor of shape (..., 8), e.g. (batch, T, 8) or (T, 8).

        Returns:
            Tensor of the same shape in original scale (media > 0). Returned
            unchanged when the dataset was built with ``normalize=False``.
        """
        if not self.normalize:
            return coeffs_normalized
        coeffs = coeffs_normalized.clone()
        device, dtype = coeffs.device, coeffs.dtype
        media_log_mean = torch.tensor(self.media_log_mean, device=device, dtype=dtype)
        media_log_std = torch.tensor(self.media_log_std, device=device, dtype=dtype)
        coeff_mean = torch.tensor(self.coeff_mean[self.n_media:], device=device, dtype=dtype)
        coeff_std = torch.tensor(self.coeff_std[self.n_media:], device=device, dtype=dtype)
        # Ellipsis indexing: works for any number of leading dims.
        media = coeffs[..., :self.n_media]
        coeffs[..., :self.n_media] = torch.exp(media * media_log_std + media_log_mean)
        coeffs[..., self.n_media:] = coeffs[..., self.n_media:] * coeff_std + coeff_mean
        return coeffs
# =============================================================================
# 3. DIFFUSION NOISE SCHEDULE
# =============================================================================
def cosine_beta_schedule(T, s=0.008):
    """
    Build the cosine noise schedule of 'Improved DDPM' (Nichol & Dhariwal, 2021).

    Args:
        T: number of diffusion steps.
        s: small offset that keeps the first betas from being too tiny.

    Returns:
        float32 tensor of T betas, clamped to [0, 0.999].
    """
    steps = torch.arange(T + 1, dtype=torch.float64)
    phase = (steps / T + s) / (1 + s) * math.pi / 2
    alpha_bar = torch.cos(phase) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    ratio = alpha_bar[1:] / alpha_bar[:-1]
    return (1 - ratio).clamp(0, 0.999).float()
class DiffusionSchedule:
    """
    Precomputed DDPM quantities for a cosine noise schedule.

    Every tensor attribute is 1-D with length ``T`` and is indexed by diffusion
    step. This is a plain object, not an ``nn.Module``, so ``Module.to()`` does
    not move it; call ``DiffusionSchedule.to(device)`` explicitly.
    """

    def __init__(self, T=1000):
        self.T = T
        betas = cosine_beta_schedule(T)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1
        alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)
        one_minus_alpha_bar = 1.0 - alpha_bar

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alpha_bar
        self.sqrt_alphas_cumprod = torch.sqrt(alpha_bar)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(one_minus_alpha_bar)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        self.alphas_cumprod_prev = alpha_bar_prev

        # Posterior q(x_{t-1} | x_t, x_0): variance and mean coefficients
        post_var = betas * (1.0 - alpha_bar_prev) / one_minus_alpha_bar
        self.posterior_variance = post_var
        # Step 0 has zero variance; clamp before log to keep it finite
        self.posterior_log_variance_clipped = torch.log(torch.clamp(post_var, min=1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(alpha_bar_prev) / one_minus_alpha_bar
        self.posterior_mean_coef2 = (1.0 - alpha_bar_prev) * torch.sqrt(alphas) / one_minus_alpha_bar

    def to(self, device):
        """Move every schedule tensor to ``device`` in place; returns self."""
        tensor_names = (
            'betas', 'alphas', 'alphas_cumprod', 'sqrt_alphas_cumprod',
            'sqrt_one_minus_alphas_cumprod', 'sqrt_recip_alphas',
            'alphas_cumprod_prev', 'posterior_variance',
            'posterior_log_variance_clipped', 'posterior_mean_coef1',
            'posterior_mean_coef2',
        )
        self.__dict__.update({name: getattr(self, name).to(device) for name in tensor_names})
        return self

    @staticmethod
    def _per_sample(values, t):
        """Gather ``values[t]`` and reshape to (B, 1, 1) for broadcasting over (B, T, C)."""
        return values[t].view(-1, 1, 1)

    def q_sample(self, x_0, t, noise=None):
        """Forward diffusion: draw x_t ~ q(x_t | x_0) for per-sample steps ``t``."""
        eps = torch.randn_like(x_0) if noise is None else noise
        signal_scale = self._per_sample(self.sqrt_alphas_cumprod, t)
        noise_scale = self._per_sample(self.sqrt_one_minus_alphas_cumprod, t)
        return signal_scale * x_0 + noise_scale * eps

    def posterior_mean(self, x_0_pred, x_t, t):
        """Mean of q(x_{t-1} | x_t, x_0_pred) for per-sample steps ``t``."""
        w_x0 = self._per_sample(self.posterior_mean_coef1, t)
        w_xt = self._per_sample(self.posterior_mean_coef2, t)
        return w_x0 * x_0_pred + w_xt * x_t
# =============================================================================
# 4. DENOISER NETWORKS (Kimodo-adapted dual denoiser)
# =============================================================================
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Sinusoidal embeddings for the diffusion timestep t.

    Maps a 1-D tensor of B timesteps to a (B, dim) tensor laid out as
    [sin(t * f_0..f_{h-1}) | cos(t * f_0..f_{h-1})], with h = dim // 2 and
    geometrically spaced frequencies from 1 down to 1/10000. For odd ``dim``
    a zero column is appended, so the output width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim is 2 or 3 (half_dim == 1)
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        args = t[:, None].float() * freqs[None, :]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2:
            # Pad odd widths so downstream Linear(dim, ...) layers receive `dim` features
            emb = F.pad(emb, (0, 1))
        return emb
class TemporalTransformerBlock(nn.Module):
    """
    Pre-norm transformer block that self-attends over the time axis.

    Expects batch-first input (B, T, d_model) and returns the same shape:
        x = x + SelfAttn(LayerNorm(x))
        x = x + FeedForward(LayerNorm(x))
    """

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        hidden = d_model * 4
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed)
        residual = x + attended
        return residual + self.ff(self.norm2(residual))
class CampaignDenoiser(nn.Module):
    """
    Stage 1: Campaign/Geo-level Denoiser.

    Analogous to GMD's trajectory DPM / Kimodo's root denoiser. Denoises
    aggregate-level patterns conditioned on non-marketing vars + total sales.
    Predicts x_0 directly (not noise ε), which allows constraint projection
    at each sampling step.

    Input:  x_t  (B, T, n_agg)     — noisy aggregate coefficients
            t    (B,)              — integer diffusion step per sample
    Cond:        (B, T, cond_dim)  — non-marketing vars + total sales
    Output: x_0_hat (B, T, n_agg)  — predicted clean aggregate coefficients

    Args:
        n_agg_channels: number of aggregate channels.
        cond_dim: per-timestep conditioning width.
        d_model: transformer width (nn.MultiheadAttention requires d_model % nhead == 0).
        nhead: attention heads per block.
        n_layers: number of TemporalTransformerBlock layers.
        T_diff: accepted for API symmetry; not used by this module.
        max_len: length of the learned positional table, i.e. the longest
            supported sequence (256 was the former hard-coded value).

    Raises (forward):
        ValueError: if the sequence is longer than the positional table.
    """

    def __init__(self, n_agg_channels=3, cond_dim=4, d_model=256, nhead=4, n_layers=4, T_diff=1000,
                 max_len=256):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        # Timestep embedding
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        # Conditioning projection: non-marketing vars + total sales
        self.cond_proj = nn.Linear(cond_dim, d_model)
        # Input projection: noisy aggregate coefficients
        self.input_proj = nn.Linear(n_agg_channels, d_model)
        # Learnable temporal positional encoding (one row per time step)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)
        # Transformer encoder for temporal attention
        self.blocks = nn.ModuleList([
            TemporalTransformerBlock(d_model, nhead) for _ in range(n_layers)
        ])
        # Output projection
        self.output_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_agg_channels),
        )

    def forward(self, x_t, t, cond):
        _, T_seq, _ = x_t.shape
        n_pos = self.pos_embed.shape[1]
        if T_seq > n_pos:
            # Slicing would silently return fewer rows and fail later with a shape error.
            raise ValueError(
                f"sequence length {T_seq} exceeds the positional embedding length {n_pos}; "
                f"construct the denoiser with a larger max_len"
            )
        t_emb = self.time_embed(t)  # (B, d_model)
        h_x = self.input_proj(x_t)  # (B, T, d_model)
        h_c = self.cond_proj(cond)  # (B, T, d_model)
        # Additive fusion: input + conditioning + timestep (broadcast over T) + position
        h = h_x + h_c + t_emb.unsqueeze(1) + self.pos_embed[:, :T_seq, :]
        for block in self.blocks:
            h = block(h)
        return self.output_proj(h)  # (B, T, n_agg)
class ChannelDenoiser(nn.Module):
    """
    Stage 2: Channel-level Denoiser.

    Analogous to GMD's full-body DPM / Kimodo's body denoiser. Denoises
    per-channel time-varying coefficients, conditioned on:
      - Stage 1 output (aggregate patterns)
      - media spend features
      - total sales
    Predicts x_0 directly, which allows constraint projection during sampling.

    Input:  x_t          (B, T, n_channels) — noisy channel coefficients
            t            (B,)               — integer diffusion step per sample
    Cond:   campaign_ctx (B, T, n_agg)      — from Stage 1
            media_spend  (B, T, n_media)    — media spend features
            total_sales  (B, T, 1)          — total sales
    Output: x_0_hat      (B, T, n_channels) — predicted clean coefficients

    Args:
        n_channels, n_media, n_agg: channel counts for the inputs above.
        d_model: transformer width (nn.MultiheadAttention requires d_model % nhead == 0).
        nhead: attention heads.
        n_layers: number of TemporalTransformerBlock layers.
        T_diff: accepted for API symmetry; not used by this module.
        max_len: length of the learned positional table (longest supported sequence).

    Raises (forward):
        ValueError: if the sequence is longer than the positional table.
    """

    def __init__(self, n_channels=8, n_media=5, n_agg=3, d_model=384, nhead=8, n_layers=6, T_diff=1000,
                 max_len=256):
        super().__init__()
        self.d_model = d_model
        self.n_media = n_media
        self.n_channels = n_channels
        self.max_len = max_len
        # Timestep embedding
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        # Input projection
        self.input_proj = nn.Linear(n_channels, d_model)
        # Multi-source conditioning
        self.campaign_proj = nn.Linear(n_agg, d_model)  # Stage 1 output
        self.spend_proj = nn.Linear(n_media, d_model)   # Media spend
        self.sales_proj = nn.Linear(1, d_model)         # Total sales
        # Cross-attention: per-timestep coefficient features attend over the
        # fused conditioning sequence (along the time axis)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_norm = nn.LayerNorm(d_model)
        # Temporal position encoding
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TemporalTransformerBlock(d_model, nhead) for _ in range(n_layers)
        ])
        # Single output head shared by all channels (emits n_channels values per step)
        self.output_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_channels),
        )

    def forward(self, x_t, t, campaign_ctx, media_spend, total_sales):
        _, T_seq, _ = x_t.shape
        n_pos = self.pos_embed.shape[1]
        if T_seq > n_pos:
            raise ValueError(
                f"sequence length {T_seq} exceeds the positional embedding length {n_pos}; "
                f"construct the denoiser with a larger max_len"
            )
        t_emb = self.time_embed(t)  # (B, d_model)
        h_x = self.input_proj(x_t)                 # (B, T, d_model)
        h_camp = self.campaign_proj(campaign_ctx)  # (B, T, d_model)
        h_spend = self.spend_proj(media_spend)     # (B, T, d_model)
        h_sales = self.sales_proj(total_sales)     # (B, T, d_model)
        # Conditioning context: additive fusion of the three sources (same length T)
        cond_ctx = h_camp + h_spend + h_sales      # (B, T, d_model)
        # Add position + time embeddings
        h = h_x + t_emb.unsqueeze(1) + self.pos_embed[:, :T_seq, :]
        # Pre-norm residual cross-attention into the conditioning context
        h_normed = self.cross_norm(h)
        h = h + self.cross_attn(h_normed, cond_ctx, cond_ctx)[0]
        # Self-attention transformer blocks
        for block in self.blocks:
            h = block(h)
        return self.output_proj(h)  # (B, T, n_channels)
# =============================================================================
# 5. MMM DIFFUSION MODEL (Full Pipeline)
# =============================================================================
class MMMDiffusionModel(nn.Module):
    """
    Full MMM-Diffusion model with a dual denoiser architecture.

    Stage 1 (Campaign Denoiser): denoises aggregate patterns from non-mktg vars + sales.
    Stage 2 (Channel Denoiser):  denoises per-channel coefficients conditioned on Stage 1.

    Constraint enforcement:
      1. Log-space reparametrization: media coefficients are modeled in (normalized) log-space.
      2. PhysDiff-style projection: clamp media log-coefficients every K sampling steps.
      3. Soft loss penalties: L_smooth and L_sign (L_sales is not implemented).

    Note: ``self.schedule`` is a plain object, not a registered buffer, so
    ``model.to(device)`` does not move it. ``forward_train`` and ``sample``
    move it to the input's device on demand.
    """

    def __init__(self, n_media=5, n_ctrl=3, d_model_campaign=256, d_model_channel=384,
                 n_layers_campaign=4, n_layers_channel=6, T_diff=1000):
        super().__init__()
        self.n_media = n_media
        self.n_ctrl = n_ctrl
        self.n_channels = n_media + n_ctrl
        self.T_diff = T_diff
        # Aggregate channels for Stage 1 (3 latent aggregate dimensions)
        self.n_agg = 3
        # Stage 1: Campaign/Geo Denoiser
        self.campaign_denoiser = CampaignDenoiser(
            n_agg_channels=self.n_agg,
            cond_dim=n_ctrl + 1,  # non-marketing vars + total sales
            d_model=d_model_campaign,
            nhead=4,
            n_layers=n_layers_campaign,
            T_diff=T_diff,
        )
        # Stage 2: Channel Denoiser
        self.channel_denoiser = ChannelDenoiser(
            n_channels=self.n_channels,
            n_media=n_media,
            n_agg=self.n_agg,
            d_model=d_model_channel,
            nhead=8,
            n_layers=n_layers_channel,
            T_diff=T_diff,
        )
        # Projection from full coefficients to aggregate representation.
        # NOTE: only ever applied under no_grad in forward_train, so it receives
        # no gradients and acts as a fixed (randomly initialized) projection.
        self.coeff_to_agg = nn.Linear(self.n_channels, self.n_agg)
        # NOTE(review): not used anywhere in this class.
        self.agg_to_coeff_init = nn.Linear(self.n_agg, self.n_channels)
        # Diffusion schedule
        self.schedule = DiffusionSchedule(T_diff)

    def _sync_schedule(self, device):
        """Move the (non-Module) diffusion schedule onto ``device`` if it lives elsewhere."""
        if self.schedule.betas.device != device:
            self.schedule = self.schedule.to(device)

    def compute_aggregate(self, coefficients):
        """Project full coefficients to the aggregate representation used by Stage 1."""
        return self.coeff_to_agg(coefficients)

    def forward_train(self, batch):
        """
        Training forward pass using x_0-prediction (predicts clean data, not noise).

        Args:
            batch: dict with 'conditioning' (B, T, n_media + n_ctrl + 1) and
                'coefficients' (B, T, n_channels), media already in normalized log-space.

        Returns:
            dict with 'loss' (tensor, for backward) and float components
            'loss_campaign', 'loss_channel', 'loss_smooth', 'loss_sign'.
        """
        cond = batch['conditioning']
        coeffs = batch['coefficients']
        B, T_seq, _ = coeffs.shape
        device = coeffs.device
        self._sync_schedule(device)
        # Separate conditioning components
        media_spend = cond[:, :, :self.n_media]
        controls = cond[:, :, self.n_media:self.n_media + self.n_ctrl]
        total_sales = cond[:, :, -1:]
        stage1_cond = torch.cat([controls, total_sales], dim=-1)
        # Aggregate targets for Stage 1 (no gradient into coeff_to_agg)
        with torch.no_grad():
            agg_target = self.coeff_to_agg(coeffs)
        # ---- Stage 1: Campaign Denoiser ----
        t1 = torch.randint(0, self.T_diff, (B,), device=device)
        noise1 = torch.randn_like(agg_target)
        agg_noisy = self.schedule.q_sample(agg_target, t1, noise1)
        agg_pred = self.campaign_denoiser(agg_noisy, t1, stage1_cond)
        loss_campaign = F.mse_loss(agg_pred, agg_target)
        # ---- Stage 2: Channel Denoiser (independent diffusion step) ----
        t2 = torch.randint(0, self.T_diff, (B,), device=device)
        noise2 = torch.randn_like(coeffs)
        coeffs_noisy = self.schedule.q_sample(coeffs, t2, noise2)
        # Teacher forcing: condition on the ground-truth aggregate during training
        campaign_ctx = agg_target.detach()
        coeffs_pred = self.channel_denoiser(
            coeffs_noisy, t2, campaign_ctx, media_spend, total_sales
        )
        loss_channel = F.mse_loss(coeffs_pred, coeffs)
        # ---- Auxiliary losses (geometric losses from MDM) ----
        # L_smooth: match first differences over time (analog of MDM's velocity loss)
        delta_pred = coeffs_pred[:, 1:, :] - coeffs_pred[:, :-1, :]
        delta_true = coeffs[:, 1:, :] - coeffs[:, :-1, :]
        loss_smooth = F.mse_loss(delta_pred, delta_true)
        # L_sign: mild penalty when a predicted media value drops below -5 in the
        # *normalized* log-space (i.e. 5 log-std below the log-mean)
        media_pred_log = coeffs_pred[:, :, :self.n_media]
        loss_sign = F.relu(-media_pred_log - 5.0).mean()
        # L_sales (decomposition vs. sales) is not computed during training.
        loss = (
            1.0 * loss_campaign +
            1.0 * loss_channel +
            0.1 * loss_smooth +
            0.01 * loss_sign
        )
        return {
            'loss': loss,
            'loss_campaign': loss_campaign.item(),
            'loss_channel': loss_channel.item(),
            'loss_smooth': loss_smooth.item(),
            'loss_sign': loss_sign.item(),
        }

    def _sampling_timesteps(self, n_steps):
        """Ascending, distinct timesteps in [0, T_diff-1] visited by the sampler (always includes 0)."""
        if n_steps == 1:
            return [self.T_diff - 1]
        span = self.T_diff - 1
        # Integer floor division: consecutive values differ by >= 1 when n_steps <= T_diff,
        # so no duplicates; equals range(T_diff) when n_steps == T_diff.
        return [(i * span) // (n_steps - 1) for i in range(n_steps)]

    def _reverse_step(self, x_0_pred, x_t, t, t_prev, t_batch):
        """
        One ancestral sampling step from timestep ``t`` to ``t_prev``.

        Returns ``x_0_pred`` when ``t_prev < 0`` (final step). For consecutive
        steps the precomputed DDPM posterior is used; for strided steps the
        posterior of the respaced process (alpha_bar_t / alpha_bar_prev) is used.
        """
        if t_prev < 0:
            return x_0_pred
        if t_prev == t - 1:
            mean = self.schedule.posterior_mean(x_0_pred, x_t, t_batch)
            var = self.schedule.posterior_variance[t]
        else:
            ab_t = self.schedule.alphas_cumprod[t]
            ab_prev = self.schedule.alphas_cumprod[t_prev]
            alpha = ab_t / ab_prev
            beta = 1.0 - alpha
            coef1 = torch.sqrt(ab_prev) * beta / (1.0 - ab_t)
            coef2 = torch.sqrt(alpha) * (1.0 - ab_prev) / (1.0 - ab_t)
            mean = coef1 * x_0_pred + coef2 * x_t
            var = beta * (1.0 - ab_prev) / (1.0 - ab_t)
        noise = torch.randn_like(x_t)
        return mean + torch.sqrt(var) * noise

    @torch.no_grad()
    def sample(self, conditioning, n_steps=None, constraint_every_k=10, guidance_scale=1.0):
        """
        Generate time-varying coefficients via dual-denoiser reverse diffusion.

        Uses PhysDiff-style projection every K steps for constraint enforcement.
        The caller should put the model in eval mode; this method does not.

        Args:
            conditioning: (B, T, n_media + n_ctrl + 1) — [media_spend, controls, total_sales]
            n_steps: number of denoising steps, 1..T_diff (None or 0 = full T_diff).
                Fewer steps use evenly strided timesteps over the full schedule.
            constraint_every_k: clamp media log-coefficients on every K-th step
                (step index 0, the final step, is always included); <= 0 disables.
            guidance_scale: currently ignored (classifier-free guidance is not implemented).

        Returns:
            (B, T, n_channels) predicted coefficients in normalized space.

        Raises:
            ValueError: if ``n_steps`` is outside 1..T_diff.
        """
        B, T_seq, _ = conditioning.shape
        device = conditioning.device
        self._sync_schedule(device)
        n_steps = n_steps or self.T_diff
        if not 1 <= n_steps <= self.T_diff:
            raise ValueError(f"n_steps must be in [1, {self.T_diff}], got {n_steps}")
        timesteps = self._sampling_timesteps(n_steps)
        # Separate conditioning
        media_spend = conditioning[:, :, :self.n_media]
        controls = conditioning[:, :, self.n_media:self.n_media + self.n_ctrl]
        total_sales = conditioning[:, :, -1:]
        stage1_cond = torch.cat([controls, total_sales], dim=-1)

        # ==== Stage 1: Denoise aggregate patterns ====
        z_t = torch.randn(B, T_seq, self.n_agg, device=device)
        for idx in reversed(range(len(timesteps))):
            t = timesteps[idx]
            t_prev = timesteps[idx - 1] if idx > 0 else -1
            t_batch = torch.full((B,), t, device=device, dtype=torch.long)
            z_0_pred = self.campaign_denoiser(z_t, t_batch, stage1_cond)
            z_t = self._reverse_step(z_0_pred, z_t, t, t_prev, t_batch)
        campaign_ctx = z_t  # (B, T, n_agg) — denoised aggregate

        # ==== Stage 2: Denoise channel coefficients ====
        x_t = torch.randn(B, T_seq, self.n_channels, device=device)
        for idx in reversed(range(len(timesteps))):
            t = timesteps[idx]
            t_prev = timesteps[idx - 1] if idx > 0 else -1
            t_batch = torch.full((B,), t, device=device, dtype=torch.long)
            x_0_pred = self.channel_denoiser(
                x_t, t_batch, campaign_ctx, media_spend, total_sales
            )
            # PhysDiff-style projection; with full-length sampling idx == t.
            if constraint_every_k > 0 and idx % constraint_every_k == 0:
                # Bound normalized log-coefficients to avoid extreme magnitudes
                x_0_pred[:, :, :self.n_media] = torch.clamp(
                    x_0_pred[:, :, :self.n_media], min=-8.0, max=8.0
                )
            x_t = self._reverse_step(x_0_pred, x_t, t, t_prev, t_batch)
        return x_t  # (B, T, n_channels) — normalized coefficients
# =============================================================================
# 6. TRAINING LOOP
# =============================================================================
def train_mmm_diffusion(
    model, dataset,
    n_epochs=50, batch_size=16, lr=1e-4,
    device='cpu', log_every=50, save_path='mmm_diffusion_model.pt'
):
    """
    Train the MMM diffusion model with AdamW and per-epoch cosine LR annealing.

    Args:
        model: object exposing ``forward_train(batch) -> dict`` with keys 'loss'
            (tensor), 'loss_campaign', 'loss_channel', 'loss_smooth', 'loss_sign',
            plus attributes ``schedule`` (with ``.to(device)``), ``n_media``,
            ``n_ctrl`` and ``T_diff``.
        dataset: map-style dataset yielding dicts of tensors.
        n_epochs, batch_size, lr: optimization settings.
        device: device to train on.
        log_every: print running averages every this many steps (<= 0 disables).
        save_path: where the state dict, history and config are saved.

    Returns:
        history: dict mapping each loss name to a list of per-epoch means.

    Raises:
        ValueError: if ``len(dataset) < batch_size`` (drop_last would yield no batches).
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    if len(dataloader) == 0:
        # Otherwise every epoch is empty and the epoch summary fails with KeyError.
        raise ValueError(
            f"dataset has {len(dataset)} samples but batch_size={batch_size} with "
            f"drop_last=True yields no batches; reduce batch_size or add samples"
        )
    # Move the model before building the optimizer; the schedule is not an
    # nn.Module, so it must be moved explicitly.
    model = model.to(device)
    model.schedule = model.schedule.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    history = {'loss': [], 'loss_campaign': [], 'loss_channel': [], 'loss_smooth': [], 'loss_sign': []}
    print("\nTraining MMM-Diffusion Model")
    print(f" Device: {device}")
    print(f" Samples: {len(dataset)}, Batch size: {batch_size}")
    print(f" Epochs: {n_epochs}, LR: {lr}")
    print(f" Model params: {sum(p.numel() for p in model.parameters()):,}")
    print(f" Diffusion steps: {model.T_diff}")
    print("-" * 60)
    step = 0
    for epoch in range(n_epochs):
        model.train()
        epoch_losses = {k: [] for k in history}
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            losses = model.forward_train(batch)
            optimizer.zero_grad()
            losses['loss'].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            for k in history:
                val = losses[k].item() if isinstance(losses[k], torch.Tensor) else losses[k]
                epoch_losses[k].append(val)
            step += 1
            if log_every > 0 and step % log_every == 0:
                # Running mean over the last `log_every` steps of the current epoch
                avg = {k: np.mean(v[-log_every:]) for k, v in epoch_losses.items() if v}
                print(f" Step {step:5d} | loss={avg['loss']:.4f} "
                      f"camp={avg['loss_campaign']:.4f} chan={avg['loss_channel']:.4f} "
                      f"smooth={avg['loss_smooth']:.4f} sign={avg['loss_sign']:.6f}")
        scheduler.step()
        # Epoch summary
        avg = {k: np.mean(v) for k, v in epoch_losses.items() if v}
        for k, v in avg.items():
            history[k].append(v)
        print(f"Epoch {epoch+1:3d}/{n_epochs} | loss={avg['loss']:.4f} "
              f"camp={avg['loss_campaign']:.4f} chan={avg['loss_channel']:.4f} "
              f"smooth={avg['loss_smooth']:.4f} sign={avg['loss_sign']:.6f} "
              f"lr={scheduler.get_last_lr()[0]:.6f}")
    # Save
    torch.save({
        'model_state_dict': model.state_dict(),
        'history': history,
        'config': {
            'n_media': model.n_media,
            'n_ctrl': model.n_ctrl,
            'T_diff': model.T_diff,
        }
    }, save_path)
    print(f"\nModel saved to {save_path}")
    return history
# =============================================================================
# 7. SALES DECOMPOSITION FROM PREDICTED COEFFICIENTS
# =============================================================================
def decompose_sales(coefficients, media_spend, controls, adstock_alphas=None):
    """
    Decompose sales into per-variable contributions from predicted coefficients.

    Media:    contrib_m,t = β_m,t * s_m,t / (P90(s_m) + 1e-10), where s_m is the
              channel's spend, geometrically adstocked if ``adstock_alphas`` is given.
              Simplification: no Hill saturation is applied here.
    Controls: contrib_c,t = β_c,t * ctrl_c,t.
    'Predicted_Sales' is media + controls only; it has no base/intercept term.

    Args:
        coefficients: (T, 8) decoded coefficients [5 media + 3 ctrl] (extra columns ignored).
        media_spend: (T, 5) raw media spend.
        controls: (T, 3) control variables.
        adstock_alphas: optional (5,) adstock retention rates.

    Returns:
        dict mapping each media/control name plus 'Total_Media', 'Total_Controls'
        and 'Predicted_Sales' to a length-T array.

    Raises:
        ValueError: if ``coefficients`` has fewer than 8 columns.
    """
    media_names = MMMDataGenerator.MEDIA_CHANNELS
    ctrl_names = MMMDataGenerator.CONTROL_VARS
    n_media = len(media_names)
    T, n_cols = coefficients.shape
    if n_cols < n_media + len(ctrl_names):
        raise ValueError(
            f"coefficients must have at least {n_media + len(ctrl_names)} columns, got {n_cols}"
        )
    contributions = {}
    total_media = np.zeros(T)
    for m, name in enumerate(media_names):
        spend = media_spend[:, m]
        if adstock_alphas is not None:
            # Geometric adstock y[t] = x[t] + alpha * y[t-1] as an IIR filter
            spend = lfilter([1.0], [1.0, -adstock_alphas[m]], np.asarray(spend, dtype=float))
        # Coefficient × P90-normalized (adstocked) spend
        contrib = coefficients[:, m] * (spend / (np.percentile(spend, 90) + 1e-10))
        contributions[name] = contrib
        total_media += contrib
    total_ctrl = np.zeros(T)
    for c, name in enumerate(ctrl_names):
        contrib = coefficients[:, n_media + c] * controls[:, c]
        contributions[name] = contrib
        total_ctrl += contrib
    contributions['Total_Media'] = total_media
    contributions['Total_Controls'] = total_ctrl
    contributions['Predicted_Sales'] = total_media + total_ctrl
    return contributions
# =============================================================================
# 8. VISUALIZATION
# =============================================================================
def plot_training_history(history, save_path='training_history.png'):
    """
    Plot one loss curve per entry of ``history`` and save the figure.

    Panels are laid out two per row, with as many rows as needed, so every key
    gets a panel; unused grid cells are hidden. A panel uses a log y-axis only
    when its series is non-empty and strictly positive.

    Args:
        history: dict mapping loss name -> sequence of per-epoch values.
        save_path: output image path.
    """
    n_cols = 2
    n_rows = max(1, (len(history) + n_cols - 1) // n_cols)
    # squeeze=False keeps `axes` 2-D for any grid size
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 5 * n_rows), squeeze=False)
    flat_axes = axes.flatten()
    for ax, (key, values) in zip(flat_axes, history.items()):
        ax.plot(values, linewidth=1.5)
        ax.set_title(f'{key}', fontsize=12)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.grid(True, alpha=0.3)
        # min() of an empty series would raise; log scale needs strictly positive data
        all_positive = len(values) > 0 and min(values) > 0
        ax.set_yscale('log' if all_positive else 'linear')
    for ax in flat_axes[len(history):]:
        ax.set_visible(False)
    plt.suptitle('MMM-Diffusion Training History', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Training history plot saved to {save_path}")
def plot_coefficient_comparison(true_coeffs, pred_coeffs, channel_names,
                                save_path='coeff_comparison.png', n_media=5):
    """Overlay ground-truth and predicted time-varying coefficients per channel.

    Args:
        true_coeffs: Array indexed as ``[week, channel]``.
        pred_coeffs: Array with the same layout as ``true_coeffs``.
        channel_names: One label per column, media channels first.
        save_path: Output image path.
        n_media: Number of leading columns that are media channels. Those
            panels get a zero reference line and a ``β (≥0)`` y-label.
            Defaults to 5, the previously hard-coded value.
    """
    n_channels = true_coeffs.shape[1]
    # squeeze=False: with a single channel, plt.subplots would otherwise
    # return a bare Axes, which is neither iterable nor indexable.
    fig, axes = plt.subplots(n_channels, 1, figsize=(14, 2.5 * n_channels), squeeze=False)
    axes = axes[:, 0]
    for i, (ax, name) in enumerate(zip(axes, channel_names)):
        ax.plot(true_coeffs[:, i], 'b-', label='Ground Truth', linewidth=1.5)
        ax.plot(pred_coeffs[:, i], 'r--', label='Predicted', linewidth=1.5, alpha=0.8)
        ax.set_title(f'{name} — Time-Varying Coefficient', fontsize=11)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        if i < n_media:  # Media channels: coefficients are sign-constrained
            ax.axhline(y=0, color='gray', linestyle=':', alpha=0.5)
            ax.set_ylabel('β (≥0)')
        else:
            ax.set_ylabel('β')
    axes[-1].set_xlabel('Week')
    plt.suptitle('MMM-Diffusion: Coefficient Prediction Quality', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Coefficient comparison plot saved to {save_path}")
def plot_sales_decomposition(contributions, total_sales, save_path='sales_decomposition.png'):
    """Save a two-panel figure showing where sales come from.

    The upper panel stacks the media-channel contributions under the actual
    sales line. Negative values are clipped to zero so the layers stack
    cleanly. The lower panel overlays actual sales with
    ``contributions['Predicted_Sales']``.

    Args:
        contributions: Mapping holding one series per name in
            ``MMMDataGenerator.MEDIA_CHANNELS`` plus ``'Predicted_Sales'``.
        total_sales: Observed sales series, one value per week.
        save_path: Output image path.
    """
    fig, (stack_ax, fit_ax) = plt.subplots(2, 1, figsize=(14, 10))
    n_weeks = len(total_sales)
    week_index = np.arange(n_weeks)
    channel_list = MMMDataGenerator.MEDIA_CHANNELS
    palette = plt.cm.Set2(np.linspace(0, 1, len(channel_list)))

    # Upper panel: cumulative stack of non-negative media contributions.
    baseline = np.zeros(n_weeks)
    for channel, shade in zip(channel_list, palette):
        layer = np.maximum(contributions[channel], 0)
        stack_ax.fill_between(week_index, baseline, baseline + layer,
                              alpha=0.7, label=channel, color=shade)
        baseline += layer
    stack_ax.plot(week_index, total_sales, 'k-', linewidth=2, label='Total Sales', alpha=0.8)
    stack_ax.set_title('Sales Decomposition: Media Channel Contributions', fontsize=12)
    stack_ax.legend(loc='upper left', fontsize=9)
    stack_ax.set_xlabel('Week')
    stack_ax.set_ylabel('Sales Contribution')
    stack_ax.grid(True, alpha=0.3)

    # Lower panel: observed sales vs. the model's reconstructed total.
    fit_ax.plot(week_index, total_sales, 'b-', linewidth=2, label='Actual Sales')
    fit_ax.plot(week_index, contributions['Predicted_Sales'], 'r--', linewidth=2,
                label='Predicted (Media + Controls)', alpha=0.8)
    fit_ax.set_title('Total Sales: Actual vs Predicted Decomposition', fontsize=12)
    fit_ax.legend(fontsize=10)
    fit_ax.set_xlabel('Week')
    fit_ax.set_ylabel('Sales')
    fit_ax.grid(True, alpha=0.3)

    plt.suptitle('MMM-Diffusion: Sales Decomposition', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Sales decomposition plot saved to {save_path}")
# =============================================================================
# 9. MAIN — FULL POC PIPELINE
# =============================================================================
def main(output_dir='/app'):
    """Run the end-to-end MMM-Diffusion proof of concept.

    Steps:
        1. Generate synthetic MMM scenarios.
        2. Wrap them in normalized datasets.
        3. Build the dual-denoiser model.
        4. Train the model.
        5. Sample coefficients for one held-out scenario, print constraint
           and accuracy diagnostics, and write the plots.

    Args:
        output_dir: Directory that receives the model checkpoint and the three
            PNG plots. It is created if missing. The default reproduces the
            original hard-coded ``/app/...`` paths.

    Returns:
        Tuple ``(model, history, train_dataset, val_dataset)``.
    """
    import time

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Create the output directory up front so a bad location fails before
    # training rather than when the checkpoint is written.
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'mmm_diffusion_model.pt')
    history_plot_path = os.path.join(output_dir, 'training_history.png')
    coeff_plot_path = os.path.join(output_dir, 'coeff_comparison.png')
    decomp_plot_path = os.path.join(output_dir, 'sales_decomposition.png')

    print("=" * 60)
    print("MMM-DIFFUSION: Marketing Mix Model via Diffusion")
    print("Adapted from Kimodo/GMD dual-denoiser architecture")
    print(f"Device: {device}")
    print("=" * 60)

    # ---- Step 1: Generate synthetic data ----
    print("\n[1/5] Generating synthetic MMM data...")
    t0 = time.time()
    gen = MMMDataGenerator(n_weeks=104, seed=42)
    n_media, n_ctrl = gen.n_media, gen.n_ctrl
    n_train = 500  # 500 training scenarios
    n_val = 50  # 50 validation scenarios
    train_samples = gen.generate_dataset(n_train)
    val_samples = gen.generate_dataset(n_val)
    print(f" Generated {n_train} train + {n_val} val scenarios")
    print(f" Each: {gen.n_weeks} weeks, {gen.n_media} media channels, {gen.n_ctrl} control vars")
    print(f" Time: {time.time()-t0:.1f}s")

    # Quick data audit
    sample = train_samples[0]
    sample_media_coeffs = sample['true_coefficients'][:, :n_media]
    print("\n Data audit (sample 0):")
    print(f" Media spend shape: {sample['media_spend'].shape}")
    print(f" Media spend range: [{sample['media_spend'].min():.1f}, {sample['media_spend'].max():.1f}]")
    print(f" Sales range: [{sample['total_sales'].min():.1f}, {sample['total_sales'].max():.1f}]")
    print(f" Media coeff range: [{sample_media_coeffs.min():.4f}, {sample_media_coeffs.max():.4f}]")
    print(f" All media coeffs positive: {(sample_media_coeffs > 0).all()}")

    # ---- Step 2: Create datasets ----
    print("\n[2/5] Creating training datasets...")
    train_dataset = MMMDiffusionDataset(train_samples, normalize=True)
    # NOTE(review): each dataset is built with normalize=True from its own
    # samples. If MMMDiffusionDataset derives normalization statistics per
    # instance, validation data is encoded/decoded with different statistics
    # than the model was trained on. Confirm, and if so reuse the train stats.
    val_dataset = MMMDiffusionDataset(val_samples, normalize=True)
    item = train_dataset[0]
    print(f" Conditioning shape: {item['conditioning'].shape}")
    print(f" Coefficients shape: {item['coefficients'].shape}")

    # ---- Step 3: Build model ----
    print("\n[3/5] Building MMM-Diffusion model...")
    # For PoC: use smaller model and fewer diffusion steps for faster training on CPU
    T_DIFF = 200 if device == 'cpu' else 500
    model = MMMDiffusionModel(
        n_media=n_media, n_ctrl=n_ctrl,
        d_model_campaign=128,  # Smaller for PoC
        d_model_channel=192,  # Smaller for PoC
        n_layers_campaign=3,
        n_layers_channel=4,
        T_diff=T_DIFF,
    )
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Total parameters: {total_params:,}")
    print(f" Campaign denoiser: {sum(p.numel() for p in model.campaign_denoiser.parameters()):,}")
    print(f" Channel denoiser: {sum(p.numel() for p in model.channel_denoiser.parameters()):,}")
    print(f" Diffusion steps: {T_DIFF}")

    # ---- Step 4: Train ----
    print("\n[4/5] Training...")
    N_EPOCHS = 30 if device == 'cpu' else 50
    BATCH_SIZE = 8 if device == 'cpu' else 16
    history = train_mmm_diffusion(
        model, train_dataset,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        lr=3e-4,
        device=device,
        log_every=25,
        save_path=checkpoint_path,
    )
    # Plot training history
    plot_training_history(history, save_path=history_plot_path)

    # ---- Step 5: Validate — Generate coefficients and decompose ----
    print("\n[5/5] Validation: generating coefficients for held-out sample...")
    model.eval()
    model = model.to(device)
    # Take a validation sample
    val_item = val_dataset[0]
    cond = val_item['conditioning'].unsqueeze(0).to(device)  # (1, T, 9)
    true_coeffs_norm = val_item['coefficients'].unsqueeze(0)  # (1, T, 8)
    # Generate coefficients via reverse diffusion
    t0 = time.time()
    pred_coeffs_norm = model.sample(
        cond,
        n_steps=T_DIFF,
        constraint_every_k=10,
    )
    gen_time = time.time() - t0
    print(f" Generation time: {gen_time:.1f}s")

    # Decode to original scale
    pred_coeffs = val_dataset.decode_coefficients(pred_coeffs_norm.cpu())
    true_coeffs = val_dataset.decode_coefficients(true_coeffs_norm)
    pred_np = pred_coeffs[0].numpy()
    true_np = true_coeffs[0].numpy()

    # Check constraints
    media_pred = pred_np[:, :n_media]
    print("\n Constraint check:")
    print(f" Media coefficients all positive: {(media_pred > 0).all()}")
    print(f" Media coeff min: {media_pred.min():.6f}")
    print(f" Media coeff max: {media_pred.max():.6f}")

    # Correlation between true and predicted coefficients
    print("\n Per-channel correlation (true vs predicted):")
    channel_names = MMMDataGenerator.MEDIA_CHANNELS + MMMDataGenerator.CONTROL_VARS
    for i, name in enumerate(channel_names):
        true_series = true_np[:, i]
        pred_series = pred_np[:, i]
        # Pearson correlation is undefined for a constant series: np.corrcoef
        # returns NaN and emits a RuntimeWarning. The generator's model writes
        # control coefficients without a time index, so this case is expected.
        if np.std(true_series) > 0 and np.std(pred_series) > 0:
            corr_str = f"{np.corrcoef(true_series, pred_series)[0, 1]:.3f}"
        else:
            corr_str = "n/a"
        rmse = np.sqrt(np.mean((true_series - pred_series) ** 2))
        print(f" {name:20s}: corr={corr_str}, RMSE={rmse:.4f}")

    # Plot coefficient comparison
    plot_coefficient_comparison(
        true_np, pred_np, channel_names,
        save_path=coeff_plot_path,
    )

    # Sales decomposition
    val_raw = val_samples[0]
    contributions = decompose_sales(
        pred_np, val_raw['media_spend'], val_raw['controls']
    )
    plot_sales_decomposition(
        contributions, val_raw['total_sales'],
        save_path=decomp_plot_path,
    )

    # ---- Summary ----
    print(f"\n{'='*60}")
    print("MMM-DIFFUSION POC COMPLETE")
    print(f"{'='*60}")
    print(" Architecture: Dual-denoiser (Campaign + Channel) diffusion")
    print(" Based on: Kimodo/GMD pattern with PhysDiff constraint projection")
    print(f" Data: {n_train} synthetic MMM scenarios, {gen.n_weeks} weeks each")
    print(f" Channels: {gen.n_media} media + {gen.n_ctrl} non-marketing")
    print(" Constraints: Log-space media (guaranteed positive) + soft sign loss + PhysDiff projection")
    print(f" Model size: {total_params:,} parameters")
    print(f" Final training loss: {history['loss'][-1]:.4f}")
    print("\nOutputs:")
    print(f" Model checkpoint: {checkpoint_path}")
    print(f" Training history: {history_plot_path}")
    print(f" Coefficient comparison: {coeff_plot_path}")
    print(f" Sales decomposition: {decomp_plot_path}")
    return model, history, train_dataset, val_dataset
if __name__ == "__main__":
    # Script entry point: run the full proof of concept. The returned objects
    # stay bound at module scope so they can be inspected interactively
    # (e.g. when launched with `python -i`).
    (model, history, train_dataset, val_dataset) = main()