v2: Upload mmm_diffusion_v2.py
Browse files- mmm_diffusion_v2.py +1267 -0
mmm_diffusion_v2.py
ADDED
|
@@ -0,0 +1,1267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Marketing Mix Model Diffusion (MMM-Diffusion) v2
|
| 3 |
+
=================================================
|
| 4 |
+
Fixed version addressing:
|
| 5 |
+
1. Sales alignment: Added explicit sales reconstruction loss during training
|
| 6 |
+
2. Coefficient smoothness: Reduced smoothness weight, added spectral loss,
|
| 7 |
+
increased GT coefficient volatility, added multi-scale temporal loss
|
| 8 |
+
|
| 9 |
+
Architecture mapping (from Kimodo/GMD):
|
| 10 |
+
Text prompts → Media spend, non-marketing vars, total sales
|
| 11 |
+
Motion/position constraints → Sign constraints (β_media ≥ 0) + prior constraints
|
| 12 |
+
Root denoiser → Campaign/Geo-level denoiser (aggregate patterns)
|
| 13 |
+
Body denoiser → Channel-level denoiser (per-channel coefficients)
|
| 14 |
+
Skeleton positions/rotations → Time-varying coefficients for sales decomposition
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from torch.utils.data import Dataset, DataLoader
|
| 25 |
+
import matplotlib
|
| 26 |
+
matplotlib.use('Agg')
|
| 27 |
+
import matplotlib.pyplot as plt
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# =============================================================================
|
| 31 |
+
# 1. SYNTHETIC MMM DATA GENERATOR (increased volatility for richer coefficients)
|
| 32 |
+
# =============================================================================
|
| 33 |
+
|
| 34 |
+
class MMMDataGenerator:
|
| 35 |
+
MEDIA_CHANNELS = ['TV', 'Digital', 'Social', 'Print', 'Radio']
|
| 36 |
+
CONTROL_VARS = ['Seasonality', 'Trend', 'Competitor_Price']
|
| 37 |
+
|
| 38 |
+
def __init__(self, n_weeks=104, n_geos=1, seed=None):
|
| 39 |
+
self.n_weeks = n_weeks
|
| 40 |
+
self.n_geos = n_geos
|
| 41 |
+
self.n_media = 5
|
| 42 |
+
self.n_ctrl = 3
|
| 43 |
+
self.rng = np.random.RandomState(seed)
|
| 44 |
+
|
| 45 |
+
def _generate_media_spend(self):
|
| 46 |
+
spend = np.zeros((self.n_weeks, self.n_media))
|
| 47 |
+
t = np.arange(self.n_weeks)
|
| 48 |
+
base_levels = self.rng.uniform(50, 500, size=self.n_media)
|
| 49 |
+
|
| 50 |
+
for m in range(self.n_media):
|
| 51 |
+
base = base_levels[m] * (1 + 0.2 * np.sin(2 * np.pi * t / 52))
|
| 52 |
+
n_campaigns = self.rng.randint(3, 9)
|
| 53 |
+
for _ in range(n_campaigns):
|
| 54 |
+
start = self.rng.randint(0, self.n_weeks - 4)
|
| 55 |
+
duration = self.rng.randint(1, 6)
|
| 56 |
+
intensity = self.rng.uniform(1.5, 4.0)
|
| 57 |
+
end = min(start + duration, self.n_weeks)
|
| 58 |
+
base[start:end] *= intensity
|
| 59 |
+
spend[:, m] = np.maximum(base + self.rng.normal(0, base_levels[m] * 0.1, self.n_weeks), 0)
|
| 60 |
+
return spend
|
| 61 |
+
|
| 62 |
+
def _adstock(self, x, alpha):
|
| 63 |
+
result = np.zeros_like(x)
|
| 64 |
+
result[0] = x[0]
|
| 65 |
+
for t in range(1, len(x)):
|
| 66 |
+
result[t] = x[t] + alpha * result[t-1]
|
| 67 |
+
return result
|
| 68 |
+
|
| 69 |
+
def _hill(self, x, ec50, slope):
|
| 70 |
+
x_safe = np.maximum(x, 0)
|
| 71 |
+
return x_safe**slope / (x_safe**slope + ec50**slope + 1e-10)
|
| 72 |
+
|
| 73 |
+
def _generate_controls(self):
|
| 74 |
+
t = np.arange(self.n_weeks)
|
| 75 |
+
controls = np.zeros((self.n_weeks, self.n_ctrl))
|
| 76 |
+
controls[:, 0] = (np.sin(2 * np.pi * t / 52) +
|
| 77 |
+
0.5 * np.sin(4 * np.pi * t / 52) +
|
| 78 |
+
0.3 * np.cos(2 * np.pi * t / 52))
|
| 79 |
+
trend = t / self.n_weeks
|
| 80 |
+
controls[:, 1] = trend + 0.5 * trend**2
|
| 81 |
+
price = np.zeros(self.n_weeks)
|
| 82 |
+
price[0] = 1.0
|
| 83 |
+
for i in range(1, self.n_weeks):
|
| 84 |
+
price[i] = 0.95 * price[i-1] + 0.05 * 1.0 + self.rng.normal(0, 0.05)
|
| 85 |
+
controls[:, 2] = price
|
| 86 |
+
return controls
|
| 87 |
+
|
| 88 |
+
def _sample_true_params(self):
|
| 89 |
+
params = {}
|
| 90 |
+
params['beta_media'] = np.abs(self.rng.normal(0, 0.5, self.n_media)) + 0.05
|
| 91 |
+
params['adstock_alpha'] = self.rng.beta(2, 2, self.n_media)
|
| 92 |
+
params['adstock_alpha'] = np.clip(params['adstock_alpha'], 0.1, 0.95)
|
| 93 |
+
params['hill_ec50'] = np.abs(self.rng.lognormal(0, 0.5, self.n_media)) + 0.1
|
| 94 |
+
params['hill_slope'] = self.rng.uniform(0.5, 3.0, self.n_media)
|
| 95 |
+
params['beta_base'] = self.rng.uniform(500, 2000)
|
| 96 |
+
params['beta_ctrl'] = self.rng.normal(0, 50, self.n_ctrl)
|
| 97 |
+
params['noise_std'] = self.rng.uniform(20, 100)
|
| 98 |
+
return params
|
| 99 |
+
|
| 100 |
+
def _make_time_varying(self, base_coeff, n_weeks, volatility=0.1):
|
| 101 |
+
"""
|
| 102 |
+
FIX: Increased default volatility and added regime-change jumps
|
| 103 |
+
to produce more realistic, less smooth coefficients.
|
| 104 |
+
"""
|
| 105 |
+
z = np.zeros(n_weeks)
|
| 106 |
+
# OU process with higher volatility
|
| 107 |
+
mean_reversion = 0.85 # slightly less mean-reverting (was 0.9)
|
| 108 |
+
for t in range(1, n_weeks):
|
| 109 |
+
z[t] = mean_reversion * z[t-1] + self.rng.normal(0, volatility)
|
| 110 |
+
|
| 111 |
+
# Add occasional regime jumps (structural breaks)
|
| 112 |
+
n_jumps = self.rng.randint(0, 4)
|
| 113 |
+
for _ in range(n_jumps):
|
| 114 |
+
jump_t = self.rng.randint(5, n_weeks - 5)
|
| 115 |
+
jump_size = self.rng.normal(0, volatility * 3)
|
| 116 |
+
z[jump_t:] += jump_size
|
| 117 |
+
|
| 118 |
+
return base_coeff * np.exp(z)
|
| 119 |
+
|
| 120 |
+
def generate_single(self):
|
| 121 |
+
spend = self._generate_media_spend()
|
| 122 |
+
controls = self._generate_controls()
|
| 123 |
+
params = self._sample_true_params()
|
| 124 |
+
|
| 125 |
+
transformed_media = np.zeros_like(spend)
|
| 126 |
+
for m in range(self.n_media):
|
| 127 |
+
adstocked = self._adstock(spend[:, m], params['adstock_alpha'][m])
|
| 128 |
+
adstocked_norm = adstocked / (np.percentile(adstocked, 90) + 1e-10)
|
| 129 |
+
transformed_media[:, m] = self._hill(
|
| 130 |
+
adstocked_norm, params['hill_ec50'][m], params['hill_slope'][m]
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# FIX: Higher volatility for time-varying coefficients
|
| 134 |
+
tv_coeffs = np.zeros((self.n_weeks, self.n_media + self.n_ctrl))
|
| 135 |
+
for m in range(self.n_media):
|
| 136 |
+
tv_coeffs[:, m] = self._make_time_varying(
|
| 137 |
+
params['beta_media'][m], self.n_weeks, volatility=0.12 # was 0.05
|
| 138 |
+
)
|
| 139 |
+
tv_coeffs[:, m] = np.maximum(tv_coeffs[:, m], 0.01)
|
| 140 |
+
for c in range(self.n_ctrl):
|
| 141 |
+
tv_coeffs[:, self.n_media + c] = self._make_time_varying(
|
| 142 |
+
params['beta_ctrl'][c], self.n_weeks, volatility=0.08 # was 0.03
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
contributions = np.zeros((self.n_weeks, self.n_media + self.n_ctrl))
|
| 146 |
+
for m in range(self.n_media):
|
| 147 |
+
contributions[:, m] = tv_coeffs[:, m] * transformed_media[:, m]
|
| 148 |
+
for c in range(self.n_ctrl):
|
| 149 |
+
contributions[:, self.n_media + c] = tv_coeffs[:, self.n_media + c] * controls[:, c]
|
| 150 |
+
|
| 151 |
+
base = params['beta_base']
|
| 152 |
+
noise = self.rng.normal(0, params['noise_std'], self.n_weeks)
|
| 153 |
+
total_sales = base + contributions.sum(axis=1) + noise
|
| 154 |
+
total_sales = np.maximum(total_sales, 0)
|
| 155 |
+
|
| 156 |
+
return {
|
| 157 |
+
'media_spend': spend,
|
| 158 |
+
'controls': controls,
|
| 159 |
+
'total_sales': total_sales,
|
| 160 |
+
'true_coefficients': tv_coeffs,
|
| 161 |
+
'true_contributions': contributions,
|
| 162 |
+
'transformed_media': transformed_media,
|
| 163 |
+
'base_sales': np.full(self.n_weeks, base),
|
| 164 |
+
'true_params': params
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
def generate_dataset(self, n_samples):
|
| 168 |
+
samples = []
|
| 169 |
+
for i in range(n_samples):
|
| 170 |
+
self.rng = np.random.RandomState(self.rng.randint(0, 2**31))
|
| 171 |
+
samples.append(self.generate_single())
|
| 172 |
+
return samples
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# =============================================================================
|
| 176 |
+
# 2. DATASET CLASS - now also stores transformed media for sales reconstruction
|
| 177 |
+
# =============================================================================
|
| 178 |
+
|
| 179 |
+
class MMMDiffusionDataset(Dataset):
|
| 180 |
+
def __init__(self, samples, normalize=True):
|
| 181 |
+
self.samples = samples
|
| 182 |
+
self.normalize = normalize
|
| 183 |
+
self.n_media = 5
|
| 184 |
+
self.n_ctrl = 3
|
| 185 |
+
self.n_channels = self.n_media + self.n_ctrl
|
| 186 |
+
|
| 187 |
+
if normalize:
|
| 188 |
+
all_cond = np.stack([
|
| 189 |
+
np.concatenate([s['media_spend'], s['controls'], s['total_sales'][:, None]], axis=1)
|
| 190 |
+
for s in samples
|
| 191 |
+
])
|
| 192 |
+
all_coeff = np.stack([s['true_coefficients'] for s in samples])
|
| 193 |
+
|
| 194 |
+
self.cond_mean = all_cond.mean(axis=(0, 1))
|
| 195 |
+
self.cond_std = all_cond.std(axis=(0, 1)) + 1e-8
|
| 196 |
+
self.coeff_mean = all_coeff.mean(axis=(0, 1))
|
| 197 |
+
self.coeff_std = all_coeff.std(axis=(0, 1)) + 1e-8
|
| 198 |
+
|
| 199 |
+
media_coeffs = all_coeff[:, :, :self.n_media]
|
| 200 |
+
self.media_log_mean = np.log(media_coeffs + 1e-8).mean(axis=(0, 1))
|
| 201 |
+
self.media_log_std = np.log(media_coeffs + 1e-8).std(axis=(0, 1)) + 1e-8
|
| 202 |
+
|
| 203 |
+
# Store sales normalization for reconstruction loss
|
| 204 |
+
all_sales = np.stack([s['total_sales'] for s in samples])
|
| 205 |
+
self.sales_mean = float(all_sales.mean())
|
| 206 |
+
self.sales_std = float(all_sales.std()) + 1e-8
|
| 207 |
+
|
| 208 |
+
# Store transformed media normalization
|
| 209 |
+
all_trans = np.stack([s['transformed_media'] for s in samples])
|
| 210 |
+
self.trans_media_mean = all_trans.mean(axis=(0, 1))
|
| 211 |
+
self.trans_media_std = all_trans.std(axis=(0, 1)) + 1e-8
|
| 212 |
+
|
| 213 |
+
# Store base sales statistics
|
| 214 |
+
all_base = np.array([s['base_sales'][0] for s in samples])
|
| 215 |
+
self.base_mean = float(all_base.mean())
|
| 216 |
+
self.base_std = float(all_base.std()) + 1e-8
|
| 217 |
+
|
| 218 |
+
def __len__(self):
|
| 219 |
+
return len(self.samples)
|
| 220 |
+
|
| 221 |
+
def __getitem__(self, idx):
|
| 222 |
+
s = self.samples[idx]
|
| 223 |
+
|
| 224 |
+
cond = np.concatenate([
|
| 225 |
+
s['media_spend'], s['controls'], s['total_sales'][:, None]
|
| 226 |
+
], axis=1).astype(np.float32)
|
| 227 |
+
|
| 228 |
+
coeffs = s['true_coefficients'].astype(np.float32)
|
| 229 |
+
trans_media = s['transformed_media'].astype(np.float32)
|
| 230 |
+
controls = s['controls'].astype(np.float32)
|
| 231 |
+
total_sales = s['total_sales'].astype(np.float32)
|
| 232 |
+
base_sales = s['base_sales'].astype(np.float32)
|
| 233 |
+
contributions = s['true_contributions'].astype(np.float32)
|
| 234 |
+
|
| 235 |
+
if self.normalize:
|
| 236 |
+
cond = (cond - self.cond_mean) / self.cond_std
|
| 237 |
+
|
| 238 |
+
log_media = np.log(coeffs[:, :self.n_media] + 1e-8)
|
| 239 |
+
log_media = (log_media - self.media_log_mean) / self.media_log_std
|
| 240 |
+
ctrl = (coeffs[:, self.n_media:] - self.coeff_mean[self.n_media:]) / self.coeff_std[self.n_media:]
|
| 241 |
+
coeffs = np.concatenate([log_media, ctrl], axis=1)
|
| 242 |
+
|
| 243 |
+
return {
|
| 244 |
+
'conditioning': torch.tensor(cond, dtype=torch.float32),
|
| 245 |
+
'coefficients': torch.tensor(coeffs, dtype=torch.float32),
|
| 246 |
+
'transformed_media': torch.tensor(trans_media, dtype=torch.float32),
|
| 247 |
+
'controls': torch.tensor(controls, dtype=torch.float32),
|
| 248 |
+
'total_sales': torch.tensor(total_sales, dtype=torch.float32),
|
| 249 |
+
'base_sales': torch.tensor(base_sales, dtype=torch.float32),
|
| 250 |
+
'contributions': torch.tensor(contributions, dtype=torch.float32),
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
def decode_coefficients(self, coeffs_normalized):
|
| 254 |
+
if not self.normalize:
|
| 255 |
+
return coeffs_normalized
|
| 256 |
+
coeffs = coeffs_normalized.clone()
|
| 257 |
+
media_log_mean = torch.tensor(self.media_log_mean, device=coeffs.device, dtype=coeffs.dtype)
|
| 258 |
+
media_log_std = torch.tensor(self.media_log_std, device=coeffs.device, dtype=coeffs.dtype)
|
| 259 |
+
coeffs[:, :, :self.n_media] = torch.exp(
|
| 260 |
+
coeffs[:, :, :self.n_media] * media_log_std + media_log_mean
|
| 261 |
+
)
|
| 262 |
+
coeff_mean = torch.tensor(self.coeff_mean[self.n_media:], device=coeffs.device, dtype=coeffs.dtype)
|
| 263 |
+
coeff_std = torch.tensor(self.coeff_std[self.n_media:], device=coeffs.device, dtype=coeffs.dtype)
|
| 264 |
+
coeffs[:, :, self.n_media:] = coeffs[:, :, self.n_media:] * coeff_std + coeff_mean
|
| 265 |
+
return coeffs
|
| 266 |
+
|
| 267 |
+
def decode_media_coefficients_differentiable(self, coeffs_norm_media):
|
| 268 |
+
"""Decode only media coefficients, keeping gradients (for sales loss)."""
|
| 269 |
+
media_log_mean = torch.tensor(self.media_log_mean, device=coeffs_norm_media.device, dtype=coeffs_norm_media.dtype)
|
| 270 |
+
media_log_std = torch.tensor(self.media_log_std, device=coeffs_norm_media.device, dtype=coeffs_norm_media.dtype)
|
| 271 |
+
return torch.exp(coeffs_norm_media * media_log_std + media_log_mean)
|
| 272 |
+
|
| 273 |
+
def decode_ctrl_coefficients_differentiable(self, coeffs_norm_ctrl):
|
| 274 |
+
"""Decode only control coefficients, keeping gradients."""
|
| 275 |
+
coeff_mean = torch.tensor(self.coeff_mean[self.n_media:], device=coeffs_norm_ctrl.device, dtype=coeffs_norm_ctrl.dtype)
|
| 276 |
+
coeff_std = torch.tensor(self.coeff_std[self.n_media:], device=coeffs_norm_ctrl.device, dtype=coeffs_norm_ctrl.dtype)
|
| 277 |
+
return coeffs_norm_ctrl * coeff_std + coeff_mean
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# =============================================================================
|
| 281 |
+
# 3. DIFFUSION NOISE SCHEDULE
|
| 282 |
+
# =============================================================================
|
| 283 |
+
|
| 284 |
+
def cosine_beta_schedule(T, s=0.008):
|
| 285 |
+
t = torch.arange(T + 1, dtype=torch.float64)
|
| 286 |
+
f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
|
| 287 |
+
alphas_cumprod = f / f[0]
|
| 288 |
+
betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
|
| 289 |
+
return torch.clamp(betas, 0, 0.999).float()
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class DiffusionSchedule:
|
| 293 |
+
def __init__(self, T=1000):
|
| 294 |
+
self.T = T
|
| 295 |
+
self.betas = cosine_beta_schedule(T)
|
| 296 |
+
self.alphas = 1.0 - self.betas
|
| 297 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
| 298 |
+
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
| 299 |
+
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
|
| 300 |
+
self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
|
| 301 |
+
|
| 302 |
+
self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
|
| 303 |
+
self.posterior_variance = (
|
| 304 |
+
self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
| 305 |
+
)
|
| 306 |
+
self.posterior_log_variance_clipped = torch.log(
|
| 307 |
+
torch.clamp(self.posterior_variance, min=1e-20)
|
| 308 |
+
)
|
| 309 |
+
self.posterior_mean_coef1 = (
|
| 310 |
+
self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
| 311 |
+
)
|
| 312 |
+
self.posterior_mean_coef2 = (
|
| 313 |
+
(1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
def to(self, device):
|
| 317 |
+
for attr in ['betas', 'alphas', 'alphas_cumprod', 'sqrt_alphas_cumprod',
|
| 318 |
+
'sqrt_one_minus_alphas_cumprod', 'sqrt_recip_alphas',
|
| 319 |
+
'alphas_cumprod_prev', 'posterior_variance',
|
| 320 |
+
'posterior_log_variance_clipped', 'posterior_mean_coef1',
|
| 321 |
+
'posterior_mean_coef2']:
|
| 322 |
+
setattr(self, attr, getattr(self, attr).to(device))
|
| 323 |
+
return self
|
| 324 |
+
|
| 325 |
+
def q_sample(self, x_0, t, noise=None):
|
| 326 |
+
if noise is None:
|
| 327 |
+
noise = torch.randn_like(x_0)
|
| 328 |
+
sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1)
|
| 329 |
+
sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
|
| 330 |
+
return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise
|
| 331 |
+
|
| 332 |
+
def posterior_mean(self, x_0_pred, x_t, t):
|
| 333 |
+
coef1 = self.posterior_mean_coef1[t].view(-1, 1, 1)
|
| 334 |
+
coef2 = self.posterior_mean_coef2[t].view(-1, 1, 1)
|
| 335 |
+
return coef1 * x_0_pred + coef2 * x_t
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# =============================================================================
|
| 339 |
+
# 4. DENOISER NETWORKS
|
| 340 |
+
# =============================================================================
|
| 341 |
+
|
| 342 |
+
class SinusoidalPositionEmbeddings(nn.Module):
|
| 343 |
+
def __init__(self, dim):
|
| 344 |
+
super().__init__()
|
| 345 |
+
self.dim = dim
|
| 346 |
+
|
| 347 |
+
def forward(self, t):
|
| 348 |
+
device = t.device
|
| 349 |
+
half_dim = self.dim // 2
|
| 350 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 351 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
| 352 |
+
emb = t[:, None].float() * emb[None, :]
|
| 353 |
+
return torch.cat([emb.sin(), emb.cos()], dim=-1)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class TemporalTransformerBlock(nn.Module):
|
| 357 |
+
def __init__(self, d_model, nhead, dropout=0.1):
|
| 358 |
+
super().__init__()
|
| 359 |
+
self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
|
| 360 |
+
self.ff = nn.Sequential(
|
| 361 |
+
nn.Linear(d_model, d_model * 4),
|
| 362 |
+
nn.GELU(),
|
| 363 |
+
nn.Dropout(dropout),
|
| 364 |
+
nn.Linear(d_model * 4, d_model),
|
| 365 |
+
nn.Dropout(dropout),
|
| 366 |
+
)
|
| 367 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 368 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 369 |
+
|
| 370 |
+
def forward(self, x):
|
| 371 |
+
h = self.norm1(x)
|
| 372 |
+
h = x + self.attn(h, h, h)[0]
|
| 373 |
+
h = h + self.ff(self.norm2(h))
|
| 374 |
+
return h
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class CampaignDenoiser(nn.Module):
|
| 378 |
+
"""Stage 1: Campaign/Geo-level Denoiser — denoises aggregate patterns."""
|
| 379 |
+
|
| 380 |
+
def __init__(self, n_agg_channels=3, cond_dim=4, d_model=256, nhead=4, n_layers=4, T_diff=1000):
|
| 381 |
+
super().__init__()
|
| 382 |
+
self.d_model = d_model
|
| 383 |
+
self.time_embed = nn.Sequential(
|
| 384 |
+
SinusoidalPositionEmbeddings(d_model),
|
| 385 |
+
nn.Linear(d_model, d_model),
|
| 386 |
+
nn.GELU(),
|
| 387 |
+
nn.Linear(d_model, d_model),
|
| 388 |
+
)
|
| 389 |
+
self.cond_proj = nn.Linear(cond_dim, d_model)
|
| 390 |
+
self.input_proj = nn.Linear(n_agg_channels, d_model)
|
| 391 |
+
self.pos_embed = nn.Parameter(torch.randn(1, 256, d_model) * 0.02)
|
| 392 |
+
self.blocks = nn.ModuleList([
|
| 393 |
+
TemporalTransformerBlock(d_model, nhead) for _ in range(n_layers)
|
| 394 |
+
])
|
| 395 |
+
self.output_proj = nn.Sequential(
|
| 396 |
+
nn.LayerNorm(d_model),
|
| 397 |
+
nn.Linear(d_model, d_model // 2),
|
| 398 |
+
nn.GELU(),
|
| 399 |
+
nn.Linear(d_model // 2, n_agg_channels),
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
def forward(self, x_t, t, cond):
|
| 403 |
+
B, T_seq, _ = x_t.shape
|
| 404 |
+
t_emb = self.time_embed(t)
|
| 405 |
+
h_x = self.input_proj(x_t)
|
| 406 |
+
h_c = self.cond_proj(cond)
|
| 407 |
+
h = h_x + h_c + t_emb.unsqueeze(1) + self.pos_embed[:, :T_seq, :]
|
| 408 |
+
for block in self.blocks:
|
| 409 |
+
h = block(h)
|
| 410 |
+
return self.output_proj(h)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class ChannelDenoiser(nn.Module):
|
| 414 |
+
"""Stage 2: Channel-level Denoiser — denoises per-channel coefficients."""
|
| 415 |
+
|
| 416 |
+
def __init__(self, n_channels=8, n_media=5, n_agg=3, d_model=384, nhead=8, n_layers=6, T_diff=1000):
|
| 417 |
+
super().__init__()
|
| 418 |
+
self.d_model = d_model
|
| 419 |
+
self.n_media = n_media
|
| 420 |
+
self.n_channels = n_channels
|
| 421 |
+
|
| 422 |
+
self.time_embed = nn.Sequential(
|
| 423 |
+
SinusoidalPositionEmbeddings(d_model),
|
| 424 |
+
nn.Linear(d_model, d_model),
|
| 425 |
+
nn.GELU(),
|
| 426 |
+
nn.Linear(d_model, d_model),
|
| 427 |
+
)
|
| 428 |
+
self.input_proj = nn.Linear(n_channels, d_model)
|
| 429 |
+
self.campaign_proj = nn.Linear(n_agg, d_model)
|
| 430 |
+
self.spend_proj = nn.Linear(n_media, d_model)
|
| 431 |
+
self.sales_proj = nn.Linear(1, d_model)
|
| 432 |
+
self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
|
| 433 |
+
self.cross_norm = nn.LayerNorm(d_model)
|
| 434 |
+
self.pos_embed = nn.Parameter(torch.randn(1, 256, d_model) * 0.02)
|
| 435 |
+
self.blocks = nn.ModuleList([
|
| 436 |
+
TemporalTransformerBlock(d_model, nhead) for _ in range(n_layers)
|
| 437 |
+
])
|
| 438 |
+
self.output_proj = nn.Sequential(
|
| 439 |
+
nn.LayerNorm(d_model),
|
| 440 |
+
nn.Linear(d_model, d_model // 2),
|
| 441 |
+
nn.GELU(),
|
| 442 |
+
nn.Linear(d_model // 2, n_channels),
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
def forward(self, x_t, t, campaign_ctx, media_spend, total_sales):
|
| 446 |
+
B, T_seq, _ = x_t.shape
|
| 447 |
+
t_emb = self.time_embed(t)
|
| 448 |
+
h_x = self.input_proj(x_t)
|
| 449 |
+
h_camp = self.campaign_proj(campaign_ctx)
|
| 450 |
+
h_spend = self.spend_proj(media_spend)
|
| 451 |
+
h_sales = self.sales_proj(total_sales)
|
| 452 |
+
cond_ctx = h_camp + h_spend + h_sales
|
| 453 |
+
h = h_x + t_emb.unsqueeze(1) + self.pos_embed[:, :T_seq, :]
|
| 454 |
+
h_normed = self.cross_norm(h)
|
| 455 |
+
h = h + self.cross_attn(h_normed, cond_ctx, cond_ctx)[0]
|
| 456 |
+
for block in self.blocks:
|
| 457 |
+
h = block(h)
|
| 458 |
+
return self.output_proj(h)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
# =============================================================================
|
| 462 |
+
# 5. MMM DIFFUSION MODEL — with sales reconstruction loss + spectral loss
|
| 463 |
+
# =============================================================================
|
| 464 |
+
|
| 465 |
+
class MMMDiffusionModel(nn.Module):
|
| 466 |
+
def __init__(self, n_media=5, n_ctrl=3, d_model_campaign=256, d_model_channel=384,
|
| 467 |
+
n_layers_campaign=4, n_layers_channel=6, T_diff=1000):
|
| 468 |
+
super().__init__()
|
| 469 |
+
self.n_media = n_media
|
| 470 |
+
self.n_ctrl = n_ctrl
|
| 471 |
+
self.n_channels = n_media + n_ctrl
|
| 472 |
+
self.T_diff = T_diff
|
| 473 |
+
self.n_agg = 3
|
| 474 |
+
|
| 475 |
+
self.campaign_denoiser = CampaignDenoiser(
|
| 476 |
+
n_agg_channels=self.n_agg, cond_dim=n_ctrl + 1,
|
| 477 |
+
d_model=d_model_campaign, nhead=4, n_layers=n_layers_campaign, T_diff=T_diff,
|
| 478 |
+
)
|
| 479 |
+
self.channel_denoiser = ChannelDenoiser(
|
| 480 |
+
n_channels=self.n_channels, n_media=n_media, n_agg=self.n_agg,
|
| 481 |
+
d_model=d_model_channel, nhead=8, n_layers=n_layers_channel, T_diff=T_diff,
|
| 482 |
+
)
|
| 483 |
+
self.coeff_to_agg = nn.Linear(self.n_channels, self.n_agg)
|
| 484 |
+
self.agg_to_coeff_init = nn.Linear(self.n_agg, self.n_channels)
|
| 485 |
+
self.schedule = DiffusionSchedule(T_diff)
|
| 486 |
+
|
| 487 |
+
# Learnable base sales predictor (from conditioning)
|
| 488 |
+
self.base_predictor = nn.Sequential(
|
| 489 |
+
nn.Linear(n_ctrl + 1, 64),
|
| 490 |
+
nn.GELU(),
|
| 491 |
+
nn.Linear(64, 1),
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
def compute_aggregate(self, coefficients):
|
| 495 |
+
return self.coeff_to_agg(coefficients)
|
| 496 |
+
|
| 497 |
+
def _compute_sales_from_coeffs(self, coeffs_pred, batch, dataset_ref):
|
| 498 |
+
"""
|
| 499 |
+
FIX #1: Compute predicted sales from predicted coefficients.
|
| 500 |
+
This creates a differentiable path: coeffs -> contributions -> total sales.
|
| 501 |
+
"""
|
| 502 |
+
# Decode coefficients to original scale (differentiable)
|
| 503 |
+
media_pred_norm = coeffs_pred[:, :, :self.n_media]
|
| 504 |
+
ctrl_pred_norm = coeffs_pred[:, :, self.n_media:]
|
| 505 |
+
|
| 506 |
+
media_coeffs = dataset_ref.decode_media_coefficients_differentiable(media_pred_norm)
|
| 507 |
+
ctrl_coeffs = dataset_ref.decode_ctrl_coefficients_differentiable(ctrl_pred_norm)
|
| 508 |
+
|
| 509 |
+
# Compute contributions
|
| 510 |
+
trans_media = batch['transformed_media'] # (B, T, 5)
|
| 511 |
+
controls = batch['controls'] # (B, T, 3)
|
| 512 |
+
|
| 513 |
+
media_contributions = media_coeffs * trans_media # (B, T, 5)
|
| 514 |
+
ctrl_contributions = ctrl_coeffs * controls # (B, T, 3)
|
| 515 |
+
|
| 516 |
+
total_contributions = media_contributions.sum(dim=-1) + ctrl_contributions.sum(dim=-1) # (B, T)
|
| 517 |
+
|
| 518 |
+
# Predict base sales from conditioning
|
| 519 |
+
cond = batch['conditioning']
|
| 520 |
+
stage1_cond = torch.cat([
|
| 521 |
+
cond[:, :, self.n_media:self.n_media + self.n_ctrl],
|
| 522 |
+
cond[:, :, -1:]
|
| 523 |
+
], dim=-1)
|
| 524 |
+
base_pred = self.base_predictor(stage1_cond).squeeze(-1) # (B, T)
|
| 525 |
+
|
| 526 |
+
pred_sales = base_pred + total_contributions
|
| 527 |
+
return pred_sales
|
| 528 |
+
|
| 529 |
+
def _spectral_loss(self, pred, target):
|
| 530 |
+
"""
|
| 531 |
+
FIX #2: Spectral (frequency-domain) loss to preserve temporal variation.
|
| 532 |
+
Uses log-magnitude to keep scale comparable to other losses.
|
| 533 |
+
Weights higher frequencies more to fight smoothing.
|
| 534 |
+
"""
|
| 535 |
+
pred_fft = torch.fft.rfft(pred, dim=1)
|
| 536 |
+
target_fft = torch.fft.rfft(target, dim=1)
|
| 537 |
+
|
| 538 |
+
# Log-magnitude for scale normalization (brings to ~O(1) range)
|
| 539 |
+
pred_mag = torch.log1p(torch.abs(pred_fft))
|
| 540 |
+
target_mag = torch.log1p(torch.abs(target_fft))
|
| 541 |
+
|
| 542 |
+
# Weight higher frequencies more to fight smoothing
|
| 543 |
+
n_freq = pred_mag.shape[1]
|
| 544 |
+
freq_weights = torch.linspace(1.0, 3.0, n_freq, device=pred.device)
|
| 545 |
+
freq_weights = freq_weights.view(1, -1, 1)
|
| 546 |
+
|
| 547 |
+
return F.mse_loss(pred_mag * freq_weights, target_mag * freq_weights)
|
| 548 |
+
|
| 549 |
+
def _multi_scale_temporal_loss(self, pred, target):
|
| 550 |
+
"""
|
| 551 |
+
Multi-scale temporal difference loss. Matches first AND second order
|
| 552 |
+
temporal derivatives to capture both trends and curvature.
|
| 553 |
+
"""
|
| 554 |
+
# First-order differences (velocity)
|
| 555 |
+
d1_pred = pred[:, 1:, :] - pred[:, :-1, :]
|
| 556 |
+
d1_true = target[:, 1:, :] - target[:, :-1, :]
|
| 557 |
+
loss_d1 = F.mse_loss(d1_pred, d1_true)
|
| 558 |
+
|
| 559 |
+
# Second-order differences (acceleration / curvature)
|
| 560 |
+
d2_pred = d1_pred[:, 1:, :] - d1_pred[:, :-1, :]
|
| 561 |
+
d2_true = d1_true[:, 1:, :] - d1_true[:, :-1, :]
|
| 562 |
+
loss_d2 = F.mse_loss(d2_pred, d2_true)
|
| 563 |
+
|
| 564 |
+
return loss_d1 + 0.5 * loss_d2
|
| 565 |
+
|
| 566 |
+
def forward_train(self, batch, dataset_ref=None, epoch=0, total_epochs=80):
|
| 567 |
+
cond = batch['conditioning']
|
| 568 |
+
coeffs = batch['coefficients']
|
| 569 |
+
B, T_seq, _ = coeffs.shape
|
| 570 |
+
device = coeffs.device
|
| 571 |
+
|
| 572 |
+
media_spend = cond[:, :, :self.n_media]
|
| 573 |
+
controls = cond[:, :, self.n_media:self.n_media + self.n_ctrl]
|
| 574 |
+
total_sales = cond[:, :, -1:]
|
| 575 |
+
stage1_cond = torch.cat([controls, total_sales], dim=-1)
|
| 576 |
+
|
| 577 |
+
with torch.no_grad():
|
| 578 |
+
agg_target = self.coeff_to_agg(coeffs)
|
| 579 |
+
|
| 580 |
+
# ---- Stage 1: Campaign Denoiser ----
|
| 581 |
+
t1 = torch.randint(0, self.T_diff, (B,), device=device)
|
| 582 |
+
noise1 = torch.randn_like(agg_target)
|
| 583 |
+
agg_noisy = self.schedule.q_sample(agg_target, t1, noise1)
|
| 584 |
+
agg_pred = self.campaign_denoiser(agg_noisy, t1, stage1_cond)
|
| 585 |
+
loss_campaign = F.mse_loss(agg_pred, agg_target)
|
| 586 |
+
|
| 587 |
+
# ---- Stage 2: Channel Denoiser ----
|
| 588 |
+
# Uniform timestep sampling (removed biased sampling which hurt learning)
|
| 589 |
+
t2 = torch.randint(0, self.T_diff, (B,), device=device)
|
| 590 |
+
|
| 591 |
+
noise2 = torch.randn_like(coeffs)
|
| 592 |
+
coeffs_noisy = self.schedule.q_sample(coeffs, t2, noise2)
|
| 593 |
+
campaign_ctx = agg_target.detach()
|
| 594 |
+
|
| 595 |
+
coeffs_pred = self.channel_denoiser(
|
| 596 |
+
coeffs_noisy, t2, campaign_ctx, media_spend, total_sales
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
loss_channel = F.mse_loss(coeffs_pred, coeffs)
|
| 600 |
+
|
| 601 |
+
# ---- Warmup schedule for auxiliary losses ----
|
| 602 |
+
# Phase 1 (epochs 0-warmup): focus on core denoising (channel + campaign)
|
| 603 |
+
# Phase 2 (epochs warmup+): gradually add sales, contrib, spectral losses
|
| 604 |
+
warmup_epochs = total_epochs // 4 # first 25% is warmup
|
| 605 |
+
aux_weight = max(0.0, min(1.0, (epoch - warmup_epochs) / max(warmup_epochs, 1)))
|
| 606 |
+
|
| 607 |
+
# ---- Sales reconstruction loss (only after warmup) ----
|
| 608 |
+
loss_sales = torch.tensor(0.0, device=device)
|
| 609 |
+
if dataset_ref is not None and aux_weight > 0:
|
| 610 |
+
# Only compute for VERY low noise (t < T/10) where prediction is accurate
|
| 611 |
+
low_noise_mask = t2 < (self.T_diff // 10)
|
| 612 |
+
if low_noise_mask.any():
|
| 613 |
+
pred_sales = self._compute_sales_from_coeffs(coeffs_pred, batch, dataset_ref)
|
| 614 |
+
actual_sales = batch['total_sales']
|
| 615 |
+
# Use relative error (scale-invariant)
|
| 616 |
+
scale = actual_sales[low_noise_mask].abs().mean() + 1e-8
|
| 617 |
+
loss_sales = F.mse_loss(
|
| 618 |
+
pred_sales[low_noise_mask] / scale,
|
| 619 |
+
actual_sales[low_noise_mask] / scale,
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
# ---- Spectral loss ----
|
| 623 |
+
loss_spectral = self._spectral_loss(coeffs_pred, coeffs)
|
| 624 |
+
|
| 625 |
+
# ---- Multi-scale temporal loss ----
|
| 626 |
+
loss_temporal = self._multi_scale_temporal_loss(coeffs_pred, coeffs)
|
| 627 |
+
|
| 628 |
+
# ---- Contribution matching loss (only after warmup) ----
|
| 629 |
+
loss_contrib = torch.tensor(0.0, device=device)
|
| 630 |
+
if dataset_ref is not None and aux_weight > 0:
|
| 631 |
+
low_noise_mask = t2 < (self.T_diff // 10)
|
| 632 |
+
if low_noise_mask.any():
|
| 633 |
+
media_coeffs = dataset_ref.decode_media_coefficients_differentiable(
|
| 634 |
+
coeffs_pred[low_noise_mask, :, :self.n_media]
|
| 635 |
+
)
|
| 636 |
+
ctrl_coeffs = dataset_ref.decode_ctrl_coefficients_differentiable(
|
| 637 |
+
coeffs_pred[low_noise_mask, :, self.n_media:]
|
| 638 |
+
)
|
| 639 |
+
pred_contrib_media = media_coeffs * batch['transformed_media'][low_noise_mask]
|
| 640 |
+
pred_contrib_ctrl = ctrl_coeffs * batch['controls'][low_noise_mask]
|
| 641 |
+
pred_contrib = torch.cat([pred_contrib_media, pred_contrib_ctrl], dim=-1)
|
| 642 |
+
true_contrib = batch['contributions'][low_noise_mask]
|
| 643 |
+
contrib_scale = true_contrib.abs().mean() + 1e-8
|
| 644 |
+
loss_contrib = F.mse_loss(pred_contrib / contrib_scale, true_contrib / contrib_scale)
|
| 645 |
+
|
| 646 |
+
# Sign loss (soft positivity for media in log-space)
|
| 647 |
+
media_pred_log = coeffs_pred[:, :, :self.n_media]
|
| 648 |
+
loss_sign = F.relu(-media_pred_log - 5.0).mean()
|
| 649 |
+
|
| 650 |
+
# ---- Total loss with warmup-gated auxiliary losses ----
|
| 651 |
+
# Core losses: always active (coefficient matching is primary)
|
| 652 |
+
# Aux losses: ramp in after warmup with controlled weights
|
| 653 |
+
loss = (
|
| 654 |
+
1.0 * loss_campaign +
|
| 655 |
+
2.0 * loss_channel + # PRIMARY: coefficient matching (doubled weight)
|
| 656 |
+
0.5 * loss_spectral + # Anti-smoothing (always active, log-scale ~O(1))
|
| 657 |
+
0.1 * loss_temporal + # Temporal dynamics
|
| 658 |
+
aux_weight * 0.2 * loss_sales + # Sales alignment (ramped in)
|
| 659 |
+
aux_weight * 0.2 * loss_contrib + # Contribution matching (ramped in)
|
| 660 |
+
0.01 * loss_sign
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
return {
|
| 664 |
+
'loss': loss,
|
| 665 |
+
'loss_campaign': loss_campaign.item(),
|
| 666 |
+
'loss_channel': loss_channel.item(),
|
| 667 |
+
'loss_sales': loss_sales.item() if isinstance(loss_sales, torch.Tensor) else loss_sales,
|
| 668 |
+
'loss_spectral': loss_spectral.item(),
|
| 669 |
+
'loss_temporal': loss_temporal.item(),
|
| 670 |
+
'loss_contrib': loss_contrib.item() if isinstance(loss_contrib, torch.Tensor) else loss_contrib,
|
| 671 |
+
'loss_sign': loss_sign.item(),
|
| 672 |
+
}
|
| 673 |
+
|
| 674 |
+
@torch.no_grad()
|
| 675 |
+
def sample(self, conditioning, n_steps=None, constraint_every_k=10, guidance_scale=1.0):
|
| 676 |
+
B, T_seq, _ = conditioning.shape
|
| 677 |
+
device = conditioning.device
|
| 678 |
+
|
| 679 |
+
media_spend = conditioning[:, :, :self.n_media]
|
| 680 |
+
controls = conditioning[:, :, self.n_media:self.n_media + self.n_ctrl]
|
| 681 |
+
total_sales = conditioning[:, :, -1:]
|
| 682 |
+
stage1_cond = torch.cat([controls, total_sales], dim=-1)
|
| 683 |
+
|
| 684 |
+
T_diff = n_steps or self.T_diff
|
| 685 |
+
|
| 686 |
+
# Stage 1: Denoise aggregate patterns
|
| 687 |
+
z_t = torch.randn(B, T_seq, self.n_agg, device=device)
|
| 688 |
+
for t in reversed(range(T_diff)):
|
| 689 |
+
t_batch = torch.full((B,), t, device=device, dtype=torch.long)
|
| 690 |
+
z_0_pred = self.campaign_denoiser(z_t, t_batch, stage1_cond)
|
| 691 |
+
if t > 0:
|
| 692 |
+
mean = self.schedule.posterior_mean(z_0_pred, z_t, t_batch)
|
| 693 |
+
var = self.schedule.posterior_variance[t]
|
| 694 |
+
noise = torch.randn_like(z_t)
|
| 695 |
+
z_t = mean + torch.sqrt(var) * noise
|
| 696 |
+
else:
|
| 697 |
+
z_t = z_0_pred
|
| 698 |
+
|
| 699 |
+
campaign_ctx = z_t
|
| 700 |
+
|
| 701 |
+
# Stage 2: Denoise channel coefficients
|
| 702 |
+
x_t = torch.randn(B, T_seq, self.n_channels, device=device)
|
| 703 |
+
for t in reversed(range(T_diff)):
|
| 704 |
+
t_batch = torch.full((B,), t, device=device, dtype=torch.long)
|
| 705 |
+
x_0_pred = self.channel_denoiser(
|
| 706 |
+
x_t, t_batch, campaign_ctx, media_spend, total_sales
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
# PhysDiff-style constraint projection
|
| 710 |
+
if t % constraint_every_k == 0:
|
| 711 |
+
x_0_pred[:, :, :self.n_media] = torch.clamp(
|
| 712 |
+
x_0_pred[:, :, :self.n_media], min=-8.0, max=8.0
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
if t > 0:
|
| 716 |
+
mean = self.schedule.posterior_mean(x_0_pred, x_t, t_batch)
|
| 717 |
+
var = self.schedule.posterior_variance[t]
|
| 718 |
+
noise = torch.randn_like(x_t)
|
| 719 |
+
x_t = mean + torch.sqrt(var) * noise
|
| 720 |
+
else:
|
| 721 |
+
x_t = x_0_pred
|
| 722 |
+
|
| 723 |
+
return x_t
|
| 724 |
+
|
| 725 |
+
@torch.no_grad()
|
| 726 |
+
def sample_ddim(self, conditioning, n_steps=50, constraint_every_k=5, eta=0.0):
|
| 727 |
+
"""
|
| 728 |
+
DDIM sampling — faster and more deterministic than DDPM.
|
| 729 |
+
eta=0: fully deterministic, eta=1: equivalent to DDPM.
|
| 730 |
+
"""
|
| 731 |
+
B, T_seq, _ = conditioning.shape
|
| 732 |
+
device = conditioning.device
|
| 733 |
+
|
| 734 |
+
media_spend = conditioning[:, :, :self.n_media]
|
| 735 |
+
controls = conditioning[:, :, self.n_media:self.n_media + self.n_ctrl]
|
| 736 |
+
total_sales = conditioning[:, :, -1:]
|
| 737 |
+
stage1_cond = torch.cat([controls, total_sales], dim=-1)
|
| 738 |
+
|
| 739 |
+
# Create sub-sequence of timesteps for DDIM
|
| 740 |
+
step_size = max(self.T_diff // n_steps, 1)
|
| 741 |
+
timesteps = list(range(0, self.T_diff, step_size))
|
| 742 |
+
timesteps = list(reversed(timesteps))
|
| 743 |
+
|
| 744 |
+
# Stage 1: DDIM denoise aggregate
|
| 745 |
+
z_t = torch.randn(B, T_seq, self.n_agg, device=device)
|
| 746 |
+
for i, t in enumerate(timesteps):
|
| 747 |
+
t_batch = torch.full((B,), t, device=device, dtype=torch.long)
|
| 748 |
+
z_0_pred = self.campaign_denoiser(z_t, t_batch, stage1_cond)
|
| 749 |
+
|
| 750 |
+
if i < len(timesteps) - 1:
|
| 751 |
+
t_next = timesteps[i + 1]
|
| 752 |
+
alpha_t = self.schedule.alphas_cumprod[t]
|
| 753 |
+
alpha_next = self.schedule.alphas_cumprod[t_next]
|
| 754 |
+
|
| 755 |
+
# DDIM update
|
| 756 |
+
pred_noise = (z_t - torch.sqrt(alpha_t) * z_0_pred) / torch.sqrt(1 - alpha_t)
|
| 757 |
+
sigma = eta * torch.sqrt((1 - alpha_next) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_next)
|
| 758 |
+
|
| 759 |
+
z_t = (torch.sqrt(alpha_next) * z_0_pred +
|
| 760 |
+
torch.sqrt(1 - alpha_next - sigma**2) * pred_noise +
|
| 761 |
+
sigma * torch.randn_like(z_t))
|
| 762 |
+
else:
|
| 763 |
+
z_t = z_0_pred
|
| 764 |
+
|
| 765 |
+
campaign_ctx = z_t
|
| 766 |
+
|
| 767 |
+
# Stage 2: DDIM denoise channel coefficients
|
| 768 |
+
x_t = torch.randn(B, T_seq, self.n_channels, device=device)
|
| 769 |
+
for i, t in enumerate(timesteps):
|
| 770 |
+
t_batch = torch.full((B,), t, device=device, dtype=torch.long)
|
| 771 |
+
x_0_pred = self.channel_denoiser(
|
| 772 |
+
x_t, t_batch, campaign_ctx, media_spend, total_sales
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
# PhysDiff projection
|
| 776 |
+
if t % constraint_every_k == 0:
|
| 777 |
+
x_0_pred[:, :, :self.n_media] = torch.clamp(
|
| 778 |
+
x_0_pred[:, :, :self.n_media], min=-8.0, max=8.0
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
if i < len(timesteps) - 1:
|
| 782 |
+
t_next = timesteps[i + 1]
|
| 783 |
+
alpha_t = self.schedule.alphas_cumprod[t]
|
| 784 |
+
alpha_next = self.schedule.alphas_cumprod[t_next]
|
| 785 |
+
|
| 786 |
+
pred_noise = (x_t - torch.sqrt(alpha_t) * x_0_pred) / torch.sqrt(1 - alpha_t)
|
| 787 |
+
sigma = eta * torch.sqrt((1 - alpha_next) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_next)
|
| 788 |
+
|
| 789 |
+
x_t = (torch.sqrt(alpha_next) * x_0_pred +
|
| 790 |
+
torch.sqrt(1 - alpha_next - sigma**2) * pred_noise +
|
| 791 |
+
sigma * torch.randn_like(x_t))
|
| 792 |
+
else:
|
| 793 |
+
x_t = x_0_pred
|
| 794 |
+
|
| 795 |
+
return x_t
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
# =============================================================================
|
| 799 |
+
# 6. TRAINING LOOP
|
| 800 |
+
# =============================================================================
|
| 801 |
+
|
| 802 |
+
def train_mmm_diffusion(
|
| 803 |
+
model, dataset,
|
| 804 |
+
n_epochs=50, batch_size=16, lr=1e-4,
|
| 805 |
+
device='cpu', log_every=50, save_path='mmm_diffusion_model.pt'
|
| 806 |
+
):
|
| 807 |
+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
|
| 808 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
|
| 809 |
+
|
| 810 |
+
# Warmup + cosine decay
|
| 811 |
+
warmup_steps = len(dataloader) * (n_epochs // 10) # 10% warmup
|
| 812 |
+
total_steps = len(dataloader) * n_epochs
|
| 813 |
+
|
| 814 |
+
def lr_lambda(step):
|
| 815 |
+
if step < warmup_steps:
|
| 816 |
+
return step / max(warmup_steps, 1)
|
| 817 |
+
progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
|
| 818 |
+
return 0.5 * (1 + math.cos(math.pi * progress))
|
| 819 |
+
|
| 820 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
| 821 |
+
|
| 822 |
+
model = model.to(device)
|
| 823 |
+
model.schedule = model.schedule.to(device)
|
| 824 |
+
|
| 825 |
+
history = {
|
| 826 |
+
'loss': [], 'loss_campaign': [], 'loss_channel': [],
|
| 827 |
+
'loss_sales': [], 'loss_spectral': [], 'loss_temporal': [],
|
| 828 |
+
'loss_contrib': [], 'loss_sign': [],
|
| 829 |
+
}
|
| 830 |
+
|
| 831 |
+
print(f"\nTraining MMM-Diffusion Model v2")
|
| 832 |
+
print(f" Device: {device}")
|
| 833 |
+
print(f" Samples: {len(dataset)}, Batch size: {batch_size}")
|
| 834 |
+
print(f" Epochs: {n_epochs}, LR: {lr}")
|
| 835 |
+
print(f" Model params: {sum(p.numel() for p in model.parameters()):,}")
|
| 836 |
+
print(f" Diffusion steps: {model.T_diff}")
|
| 837 |
+
print(f" NEW losses: sales_recon, spectral, contribution_match")
|
| 838 |
+
print("-" * 70)
|
| 839 |
+
|
| 840 |
+
step = 0
|
| 841 |
+
best_loss = float('inf')
|
| 842 |
+
|
| 843 |
+
for epoch in range(n_epochs):
|
| 844 |
+
model.train()
|
| 845 |
+
epoch_losses = {k: [] for k in history}
|
| 846 |
+
|
| 847 |
+
for batch in dataloader:
|
| 848 |
+
batch = {k: v.to(device) for k, v in batch.items()}
|
| 849 |
+
|
| 850 |
+
losses = model.forward_train(batch, dataset_ref=dataset, epoch=epoch, total_epochs=n_epochs)
|
| 851 |
+
|
| 852 |
+
optimizer.zero_grad()
|
| 853 |
+
losses['loss'].backward()
|
| 854 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 855 |
+
optimizer.step()
|
| 856 |
+
scheduler.step() # step-level scheduling for warmup
|
| 857 |
+
|
| 858 |
+
for k in history:
|
| 859 |
+
val = losses[k].item() if isinstance(losses[k], torch.Tensor) else losses[k]
|
| 860 |
+
epoch_losses[k].append(val)
|
| 861 |
+
|
| 862 |
+
step += 1
|
| 863 |
+
if step % log_every == 0:
|
| 864 |
+
avg = {k: np.mean(v[-log_every:]) for k, v in epoch_losses.items() if v}
|
| 865 |
+
print(f" Step {step:5d} | loss={avg['loss']:.4f} "
|
| 866 |
+
f"camp={avg['loss_campaign']:.4f} chan={avg['loss_channel']:.4f} "
|
| 867 |
+
f"sales={avg.get('loss_sales', 0):.4f} spec={avg.get('loss_spectral', 0):.4f} "
|
| 868 |
+
f"temp={avg.get('loss_temporal', 0):.4f} contr={avg.get('loss_contrib', 0):.4f}")
|
| 869 |
+
|
| 870 |
+
# scheduler steps per-batch above
|
| 871 |
+
|
| 872 |
+
avg = {k: np.mean(v) for k, v in epoch_losses.items() if v}
|
| 873 |
+
for k, v in avg.items():
|
| 874 |
+
history[k].append(v)
|
| 875 |
+
|
| 876 |
+
if avg['loss'] < best_loss:
|
| 877 |
+
best_loss = avg['loss']
|
| 878 |
+
torch.save({
|
| 879 |
+
'model_state_dict': model.state_dict(),
|
| 880 |
+
'history': history,
|
| 881 |
+
'config': {
|
| 882 |
+
'n_media': model.n_media, 'n_ctrl': model.n_ctrl, 'T_diff': model.T_diff,
|
| 883 |
+
}
|
| 884 |
+
}, save_path.replace('.pt', '_best.pt'))
|
| 885 |
+
|
| 886 |
+
print(f"Epoch {epoch+1:3d}/{n_epochs} | loss={avg['loss']:.4f} "
|
| 887 |
+
f"camp={avg['loss_campaign']:.4f} chan={avg['loss_channel']:.4f} "
|
| 888 |
+
f"sales={avg.get('loss_sales', 0):.4f} spec={avg.get('loss_spectral', 0):.4f} "
|
| 889 |
+
f"lr={scheduler.get_last_lr()[0]:.6f}")
|
| 890 |
+
|
| 891 |
+
torch.save({
|
| 892 |
+
'model_state_dict': model.state_dict(),
|
| 893 |
+
'history': history,
|
| 894 |
+
'config': {
|
| 895 |
+
'n_media': model.n_media, 'n_ctrl': model.n_ctrl, 'T_diff': model.T_diff,
|
| 896 |
+
}
|
| 897 |
+
}, save_path)
|
| 898 |
+
print(f"\nModel saved to {save_path}")
|
| 899 |
+
|
| 900 |
+
return history
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
# =============================================================================
|
| 904 |
+
# 7. SALES DECOMPOSITION
|
| 905 |
+
# =============================================================================
|
| 906 |
+
|
| 907 |
+
def decompose_sales(coefficients, media_spend, controls, transformed_media=None, base_sales=None):
|
| 908 |
+
"""
|
| 909 |
+
FIX: Use transformed media (adstock+Hill) if available for accurate decomposition.
|
| 910 |
+
"""
|
| 911 |
+
T, n_total = coefficients.shape
|
| 912 |
+
n_media = 5
|
| 913 |
+
|
| 914 |
+
contributions = {}
|
| 915 |
+
total_media = np.zeros(T)
|
| 916 |
+
|
| 917 |
+
for m in range(n_media):
|
| 918 |
+
name = MMMDataGenerator.MEDIA_CHANNELS[m]
|
| 919 |
+
if transformed_media is not None:
|
| 920 |
+
feature = transformed_media[:, m]
|
| 921 |
+
else:
|
| 922 |
+
spend = media_spend[:, m]
|
| 923 |
+
feature = spend / (np.percentile(spend, 90) + 1e-10)
|
| 924 |
+
|
| 925 |
+
contrib = coefficients[:, m] * feature
|
| 926 |
+
contributions[name] = contrib
|
| 927 |
+
total_media += contrib
|
| 928 |
+
|
| 929 |
+
total_ctrl = np.zeros(T)
|
| 930 |
+
for c in range(3):
|
| 931 |
+
name = MMMDataGenerator.CONTROL_VARS[c]
|
| 932 |
+
contrib = coefficients[:, n_media + c] * controls[:, c]
|
| 933 |
+
contributions[name] = contrib
|
| 934 |
+
total_ctrl += contrib
|
| 935 |
+
|
| 936 |
+
contributions['Total_Media'] = total_media
|
| 937 |
+
contributions['Total_Controls'] = total_ctrl
|
| 938 |
+
|
| 939 |
+
if base_sales is not None:
|
| 940 |
+
contributions['Predicted_Sales'] = base_sales + total_media + total_ctrl
|
| 941 |
+
else:
|
| 942 |
+
contributions['Predicted_Sales'] = total_media + total_ctrl
|
| 943 |
+
|
| 944 |
+
return contributions
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
# =============================================================================
|
| 948 |
+
# 8. VISUALIZATION
|
| 949 |
+
# =============================================================================
|
| 950 |
+
|
| 951 |
+
def plot_training_history(history, save_path='training_history.png'):
|
| 952 |
+
keys_to_plot = ['loss', 'loss_campaign', 'loss_channel', 'loss_sales',
|
| 953 |
+
'loss_spectral', 'loss_temporal', 'loss_contrib']
|
| 954 |
+
keys_to_plot = [k for k in keys_to_plot if k in history and len(history[k]) > 0]
|
| 955 |
+
|
| 956 |
+
n_plots = len(keys_to_plot)
|
| 957 |
+
cols = 3
|
| 958 |
+
rows = (n_plots + cols - 1) // cols
|
| 959 |
+
fig, axes = plt.subplots(rows, cols, figsize=(16, 4 * rows))
|
| 960 |
+
axes = axes.flatten()
|
| 961 |
+
|
| 962 |
+
for i, key in enumerate(keys_to_plot):
|
| 963 |
+
values = history[key]
|
| 964 |
+
axes[i].plot(values, linewidth=1.5)
|
| 965 |
+
axes[i].set_title(f'{key}', fontsize=12)
|
| 966 |
+
axes[i].set_xlabel('Epoch')
|
| 967 |
+
axes[i].set_ylabel('Loss')
|
| 968 |
+
axes[i].grid(True, alpha=0.3)
|
| 969 |
+
if min(values) > 0:
|
| 970 |
+
axes[i].set_yscale('log')
|
| 971 |
+
|
| 972 |
+
for i in range(len(keys_to_plot), len(axes)):
|
| 973 |
+
axes[i].set_visible(False)
|
| 974 |
+
|
| 975 |
+
plt.suptitle('MMM-Diffusion v2 Training History', fontsize=14, fontweight='bold')
|
| 976 |
+
plt.tight_layout()
|
| 977 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 978 |
+
plt.close()
|
| 979 |
+
print(f"Training history plot saved to {save_path}")
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
def plot_coefficient_comparison(true_coeffs, pred_coeffs, channel_names, save_path='coeff_comparison.png'):
|
| 983 |
+
n_channels = true_coeffs.shape[1]
|
| 984 |
+
fig, axes = plt.subplots(n_channels, 1, figsize=(14, 2.5 * n_channels))
|
| 985 |
+
|
| 986 |
+
for i, (ax, name) in enumerate(zip(axes, channel_names)):
|
| 987 |
+
ax.plot(true_coeffs[:, i], 'b-', label='Ground Truth', linewidth=1.5)
|
| 988 |
+
ax.plot(pred_coeffs[:, i], 'r--', label='Predicted', linewidth=1.5, alpha=0.8)
|
| 989 |
+
ax.set_title(f'{name} — Time-Varying Coefficient', fontsize=11)
|
| 990 |
+
ax.legend(fontsize=9)
|
| 991 |
+
ax.grid(True, alpha=0.3)
|
| 992 |
+
if i < 5:
|
| 993 |
+
ax.axhline(y=0, color='gray', linestyle=':', alpha=0.5)
|
| 994 |
+
ax.set_ylabel('β (≥0)')
|
| 995 |
+
else:
|
| 996 |
+
ax.set_ylabel('β')
|
| 997 |
+
|
| 998 |
+
axes[-1].set_xlabel('Week')
|
| 999 |
+
plt.suptitle('MMM-Diffusion v2: Coefficient Prediction Quality', fontsize=14, fontweight='bold')
|
| 1000 |
+
plt.tight_layout()
|
| 1001 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 1002 |
+
plt.close()
|
| 1003 |
+
print(f"Coefficient comparison plot saved to {save_path}")
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
def plot_sales_decomposition(contributions, total_sales, save_path='sales_decomposition.png'):
|
| 1007 |
+
fig, axes = plt.subplots(2, 1, figsize=(14, 10))
|
| 1008 |
+
|
| 1009 |
+
ax = axes[0]
|
| 1010 |
+
weeks = np.arange(len(total_sales))
|
| 1011 |
+
media_names = MMMDataGenerator.MEDIA_CHANNELS
|
| 1012 |
+
colors = plt.cm.Set2(np.linspace(0, 1, len(media_names)))
|
| 1013 |
+
|
| 1014 |
+
bottom = np.zeros(len(total_sales))
|
| 1015 |
+
for name, color in zip(media_names, colors):
|
| 1016 |
+
vals = np.maximum(contributions[name], 0)
|
| 1017 |
+
ax.fill_between(weeks, bottom, bottom + vals, alpha=0.7, label=name, color=color)
|
| 1018 |
+
bottom += vals
|
| 1019 |
+
|
| 1020 |
+
ax.plot(weeks, total_sales, 'k-', linewidth=2, label='Total Sales', alpha=0.8)
|
| 1021 |
+
ax.set_title('Sales Decomposition: Media Channel Contributions', fontsize=12)
|
| 1022 |
+
ax.legend(loc='upper left', fontsize=9)
|
| 1023 |
+
ax.set_xlabel('Week')
|
| 1024 |
+
ax.set_ylabel('Sales Contribution')
|
| 1025 |
+
ax.grid(True, alpha=0.3)
|
| 1026 |
+
|
| 1027 |
+
ax = axes[1]
|
| 1028 |
+
ax.plot(weeks, total_sales, 'b-', linewidth=2, label='Actual Sales')
|
| 1029 |
+
if 'Predicted_Sales' in contributions:
|
| 1030 |
+
ax.plot(weeks, contributions['Predicted_Sales'], 'r--', linewidth=2,
|
| 1031 |
+
label='Predicted (Base + Media + Controls)', alpha=0.8)
|
| 1032 |
+
# Add R² annotation
|
| 1033 |
+
ss_res = np.sum((total_sales - contributions['Predicted_Sales'])**2)
|
| 1034 |
+
ss_tot = np.sum((total_sales - total_sales.mean())**2)
|
| 1035 |
+
r2 = 1 - ss_res / (ss_tot + 1e-10)
|
| 1036 |
+
mape = np.mean(np.abs(total_sales - contributions['Predicted_Sales']) / (np.abs(total_sales) + 1e-10)) * 100
|
| 1037 |
+
ax.text(0.02, 0.95, f'R² = {r2:.4f}\nMAPE = {mape:.1f}%',
|
| 1038 |
+
transform=ax.transAxes, fontsize=11, verticalalignment='top',
|
| 1039 |
+
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
|
| 1040 |
+
ax.set_title('Total Sales: Actual vs Predicted Decomposition', fontsize=12)
|
| 1041 |
+
ax.legend(fontsize=10)
|
| 1042 |
+
ax.set_xlabel('Week')
|
| 1043 |
+
ax.set_ylabel('Sales')
|
| 1044 |
+
ax.grid(True, alpha=0.3)
|
| 1045 |
+
|
| 1046 |
+
plt.suptitle('MMM-Diffusion v2: Sales Decomposition', fontsize=14, fontweight='bold')
|
| 1047 |
+
plt.tight_layout()
|
| 1048 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 1049 |
+
plt.close()
|
| 1050 |
+
print(f"Sales decomposition plot saved to {save_path}")
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
# =============================================================================
|
| 1054 |
+
# 9. MAIN
|
| 1055 |
+
# =============================================================================
|
| 1056 |
+
|
| 1057 |
+
def main():
|
| 1058 |
+
import time
|
| 1059 |
+
|
| 1060 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 1061 |
+
print(f"=" * 70)
|
| 1062 |
+
print(f"MMM-DIFFUSION v2: Fixed Sales Alignment + Coefficient Dynamics")
|
| 1063 |
+
print(f"Device: {device}")
|
| 1064 |
+
print(f"=" * 70)
|
| 1065 |
+
|
| 1066 |
+
# ---- Step 1: Generate synthetic data ----
|
| 1067 |
+
print("\n[1/5] Generating synthetic MMM data (higher volatility)...")
|
| 1068 |
+
t0 = time.time()
|
| 1069 |
+
|
| 1070 |
+
gen = MMMDataGenerator(n_weeks=104, seed=42)
|
| 1071 |
+
n_train = 800
|
| 1072 |
+
n_val = 50
|
| 1073 |
+
|
| 1074 |
+
train_samples = gen.generate_dataset(n_train)
|
| 1075 |
+
val_samples = gen.generate_dataset(n_val)
|
| 1076 |
+
|
| 1077 |
+
print(f" Generated {n_train} train + {n_val} val scenarios")
|
| 1078 |
+
print(f" Each: {gen.n_weeks} weeks, {gen.n_media} media + {gen.n_ctrl} control vars")
|
| 1079 |
+
print(f" Time: {time.time()-t0:.1f}s")
|
| 1080 |
+
|
| 1081 |
+
# Quick audit
|
| 1082 |
+
sample = train_samples[0]
|
| 1083 |
+
coeffs = sample['true_coefficients']
|
| 1084 |
+
print(f"\n Data audit:")
|
| 1085 |
+
print(f" Media spend shape: {sample['media_spend'].shape}")
|
| 1086 |
+
print(f" Media coeff range: [{coeffs[:,:5].min():.4f}, {coeffs[:,:5].max():.4f}]")
|
| 1087 |
+
print(f" Media coeff std per channel: {coeffs[:,:5].std(axis=0).round(4)}")
|
| 1088 |
+
print(f" Ctrl coeff range: [{coeffs[:,5:].min():.2f}, {coeffs[:,5:].max():.2f}]")
|
| 1089 |
+
print(f" All media coeffs positive: {(coeffs[:,:5] > 0).all()}")
|
| 1090 |
+
print(f" Sales range: [{sample['total_sales'].min():.1f}, {sample['total_sales'].max():.1f}]")
|
| 1091 |
+
|
| 1092 |
+
# Check temporal variation (std of first-differences)
|
| 1093 |
+
media_diffs = np.diff(coeffs[:, :5], axis=0)
|
| 1094 |
+
print(f" Media coeff Δ std (temporal variation): {media_diffs.std(axis=0).round(5)}")
|
| 1095 |
+
|
| 1096 |
+
# ---- Step 2: Create datasets ----
|
| 1097 |
+
print("\n[2/5] Creating training datasets...")
|
| 1098 |
+
train_dataset = MMMDiffusionDataset(train_samples, normalize=True)
|
| 1099 |
+
val_dataset = MMMDiffusionDataset(val_samples, normalize=True)
|
| 1100 |
+
|
| 1101 |
+
item = train_dataset[0]
|
| 1102 |
+
print(f" Conditioning shape: {item['conditioning'].shape}")
|
| 1103 |
+
print(f" Coefficients shape: {item['coefficients'].shape}")
|
| 1104 |
+
print(f" Transformed media shape: {item['transformed_media'].shape}")
|
| 1105 |
+
|
| 1106 |
+
# ---- Step 3: Build model ----
|
| 1107 |
+
print("\n[3/5] Building MMM-Diffusion v2 model...")
|
| 1108 |
+
|
| 1109 |
+
T_DIFF = 500
|
| 1110 |
+
|
| 1111 |
+
model = MMMDiffusionModel(
|
| 1112 |
+
n_media=5, n_ctrl=3,
|
| 1113 |
+
d_model_campaign=192,
|
| 1114 |
+
d_model_channel=256,
|
| 1115 |
+
n_layers_campaign=4,
|
| 1116 |
+
n_layers_channel=6,
|
| 1117 |
+
T_diff=T_DIFF,
|
| 1118 |
+
)
|
| 1119 |
+
|
| 1120 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 1121 |
+
print(f" Total parameters: {total_params:,}")
|
| 1122 |
+
print(f" Campaign denoiser: {sum(p.numel() for p in model.campaign_denoiser.parameters()):,}")
|
| 1123 |
+
print(f" Channel denoiser: {sum(p.numel() for p in model.channel_denoiser.parameters()):,}")
|
| 1124 |
+
print(f" Diffusion steps: {T_DIFF}")
|
| 1125 |
+
|
| 1126 |
+
# ---- Step 4: Train ----
|
| 1127 |
+
print("\n[4/5] Training...")
|
| 1128 |
+
|
| 1129 |
+
N_EPOCHS = 150
|
| 1130 |
+
BATCH_SIZE = 32 if device == 'cuda' else 8
|
| 1131 |
+
LR = 3e-4
|
| 1132 |
+
|
| 1133 |
+
history = train_mmm_diffusion(
|
| 1134 |
+
model, train_dataset,
|
| 1135 |
+
n_epochs=N_EPOCHS,
|
| 1136 |
+
batch_size=BATCH_SIZE,
|
| 1137 |
+
lr=LR,
|
| 1138 |
+
device=device,
|
| 1139 |
+
log_every=25,
|
| 1140 |
+
save_path='/app/mmm_diffusion_model_v2.pt',
|
| 1141 |
+
)
|
| 1142 |
+
|
| 1143 |
+
plot_training_history(history, save_path='/app/training_history.png')
|
| 1144 |
+
|
| 1145 |
+
# ---- Step 5: Validate ----
|
| 1146 |
+
print("\n[5/5] Validation...")
|
| 1147 |
+
|
| 1148 |
+
# Load best model
|
| 1149 |
+
best_ckpt = torch.load('/app/mmm_diffusion_model_v2_best.pt', map_location=device, weights_only=False)
|
| 1150 |
+
model.load_state_dict(best_ckpt['model_state_dict'])
|
| 1151 |
+
model.eval()
|
| 1152 |
+
model = model.to(device)
|
| 1153 |
+
|
| 1154 |
+
# Evaluate on multiple validation samples
|
| 1155 |
+
all_corrs = []
|
| 1156 |
+
for val_idx in range(min(10, len(val_samples))):
|
| 1157 |
+
val_item = val_dataset[val_idx]
|
| 1158 |
+
cond = val_item['conditioning'].unsqueeze(0).to(device)
|
| 1159 |
+
true_coeffs_norm = val_item['coefficients'].unsqueeze(0)
|
| 1160 |
+
|
| 1161 |
+
pred_coeffs_norm = model.sample(cond, n_steps=T_DIFF, constraint_every_k=5)
|
| 1162 |
+
|
| 1163 |
+
pred_coeffs = val_dataset.decode_coefficients(pred_coeffs_norm.cpu())
|
| 1164 |
+
true_coeffs = val_dataset.decode_coefficients(true_coeffs_norm)
|
| 1165 |
+
|
| 1166 |
+
pred_np = pred_coeffs[0].numpy()
|
| 1167 |
+
true_np = true_coeffs[0].numpy()
|
| 1168 |
+
|
| 1169 |
+
sample_corrs = []
|
| 1170 |
+
for i in range(8):
|
| 1171 |
+
corr = np.corrcoef(true_np[:, i], pred_np[:, i])[0, 1]
|
| 1172 |
+
sample_corrs.append(corr)
|
| 1173 |
+
all_corrs.append(sample_corrs)
|
| 1174 |
+
|
| 1175 |
+
all_corrs = np.array(all_corrs)
|
| 1176 |
+
channel_names = MMMDataGenerator.MEDIA_CHANNELS + MMMDataGenerator.CONTROL_VARS
|
| 1177 |
+
|
| 1178 |
+
print(f"\n Average per-channel correlation (over {len(all_corrs)} val samples):")
|
| 1179 |
+
for i, name in enumerate(channel_names):
|
| 1180 |
+
mean_corr = np.nanmean(all_corrs[:, i])
|
| 1181 |
+
std_corr = np.nanstd(all_corrs[:, i])
|
| 1182 |
+
print(f" {name:20s}: corr={mean_corr:.3f} ± {std_corr:.3f}")
|
| 1183 |
+
|
| 1184 |
+
# Detailed analysis on first sample
|
| 1185 |
+
val_item = val_dataset[0]
|
| 1186 |
+
cond = val_item['conditioning'].unsqueeze(0).to(device)
|
| 1187 |
+
true_coeffs_norm = val_item['coefficients'].unsqueeze(0)
|
| 1188 |
+
|
| 1189 |
+
pred_coeffs_norm = model.sample(cond, n_steps=T_DIFF, constraint_every_k=5)
|
| 1190 |
+
pred_coeffs = val_dataset.decode_coefficients(pred_coeffs_norm.cpu())
|
| 1191 |
+
true_coeffs = val_dataset.decode_coefficients(true_coeffs_norm)
|
| 1192 |
+
|
| 1193 |
+
pred_np = pred_coeffs[0].numpy()
|
| 1194 |
+
true_np = true_coeffs[0].numpy()
|
| 1195 |
+
|
| 1196 |
+
print(f"\n Constraint check (sample 0):")
|
| 1197 |
+
print(f" Media coefficients all positive: {(pred_np[:, :5] > 0).all()}")
|
| 1198 |
+
print(f" Media coeff range: [{pred_np[:,:5].min():.6f}, {pred_np[:,:5].max():.6f}]")
|
| 1199 |
+
|
| 1200 |
+
# Check temporal variation of predictions vs GT
|
| 1201 |
+
pred_media_diffs = np.diff(pred_np[:, :5], axis=0)
|
| 1202 |
+
true_media_diffs = np.diff(true_np[:, :5], axis=0)
|
| 1203 |
+
print(f"\n Temporal variation (std of first-differences):")
|
| 1204 |
+
print(f" GT media Δ std: {true_media_diffs.std(axis=0).round(5)}")
|
| 1205 |
+
print(f" Pred media Δ std: {pred_media_diffs.std(axis=0).round(5)}")
|
| 1206 |
+
ratio = pred_media_diffs.std(axis=0) / (true_media_diffs.std(axis=0) + 1e-10)
|
| 1207 |
+
print(f" Ratio pred/GT: {ratio.round(3)} (want close to 1.0)")
|
| 1208 |
+
|
| 1209 |
+
# Plot
|
| 1210 |
+
plot_coefficient_comparison(
|
| 1211 |
+
true_np, pred_np, channel_names,
|
| 1212 |
+
save_path='/app/coeff_comparison.png'
|
| 1213 |
+
)
|
| 1214 |
+
|
| 1215 |
+
# Sales decomposition with proper transformed media
|
| 1216 |
+
val_raw = val_samples[0]
|
| 1217 |
+
contributions = decompose_sales(
|
| 1218 |
+
pred_np, val_raw['media_spend'], val_raw['controls'],
|
| 1219 |
+
transformed_media=val_raw['transformed_media'],
|
| 1220 |
+
base_sales=val_raw['base_sales'],
|
| 1221 |
+
)
|
| 1222 |
+
|
| 1223 |
+
# Also compute GT decomposition for comparison
|
| 1224 |
+
gt_contributions = decompose_sales(
|
| 1225 |
+
true_np, val_raw['media_spend'], val_raw['controls'],
|
| 1226 |
+
transformed_media=val_raw['transformed_media'],
|
| 1227 |
+
base_sales=val_raw['base_sales'],
|
| 1228 |
+
)
|
| 1229 |
+
|
| 1230 |
+
plot_sales_decomposition(
|
| 1231 |
+
contributions, val_raw['total_sales'],
|
| 1232 |
+
save_path='/app/sales_decomposition.png'
|
| 1233 |
+
)
|
| 1234 |
+
|
| 1235 |
+
# Sales alignment metrics
|
| 1236 |
+
pred_sales = contributions['Predicted_Sales']
|
| 1237 |
+
actual_sales = val_raw['total_sales']
|
| 1238 |
+
ss_res = np.sum((actual_sales - pred_sales)**2)
|
| 1239 |
+
ss_tot = np.sum((actual_sales - actual_sales.mean())**2)
|
| 1240 |
+
r2 = 1 - ss_res / (ss_tot + 1e-10)
|
| 1241 |
+
mape = np.mean(np.abs(actual_sales - pred_sales) / (np.abs(actual_sales) + 1e-10)) * 100
|
| 1242 |
+
|
| 1243 |
+
print(f"\n Sales Alignment:")
|
| 1244 |
+
print(f" R² = {r2:.4f}")
|
| 1245 |
+
print(f" MAPE = {mape:.1f}%")
|
| 1246 |
+
print(f" Actual sales range: [{actual_sales.min():.0f}, {actual_sales.max():.0f}]")
|
| 1247 |
+
print(f" Predicted sales range: [{pred_sales.min():.0f}, {pred_sales.max():.0f}]")
|
| 1248 |
+
|
| 1249 |
+
print(f"\n{'='*70}")
|
| 1250 |
+
print(f"MMM-DIFFUSION v2 COMPLETE")
|
| 1251 |
+
print(f"{'='*70}")
|
| 1252 |
+
print(f" Fixes applied:")
|
| 1253 |
+
print(f" 1. Sales reconstruction loss (L_sales) — aligns predicted with actual sales")
|
| 1254 |
+
print(f" 2. Spectral loss (L_spectral) — preserves frequency content, fights smoothing")
|
| 1255 |
+
print(f" 3. Multi-scale temporal loss — matches velocity AND acceleration")
|
| 1256 |
+
print(f" 4. Contribution matching loss — aligns channel-level decomposition")
|
| 1257 |
+
print(f" 5. Higher GT coefficient volatility with regime jumps")
|
| 1258 |
+
print(f" 6. Low-noise biased timestep sampling for better detail learning")
|
| 1259 |
+
print(f" 7. Reduced smoothness weight (0.1 → 0.05)")
|
| 1260 |
+
print(f" Final training loss: {history['loss'][-1]:.4f}")
|
| 1261 |
+
print(f" Sales R²: {r2:.4f}")
|
| 1262 |
+
|
| 1263 |
+
return model, history, train_dataset, val_dataset
|
| 1264 |
+
|
| 1265 |
+
|
| 1266 |
+
if __name__ == '__main__':
|
| 1267 |
+
model, history, train_dataset, val_dataset = main()
|