CompressedGemma commited on
Commit
414e1de
·
verified ·
1 Parent(s): 7d55b19

Upload 2 files

Browse files
Files changed (2) hide show
  1. generate_imatrix.py +102 -11
  2. hpc_forward_merged.c +48 -87
generate_imatrix.py CHANGED
@@ -132,7 +132,12 @@ class GGUFModel:
132
  self.data_offset = align_offset(f.tell())
133
 
134
  def get_arch(self):
135
- arch = self.kv.get('general.architecture', 'gemma2')
 
 
 
 
 
136
  return arch
137
 
138
  def get_config(self):
@@ -177,7 +182,11 @@ class GGUFModel:
177
  ti = self.tensor_infos[name]
178
  abs_offset = self.data_offset + ti['offset']
179
  raw = bytes(self._mm[abs_offset:abs_offset + ti['data_size']])
180
- return dequantize(raw, ti['type'], ti['n_elements'])
 
 
 
 
181
 
182
  def get_tensor_shape(self, name):
183
  """Return the shape of a tensor (GGUF stores reversed dims)."""
@@ -226,7 +235,8 @@ def dequant_q4_0(raw, n_elements):
226
  qs = data[:, 2:18] # 16 bytes = 32 nibbles
227
  lo = (qs & 0xF).astype(np.float32) - 8.0
228
  hi = (qs >> 4).astype(np.float32) - 8.0
229
- x = np.concatenate([lo, hi], axis=1) # [n_blocks, 32]
 
230
  return (d * x).reshape(-1)[:n_elements]
231
 
232
  def dequant_q2k(raw, n_elements):
@@ -243,7 +253,7 @@ def dequant_q2k(raw, n_elements):
243
  dmin = dmin_fp16[blk]
244
  for half in range(2):
245
  for sub in range(4):
246
- j = half * 8 + sub
247
  sc = int(scales_packed[blk, j]) & 0xF
248
  mn = int(scales_packed[blk, j]) >> 4
249
  d_sub = d * sc
@@ -430,10 +440,13 @@ class SimpleTokenizer:
430
  """Encode text and split into fixed-length chunks."""
431
  ids = self.encode(text)
432
  chunks = []
