| |
| """ |
| Apply generated SwiGLU MLP weights to a Gemma 4 31B safetensors model. |
| Layer files contain gate_proj.weight / up_proj.weight / down_proj.weight |
| as pre-computed delta tensors — fused via Shape-Contoured Fusion (SCF). |
| |
| SCF replaces the old naive additive delta approach: |
| - down_proj : contoured delta, W + dynamic_alpha * delta * W_existing (delta scaled elementwise by the existing weights, then added) |
| - gate_proj : multiplicative gamma scaling (W * (1 + clamp(delta, +/-gamma_cap))) |
| - up_proj : intentionally unchanged (linear path, as in fuzer.py) |
| |
| Gemma 4 31B interleaved attention: 5 SWA + 1 global per period (60 layers total). |
| Global layers (5, 11, 17, 23, 29, 35, 41, 47, 53, 59) may carry double-wide MLP tensors; |
| partial coverage is handled transparently via row/col clamping. |
| """ |
|
|
| import argparse |
| import json |
| import shutil |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from safetensors.torch import load, load_file, save_file |
|
|
# MLP projection tensors a generated layer file may carry, in lookup order:
# gate, up, down (SwiGLU).
PROJ_KEYS = tuple(f"{name}_proj.weight" for name in ("gate", "up", "down"))

# Attention interleave pattern used by is_global_attention_layer(): layer
# GLOBAL_LAYER_OFFSET and every INTERLEAVE_PERIOD-th layer after it are global;
# the rest are labelled sliding-window ("swa").
INTERLEAVE_PERIOD = 6
GLOBAL_LAYER_OFFSET = 5
|
|
|
|
| def is_global_attention_layer(layer_idx: int) -> bool: |
| return ( |
| layer_idx >= GLOBAL_LAYER_OFFSET |
| and (layer_idx - GLOBAL_LAYER_OFFSET) % INTERLEAVE_PERIOD == 0 |
| ) |
|
|
|
|
def detect_key_prefix(tensor_keys, layer_idx: int, proj: str) -> str:
    """Return the key prefix preceding ``layers.<idx>.mlp.<proj>`` in ``tensor_keys``.

    Gemma 4 is a VLM, so a key containing "language_model" wins over any other
    match (e.g. a vision tower). Otherwise the first matching key is used.
    When nothing matches, the default "model.language_model.model." is
    returned; the caller is expected to verify that the resulting key exists.
    """
    tail = f"layers.{layer_idx}.mlp.{proj}"
    candidates = [key for key in tensor_keys if key.endswith(tail)]

    preferred = next((key for key in candidates if "language_model" in key), None)
    chosen = preferred if preferred is not None else next(iter(candidates), None)

    if chosen is None:
        return "model.language_model.model."
    return chosen[: -len(tail)]
|
|
|
|
def discover_generated_layers(weights_dir: Path) -> dict:
    """Map layer index -> path for every ``layer_<N>.safetensors`` file in ``weights_dir``.

    A file is accepted only when its stem is exactly ``layer_`` followed by
    ASCII digits and it is a regular file. Files such as
    ``layer_5_old.safetensors`` or ``layer_x.safetensors`` are ignored.

    Files are visited in sorted name order. If two files resolve to the same
    index (e.g. ``layer_05`` and ``layer_5``), the later one is kept and a
    warning is printed.

    Returns:
        dict[int, Path]; empty when no file qualifies.
    """
    prefix = "layer_"
    layers = {}
    for f in sorted(weights_dir.glob(f"{prefix}*.safetensors")):
        if not f.is_file():
            continue
        digits = f.stem[len(prefix):]
        # Strict match. A split("_")-based parse reads "layer_5_old" as 5, and
        # because that name sorts after "layer_5" it would silently replace
        # the real deltas. isascii() also stops int() from accepting non-ASCII
        # digits; the explicit digit check keeps out "+5", " 5" and "5_0".
        if not (digits.isascii() and digits.isdigit()):
            continue
        idx = int(digits)
        if idx in layers:
            print(f"  [warn] Layer {idx}: both {layers[idx].name} and {f.name} "
                  f"found -- using {f.name}")
        layers[idx] = f
    return layers
