Optimize RAM usage
Browse files- applyweights.py +29 -12
applyweights.py
CHANGED
|
@@ -21,7 +21,7 @@ from pathlib import Path
|
|
| 21 |
|
| 22 |
import numpy as np
|
| 23 |
import torch
|
| 24 |
-
from safetensors.torch import load, save_file
|
| 25 |
|
| 26 |
PROJ_KEYS = ("gate_proj.weight", "up_proj.weight", "down_proj.weight")
|
| 27 |
|
|
@@ -133,8 +133,8 @@ def apply_single_file(model_path: Path, output_dir: Path, layer_files: dict, arg
|
|
| 133 |
dry_run = args.dry_run
|
| 134 |
print(f"\n[model] Processing file: {model_path.name}")
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
|
| 139 |
fused = 0
|
| 140 |
skipped = 0
|
|
@@ -142,8 +142,7 @@ def apply_single_file(model_path: Path, output_dir: Path, layer_files: dict, arg
|
|
| 142 |
for layer_idx, layer_path in sorted(layer_files.items()):
|
| 143 |
layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
|
| 144 |
|
| 145 |
-
|
| 146 |
-
new_weights = load(f.read())
|
| 147 |
|
| 148 |
if not any(k in new_weights for k in PROJ_KEYS):
|
| 149 |
print(f" [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
|
|
@@ -227,8 +226,7 @@ def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) ->
|
|
| 227 |
for layer_idx, layer_path in sorted(layer_files.items()):
|
| 228 |
layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
|
| 229 |
|
| 230 |
-
|
| 231 |
-
new_weights = load(f.read())
|
| 232 |
|
| 233 |
if not any(k in new_weights for k in PROJ_KEYS):
|
| 234 |
print(f" [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
|
|
@@ -260,10 +258,28 @@ def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) ->
|
|
| 260 |
f"No layers matched in weight_map. Sample keys: {sample}"
|
| 261 |
)
|
| 262 |
|
|
|
|
|
|
|
|
|
|
| 263 |
if not dry_run:
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
fused_layer_idxs: set = set()
|
| 269 |
|
|
@@ -271,8 +287,8 @@ def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) ->
|
|
| 271 |
shard_src = model_dir / shard_name
|
| 272 |
shard_dst = output_dir / shard_name
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
|
| 277 |
# Re-group by layer so fuse_layer_deltas is called once per layer per shard.
|
| 278 |
by_layer: dict = {}
|
|
@@ -320,6 +336,7 @@ def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) ->
|
|
| 320 |
if not dry_run:
|
| 321 |
save_file(tensors, str(shard_dst))
|
| 322 |
print(f" [ok] Saved shard {shard_name} ({len(by_layer)} layer(s))")
|
|
|
|
| 323 |
|
| 324 |
if skipped > 0:
|
| 325 |
print(f" [warn] {skipped} layer(s) fully skipped, "
|
|
|
|
| 21 |
|
| 22 |
import numpy as np
|
| 23 |
import torch
|
| 24 |
+
from safetensors.torch import load, load_file, save_file
|
| 25 |
|
| 26 |
PROJ_KEYS = ("gate_proj.weight", "up_proj.weight", "down_proj.weight")
|
| 27 |
|
|
|
|
| 133 |
dry_run = args.dry_run
|
| 134 |
print(f"\n[model] Processing file: {model_path.name}")
|
| 135 |
|
| 136 |
+
# load_file uses memory-mapping — avoids reading the whole file into RAM twice
|
| 137 |
+
tensors = load_file(str(model_path))
|
| 138 |
|
| 139 |
fused = 0
|
| 140 |
skipped = 0
|
|
|
|
| 142 |
for layer_idx, layer_path in sorted(layer_files.items()):
|
| 143 |
layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
|
| 144 |
|
| 145 |
+
new_weights = load_file(str(layer_path))
|
|
|
|
| 146 |
|
| 147 |
if not any(k in new_weights for k in PROJ_KEYS):
|
| 148 |
print(f" [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
|
|
|
|
| 226 |
for layer_idx, layer_path in sorted(layer_files.items()):
|
| 227 |
layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
|
| 228 |
|
| 229 |
+
new_weights = load_file(str(layer_path))
|
|
|
|
| 230 |
|
| 231 |
if not any(k in new_weights for k in PROJ_KEYS):
|
| 232 |
print(f" [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
|
|
|
|
| 258 |
f"No layers matched in weight_map. Sample keys: {sample}"
|
| 259 |
)
|
| 260 |
|
| 261 |
+
# Identify which shards will be modified so they can be excluded from the upfront copy below.
|
| 262 |
+
modified_shards = set(fusion_plan.keys())
|
| 263 |
+
|
| 264 |
if not dry_run:
|
| 265 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 266 |
+
# Eagerly copy every entry except the shards that will be modified (config, tokenizer, index, any shards that are not being modified, etc.).
|
| 267 |
+
# Modified shards are not copied: each is loaded from model_dir and saved to output_dir further below,
|
| 268 |
+
# avoiding a full model copy upfront that can exhaust RAM and disk I/O.
|
| 269 |
+
for src_file in model_dir.iterdir():
|
| 270 |
+
dst_file = output_dir / src_file.name
|
| 271 |
+
if src_file.name not in modified_shards:
|
| 272 |
+
if src_file.is_dir():
|
| 273 |
+
shutil.copytree(src_file, dst_file, dirs_exist_ok=True)
|
| 274 |
+
else:
|
| 275 |
+
shutil.copy2(src_file, dst_file)
|
| 276 |
+
# Fallback: copy any unmodified shard listed in weight_map that the loop above did not already place in the output.
|
| 277 |
+
all_shards = {v for v in weight_map.values()}
|
| 278 |
+
for shard_name in all_shards - modified_shards:
|
| 279 |
+
src = model_dir / shard_name
|
| 280 |
+
dst = output_dir / shard_name
|
| 281 |
+
if src.exists() and not dst.exists():
|
| 282 |
+
shutil.copy2(src, dst)
|
| 283 |
|
| 284 |
fused_layer_idxs: set = set()
|
| 285 |
|
|
|
|
| 287 |
shard_src = model_dir / shard_name
|
| 288 |
shard_dst = output_dir / shard_name
|
| 289 |
|
| 290 |
+
# load_file uses memory-mapped I/O — no full f.read() into RAM
|
| 291 |
+
tensors = load_file(str(shard_src))
|
| 292 |
|
| 293 |
# Re-group by layer so fuse_layer_deltas is called once per layer per shard.
|
| 294 |
by_layer: dict = {}
|
|
|
|
| 336 |
if not dry_run:
|
| 337 |
save_file(tensors, str(shard_dst))
|
| 338 |
print(f" [ok] Saved shard {shard_name} ({len(by_layer)} layer(s))")
|
| 339 |
+
del tensors # free RAM before loading next shard
|
| 340 |
|
| 341 |
if skipped > 0:
|
| 342 |
print(f" [warn] {skipped} layer(s) fully skipped, "
|