433
- for i in range(0, len(ids) - chunk_size, chunk_size // 2): # 50% overlap
 
 
 
434
  chunk = ids[i:i + chunk_size]
435
- if len(chunk) == chunk_size:
436
- chunks.append(np.array(chunk, dtype=np.int32))
437
  if not chunks and ids:
438
  # Pad short text
439
  padded = ids + [self.eos_id] * (chunk_size - len(ids))
@@ -555,6 +568,8 @@ class TransformerRunner:
555
  n_embd = cfg['n_embd']
556
  n_head = cfg['n_head']
557
  n_head_kv = cfg['n_head_kv']
 
 
558
  head_dim = self.head_dim
559
  eps = cfg['rms_eps']
560
 
@@ -673,6 +688,7 @@ class TransformerRunner:
673
 
674
  # Read back importance for the tensors that WERE processed in C
675
  for name, arr, cnt in imp_refs:
 
676
  self.importance[name] = (arr.astype(np.float64), cnt.value)
677
 
678
  # Handle MoE FFN if C code skipped it
@@ -797,16 +813,17 @@ class TransformerRunner:
797
  imp_f32 = self.importance[name][0].astype(np.float32)
798
  count = ctypes.c_int64(self.importance[name][1])
799
 
 
 
800
  # Dummy output — we only want the importance recording
801
  dummy_out = np.empty((M, 1), dtype=np.float32)
802
- dummy_w = np.zeros((1, K), dtype=np.float32)
803
 
804
  self._hpc_lib.hexstate_matmul_record(
805
  x.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
806
- dummy_w.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
807
  dummy_out.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
808
  imp_f32.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
809
- ctypes.c_int64(M), ctypes.c_int64(K), ctypes.c_int64(1),
810
  ctypes.byref(count))
811
 
812
  self.importance[name] = (imp_f32.astype(np.float64), count.value)
@@ -832,6 +849,8 @@ class TransformerRunner:
832
  cfg = self.cfg
833
  n_head = cfg['n_head']
834
  n_head_kv = cfg['n_head_kv']
 
 
835
  seq_len = hidden.shape[0]
836
 
837
  # ── Attention norm ──
@@ -1009,6 +1028,8 @@ class TransformerRunner:
1009
  cfg = self.cfg
1010
  n_head = cfg['n_head']
1011
  n_head_kv = cfg['n_head_kv']
 
 
1012
  head_dim = self.head_dim
1013
  seq_len = hidden.shape[0]
1014
 
@@ -1232,6 +1253,8 @@ class TransformerRunner:
1232
  cfg = self.cfg
1233
  n_head = cfg['n_head']
1234
  n_head_kv = cfg['n_head_kv']
 
 
1235
  head_dim = self.head_dim
1236
  seq_len = hidden.shape[0]
1237
 
@@ -1461,7 +1484,9 @@ def hpc_propagate_importance(importance_dict, n_layers, verbose=False):
1461
  n_nbr += 1
1462
  if n_nbr > 0:
1463
  e_nbr /= n_nbr
1464
- new_mult[i] = np.exp((e_self + 0.3 * e_nbr) / temperature)
 
 
1465
 
1466
  mean_m = np.mean(new_mult)
1467
  if mean_m > 1e-30:
@@ -1511,6 +1536,54 @@ def write_imatrix(path, importance_dict):
1511
  return len(entries)
1512
 
1513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1514
  # ─── Main ───────────────────────────────────────────────────────────────────
1515
 
1516
  def main():
@@ -1521,6 +1594,7 @@ def main():
1521
  parser.add_argument('calibration', help='Calibration text file')
1522
  parser.add_argument('-o', '--output', default='imatrix.dat',
1523
  help='Output imatrix file (default: imatrix.dat)')
 
1524
  parser.add_argument('--chunks', type=int, default=10,
1525
  help='Number of token chunks to process (default: 10)')
1526
  parser.add_argument('--chunk-size', type=int, default=4096,
@@ -1547,6 +1621,23 @@ def main():
1547
  model = GGUFModel(args.model)
1548
  config = model.get_config()
1549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1550
  print(f" Architecture: {config['arch']}")
1551
  print(f" Layers: {config['n_layers']}")
1552
  print(f" Hidden: {config['n_embd']}")
 
132
  self.data_offset = align_offset(f.tell())
133
 
134
  def get_arch(self):
135
+ arch = self.kv.get('general.architecture')
136
+ if not arch:
137
+ # Try to infer from tensor names
138
+ if any('attn_gate' in n for n in self.tensor_infos):
139
+ return 'gemma2'
140
+ return 'llama'
141
  return arch
142
 
143
  def get_config(self):
 
182
  ti = self.tensor_infos[name]
183
  abs_offset = self.data_offset + ti['offset']
184
  raw = bytes(self._mm[abs_offset:abs_offset + ti['data_size']])
185
+ try:
186
+ return dequantize(raw, ti['type'], ti['n_elements'])
187
+ except ValueError as e:
188
+ print(f" Error dequantizing {name}: {e}")
189
+ return None
190
 
191
  def get_tensor_shape(self, name):
192
  """Return the shape of a tensor (GGUF stores reversed dims)."""
 
235
  qs = data[:, 2:18] # 16 bytes = 32 nibbles
236
  lo = (qs & 0xF).astype(np.float32) - 8.0
237
  hi = (qs >> 4).astype(np.float32) - 8.0
238
+ # Correct nibble interleaving: [lo0, hi0, lo1, hi1, ...]
239
+ x = np.stack([lo, hi], axis=2).reshape(n_blocks, 32)
240
  return (d * x).reshape(-1)[:n_elements]
241
 
242
  def dequant_q2k(raw, n_elements):
 
253
  dmin = dmin_fp16[blk]
254
  for half in range(2):
255
  for sub in range(4):
256
+ j = half * 4 + sub # Corrected index: 0-3 and 4-7
257
  sc = int(scales_packed[blk, j]) & 0xF
258
  mn = int(scales_packed[blk, j]) >> 4
259
  d_sub = d * sc
 
440
  """Encode text and split into fixed-length chunks."""
441
  ids = self.encode(text)
442
  chunks = []
443
+ # Use a more reasonable stride (75% overlap instead of 50% for better coverage)
444
+ # or just 0% for pure speed. Let's go with 25% overlap as a middle ground.
445
+ stride = chunk_size * 3 // 4
446
+ for i in range(0, len(ids) - chunk_size + 1, stride):
447
  chunk = ids[i:i + chunk_size]
448
+ chunks.append(np.array(chunk, dtype=np.int32))
449
+
450
  if not chunks and ids:
451
  # Pad short text
452
  padded = ids + [self.eos_id] * (chunk_size - len(ids))
 
568
  n_embd = cfg['n_embd']
569
  n_head = cfg['n_head']
570
  n_head_kv = cfg['n_head_kv']
571
+ if isinstance(n_head_kv, list):
572
+ n_head_kv = n_head_kv[layer_idx]
573
  head_dim = self.head_dim
574
  eps = cfg['rms_eps']
575
 
 
688
 
689
  # Read back importance for the tensors that WERE processed in C
690
  for name, arr, cnt in imp_refs:
691
+ # Extract value from ctypes byref pointer
692
  self.importance[name] = (arr.astype(np.float64), cnt.value)
693
 
694
  # Handle MoE FFN if C code skipped it
 
813
  imp_f32 = self.importance[name][0].astype(np.float32)
814
  count = ctypes.c_int64(self.importance[name][1])
815
 
816
+ # Pass real weights to C library for importance recording
817
+ weight_ptr = weight.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
818
  # Dummy output — we only want the importance recording
819
  dummy_out = np.empty((M, 1), dtype=np.float32)
 
820
 
821
  self._hpc_lib.hexstate_matmul_record(
822
  x.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
823
+ weight_ptr,
824
  dummy_out.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
825
  imp_f32.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
826
+ ctypes.c_int64(M), ctypes.c_int64(K), ctypes.c_int64(N),
827
  ctypes.byref(count))
828
 
829
  self.importance[name] = (imp_f32.astype(np.float64), count.value)
 
849
  cfg = self.cfg
850
  n_head = cfg['n_head']
851
  n_head_kv = cfg['n_head_kv']
852
+ if isinstance(n_head_kv, list):
853
+ n_head_kv = n_head_kv[layer_idx]
854
  seq_len = hidden.shape[0]
855
 
856
  # ── Attention norm ──
 
1028
  cfg = self.cfg
1029
  n_head = cfg['n_head']
1030
  n_head_kv = cfg['n_head_kv']
1031
+ if isinstance(n_head_kv, list):
1032
+ n_head_kv = n_head_kv[layer_idx]
1033
  head_dim = self.head_dim
1034
  seq_len = hidden.shape[0]
1035
 
 
1253
  cfg = self.cfg
1254
  n_head = cfg['n_head']
1255
  n_head_kv = cfg['n_head_kv']
1256
+ if isinstance(n_head_kv, list):
1257
+ n_head_kv = n_head_kv[layer_idx]
1258
  head_dim = self.head_dim
1259
  seq_len = hidden.shape[0]
1260
 
 
1484
  n_nbr += 1
1485
  if n_nbr > 0:
1486
  e_nbr /= n_nbr
1487
+ # Clamp energy to prevent exponential explosion (max exp(5) ~ 148)
1488
+ energy = np.clip((e_self + 0.3 * e_nbr) / temperature, -10, 5)
1489
+ new_mult[i] = np.exp(energy)
1490
 
1491
  mean_m = np.mean(new_mult)
1492
  if mean_m > 1e-30:
 
1536
  return len(entries)
1537
 
1538
 
1539
+ def load_hf_config(config_path):
1540
+ """Load a HuggingFace config.json and extract architecture info.
1541
+
1542
+ Maps HF keys to internal generate_imatrix.py keys:
1543
+ hidden_size -> n_embd
1544
+ num_hidden_layers -> n_layers
1545
+ num_attention_heads -> n_head
1546
+ num_key_value_heads -> n_head_kv
1547
+ intermediate_size -> n_ff
1548
+ vocab_size -> vocab_size
1549
+ rms_norm_eps -> rms_eps
1550
+ rope_theta -> rope_base
1551
+ model_type -> arch
1552
+ """
1553
+ import json
1554
+ with open(config_path, 'r') as f:
1555
+ raw = json.load(f)
1556
+
1557
+ src = raw
1558
+ if 'text_config' in raw and 'hidden_size' not in raw:
1559
+ src = raw['text_config']
1560
+
1561
+ cfg = {}
1562
+ cfg['arch'] = src.get('model_type', raw.get('model_type', 'unknown'))
1563
+ cfg['n_embd'] = src.get('hidden_size', 0)
1564
+ cfg['n_layers'] = src.get('num_hidden_layers', 0)
1565
+ cfg['n_head'] = src.get('num_attention_heads', 0)
1566
+ cfg['n_head_kv'] = src.get('num_key_value_heads', 0)
1567
+ cfg['n_ff'] = src.get('intermediate_size', 0)
1568
+ cfg['vocab_size'] = src.get('vocab_size', 0)
1569
+ cfg['rms_eps'] = src.get('rms_norm_eps', 1e-6)
1570
+
1571
+ rope_params = src.get('rope_parameters', {})
1572
+ cfg['rope_base'] = rope_params.get('rope_theta',
1573
+ src.get('rope_theta', 10000.0))
1574
+
1575
+ cfg['expert_count'] = src.get('num_local_experts', src.get('num_experts', 0))
1576
+ cfg['expert_used_count'] = src.get('num_experts_per_tok', 0)
1577
+
1578
+ # head_dim fallback
1579
+ if src.get('head_dim'):
1580
+ cfg['head_dim'] = src['head_dim']
1581
+ elif cfg['n_head'] > 0:
1582
+ cfg['head_dim'] = cfg['n_embd'] // cfg['n_head']
1583
+
1584
+ return cfg
1585
+
1586
+
1587
  # ─── Main ───────────────────────────────────────────────────────────────────
1588
 
1589
  def main():
 
1594
  parser.add_argument('calibration', help='Calibration text file')
1595
  parser.add_argument('-o', '--output', default='imatrix.dat',
1596
  help='Output imatrix file (default: imatrix.dat)')
1597
+ parser.add_argument('--config', help='Optional HuggingFace config.json')
1598
  parser.add_argument('--chunks', type=int, default=10,
1599
  help='Number of token chunks to process (default: 10)')
1600
  parser.add_argument('--chunk-size', type=int, default=4096,
 
1621
  model = GGUFModel(args.model)
1622
  config = model.get_config()
1623
 
1624
+ # ── Load/Merge config.json ──
1625
+ cfg_path = args.config
1626
+ if not cfg_path:
1627
+ # Auto-lookup in model directory
1628
+ model_dir = os.path.dirname(os.path.abspath(args.model))
1629
+ potential_cfg = os.path.join(model_dir, 'config.json')
1630
+ if os.path.exists(potential_cfg):
1631
+ cfg_path = potential_cfg
1632
+
1633
+ if cfg_path:
1634
+ print(f" Merging config from: {cfg_path}")
1635
+ hf_cfg = load_hf_config(cfg_path)
1636
+ # Override GGUF values with HF config values where they exist and are non-zero
1637
+ for k, v in hf_cfg.items():
1638
+ if v is not None:
1639
+ config[k] = v
1640
+
1641
  print(f" Architecture: {config['arch']}")
1642
  print(f" Layers: {config['n_layers']}")
1643
  print(f" Hidden: {config['n_embd']}")
hpc_forward_merged.c CHANGED
@@ -113,21 +113,30 @@ static void hpc_matmul_graph(const float *x, const float *weight, float *out,
113
  for (int64_t s = 0; s < n_sites - 1; s++)
114
  hpc_cz(g, s, s + 1);
115
 
116
- /* Read importance via graph marginals */
 
 
 
117
  double fidelity = g->avg_fidelity;
118
  for (int64_t s = 0; s < n_sites; s++) {
119
  int64_t j0 = s * stride;
120
  int64_t j1 = (s + 1) * stride;
121
  if (j1 > K) j1 = K;
122
- float e = col_energy[j0];
123
- int phase = ((int)(e * 1e3f)) % D;
124
- if (phase < 0) phase += D;
125
- double marg = hpc_marginal(g, s, phase);
126
- double boost = 1.0 + (marg * fidelity * D - 1.0) * 0.5;
127
- if (boost < 0.5) boost = 0.5;
128
- if (boost > 2.0) boost = 2.0;
129
- for (int64_t j = j0; j < j1; j++)
130
- importance[j] += col_energy[j] * (float)boost;
 
 
 
 
 
 
131
  }
132
  if (count) *count += M;
133
  }
@@ -218,6 +227,7 @@ void hexstate_forward_layer(
218
  const float *v_w, int64_t v_dim,
219
  const float *gate_w, int64_t gate_rows,
220
  const float *o_w, int64_t o_cols,
 
221
  /* FFN weights */
222
  const float *ffn_norm_w,
223
  const float *ffn_gate_w, const float *ffn_up_w, const float *ffn_down_w,
@@ -252,27 +262,13 @@ void hexstate_forward_layer(
252
  if (!qkv) { free(normed); free(attn_out); return; }
253
 
254
  /* Graph-based matmul: importance via HPCGraph marginals */
255
- printf("matmul qkv M=%ld K=%ld N=%ld\n", (long)seq_len, (long)n_embd, (long)qkv_dim); fflush(stdout); hpc_matmul_graph(normed, qkv_w, qkv, imp_qkv, cnt_qkv,
256
  seq_len, n_embd, qkv_dim, 0);
257
 
258
  /* Split Q, K, V */
259
  int64_t q_total = n_head * head_dim;
260
  int64_t kv_total = n_head_kv * head_dim;
261
- float *Q = qkv; /* [seq, q_total] */
262
- float *K = qkv + q_total; /* offset per row */
263
- float *V = qkv + q_total + kv_total; /* offset per row */
264
-
265
- /* ── HPC Linear Attention: graph IS the attention ──
266
- *
267
- * Create HPCGraph with n_head sites.
268
- * Each head is a site. K·V interaction energy → quhit amplitude.
269
- * CZ edges between adjacent heads → cross-head phase coherence.
270
- * hpc_marginal(h) → attention weight for head h.
271
- *
272
- * Running state S[h] accumulates K⊗V, weighted by coherence.
273
- * This is causal linear attention where the HPC graph determines
274
- * HOW MUCH each head contributes at each timestep.
275
- */
276
  HPCGraph *attn_graph = hpc_create(n_head);
277
  float *S = (float *)calloc(n_head * head_dim * head_dim, sizeof(float));
278
  float *z_acc = (float *)calloc(n_head * head_dim, sizeof(float));
@@ -288,14 +284,13 @@ void hexstate_forward_layer(
288
 
289
  /* Encode K·V energy into graph sites */
290
  for (int64_t h = 0; h < n_head; h++) {
291
- int64_t kv_h = h % n_head_kv; /* GQA mapping */
292
  float *kh = kt_base + kv_h * head_dim;
293
  float *vh = vt_base + kv_h * head_dim;
294
  float energy = 0.0f;
295
  for (int64_t d = 0; d < head_dim; d++)
296
  energy += kh[d] * vh[d];
297
 
298
- /* Triality encode energy → D=6 quhit amplitude */
299
  double re[D] = {0}, im[D] = {0};
300
  float ae = fabsf(energy) + 1e-6f;
301
  int ph = ((int)(ae * 100.0f)) % D;
@@ -306,11 +301,9 @@ void hexstate_forward_layer(
306
  hpc_set_local(attn_graph, h, re, im);
307
  }
308
 
309
- /* CZ-couple adjacent heads: creates cross-head entanglement */
310
  for (int64_t h = 0; h < n_head - 1; h++)
311
  hpc_cz(attn_graph, h, h + 1);
312
 
313
- /* Compute attention output per head using graph marginals */
314
  #pragma omp parallel for schedule(static)
315
  for (int64_t h = 0; h < n_head; h++) {
316
  int64_t kv_h = h % n_head_kv;
@@ -320,7 +313,6 @@ void hexstate_forward_layer(
320
  float *Sh = S + h * head_dim * head_dim;
321
  float *zh = z_acc + h * head_dim;
322
 
323
- /* Get HPC marginal: phase-coherent weight for this head */
324
  float ae = 0.0f;
325
  for (int64_t d = 0; d < head_dim; d++)
326
  ae += fabsf(kh[d] * vh[d]);
@@ -331,14 +323,14 @@ void hexstate_forward_layer(
331
  if (coherence < 0.1f) coherence = 0.1f;
332
  if (coherence > 3.0f) coherence = 3.0f;
333
 
334
- /* Feature map: φ(x) = max(x,0) + ε */
335
- float qf[256], kf[256];
 
336
  for (int64_t d = 0; d < head_dim; d++) {
337
  qf[d] = (qh[d] > 0 ? qh[d] : 0) + 1e-6f;
338
  kf[d] = (kh[d] > 0 ? kh[d] : 0) + 1e-6f;
339
  }
340
 
341
- /* Accumulate: S += coherence × outer(kf, v) */
342
  for (int64_t d1 = 0; d1 < head_dim; d1++) {
343
  float ks = kf[d1] * coherence;
344
  for (int64_t d2 = 0; d2 < head_dim; d2++)
@@ -347,37 +339,31 @@ void hexstate_forward_layer(
347
  for (int64_t d = 0; d < head_dim; d++)
348
  zh[d] += kf[d] * coherence;
349
 
350
- /* Output: (qf @ S) / (qf · z) */
351
  float den = 1e-8f;
352
  for (int64_t d = 0; d < head_dim; d++)
353
  den += qf[d] * zh[d];
354
  float inv_den = 1.0f / den;
355
 
356
- /* Write to attn_inner at position [t, h*head_dim ... ] */
357
  float *ao = attn_inner + t * inner_dim;
358
  for (int64_t d2 = 0; d2 < head_dim; d2++) {
359
  float num = 0.0f;
360
  for (int64_t d1 = 0; d1 < head_dim; d1++)
361
  num += qf[d1] * Sh[d1 * head_dim + d2];
362
- /* Accumulate into attn_inner (multiple heads write here) */
363
  ao[h * head_dim + d2] = num * inv_den;
364
  }
365
  }
366
 
367
- /* Compact graph edges periodically */
368
  if (t > 0 && t % 64 == 0)
369
  hpc_compact_edges(attn_graph);
370
  }
371
  }
372
 
373
- /* Gate projection if present */
374
  if (gate_w && gate_rows > 0) {
375
- int trans_w = (gate_rows == inner_dim) ? 1 : 0;
376
- int64_t N_out = trans_w ? n_embd : gate_rows;
377
  float *gated = (float *)malloc(seq_len * N_out * sizeof(float));
378
  if (gated) {
379
- printf("matmul gate M=%ld K=%ld N=%ld trans=%d\n", (long)seq_len, (long)inner_dim, (long)N_out, trans_w); fflush(stdout); hpc_matmul_graph(attn_inner, gate_w, gated, imp_gate, cnt_gate,
380
- seq_len, inner_dim, N_out, trans_w);
381
  for (int64_t t = 0; t < seq_len; t++) {
382
  int64_t copy_dim = N_out < n_embd ? N_out : n_embd;
383
  memcpy(attn_out + t * n_embd, gated + t * N_out, copy_dim * sizeof(float));
@@ -391,7 +377,6 @@ void hexstate_forward_layer(
391
  }
392
  }
393
  if (attn_inner) free(attn_inner);
394
-
395
  if (attn_graph) hpc_destroy(attn_graph);
396
  free(S); free(z_acc); free(qkv);
397
 
@@ -410,7 +395,6 @@ void hexstate_forward_layer(
410
  hpc_matmul_graph(normed, k_w, K_buf, imp_k, cnt_k, seq_len, n_embd, k_dim, 0);
411
  hpc_matmul_graph(normed, v_w, V_buf, imp_v, cnt_v, seq_len, n_embd, v_dim, 0);
412
 
413
- /* Same HPC attention as above but with separate Q/K/V buffers */
414
  int64_t hd_q = q_dim / n_head;
415
  int64_t hd_kv = k_dim / n_head_kv;
416
  int64_t inner_dim = n_head * hd_kv;
@@ -421,7 +405,6 @@ void hexstate_forward_layer(
421
 
422
  if (attn_graph && S && z_acc && attn_inner) {
423
  for (int64_t t = 0; t < seq_len; t++) {
424
- /* Encode heads into graph */
425
  for (int64_t h = 0; h < n_head; h++) {
426
  int64_t kv_h = h % n_head_kv;
427
  float *kh = K_buf + t * k_dim + kv_h * hd_kv;
@@ -449,34 +432,39 @@ void hexstate_forward_layer(
449
  float *zh = z_acc + h * hd_kv;
450
  int64_t feat = hd_q < hd_kv ? hd_q : hd_kv;
451
 
452
- float ae = fabsf(kh[0]*vh[0]) + 1e-6f;
 
 
453
  int ph = ((int)(ae * 100.0f)) % D;
454
  double coh_raw = hpc_marginal(attn_graph, h, ph);
455
  float coh = (float)(coh_raw * D);
456
  if (coh < 0.1f) coh = 0.1f;
457
  if (coh > 3.0f) coh = 3.0f;
458
 
 
 
 
 
 
 
 
459
  for (int64_t d1 = 0; d1 < feat; d1++) {
460
- float kf = (kh[d1] > 0 ? kh[d1] : 0) + 1e-6f;
461
- float ks = kf * coh;
462
  for (int64_t d2 = 0; d2 < hd_kv; d2++)
463
  Sh[d1*hd_kv+d2] += ks * vh[d2];
464
- zh[d1] += kf * coh;
465
  }
466
 
467
  float den = 1e-8f;
468
- for (int64_t d = 0; d < feat; d++) {
469
- float qf = (qh[d] > 0 ? qh[d] : 0) + 1e-6f;
470
- den += qf * zh[d];
471
- }
472
  float inv_den = 1.0f / den;
 
473
  float *ao = attn_inner + t * inner_dim;
474
  for (int64_t d2 = 0; d2 < hd_kv; d2++) {
475
  float num = 0.0f;
476
- for (int64_t d1 = 0; d1 < feat; d1++) {
477
- float qf = (qh[d1] > 0 ? qh[d1] : 0) + 1e-6f;
478
- num += qf * Sh[d1*hd_kv+d2];
479
- }
480
  ao[h*hd_kv+d2] = num * inv_den;
481
  }
482
  }
@@ -485,31 +473,14 @@ void hexstate_forward_layer(
485
  }
486
  }
487
 
488
- /* Output projection */
489
  if (o_w && o_cols > 0) {
490
- float *proj_in = attn_inner;
491
- int free_proj_in = 0;
492
- if (inner_dim != o_cols) {
493
- proj_in = (float *)calloc(seq_len * o_cols, sizeof(float));
494
- if (proj_in) {
495
- for (int64_t t = 0; t < seq_len; t++) {
496
- int64_t copy_dim = inner_dim < o_cols ? inner_dim : o_cols;
497
- memcpy(proj_in + t * o_cols, attn_inner + t * inner_dim, copy_dim * sizeof(float));
498
- }
499
- free_proj_in = 1;
500
- } else {
501
- proj_in = attn_inner;
502
- }
503
- }
504
-
505
  float *projected = (float *)calloc(seq_len * n_embd, sizeof(float));
506
  if (projected) {
507
- hpc_matmul_graph(proj_in, o_w, projected, imp_o, cnt_o,
508
- seq_len, o_cols, n_embd, 0);
509
  memcpy(attn_out, projected, seq_len * n_embd * sizeof(float));
510
  free(projected);
511
  }
512
- if (free_proj_in && proj_in != attn_inner) free(proj_in);
513
  } else {
514
  for (int64_t t = 0; t < seq_len; t++) {
515
  int64_t copy_dim = inner_dim < n_embd ? inner_dim : n_embd;
@@ -517,19 +488,16 @@ void hexstate_forward_layer(
517
  }
518
  }
519
  if (attn_inner) free(attn_inner);
520
-
521
  if (attn_graph) hpc_destroy(attn_graph);
522
  free(S); free(z_acc);
523
  free(Q); free(K_buf); free(V_buf);
524
  }
525
 
526
- /* Residual add: hidden += attn_out */
527
  int64_t total = seq_len * n_embd;
528
  #pragma omp parallel for schedule(static)
529
  for (int64_t i = 0; i < total; i++)
530
  hidden[i] += attn_out[i];
531
 
532
- /* ══════════════ Phase 3: FFN ══════════════ */
533
  if (ffn_norm_w && ffn_gate_w && ffn_up_w && ffn_down_w && ffn_dim > 0) {
534
  float *normed_ff = (float *)malloc(seq_len * n_embd * sizeof(float));
535
  float *gate_out = (float *)malloc(seq_len * ffn_dim * sizeof(float));
@@ -537,36 +505,29 @@ void hexstate_forward_layer(
537
 
538
  if (normed_ff && gate_out && up_out) {
539
  hpc_rms_norm(hidden, ffn_norm_w, normed_ff, seq_len, n_embd, eps);
540
-
541
- /* Graph-based matmul for FFN with importance */
542
  hpc_matmul_graph(normed_ff, ffn_gate_w, gate_out,
543
  imp_ffn_gate, cnt_ffn_gate, seq_len, n_embd, ffn_dim, 0);
544
  hpc_matmul_graph(normed_ff, ffn_up_w, up_out,
545
  imp_ffn_up, cnt_ffn_up, seq_len, n_embd, ffn_dim, 0);
546
 
547
- /* SiLU(gate) * up */
548
  hpc_silu(gate_out, seq_len * ffn_dim);
549
  #pragma omp parallel for schedule(static)
550
  for (int64_t i = 0; i < seq_len * ffn_dim; i++)
551
  gate_out[i] *= up_out[i];
552
 
553
- /* Down projection: graph-based importance recording */
554
  float *ff_out_buf = (float *)malloc(seq_len * n_embd * sizeof(float));
555
  if (ff_out_buf) {
556
  hpc_matmul_graph(gate_out, ffn_down_w, ff_out_buf,
557
  imp_ffn_down, cnt_ffn_down,
558
  seq_len, ffn_dim, n_embd, 0);
559
- /* Residual add */
560
  #pragma omp parallel for schedule(static)
561
  for (int64_t i = 0; i < total; i++)
562
  hidden[i] += ff_out_buf[i];
563
  free(ff_out_buf);
564
  }
565
  }
566
-
567
  free(normed_ff); free(gate_out); free(up_out);
568
  }
569
-
570
  free(normed);
571
  free(attn_out);
572
  }
 
113
  for (int64_t s = 0; s < n_sites - 1; s++)
114
  hpc_cz(g, s, s + 1);
115
 
116
+ /* Read importance via graph marginals.
117
+ * The bucket marginal (marg) is shared across the stride window, but
118
+ * each column gets its own phase and boost derived from col_energy[j],
119
+ * so no column inherits another column's boost factor. */
120
  double fidelity = g->avg_fidelity;
121
  for (int64_t s = 0; s < n_sites; s++) {
122
  int64_t j0 = s * stride;
123
  int64_t j1 = (s + 1) * stride;
124
  if (j1 > K) j1 = K;
125
+ /* Bucket-level marginal: computed once per site (cheap) */
126
+ float e0 = col_energy[j0];
127
+ int phase0 = ((int)(e0 * 1e3f)) % D;
128
+ if (phase0 < 0) phase0 += D;
129
+ double marg = hpc_marginal(g, s, phase0);
130
+ /* Per-column boost: each column uses its own energy */
131
+ for (int64_t j = j0; j < j1; j++) {
132
+ float e = col_energy[j];
133
+ int phase = ((int)(e * 1e3f)) % D;
134
+ if (phase < 0) phase += D;
135
+ double boost = 1.0 + (marg * fidelity * D - 1.0) * 0.5;
136
+ if (boost < 0.5) boost = 0.5;
137
+ if (boost > 2.0) boost = 2.0;
138
+ importance[j] += e * (float)boost;
139
+ }
140
  }
141
  if (count) *count += M;
142
  }
 
227
  const float *v_w, int64_t v_dim,
228
  const float *gate_w, int64_t gate_rows,
229
  const float *o_w, int64_t o_cols,
230
+ int gate_trans, /* New: explicit transpose flag */
231
  /* FFN weights */
232
  const float *ffn_norm_w,
233
  const float *ffn_gate_w, const float *ffn_up_w, const float *ffn_down_w,
 
262
  if (!qkv) { free(normed); free(attn_out); return; }
263
 
264
  /* Graph-based matmul: importance via HPCGraph marginals */
265
+ hpc_matmul_graph(normed, qkv_w, qkv, imp_qkv, cnt_qkv,
266
  seq_len, n_embd, qkv_dim, 0);
267
 
268
  /* Split Q, K, V */
269
  int64_t q_total = n_head * head_dim;
270
  int64_t kv_total = n_head_kv * head_dim;
271
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  HPCGraph *attn_graph = hpc_create(n_head);
273
  float *S = (float *)calloc(n_head * head_dim * head_dim, sizeof(float));
274
  float *z_acc = (float *)calloc(n_head * head_dim, sizeof(float));
 
284
 
285
  /* Encode K·V energy into graph sites */
286
  for (int64_t h = 0; h < n_head; h++) {
287
+ int64_t kv_h = h % n_head_kv;
288
  float *kh = kt_base + kv_h * head_dim;
289
  float *vh = vt_base + kv_h * head_dim;
290
  float energy = 0.0f;
291
  for (int64_t d = 0; d < head_dim; d++)
292
  energy += kh[d] * vh[d];
293
 
 
294
  double re[D] = {0}, im[D] = {0};
295
  float ae = fabsf(energy) + 1e-6f;
296
  int ph = ((int)(ae * 100.0f)) % D;
 
301
  hpc_set_local(attn_graph, h, re, im);
302
  }
303
 
 
304
  for (int64_t h = 0; h < n_head - 1; h++)
305
  hpc_cz(attn_graph, h, h + 1);
306
 
 
307
  #pragma omp parallel for schedule(static)
308
  for (int64_t h = 0; h < n_head; h++) {
309
  int64_t kv_h = h % n_head_kv;
 
313
  float *Sh = S + h * head_dim * head_dim;
314
  float *zh = z_acc + h * head_dim;
315
 
 
316
  float ae = 0.0f;
317
  for (int64_t d = 0; d < head_dim; d++)
318
  ae += fabsf(kh[d] * vh[d]);
 
323
  if (coherence < 0.1f) coherence = 0.1f;
324
  if (coherence > 3.0f) coherence = 3.0f;
325
 
326
+ /* Safe buffer allocation for any head_dim */
327
+ float *qf = (float *)alloca(head_dim * sizeof(float));
328
+ float *kf = (float *)alloca(head_dim * sizeof(float));
329
  for (int64_t d = 0; d < head_dim; d++) {
330
  qf[d] = (qh[d] > 0 ? qh[d] : 0) + 1e-6f;
331
  kf[d] = (kh[d] > 0 ? kh[d] : 0) + 1e-6f;
332
  }
333
 
 
334
  for (int64_t d1 = 0; d1 < head_dim; d1++) {
335
  float ks = kf[d1] * coherence;
336
  for (int64_t d2 = 0; d2 < head_dim; d2++)
 
339
  for (int64_t d = 0; d < head_dim; d++)
340
  zh[d] += kf[d] * coherence;
341
 
 
342
  float den = 1e-8f;
343
  for (int64_t d = 0; d < head_dim; d++)
344
  den += qf[d] * zh[d];
345
  float inv_den = 1.0f / den;
346
 
 
347
  float *ao = attn_inner + t * inner_dim;
348
  for (int64_t d2 = 0; d2 < head_dim; d2++) {
349
  float num = 0.0f;
350
  for (int64_t d1 = 0; d1 < head_dim; d1++)
351
  num += qf[d1] * Sh[d1 * head_dim + d2];
 
352
  ao[h * head_dim + d2] = num * inv_den;
353
  }
354
  }
355
 
 
356
  if (t > 0 && t % 64 == 0)
357
  hpc_compact_edges(attn_graph);
358
  }
359
  }
360
 
 
361
  if (gate_w && gate_rows > 0) {
362
+ int64_t N_out = gate_trans ? n_embd : gate_rows;
 
363
  float *gated = (float *)malloc(seq_len * N_out * sizeof(float));
364
  if (gated) {
365
+ hpc_matmul_graph(attn_inner, gate_w, gated, imp_gate, cnt_gate,
366
+ seq_len, inner_dim, N_out, gate_trans);
367
  for (int64_t t = 0; t < seq_len; t++) {
368
  int64_t copy_dim = N_out < n_embd ? N_out : n_embd;
369
  memcpy(attn_out + t * n_embd, gated + t * N_out, copy_dim * sizeof(float));
 
377
  }
378
  }
379
  if (attn_inner) free(attn_inner);
 
380
  if (attn_graph) hpc_destroy(attn_graph);
381
  free(S); free(z_acc); free(qkv);
382
 
 
395
  hpc_matmul_graph(normed, k_w, K_buf, imp_k, cnt_k, seq_len, n_embd, k_dim, 0);
396
  hpc_matmul_graph(normed, v_w, V_buf, imp_v, cnt_v, seq_len, n_embd, v_dim, 0);
397
 
 
398
  int64_t hd_q = q_dim / n_head;
399
  int64_t hd_kv = k_dim / n_head_kv;
400
  int64_t inner_dim = n_head * hd_kv;
 
405
 
406
  if (attn_graph && S && z_acc && attn_inner) {
407
  for (int64_t t = 0; t < seq_len; t++) {
 
408
  for (int64_t h = 0; h < n_head; h++) {
409
  int64_t kv_h = h % n_head_kv;
410
  float *kh = K_buf + t * k_dim + kv_h * hd_kv;
 
432
  float *zh = z_acc + h * hd_kv;
433
  int64_t feat = hd_q < hd_kv ? hd_q : hd_kv;
434
 
435
+ float ae = 0.0f;
436
+ for(int64_t d=0; d<hd_kv; d++) ae += fabsf(kh[d]*vh[d]);
437
+ ae += 1e-6f;
438
  int ph = ((int)(ae * 100.0f)) % D;
439
  double coh_raw = hpc_marginal(attn_graph, h, ph);
440
  float coh = (float)(coh_raw * D);
441
  if (coh < 0.1f) coh = 0.1f;
442
  if (coh > 3.0f) coh = 3.0f;
443
 
444
+ float *qf = (float *)alloca(feat * sizeof(float));
445
+ float *kf = (float *)alloca(feat * sizeof(float));
446
+ for (int64_t d = 0; d < feat; d++) {
447
+ qf[d] = (qh[d] > 0 ? qh[d] : 0) + 1e-6f;
448
+ kf[d] = (kh[d] > 0 ? kh[d] : 0) + 1e-6f;
449
+ }
450
+
451
  for (int64_t d1 = 0; d1 < feat; d1++) {
452
+ float ks = kf[d1] * coh;
 
453
  for (int64_t d2 = 0; d2 < hd_kv; d2++)
454
  Sh[d1*hd_kv+d2] += ks * vh[d2];
455
+ zh[d1] += kf[d1] * coh;
456
  }
457
 
458
  float den = 1e-8f;
459
+ for (int64_t d = 0; d < feat; d++)
460
+ den += qf[d] * zh[d];
 
 
461
  float inv_den = 1.0f / den;
462
+
463
  float *ao = attn_inner + t * inner_dim;
464
  for (int64_t d2 = 0; d2 < hd_kv; d2++) {
465
  float num = 0.0f;
466
+ for (int64_t d1 = 0; d1 < feat; d1++)
467
+ num += qf[d1] * Sh[d1*hd_kv+d2];
 
 
468
  ao[h*hd_kv+d2] = num * inv_den;
469
  }
470
  }
 
473
  }
474
  }
475
 
 
476
  if (o_w && o_cols > 0) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  float *projected = (float *)calloc(seq_len * n_embd, sizeof(float));
478
  if (projected) {
479
+ hpc_matmul_graph(attn_inner, o_w, projected, imp_o, cnt_o,
480
+ seq_len, inner_dim, n_embd, 0);
481
  memcpy(attn_out, projected, seq_len * n_embd * sizeof(float));
482
  free(projected);
483
  }
 
484
  } else {
485
  for (int64_t t = 0; t < seq_len; t++) {
486
  int64_t copy_dim = inner_dim < n_embd ? inner_dim : n_embd;
 
488
  }
489
  }
490
  if (attn_inner) free(attn_inner);
 
491
  if (attn_graph) hpc_destroy(attn_graph);
492
  free(S); free(z_acc);
493
  free(Q); free(K_buf); free(V_buf);
494
  }
495
 
 
496
  int64_t total = seq_len * n_embd;
497
  #pragma omp parallel for schedule(static)
498
  for (int64_t i = 0; i < total; i++)
499
  hidden[i] += attn_out[i];
500
 
 
501
  if (ffn_norm_w && ffn_gate_w && ffn_up_w && ffn_down_w && ffn_dim > 0) {
502
  float *normed_ff = (float *)malloc(seq_len * n_embd * sizeof(float));
503
  float *gate_out = (float *)malloc(seq_len * ffn_dim * sizeof(float));
 
505
 
506
  if (normed_ff && gate_out && up_out) {
507
  hpc_rms_norm(hidden, ffn_norm_w, normed_ff, seq_len, n_embd, eps);
 
 
508
  hpc_matmul_graph(normed_ff, ffn_gate_w, gate_out,
509
  imp_ffn_gate, cnt_ffn_gate, seq_len, n_embd, ffn_dim, 0);
510
  hpc_matmul_graph(normed_ff, ffn_up_w, up_out,
511
  imp_ffn_up, cnt_ffn_up, seq_len, n_embd, ffn_dim, 0);
512
 
 
513
  hpc_silu(gate_out, seq_len * ffn_dim);
514
  #pragma omp parallel for schedule(static)
515
  for (int64_t i = 0; i < seq_len * ffn_dim; i++)
516
  gate_out[i] *= up_out[i];
517
 
 
518
  float *ff_out_buf = (float *)malloc(seq_len * n_embd * sizeof(float));
519
  if (ff_out_buf) {
520
  hpc_matmul_graph(gate_out, ffn_down_w, ff_out_buf,
521
  imp_ffn_down, cnt_ffn_down,
522
  seq_len, ffn_dim, n_embd, 0);
 
523
  #pragma omp parallel for schedule(static)
524
  for (int64_t i = 0; i < total; i++)
525
  hidden[i] += ff_out_buf[i];
526
  free(ff_out_buf);
527
  }
528
  }
 
529
  free(normed_ff); free(gate_out); free(up_out);
530
  }
 
531
  free(normed);
532
  free(attn_out);
533
  }