|
|
|
|
| |
| |
| |
|
|
def _delta_coverage(layer_idx: int, proj: str, delta: torch.Tensor,
                    weight: torch.Tensor) -> tuple:
    """Return the (rows, cols) top-left region of ``weight`` that ``delta`` covers.

    Prints a partial-fusion warning when the delta is smaller than the weight.

    Raises:
        ValueError: if either tensor is not 2-D. Row/column slicing would
            otherwise fail obscurely or broadcast incorrectly.
    """
    if delta.dim() != 2 or weight.dim() != 2:
        raise ValueError(
            f"Layer {layer_idx}: {proj} delta {tuple(delta.shape)} and model "
            f"weight {tuple(weight.shape)} must both be 2-D"
        )
    nr = min(delta.shape[0], weight.shape[0])
    nc = min(delta.shape[1], weight.shape[1])
    if nr < weight.shape[0] or nc < weight.shape[1]:
        print(f"  [warn] Layer {layer_idx}: {proj} delta covers "
              f"{nr}x{nc} of {weight.shape[0]}x{weight.shape[1]} -- partial fusion")
    return nr, nc


def fuse_layer_deltas(
    layer_idx: int,
    gate_w: torch.Tensor,
    up_w: torch.Tensor,
    down_w: torch.Tensor,
    new_weights: dict,
    args: argparse.Namespace,
) -> None:
    """
    Apply SCF to one layer, in place, using pre-computed delta tensors.

    Only the top-left region of each weight covered by its delta is modified
    (slice assignment into ``gate_w`` / ``down_w``). A smaller delta yields a
    partial fusion and a printed warning.

    down_proj -- contoured delta (added to W, scaled by W itself):
        W <- W + dynamic_alpha * delta * W
        dynamic_alpha = clip(alpha * var(W_region) / (1/fan_in + 1e-8),
                             0.1 * alpha, 10 * alpha)
        fan_in is the full column count of down_w. var is torch's default
        (unbiased) variance over the covered region. The variance ratio keeps
        the update scale comparable across layers. When the variance is not
        finite (a single-element region), the ratio falls back to 1, so
        dynamic_alpha == alpha instead of NaN poisoning the weights.

    gate_proj -- multiplicative gamma:
        W <- W * (1 + clamp(delta, -gamma_cap, +gamma_cap))

    up_proj -- unchanged:
        The linear value path of SwiGLU is intentionally not scaled.
        ``up_w`` is accepted for interface symmetry but never read.

    Args:
        layer_idx: layer index, used only in messages.
        gate_w, up_w, down_w: model projection weights. gate_w/down_w are
            mutated in place when their delta is present.
        new_weights: projection name ("gate_proj.weight", "down_proj.weight", ...)
            -> delta tensor. Missing names are skipped.
        args: namespace providing ``alpha`` and ``gamma_cap``.

    Raises:
        ValueError: if a present delta or its target weight is not 2-D.
    """
    if "down_proj.weight" in new_weights:
        delta_down = new_weights["down_proj.weight"].float()
        nr, nc = _delta_coverage(layer_idx, "down_proj", delta_down, down_w)
        if nr > 0 and nc > 0:  # empty coverage: nothing to fuse (and fan_in may be 0)
            region = down_w[:nr, :nc]
            fan_in = down_w.shape[1]
            expected_var = 1.0 / fan_in
            variance_ratio = region.var().item() / (expected_var + 1e-8)
            if not np.isfinite(variance_ratio):
                variance_ratio = 1.0
            dynamic_alpha = float(np.clip(
                args.alpha * variance_ratio,
                args.alpha * 0.1,
                args.alpha * 10.0,
            ))
            # RHS is fully evaluated before assignment, so reading the view is safe.
            down_w[:nr, :nc] = region + dynamic_alpha * delta_down[:nr, :nc] * region

    if "gate_proj.weight" in new_weights:
        delta_gate = new_weights["gate_proj.weight"].float()
        nr, nc = _delta_coverage(layer_idx, "gate_proj", delta_gate, gate_w)
        gamma = 1.0 + delta_gate[:nr, :nc].clamp(-args.gamma_cap, args.gamma_cap)
        gate_w[:nr, :nc] = gate_w[:nr, :nc] * gamma
