
[ExecuTorch][WebGPU] q4gsw prefill: shared-memory tiled GEMM, shape-routed (+150–303% large-K/N)#20605

Open
JulianCloudNTH wants to merge 1 commit into gh/JulianCloudNTH/79/base from gh/JulianCloudNTH/79/head


@JulianCloudNTH JulianCloudNTH commented Jun 29, 2026


Stack from ghstack (oldest at bottom):

Up to +303% faster q4gsw prefill on the wide FFN matmuls (M4 Pro): a shared-memory-staged tiled GEMM, shape-routed so it only replaces the register-tiled GEMM where it actually wins.

Problem: the register-tiled q4gsw prefill GEMM re-dequantizes each 4-bit weight per output row, so on the wide projections (gate/up 2048->8192, down 8192->2048) the dequant + global-load traffic dominates.

Solution: add a shared-memory-staged tiled GEMM that dequantizes each weight block ONCE per K-tile into shared memory and reuses it across the whole workgroup.
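A CPU-side sketch of the staging idea, for illustration only. The weight format here is an assumption, not the real q4gsw packing: one unpacked 4-bit code per byte, symmetric around 8, with one scale per (output column, K-group). `Q4Weights`, `gemm_staged`, and the local `w_tile` array (standing in for workgroup shared memory) are hypothetical names. The tile sizes follow the description (32x32 output tile, TK=16).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// ASSUMED layout (not from the PR): codes in [0, 15], one per byte, stored
// row-major as [N][K]; scales stored as [N][K / group_size].
struct Q4Weights {
  std::size_t N;
  std::size_t K;
  std::size_t group_size;
  std::vector<std::uint8_t> codes;
  std::vector<float> scales;

  float dequant(std::size_t n, std::size_t k) const {
    const float scale = scales[n * (K / group_size) + k / group_size];
    return static_cast<float>(static_cast<int>(codes[n * K + k]) - 8) * scale;
  }
};

constexpr std::size_t kTileM = 32;
constexpr std::size_t kTileN = 32;
constexpr std::size_t kTileK = 16;

// Computes out[M][N] = in[M][K] * W^T one output tile at a time.
// Returns how many weights were dequantized, to make the reuse visible.
inline std::size_t gemm_staged(const std::vector<float>& in, std::size_t M,
                               const Q4Weights& w, std::vector<float>& out) {
  assert(in.size() == M * w.K);
  out.assign(M * w.N, 0.0f);
  std::size_t dequants = 0;
  float w_tile[kTileN][kTileK];  // stand-in for workgroup shared memory

  for (std::size_t m0 = 0; m0 < M; m0 += kTileM) {
    const std::size_t m_end = std::min(m0 + kTileM, M);
    for (std::size_t n0 = 0; n0 < w.N; n0 += kTileN) {
      const std::size_t n_end = std::min(n0 + kTileN, w.N);
      for (std::size_t k0 = 0; k0 < w.K; k0 += kTileK) {
        const std::size_t k_end = std::min(k0 + kTileK, w.K);
        // Stage: dequantize this weight block once for the whole tile.
        for (std::size_t n = n0; n < n_end; ++n) {
          for (std::size_t k = k0; k < k_end; ++k) {
            w_tile[n - n0][k - k0] = w.dequant(n, k);
            ++dequants;
          }
        }
        // MAC: every output row in the tile reuses the staged block.
        for (std::size_t m = m0; m < m_end; ++m) {
          for (std::size_t n = n0; n < n_end; ++n) {
            float acc = out[m * w.N + n];
            for (std::size_t k = k0; k < k_end; ++k) {
              acc += in[m * w.K + k] * w_tile[n - n0][k - k0];
            }
            out[m * w.N + n] = acc;
          }
        }
      }
    }
  }
  return dequants;
}
```

In this model the dequant count is `ceil(M/32) * N * K`, versus `M * N * K` if every output element dequantized its own weights. The actual WGSL kernel also stages the input block and splits the work across 64 threads, which this sketch does not model.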

Before: every M>1 prefill used the register-tiled GEMM.

After: M>1 prefill routes to the shmem GEMM when K >= 4096 || N >= 4096, and keeps the register-tiled GEMM otherwise. The route is shape-gated because an on-device A/B showed the shmem kernel wins big on large K/N (+303% at K2048xN8192, +150% at K8192xN2048) but regresses the square 2048x2048 shape (-28%).
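The shape gate can be sketched as a small predicate. The enum and function names below are hypothetical; only the `K >= 4096 || N >= 4096` condition and the three measured shapes come from the description above. The function assumes the caller is already in the M>1 prefill branch.

```cpp
#include <cstdint>

// Hypothetical names; the PR states only the predicate.
enum class PrefillKernel { RegisterTiled, SharedMemTiled };

// Threshold from the PR description.
constexpr std::uint32_t kShmemMinDim = 4096;

// Picks the prefill GEMM for an M>1 matmul with reduction dim K and output dim N.
constexpr PrefillKernel select_prefill_kernel(std::uint32_t K, std::uint32_t N) {
  return (K >= kShmemMinDim || N >= kShmemMinDim)
             ? PrefillKernel::SharedMemTiled
             : PrefillKernel::RegisterTiled;
}

// The three shapes from the on-device A/B route as described.
static_assert(select_prefill_kernel(2048, 8192) == PrefillKernel::SharedMemTiled,
              "gate/up projection takes the shmem path");
static_assert(select_prefill_kernel(8192, 2048) == PrefillKernel::SharedMemTiled,
              "down projection takes the shmem path");
static_assert(select_prefill_kernel(2048, 2048) == PrefillKernel::RegisterTiled,
              "square shape stays on the register-tiled path");
```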

Implementation:

  • New q4gsw_linear_gemm_shmem.wgsl: a 64-thread (8x8) workgroup computes a 32x32 output tile (4x4 per thread); per K-tile (TK=16) the workgroup cooperatively stages the input block and dequantizes the weight block into shared memory, then MACs from shmem.
  • QuantizedLinear.cpp: a third sub-route in the M>1 branch (use_shmem_gemm), dispatching div_up(M,32)*div_up(N,32) workgroups directly (no grid-stride) and throwing if the count exceeds maxComputeWorkgroupsPerDimension. The register-tiled path is unchanged.
  • Diverges from the Vulkan q4gsw GEMM, which is register-tiled with no shared memory (it reads weights from a storage buffer and relies on register tiling + the hardware cache for reuse). The existing register-tiled WebGPU GEMM mirrors that design and is kept for the square shape; this kernel adds a shared-memory stage (dequantize the weight block once per K-tile, reuse across the workgroup) that wins on wide-K/N matmuls on this target. Mirrors llama.cpp mul_mat_reg_tile.
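A minimal sketch of the workgroup-count computation described in the QuantizedLinear.cpp bullet, assuming a flat one-dimensional dispatch. The function name and error message are hypothetical, and `max_per_dim` stands in for the device's `maxComputeWorkgroupsPerDimension` limit, which the real code would query from the device.

```cpp
#include <cstdint>
#include <stdexcept>

constexpr std::uint32_t kTile = 32;  // 32x32 output tile per workgroup

// Ceiling division written to avoid overflow in a + b - 1.
constexpr std::uint32_t div_up(std::uint32_t a, std::uint32_t b) {
  return a / b + (a % b != 0 ? 1u : 0u);
}

// One workgroup per output tile, no grid-stride loop. Throws when the
// total does not fit in a single dispatch dimension.
inline std::uint32_t shmem_gemm_workgroup_count(std::uint32_t M, std::uint32_t N,
                                                std::uint32_t max_per_dim) {
  const std::uint64_t count =
      static_cast<std::uint64_t>(div_up(M, kTile)) * div_up(N, kTile);
  if (count > max_per_dim) {
    throw std::runtime_error("shmem GEMM workgroup count exceeds device limit");
  }
  return static_cast<std::uint32_t>(count);
}
```

For example, with a limit of 65535 (the WebGPU default for this limit), M=128 and N=8192 give 4 * 256 = 1024 workgroups, while M=8192 and N=8192 give 65536 and would throw.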

Constraints: the dequant math, weight layout, and tolerances are unchanged; the register-tiled square-shape path is byte-identical; q4gsw's N % 8 == 0 / K % group_size == 0 constraints are unchanged.
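The two shape constraints can be written as a single check. This is an illustrative helper with a hypothetical name, not code from the PR, and the group size of 32 in the example is an assumed value.

```cpp
#include <cstdint>

// True when (K, N) satisfies the q4gsw constraints quoted above:
// N % 8 == 0 and K % group_size == 0.
constexpr bool q4gsw_shape_ok(std::uint32_t K, std::uint32_t N,
                              std::uint32_t group_size) {
  return group_size != 0 && N % 8 == 0 && K % group_size == 0;
}

// Example with an assumed group size of 32.
static_assert(q4gsw_shape_ok(2048, 8192, 32), "gate/up shape is valid");
static_assert(!q4gsw_shape_ok(2048, 8190, 32), "N must be a multiple of 8");
```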

Co-authored-with: Claude Code.

Differential Revision: D110095128

[ghstack-poisoned]

pytorch-bot Bot commented Jun 29, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20605

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 742548a with merge base 0cef6de (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label Jun 29, 2026
@github-actions


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
