| import os |
| import json |
| import random |
| from typing import Dict, List, Tuple, Optional, Any |
|
|
| import numpy as np |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop |
| from torchvision.utils import save_image |
| import lpips |
|
|
| from diffusers import ( |
| AutoencoderKL, |
| AutoencoderKLWan, |
| AutoencoderKLLTXVideo, |
| AutoencoderKLQwenImage |
| ) |
|
|
| from scipy.stats import skew, kurtosis |
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
DEVICE: str = "cuda"  # device for the VAEs, the LPIPS network and every batch
DTYPE: torch.dtype = torch.float16  # dtype the VAEs are loaded in; encoder inputs are cast to it
IMAGE_FOLDER: str = "/home/recoilme/dataset/alchemist"  # walked recursively for images
MIN_SIZE: int = 1280  # passed to Resize(): target length of the shorter image side
CROP_SIZE: int = 512  # side of the square center crop fed to the VAEs
BATCH_SIZE: int = 5
MAX_IMAGES: int = 0  # falsy (0) means "no limit": every image found is used
NUM_WORKERS: int = 4  # DataLoader worker processes
SAMPLES_DIR: str = "test"  # output folder for sample PNGs and the correction JSON

# Models to benchmark: (display name, diffusers VAE class, Hub repo id, subfolder or None).
# The first entry is the baseline of the percentage table; the last one gets a
# correction JSON written for it.
VAE_LIST: List[Tuple[str, type, str, Optional[str]]] = [
    ("SD15 VAE", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
    ("SDXL VAE fp16 fix", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
    ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", "vae"),
    ("LTX-Video VAE", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
    ("Wan2.2-TI2V-5B", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
    ("AiArtLab/wan16x_vae", AutoencoderKLWan, "AiArtLab/wan16x_vae", "vae"),
    ("Wan2.2-T2V-A14B", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
    ("QwenImage", AutoencoderKLQwenImage, "Qwen/Qwen-Image", "vae"),
    ("AuraDiffusion/16ch-vae", AutoencoderKL, "AuraDiffusion/16ch-vae", None),
    ("FLUX.1-schnell VAE", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
    ("AiArtLab/simplevae", AutoencoderKL, "AiArtLab/simplevae", "vae"),
]
|
|
|
|
| |
def to_neg1_1(x: torch.Tensor) -> torch.Tensor:
    """Map values from the [0, 1] range to [-1, 1] via ``2 * x - 1``."""
    doubled = x * 2
    return doubled - 1
|
|
|
|
def to_0_1(x: torch.Tensor) -> torch.Tensor:
    """Map values from the [-1, 1] range back to [0, 1] via ``(x + 1) / 2``."""
    shifted = x + 1
    return shifted * 0.5
|
|
|
|
def safe_psnr(mse: float) -> float:
    """Return PSNR in dB for signals with peak value 1.0, given their MSE.

    An MSE at or below 1e-12 is treated as a perfect reconstruction and yields
    ``inf`` instead of dividing by (almost) zero.
    """
    return float("inf") if mse <= 1e-12 else 10.0 * float(np.log10(1.0 / mse))
|
|
|
|
def is_video_like_vae(vae) -> bool:
    """Return True for VAE classes that this script feeds with an explicit time axis.

    Wan, LTX-Video and Qwen-Image VAEs are handled as 5D (video-style) models by
    ``add_time_dim_if_needed`` / ``strip_time_dim_if_possible``.
    """
    five_d_classes = (AutoencoderKLWan, AutoencoderKLLTXVideo, AutoencoderKLQwenImage)
    return isinstance(vae, five_d_classes)
|
|
|
|
def add_time_dim_if_needed(x: torch.Tensor, vae) -> torch.Tensor:
    """Insert a singleton time axis at dim 2 (4D -> 5D) for video-like VAEs.

    Any other combination (image VAE, or a tensor that is not 4D) is returned unchanged.
    """
    if not is_video_like_vae(vae) or x.ndim != 4:
        return x
    return x.unsqueeze(2)
|
|
|
|
def strip_time_dim_if_possible(x: torch.Tensor, vae) -> torch.Tensor:
    """Drop a singleton time axis at dim 2 (5D -> 4D) for video-like VAEs.

    The tensor is returned unchanged unless the VAE is video-like, the tensor is
    5D and its dim-2 size is exactly 1.
    """
    removable = is_video_like_vae(vae) and x.ndim == 5 and x.shape[2] == 1
    if not removable:
        return x
    return x.squeeze(2)
|
|
|
|
@torch.no_grad()
def sobel_edge_l1(real_0_1: torch.Tensor, fake_0_1: torch.Tensor) -> float:
    """Mean absolute difference between Sobel gradient magnitudes of two image batches.

    Both inputs are first mapped from [0, 1] to [-1, 1], then filtered channel by
    channel (grouped conv, zero padding of 1 so spatial size is kept). A 1e-12
    epsilon is added under the square root of the magnitude.

    Returns:
        The L1 distance between the two gradient-magnitude maps as a Python float.
    """
    real = to_neg1_1(real_0_1)
    fake = to_neg1_1(fake_0_1)
    n_ch = real.shape[1]

    sobel = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=real.device)
    # The vertical Sobel kernel is the transpose of the horizontal one.
    weight_x = sobel.reshape(1, 1, 3, 3).to(real.dtype).repeat(n_ch, 1, 1, 1)
    weight_y = sobel.t().reshape(1, 1, 3, 3).to(real.dtype).repeat(n_ch, 1, 1, 1)

    def magnitude(img: torch.Tensor) -> torch.Tensor:
        dx = F.conv2d(img, weight_x, padding=1, groups=n_ch)
        dy = F.conv2d(img, weight_y, padding=1, groups=n_ch)
        return torch.sqrt(dx * dx + dy * dy + 1e-12)

    return F.l1_loss(magnitude(fake), magnitude(real)).item()
|
|
|
|
def flatten_channels(x: torch.Tensor) -> torch.Tensor:
    """Reshape a 4D (N, C, H, W) or 5D (N, C, T, H, W) tensor into (C, everything-else).

    Row ``c`` holds every value of channel ``c`` across the batch (and time) axes.

    Raises:
        ValueError: if ``x`` is neither 4D nor 5D.
    """
    if x.ndim not in (4, 5):
        raise ValueError(f"Unexpected tensor ndim={x.ndim}")
    # Bring the channel axis to the front, then collapse all remaining axes.
    return x.movedim(1, 0).reshape(x.shape[1], -1)
|
|
|
|
| def _to_numpy_1d(x: Any) -> Optional[np.ndarray]: |
| if x is None: |
| return None |
| if isinstance(x, (int, float)): |
| return None |
| if isinstance(x, torch.Tensor): |
| x = x.detach().cpu().float().numpy() |
| elif isinstance(x, (list, tuple)): |
| x = np.array(x, dtype=np.float32) |
| elif isinstance(x, np.ndarray): |
| x = x.astype(np.float32, copy=False) |
| else: |
| return None |
| x = x.reshape(-1) |
| return x |
|
|
|
|
| def _to_float(x: Any) -> Optional[float]: |
| if x is None: |
| return None |
| if isinstance(x, (int, float)): |
| return float(x) |
| if isinstance(x, np.ndarray) and x.size == 1: |
| return float(x.item()) |
| if isinstance(x, torch.Tensor) and x.numel() == 1: |
| return float(x.item()) |
| return None |
|
|
|
|
def get_norm_tensors_and_summary(vae, latent_like: torch.Tensor):
    """Build tensors that map raw VAE latents into the normalized latent space.

    Every scale returned here is a *multiplier*, so the tensors are meant to be
    applied as ``z = (z - shift) * scale``, global (scalar) first, then per channel:

    * global: ``shift_factor`` / ``scaling_factor`` from the config (identity when
      missing or None);
    * per channel: ``latents_mean`` as the shift and ``1 / latents_std`` as the scale,
      i.e. ``(z - mean) / std``. ``latents_std`` is a divisor, so it is inverted here.
      A mean/std that is not a vector of length C but converts to a single number is
      folded into the global terms instead. Several keys would accumulate.

    Args:
        vae: a VAE whose ``config`` attribute (or the object itself when it has no
            ``config``) exposes the keys above.
        latent_like: latent tensor whose channel count (dim 1), rank, device and dtype
            the returned tensors are shaped and cast to.

    Returns:
        ``(t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary)``. The first four are
        tensors broadcastable against ``latent_like``. ``summary`` holds plain floats
        describing the applied values (per-channel entries as min/mean/max).
    """
    cfg = getattr(vae, "config", vae)

    scale_keys = [
        "latents_std"
    ]
    shift_keys = [
        "latents_mean"
    ]

    C = latent_like.shape[1]
    nd = latent_like.ndim
    dev = latent_like.device
    dt = latent_like.dtype

    # Read the global factors from cfg as well, so objects without ``.config`` work.
    scale_global = getattr(cfg, "scaling_factor", 1.0)
    shift_global = getattr(cfg, "shift_factor", 0.0)
    if scale_global is None:
        scale_global = 1.0
    if shift_global is None:
        shift_global = 0.0

    scale_channel = np.ones(C, dtype=np.float32)
    shift_channel = np.zeros(C, dtype=np.float32)

    for k in scale_keys:
        v = getattr(cfg, k, None)
        if v is None:
            continue
        vec = _to_numpy_1d(v)
        if vec is not None and vec.size == C:
            # std is a divisor: store 1/std as the multiplier.
            # Zero entries would give inf, so those channels are left unscaled.
            inv = np.divide(1.0, vec, out=np.ones_like(vec), where=vec != 0)
            scale_channel *= inv
        else:
            s = _to_float(v)
            if s is not None and s != 0.0:
                scale_global /= s

    for k in shift_keys:
        v = getattr(cfg, k, None)
        if v is None:
            continue
        vec = _to_numpy_1d(v)
        if vec is not None and vec.size == C:
            shift_channel += vec
        else:
            s = _to_float(v)
            if s is not None:
                shift_global += s

    # Global tensors broadcast as all-ones shape; per-channel ones vary along dim 1.
    g_shape = [1] * nd
    c_shape = [1] * nd
    c_shape[1] = C

    t_scale_g = torch.tensor(scale_global, dtype=dt, device=dev).view(*g_shape)
    t_shift_g = torch.tensor(shift_global, dtype=dt, device=dev).view(*g_shape)
    t_scale_c = torch.from_numpy(scale_channel).to(device=dev, dtype=dt).view(*c_shape)
    t_shift_c = torch.from_numpy(shift_channel).to(device=dev, dtype=dt).view(*c_shape)

    summary = {
        "scale_global": float(scale_global),
        "shift_global": float(shift_global),
        "scale_channel_min": float(scale_channel.min()),
        "scale_channel_mean": float(scale_channel.mean()),
        "scale_channel_max": float(scale_channel.max()),
        "shift_channel_min": float(shift_channel.min()),
        "shift_channel_mean": float(shift_channel.mean()),
        "shift_channel_max": float(shift_channel.max()),
    }
    return t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary
|
|
|
|
@torch.no_grad()
def kl_divergence_per_image(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, exp(logvar)) || N(0, 1)), averaged over all non-batch elements.

    Args:
        mu: posterior mean; dim 0 is the batch, any further dims are flattened.
        logvar: posterior log-variance, same shape as ``mu``.

    Returns:
        float32 tensor of shape (batch,), the per-image mean of the element-wise KL.
    """
    # Upcast before the arithmetic: in fp16, mu**2 or exp(logvar) overflows to inf
    # past 65504, which would poison the whole per-image mean.
    mu32 = mu.float()
    logvar32 = logvar.float()
    kl_map = -0.5 * (1 + logvar32 - mu32.pow(2) - logvar32.exp())
    # reshape (not view): safe even if a caller passes non-contiguous inputs.
    return kl_map.reshape(kl_map.shape[0], -1).mean(dim=1)
|
|
|
|
def sanitize_filename(name: str) -> str:
    """Make ``name`` safe to use as a file name.

    Every character that is not alphanumeric (Unicode-aware ``str.isalnum``) and
    not one of ``.``, ``_`` or ``-`` is replaced by ``_``. This includes path
    separators and spaces.
    """
    allowed_punct = "._-"
    return "".join(
        ch if ch.isalnum() or ch in allowed_punct else "_"
        for ch in name
    )
|
|
|
|
| |
class ImageFolderDataset(Dataset):
    """Image dataset built from every matching file under ``root_dir`` (searched recursively).

    Construction steps:
      1. collect files whose lower-cased name ends with one of ``extensions``;
      2. keep at most ``limit`` of them (a falsy ``limit`` keeps all);
      3. drop files that PIL cannot open and ``verify()``;
      4. shuffle the survivors in place with the module-level ``random`` RNG (unseeded here).

    Each item is an RGB image resized with ``Resize(min_size)``, center-cropped to
    ``crop_size`` and converted with ``ToTensor()``.

    Raises:
        RuntimeError: if no valid image is found.
    """

    def __init__(self, root_dir: str, extensions=(".png", ".jpg", ".jpeg", ".webp"), min_size=1024, crop_size=512, limit=None):
        found = [
            os.path.join(dirpath, fname)
            for dirpath, _subdirs, fnames in os.walk(root_dir)
            for fname in fnames
            if fname.lower().endswith(extensions)
        ]
        if limit:
            found = found[:limit]

        readable = []
        for path in tqdm(found, desc="Проверяем файлы"):
            try:
                with Image.open(path) as probe:
                    probe.verify()
                    readable.append(path)
            except Exception:
                continue  # best-effort scan: unreadable/corrupt files are skipped
        if not readable:
            raise RuntimeError(f"Нет валидных изображений в {root_dir}")

        random.shuffle(readable)
        self.paths = readable
        print(f"Найдено {len(self.paths)} изображений")

        self.transform = Compose([Resize(min_size), CenterCrop(crop_size), ToTensor()])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Open image ``idx``, convert it to RGB and return the transformed tensor."""
        with Image.open(self.paths[idx]) as source:
            return self.transform(source.convert("RGB"))
|
|
|
|
| |
def _load_vaes() -> List[Tuple[str, object]]:
    """Instantiate every VAE in VAE_LIST on DEVICE (in DTYPE, eval mode).

    Best-effort: a model that fails to load is reported and skipped, so one missing
    checkpoint does not abort the whole benchmark.
    """
    vaes: List[Tuple[str, object]] = []
    print("\nЗагрузка VAE...")
    for human_name, vae_class, model_path, subfolder in VAE_LIST:
        try:
            vae = vae_class.from_pretrained(model_path, subfolder=subfolder, torch_dtype=DTYPE)
            vae = vae.to(DEVICE).eval()
        except Exception as e:  # deliberate top-level boundary: report and keep going
            print(f" ❌ {human_name}: {e}")
            continue
        vaes.append((human_name, vae))
        print(f" ✅ {human_name}")
    return vaes


def _compute_latent_stats(Z: torch.Tensor) -> Dict[str, float]:
    """Summarize CPU float32 latents ``Z`` (channel axis at dim 1).

    Returns global min/mean/max/std and per-channel skewness and excess (Fisher)
    kurtosis, aggregated as min/mean/max and mean absolute value. Both skewness and
    excess kurtosis are 0 for a Gaussian.
    """
    z_ch = flatten_channels(Z).numpy()
    sk = np.array([float(skew(v, bias=False)) for v in z_ch], dtype=np.float64)
    ku = np.array([float(kurtosis(v, fisher=True, bias=False)) for v in z_ch], dtype=np.float64)
    return {
        "Z_min": float(Z.min().item()),
        "Z_mean": float(Z.mean().item()),
        "Z_max": float(Z.max().item()),
        "Z_std": float(Z.std(unbiased=True).item()),
        "skew_min": float(sk.min()), "skew_mean": float(sk.mean()), "skew_max": float(sk.max()),
        "kurt_min": float(ku.min()), "kurt_mean": float(ku.mean()), "kurt_max": float(ku.max()),
        "mean_abs_skew": float(np.mean(np.abs(sk))),
        "mean_abs_kurt": float(np.mean(np.abs(ku))),
    }


def _print_comparison_table(vaes, per_model_metrics, per_model_latent_stats) -> None:
    """Print every model's metrics as a percentage of the first (baseline) model's.

    Ratios are oriented so that >100% always means "better than the baseline".
    Lower-is-better metrics (MSE, LPIPS, Edge, |skew|, |kurtosis|) use
    baseline / model; PSNR uses model / baseline. Models without latent stats are skipped.
    """
    baseline = vaes[0][0]
    b_s = per_model_latent_stats.get(baseline)
    if b_s is None:
        return  # nothing was collected for the baseline, so there is nothing to compare to
    b_m = per_model_metrics[baseline]

    def pct(num: float, den: float) -> float:
        return (num / max(1e-12, den)) * 100.0

    print("\n=== Сравнение с первой моделью (проценты) ===")
    print(f"| {'Модель':26s} | {'MSE':>9s} | {'PSNR':>9s} | {'LPIPS':>9s} | {'Edge':>9s} | {'Skew|0':>9s} | {'Kurt|0':>9s} |")
    print(f"|{'-'*28}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|")

    for name, _ in vaes:
        s = per_model_latent_stats.get(name)
        if s is None:
            continue
        if name == baseline:
            print(f"| {name:26s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} |")
            continue
        m = per_model_metrics[name]
        cells = (
            pct(b_m["mse"], m["mse"]),
            pct(m["psnr"], b_m["psnr"]),
            pct(b_m["lpips"], m["lpips"]),
            pct(b_m["edge"], m["edge"]),
            pct(b_s["mean_abs_skew"], s["mean_abs_skew"]),
            pct(b_s["mean_abs_kurt"], s["mean_abs_kurt"]),
        )
        print(f"| {name:26s} | " + " | ".join(f"{c:8.1f}%" for c in cells) + " |")


def _save_correction_json(name: str, Z: torch.Tensor, norm_summary: Dict[str, float]) -> None:
    """Fit residual N(0, 1) corrections on latents ``Z`` of model ``name``; print and save them.

    Two independent corrections are fitted on the same z_model data: one global and
    one per-channel. Each stores ``shift = -mean`` and ``scale = 1 / std``, so
    ``(z + shift) * scale`` standardizes z_model. Both are fitted on z_model, so they
    are alternatives and must not be chained.
    """
    z_mean = float(Z.mean().item())
    z_std = float(Z.std(unbiased=True).item())
    correction_global = {
        "shift": -z_mean,
        "scale": (1.0 / z_std) if z_std > 1e-12 else 1.0
    }

    Z_ch = flatten_channels(Z)
    ch_means = [float(x) for x in Z_ch.mean(dim=1).tolist()]
    ch_stds = [float(x) for x in (Z_ch.std(dim=1, unbiased=True) + 1e-12).tolist()]
    correction_per_channel = [
        {"shift": float(-m), "scale": float(1.0 / s)}
        for m, s in zip(ch_means, ch_stds)
    ]

    print(f"\n=== Доп. коррекция для {name} (поверх VAE-нормализации) ===")
    print(f"global_correction = {correction_global}")
    print(f"channelwise_means = {ch_means}")
    print(f"channelwise_stds = {ch_stds}")
    print(f"channelwise_correction = {correction_per_channel}")

    json_path = os.path.join(SAMPLES_DIR, f"{sanitize_filename(name)}_correction.json")
    to_save = {
        "model_name": name,
        "vae_normalization_summary": norm_summary,
        "global_correction": correction_global,
        "per_channel_means": ch_means,
        "per_channel_stds": ch_stds,
        "per_channel_correction": correction_per_channel,
        # global_shift is stored as -mean, so it is ADDED.
        # Global and per-channel corrections are both fitted on z_model: apply ONE of them.
        "apply_order": {
            "forward": "z_model -> EITHER global: (z + global_shift) * global_scale "
                       "OR per-channel: (z - mean_c) / std_c  (both fitted on z_model; apply only one)",
            "inverse": "EITHER global: z / global_scale - global_shift "
                       "OR per-channel: z * std_c + mean_c  -> z_model"
        },
        "note": "Эти коэффициенты рассчитаны по z_model (после встроенных VAE shift/scale), чтобы привести распределение к N(0,1)."
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(to_save, f, ensure_ascii=False, indent=2)
    print("Corrections JSON saved to:", os.path.abspath(json_path))


def main():
    """Benchmark every VAE in VAE_LIST on IMAGE_FOLDER and print comparison tables.

    For each batch and each loaded VAE:
      * decode the posterior mode and score the reconstruction per image
        (MSE, PSNR, LPIPS, Sobel-edge L1);
      * accumulate the analytic KL of the posterior;
      * normalize a posterior sample with the VAE's own config factors and buffer it
        on the CPU for distribution statistics.
    The first image of the first batch is saved (original + decoded) per model. A
    correction JSON is written for the last model in the list.

    Note: PSNR is averaged per image, so one perfect reconstruction makes it ``inf``.
    """
    torch.set_grad_enabled(False)
    os.makedirs(SAMPLES_DIR, exist_ok=True)

    dataset = ImageFolderDataset(IMAGE_FOLDER, min_size=MIN_SIZE, crop_size=CROP_SIZE, limit=MAX_IMAGES)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    lpips_net = lpips.LPIPS(net="vgg").to(DEVICE).eval()

    vaes = _load_vaes()
    if not vaes:
        print("Нет успешно загруженных VAE. Выходим.")
        return

    names = [name for name, _ in vaes]
    per_model_metrics: Dict[str, Dict[str, float]] = {
        name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "kl": 0.0, "count": 0.0}
        for name in names
    }
    buffers_zmodel: Dict[str, List[torch.Tensor]] = {name: [] for name in names}
    norm_summaries: Dict[str, Dict[str, float]] = {}
    saved_first_for: Dict[str, bool] = {name: False for name in names}

    for batch_0_1 in tqdm(loader, desc="Батчи"):
        batch_0_1 = batch_0_1.to(DEVICE, torch.float32)
        batch_neg1_1 = to_neg1_1(batch_0_1).to(DTYPE)
        B = batch_0_1.shape[0]

        for model_name, vae in vaes:
            x_in = add_time_dim_if_needed(batch_neg1_1, vae)
            posterior = vae.encode(x_in).latent_dist

            # Reconstruction quality is measured on the deterministic posterior mode.
            x_dec = vae.decode(posterior.mode()).sample
            x_dec = strip_time_dim_if_possible(x_dec, vae)
            x_rec_0_1 = to_0_1(x_dec.float()).clamp(0, 1)

            # Latent statistics use a posterior sample after the VAE's own normalization.
            z_raw_sample = posterior.sample()
            t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary = get_norm_tensors_and_summary(vae, z_raw_sample)
            norm_summaries.setdefault(model_name, summary)
            z_model = ((z_raw_sample - t_shift_g) * t_scale_g - t_shift_c) * t_scale_c
            z_model = strip_time_dim_if_possible(z_model, vae)
            buffers_zmodel[model_name].append(z_model.detach().to("cpu", torch.float32))

            if not saved_first_for[model_name]:
                safe = sanitize_filename(model_name)
                save_image(batch_0_1[0:1].cpu(), os.path.join(SAMPLES_DIR, f"{safe}_original.png"))
                save_image(x_rec_0_1[0:1].cpu(), os.path.join(SAMPLES_DIR, f"{safe}_decoded.png"))
                saved_first_for[model_name] = True

            metrics = per_model_metrics[model_name]
            # Per-image MSE in one vectorized pass; PSNR is computed per image from it.
            mse_per_image = (batch_0_1 - x_rec_0_1).pow(2).flatten(1).mean(dim=1).tolist()
            metrics["mse"] += sum(mse_per_image)
            metrics["psnr"] += sum(safe_psnr(v) for v in mse_per_image)
            # LPIPS already yields one score per image, so a single batched forward suffices.
            metrics["lpips"] += float(lpips_net(batch_0_1, x_rec_0_1, normalize=True).sum().item())
            metrics["edge"] += sum(sobel_edge_l1(batch_0_1[i:i + 1], x_rec_0_1[i:i + 1]) for i in range(B))
            metrics["kl"] += float(kl_divergence_per_image(posterior.mean, posterior.logvar).sum().item())
            metrics["count"] += B

    # Sums -> per-image averages.
    for metrics in per_model_metrics.values():
        c = max(1.0, metrics["count"])
        for k in ("mse", "psnr", "lpips", "edge", "kl"):
            metrics[k] /= c

    per_model_latent_stats = {
        name: _compute_latent_stats(torch.cat(buffers_zmodel[name], dim=0))
        for name in names
        if buffers_zmodel[name]
    }

    print("\n=== Параметры нормализации латентов (как применялись) ===")
    for name in names:
        s = norm_summaries.get(name)
        if s is None:
            continue
        print(
            f"{name:26s} | "
            f"shift_g={s['shift_global']:.6g} scale_g={s['scale_global']:.6g} | "
            f"shift_c[min/mean/max]=[{s['shift_channel_min']:.6g}, {s['shift_channel_mean']:.6g}, {s['shift_channel_max']:.6g}] | "
            f"scale_c[min/mean/max]=[{s['scale_channel_min']:.6g}, {s['scale_channel_mean']:.6g}, {s['scale_channel_max']:.6g}]"
        )

    print("\n=== Абсолютные метрики реконструкции и латентов ===")
    for name in names:
        s = per_model_latent_stats.get(name)
        if s is None:
            continue
        m = per_model_metrics[name]
        print(
            f"{name:26s} | "
            f"MSE={m['mse']:.3e} PSNR={m['psnr']:.2f} LPIPS={m['lpips']:.3f} Edge={m['edge']:.3f} KL={m['kl']:.3f} | "
            f"Z[min/mean/max/std]=[{s['Z_min']:.3f}, {s['Z_mean']:.3f}, {s['Z_max']:.3f}, {s['Z_std']:.3f}] | "
            f"Skew[min/mean/max]=[{s['skew_min']:.3f}, {s['skew_mean']:.3f}, {s['skew_max']:.3f}] | "
            f"Kurt[min/mean/max]=[{s['kurt_min']:.3f}, {s['kurt_mean']:.3f}, {s['kurt_max']:.3f}]"
        )

    _print_comparison_table(vaes, per_model_metrics, per_model_latent_stats)

    last_name = names[-1]
    if buffers_zmodel[last_name]:
        _save_correction_json(
            last_name,
            torch.cat(buffers_zmodel[last_name], dim=0),
            norm_summaries.get(last_name, {}),
        )

    print("\n✅ Готово. Сэмплы сохранены в:", os.path.abspath(SAMPLES_DIR))
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|