|
|
| |
|
|
|
|
| |
| |
| |
|
|
def apply_single_file(model_path: Path, output_dir: Path, layer_files: dict, args) -> int:
    """Fuse generated layer deltas into one non-sharded ``.safetensors`` file.

    The entire model file is loaded into memory. Generated layers are
    processed in ascending index order:

    - A layer is skipped when its delta file contains none of PROJ_KEYS.
    - A layer is also skipped when any of its three model keys (located with
      detect_key_prefix) is missing.
    - gate_proj and down_proj are fused on float32 copies by fuse_layer_deltas,
      then cast back to their original dtypes.
    - up_proj is never written back.

    Unless ``args.dry_run`` is set, the result is written to
    ``output_dir / model_path.name`` together with the source file's
    safetensors header metadata.

    Args:
        model_path: source ``.safetensors`` file.
        output_dir: destination directory (not created here).
        layer_files: layer index -> generated delta file.
        args: parsed CLI namespace. Reads ``dry_run`` plus the SCF parameters
            consumed by fuse_layer_deltas.

    Returns:
        Number of layers counted as fused (in a dry run: layers that would be).

    Raises:
        RuntimeError: if at least one layer was skipped and none were fused.
    """
    from safetensors import safe_open  # only used to read the header metadata

    # Projections whose fused values are written back (up_proj intentionally is not).
    fusion_labels = {"gate_proj.weight": "gate*gamma", "down_proj.weight": "down contoured"}

    dry_run = args.dry_run
    print(f"\n[model] Processing file: {model_path.name}")

    tensors = load_file(str(model_path))
    # load_file() returns tensors only. Read the header metadata
    # (e.g. {"format": "pt"}) so the rewritten file keeps it.
    with safe_open(str(model_path), framework="pt") as handle:
        metadata = handle.metadata()

    fused = 0
    skipped = 0

    for layer_idx, layer_path in sorted(layer_files.items()):
        layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
        new_weights = load_file(str(layer_path))

        if not any(k in new_weights for k in PROJ_KEYS):
            print(f"  [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
                  f"Got: {list(new_weights.keys())}")
            skipped += 1
            continue

        # All three MLP tensors must exist in the model for the layer to count.
        model_keys = {}
        for proj in PROJ_KEYS:
            prefix = detect_key_prefix(tensors.keys(), layer_idx, proj)
            model_key = f"{prefix}layers.{layer_idx}.mlp.{proj}"
            if model_key not in tensors:
                print(f"  [skip] Key not found in model: {model_key!r}")
                break
            model_keys[proj] = model_key
        if len(model_keys) < len(PROJ_KEYS):
            skipped += 1
            continue

        applied = [proj for proj in fusion_labels if proj in new_weights]

        if not dry_run and applied:
            gate_key = model_keys["gate_proj.weight"]
            down_key = model_keys["down_proj.weight"]
            orig_gate_dtype = tensors[gate_key].dtype
            orig_down_dtype = tensors[down_key].dtype

            # Fuse on float32 copies. Cloning first keeps float32 sources intact.
            gate_w = tensors[gate_key].clone().float()
            down_w = tensors[down_key].clone().float()
            # up_proj is never written back, so it is passed without a copy.
            up_w = tensors[model_keys["up_proj.weight"]]

            fuse_layer_deltas(layer_idx, gate_w, up_w, down_w, new_weights, args)
            tensors[gate_key] = gate_w.to(orig_gate_dtype)
            tensors[down_key] = down_w.to(orig_down_dtype)

        fused += 1
        summary = " + ".join(fusion_labels[p] for p in applied) or "no gate/down delta"
        print(f"  {'[dry]' if dry_run else '[ok]'} Fused layer {layer_idx:02d} [{layer_type}]"
              f" {summary} (up unchanged)")

    if skipped > 0 and fused == 0:
        raise RuntimeError(
            f"No layers were fused -- all {skipped} layer(s) were skipped.\n"
            f"Sample model keys: {list(tensors.keys())[:4]}"
        )
    if skipped > 0:
        print(f"  [warn] {skipped} layer(s) skipped, {fused} fused.")

    if not dry_run:
        out_path = output_dir / model_path.name
        save_file(tensors, str(out_path), metadata=metadata)
        print(f"  Saved -> {out_path.resolve()}")

    return fused
