[ExecuTorch][WebGPU] q4gsw prefill: shared-memory tiled GEMM, shape-routed (+150–303% large-K/N)#20605
[ExecuTorch][WebGPU] q4gsw prefill: shared-memory tiled GEMM, shape-routed (+150–303% large-K/N)#20605JulianCloudNTH wants to merge 1 commit into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20605
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 742548a with merge base 0cef6de ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
Stack from ghstack (oldest at bottom):
Up to +303% faster q4gsw prefill on the wide FFN matmuls (M4 Pro): a shared-memory-staged tiled GEMM, shape-routed so it only replaces the register-tiled GEMM where it actually wins.
Problem: the register-tiled q4gsw prefill GEMM re-dequantizes each 4-bit weight per output row, so on the wide projections (gate/up
2048->8192, down8192->2048) the dequant + global-load traffic dominates.Solution: add a shared-memory-staged tiled GEMM that dequantizes each weight block ONCE per K-tile into shared memory and reuses it across the whole workgroup.
Before: every M>1 prefill used the register-tiled GEMM.
After: M>1 prefill routes to the shmem GEMM when
K >= 4096 || N >= 4096, and keeps the register-tiled GEMM otherwise. The route is shape-gated because an on-device A/B showed the shmem kernel wins big on largeK/N(+303% atK2048xN8192, +150% atK8192xN2048) but regresses the square2048x2048shape (-28%).Implementation:
q4gsw_linear_gemm_shmem.wgsl: a 64-thread (8x8) workgroup computes a 32x32 output tile (4x4 per thread); per K-tile (TK=16) the workgroup cooperatively stages the input block and dequantizes the weight block into shared memory, then MACs from shmem.QuantizedLinear.cpp: a third sub-route in the M>1 branch (use_shmem_gemm), dispatchingdiv_up(M,32)*div_up(N,32)workgroups directly (no grid-stride) and throwing if the count exceedsmaxComputeWorkgroupsPerDimension. The register-tiled path is unchanged.K/Nmatmuls on this target. Mirrors llama.cppmul_mat_reg_tile.Constraints: the dequant math, weight layout, and tolerances are unchanged; the register-tiled square-shape path is byte-identical; q4gsw's
N % 8 == 0/K % group_size == 0constraints are unchanged.Co-authored-with: Claude Code.
Differential Revision: D110095128