CompressedGemma commited on
Commit
b53cacf
·
verified ·
1 Parent(s): fb98ac2

Optimize RAM usage

Browse files
Files changed (1) hide show
  1. applyweights.py +29 -12
applyweights.py CHANGED
@@ -21,7 +21,7 @@ from pathlib import Path
21
 
22
  import numpy as np
23
  import torch
24
- from safetensors.torch import load, save_file
25
 
26
  PROJ_KEYS = ("gate_proj.weight", "up_proj.weight", "down_proj.weight")
27
 
@@ -133,8 +133,8 @@ def apply_single_file(model_path: Path, output_dir: Path, layer_files: dict, arg
133
  dry_run = args.dry_run
134
  print(f"\n[model] Processing file: {model_path.name}")
135
 
136
- with open(model_path, "rb") as f:
137
- tensors = load(f.read())
138
 
139
  fused = 0
140
  skipped = 0
@@ -142,8 +142,7 @@ def apply_single_file(model_path: Path, output_dir: Path, layer_files: dict, arg
142
  for layer_idx, layer_path in sorted(layer_files.items()):
143
  layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
144
 
145
- with open(layer_path, "rb") as f:
146
- new_weights = load(f.read())
147
 
148
  if not any(k in new_weights for k in PROJ_KEYS):
149
  print(f" [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
@@ -227,8 +226,7 @@ def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) ->
227
  for layer_idx, layer_path in sorted(layer_files.items()):
228
  layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
229
 
230
- with open(layer_path, "rb") as f:
231
- new_weights = load(f.read())
232
 
233
  if not any(k in new_weights for k in PROJ_KEYS):
234
  print(f" [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
@@ -260,10 +258,28 @@ def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) ->
260
  f"No layers matched in weight_map. Sample keys: {sample}"
261
  )
262
 
 
 
 
263
  if not dry_run:
264
- if output_dir.exists():
265
- shutil.rmtree(output_dir)
266
- shutil.copytree(model_dir, output_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
  fused_layer_idxs: set = set()
269
 
@@ -271,8 +287,8 @@ def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) ->
271
  shard_src = model_dir / shard_name
272
  shard_dst = output_dir / shard_name
273
 
274
- with open(shard_src, "rb") as f:
275
- tensors = load(f.read())
276
 
277
  # Re-group by layer so fuse_layer_deltas is called once per layer per shard.
278
  by_layer: dict = {}
@@ -320,6 +336,7 @@ def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) ->
320
  if not dry_run:
321
  save_file(tensors, str(shard_dst))
322
  print(f" [ok] Saved shard {shard_name} ({len(by_layer)} layer(s))")
 
323
 
324
  if skipped > 0:
325
  print(f" [warn] {skipped} layer(s) fully skipped, "
 
21
 
22
  import numpy as np
23
  import torch
24
+ from safetensors.torch import load, load_file, save_file
25
 
26
  PROJ_KEYS = ("gate_proj.weight", "up_proj.weight", "down_proj.weight")
27
 
 
133
  dry_run = args.dry_run
134
  print(f"\n[model] Processing file: {model_path.name}")
135
 
136
+ # load_file uses memory-mapping — avoids reading the whole file into RAM twice
137
+ tensors = load_file(str(model_path))
138
 
139
  fused = 0
140
  skipped = 0
 
142
  for layer_idx, layer_path in sorted(layer_files.items()):
143
  layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
144
 
145
+ new_weights = load_file(str(layer_path))
 
146
 
147
  if not any(k in new_weights for k in PROJ_KEYS):
148
  print(f" [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
 
226
  for layer_idx, layer_path in sorted(layer_files.items()):
227
  layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
228
 
229
+ new_weights = load_file(str(layer_path))
 
230
 
231
  if not any(k in new_weights for k in PROJ_KEYS):
232
  print(f" [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
 
258
  f"No layers matched in weight_map. Sample keys: {sample}"
259
  )
260
 
261
+ # Identify which shards will be modified so we can copy non-modified files lazily.
262
+ modified_shards = set(fusion_plan.keys())
263
+
264
  if not dry_run:
265
+ output_dir.mkdir(parents=True, exist_ok=True)
266
+ # Copy all non-shard files (config, tokenizer, index, etc.) eagerly.
267
+ # Shard files are copied individually just before they are modified,
268
+ # avoiding a full model copy upfront that can exhaust RAM and disk I/O.
269
+ for src_file in model_dir.iterdir():
270
+ dst_file = output_dir / src_file.name
271
+ if src_file.name not in modified_shards:
272
+ if src_file.is_dir():
273
+ shutil.copytree(src_file, dst_file, dirs_exist_ok=True)
274
+ else:
275
+ shutil.copy2(src_file, dst_file)
276
+ # Copy unmodified shards (they just need to be present in the output).
277
+ all_shards = {v for v in weight_map.values()}
278
+ for shard_name in all_shards - modified_shards:
279
+ src = model_dir / shard_name
280
+ dst = output_dir / shard_name
281
+ if src.exists() and not dst.exists():
282
+ shutil.copy2(src, dst)
283
 
284
  fused_layer_idxs: set = set()
285
 
 
287
  shard_src = model_dir / shard_name
288
  shard_dst = output_dir / shard_name
289
 
290
+ # load_file uses memory-mapped I/O — no full f.read() into RAM
291
+ tensors = load_file(str(shard_src))
292
 
293
  # Re-group by layer so fuse_layer_deltas is called once per layer per shard.
294
  by_layer: dict = {}
 
336
  if not dry_run:
337
  save_file(tensors, str(shard_dst))
338
  print(f" [ok] Saved shard {shard_name} ({len(by_layer)} layer(s))")
339
+ del tensors # free RAM before loading next shard
340
 
341
  if skipped > 0:
342
  print(f" [warn] {skipped} layer(s) fully skipped, "