|
|
|
|
| |
| |
| |
|
|
def _plan_sharded_fusion(weight_map: dict, layer_files: dict) -> tuple:
    """Group each layer's delta projections by the shard holding the model tensor.

    Only tensor names are read from the layer files here. The deltas are loaded
    later, per shard, so memory does not grow with the number of layers.

    Returns:
        ``(plan, skipped)``.
        ``plan`` maps shard name -> list of
        ``(layer_idx, proj, model_key, layer_path, layer_type)``.
        ``skipped`` counts layers for which no projection was registered.
    """
    from safetensors import safe_open

    plan: dict = {}
    skipped = 0
    for layer_idx, layer_path in sorted(layer_files.items()):
        layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
        with safe_open(str(layer_path), framework="pt") as handle:
            delta_keys = list(handle.keys())

        if not any(k in delta_keys for k in PROJ_KEYS):
            print(f"  [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
                  f"Got: {delta_keys}")
            skipped += 1
            continue

        registered = 0
        for proj in PROJ_KEYS:
            if proj not in delta_keys:
                continue
            prefix = detect_key_prefix(weight_map.keys(), layer_idx, proj)
            model_key = f"{prefix}layers.{layer_idx}.mlp.{proj}"
            if model_key not in weight_map:
                print(f"  [skip] Layer {layer_idx}: {model_key!r} not in weight_map")
                continue
            plan.setdefault(weight_map[model_key], []).append(
                (layer_idx, proj, model_key, layer_path, layer_type)
            )
            registered += 1

        if registered == 0:
            skipped += 1
    return plan, skipped


def _copy_passthrough_files(model_dir: Path, output_dir: Path, modified_shards: set) -> None:
    """Copy every entry of ``model_dir`` except the shards that will be rewritten."""
    output_dir.mkdir(parents=True, exist_ok=True)
    for src_file in model_dir.iterdir():
        if src_file.name in modified_shards:
            continue
        dst_file = output_dir / src_file.name
        if src_file.is_dir():
            shutil.copytree(src_file, dst_file, dirs_exist_ok=True)
        else:
            shutil.copy2(src_file, dst_file)


def _fuse_layer_in_shard(tensors: dict, layer_idx: int, proj_ops: list, args) -> None:
    """Load one layer's gate/down deltas, fuse them, and write the results into ``tensors``.

    ``proj_ops`` holds ``(proj, model_key, layer_path)`` for this layer's
    projections stored in the current shard. up_proj entries are ignored
    because up_proj is never written back, so its delta need not be loaded.
    """
    from safetensors import safe_open

    deltas, weights, dtypes, keys = {}, {}, {}, {}
    for proj, model_key, layer_path in proj_ops:
        if proj not in ("gate_proj.weight", "down_proj.weight"):
            continue
        with safe_open(str(layer_path), framework="pt") as handle:
            deltas[proj] = handle.get_tensor(proj)
        dtypes[proj] = tensors[model_key].dtype
        weights[proj] = tensors[model_key].clone().float()
        keys[proj] = model_key

    if not deltas:
        return

    # gate/down may live in different shards; fuse_layer_deltas only touches
    # the projections present in ``deltas``.
    fuse_layer_deltas(
        layer_idx,
        weights.get("gate_proj.weight", torch.empty(0)),
        torch.empty(0),  # up_proj is intentionally left unchanged
        weights.get("down_proj.weight", torch.empty(0)),
        deltas,
        args,
    )
    for proj, fused_w in weights.items():
        tensors[keys[proj]] = fused_w.to(dtypes[proj])


