LJTSG committed on
Commit
1399c22
·
verified ·
1 Parent(s): 7df1089

Upload shaders/rmsnorm_noweight.wgsl with huggingface_hub

Browse files
Files changed (1) hide show
  1. shaders/rmsnorm_noweight.wgsl +53 -0
shaders/rmsnorm_noweight.wgsl ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // rmsnorm_noweight.wgsl — RMSNorm without learned weights.
2
+ // Used for Falcon-Mamba's B, C, dt_pre normalization.
3
+ // y[d] = x[d] / sqrt(mean(x[:]^2) + eps)
4
+ // Single workgroup, one row. D can be small (16 or 256).
5
+
6
// Number of invocations in the single workgroup. The tree reduction in
// `main` starts at WG_SIZE / 2 and halves each step, so this must remain a
// power of two or some partial sums would never be folded in.
const WG_SIZE: u32 = 64u;

// Uniform parameters read by `main`.
struct Params {
    // Number of leading elements of `buf` that form the row to normalize.
    D: u32,
    // Added to the mean of squares before taking the square root.
    eps: f32,
}

// The row being normalized; it is read and then overwritten in place.
@group(0) @binding(0) var<storage, read_write> buf: array<f32>;

@group(1) @binding(0) var<uniform> params: Params;

// Per-invocation partial sums of squares, indexed by local invocation id.
// Sized from WG_SIZE (instead of a repeated literal 64) so the scratch array
// and the workgroup size cannot drift apart if WG_SIZE is ever changed.
var<workgroup> partial_sumsq: array<f32, WG_SIZE>;
18
+
19
// Compute entry point: scales the first `params.D` elements of `buf` in place
// by 1 / sqrt(mean(buf[0..D]^2) + params.eps).
//
// Work is split across the WG_SIZE invocations of one workgroup using the
// local invocation index only, so every workgroup operates on the same
// elements starting at index 0.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
    let lane = lid.x;

    // Phase 1: each invocation accumulates the squares of elements
    // lane, lane + WG_SIZE, lane + 2 * WG_SIZE, ... below params.D.
    var acc: f32 = 0.0;
    for (var i: u32 = lane; i < params.D; i = i + WG_SIZE) {
        let x = buf[i];
        acc = acc + x * x;
    }
    partial_sumsq[lane] = acc;
    workgroupBarrier();

    // Phase 2: pairwise tree reduction in workgroup memory. The barrier is
    // executed by every invocation on every step (it sits outside the
    // lane-dependent `if`), and the total ends up in partial_sumsq[0].
    for (var stride: u32 = WG_SIZE / 2u; stride != 0u; stride = stride >> 1u) {
        if (lane < stride) {
            partial_sumsq[lane] = partial_sumsq[lane] + partial_sumsq[lane + stride];
        }
        workgroupBarrier();
    }

    // Phase 3: turn the sum of squares into the reciprocal RMS factor.
    let mean_of_squares = partial_sumsq[0u] / f32(params.D);
    let inv_rms = 1.0 / sqrt(mean_of_squares + params.eps);

    // Phase 4: rescale the same strided set of elements this invocation read.
    for (var i: u32 = lane; i < params.D; i = i + WG_SIZE) {
        buf[i] = buf[i] * inv_rms;
    }
}