def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) -> int:
    """Fuse generated layer deltas into a sharded safetensors checkpoint.

    1. Plan: map each delta projection of each generated layer to its model
       key and shard via ``model.safetensors.index.json``.
    2. Copy (non-dry-run only): every entry of ``model_dir`` except the
       shards to be modified is copied into ``output_dir``. This includes the
       index and config.
    3. Fuse: each affected shard is checked against the index. In a real run
       it is also loaded, fused layer by layer, and saved with its original
       safetensors header metadata. Dry runs read only shard headers.

    Args:
        model_dir: directory containing the index and shards.
        output_dir: destination directory (created when not a dry run).
        layer_files: layer index -> generated delta file.
        args: parsed CLI namespace. Reads ``dry_run`` plus the SCF parameters
            consumed by fuse_layer_deltas.

    Returns:
        Number of unique layer indices with at least one projection mapped to
        a shard. This includes layers whose only mapped projection is up_proj.

    Raises:
        FileNotFoundError: if the index file is missing.
        RuntimeError: if no delta projection matched the weight_map.
        KeyError: if a shard lacks a tensor the index maps to it.
    """
    from safetensors import safe_open  # header-only reads (keys, metadata)

    dry_run = args.dry_run
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        raise FileNotFoundError(f"Sharded index missing: {index_path}")

    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]

    fusion_plan, skipped = _plan_sharded_fusion(weight_map, layer_files)
    if not fusion_plan:
        sample = list(weight_map.keys())[:6]
        raise RuntimeError(
            f"No layers matched in weight_map. Sample keys: {sample}"
        )

    if not dry_run:
        _copy_passthrough_files(model_dir, output_dir, set(fusion_plan))

    fused_layer_idxs: set = set()
    for shard_name, ops in sorted(fusion_plan.items()):
        shard_src = model_dir / shard_name

        with safe_open(str(shard_src), framework="pt") as handle:
            shard_keys = set(handle.keys())
            metadata = handle.metadata()
        missing = [model_key for _, _, model_key, _, _ in ops if model_key not in shard_keys]
        if missing:
            raise KeyError(f"{shard_name} does not contain tensors the index maps to it: {missing}")

        # Dry runs only need the header; real runs load the full shard.
        tensors = None if dry_run else load_file(str(shard_src))

        by_layer: dict = {}
        for layer_idx, proj, model_key, layer_path, layer_type in ops:
            by_layer.setdefault(layer_idx, []).append((proj, model_key, layer_path, layer_type))

        for layer_idx, proj_ops in sorted(by_layer.items()):
            layer_type = proj_ops[0][3]
            if not dry_run:
                _fuse_layer_in_shard(
                    tensors,
                    layer_idx,
                    [(proj, model_key, layer_path) for proj, model_key, layer_path, _ in proj_ops],
                    args,
                )
            fused_layer_idxs.add(layer_idx)
            proj_names = [p.split(".")[0] for p, *_ in proj_ops]
            print(f"  {'[dry]' if dry_run else '[ok]'} Fused layer {layer_idx:02d} [{layer_type}]"
                  f" ({', '.join(proj_names)} in this shard)")

        if not dry_run:
            save_file(tensors, str(output_dir / shard_name), metadata=metadata)
            print(f"  [ok] Saved shard {shard_name} ({len(by_layer)} layer(s))")
        del tensors

    if skipped > 0:
        print(f"  [warn] {skipped} layer(s) fully skipped, "
              f"{len(fused_layer_idxs)} unique layer(s) fused.")

    return len(fused_layer_idxs)
|
|
|
|
| |
| |
| |
|
|
def main():
    """Command-line entry point: fuse generated MLP deltas into a model checkpoint.

    ``--model`` may be any of:
    - a single ``.safetensors`` file;
    - a directory holding ``model.safetensors``;
    - a directory with a sharded ``model.safetensors.index.json``.

    Unless ``--dry-run`` is given, the fused weights, the remaining files of a
    model directory, and ``config.json`` are written to ``--output``.

    Raises:
        SystemExit: via ``parser.error`` for a negative ``--alpha`` or ``--gamma-cap``.
        FileNotFoundError: no generated layer files, or the model path /
            expected model files are missing.
        ValueError: ``--layers`` selects no available file; the model file is
            not ``.safetensors``; or ``--output`` is the model's own directory.
    """
    parser = argparse.ArgumentParser(
        description="Apply delta weights to a model via Shape-Contoured Fusion."
    )
    parser.add_argument("--model", required=True)
    parser.add_argument("--weights", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--layers", type=int, nargs="+", default=None)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--alpha", type=float, default=0.02,
                        help="down-proj variance scale multiplier (default: 0.02)")
    parser.add_argument("--gamma-cap", type=float, default=0.05,
                        help="max fractional gate_proj adjustment (default: 0.05)")
    args = parser.parse_args()

    # Negative values invert the clip/clamp bounds used in fusion (min > max),
    # which collapses every element to a single bound. Reject them up front.
    if args.alpha < 0:
        parser.error("--alpha must be >= 0")
    if args.gamma_cap < 0:
        parser.error("--gamma-cap must be >= 0")

    model_path = Path(args.model)
    weights_dir = Path(args.weights)
    output_dir = Path(args.output)

    available = discover_generated_layers(weights_dir)
    if not available:
        raise FileNotFoundError(
            f"No layer_*.safetensors files found in: {weights_dir.resolve()}"
        )
    layer_files = available
    if args.layers is not None:
        layer_files = {i: available[i] for i in args.layers if i in available}
        if not layer_files:
            raise ValueError(f"--layers filter empty. Available: {sorted(available.keys())}")
        missing = sorted(set(args.layers) - set(available))
        if missing:
            print(f"[warn] --layers requested with no generated file: {missing}")

    print(f"[info] Found {len(layer_files)} layer file(s): indices {sorted(layer_files.keys())}")
    print(f"[info] SCF params: alpha={args.alpha}, gamma_cap={args.gamma_cap}")

    if not model_path.exists():
        raise FileNotFoundError(f"--model not found: {model_path}")

    # Writing into the source directory would either overwrite the original
    # weights (single-file mode) or fail mid-way with shutil.SameFileError.
    source_dir = model_path if model_path.is_dir() else model_path.parent
    if not args.dry_run and output_dir.resolve() == source_dir.resolve():
        raise ValueError(
            f"--output {output_dir} is the model's own directory; "
            f"refusing to overwrite the source weights in place"
        )

    if not args.dry_run:
        output_dir.mkdir(parents=True, exist_ok=True)

    if model_path.is_file():
        if model_path.suffix != ".safetensors":
            raise ValueError(f"--model file is not a .safetensors file: {model_path}")
        apply_single_file(model_path, output_dir, layer_files, args)

    elif model_path.is_dir():
        single = model_path / "model.safetensors"
        index = model_path / "model.safetensors.index.json"

        if single.exists() and not index.exists():
            if not args.dry_run:
                for f in model_path.iterdir():
                    if f.name != "model.safetensors":
                        dst = output_dir / f.name
                        if f.is_dir():
                            shutil.copytree(f, dst, dirs_exist_ok=True)
                        else:
                            shutil.copy2(f, dst)
            apply_single_file(single, output_dir, layer_files, args)

        elif index.exists():
            apply_sharded(model_path, output_dir, layer_files, args)

        else:
            raise FileNotFoundError(
                f"No model.safetensors or model.safetensors.index.json in {model_path}"
            )
    else:
        raise FileNotFoundError(f"--model not found: {model_path}")

    config_path = source_dir / "config.json"
    if config_path.exists() and not args.dry_run:
        shutil.copy2(config_path, output_dir / "config.json")
        print("  [ok] Copied config.json (activation unchanged).")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()