[ExecuTorch][WebGPU] Dynamic resize hooks for rms_norm, embedding, rope #20575
JulianCloudNTH wants to merge 3 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20575
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Pending, 1 Unrelated Failure
As of commit 6c20da7 with merge base db5f1ba:
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude review and check for any areas or opportunities for modularization
Claude finished @JulianCloudNTH's task in 1m 35s.
Review: Dynamic resize hooks for rms_norm, embedding, rope
The change is well-scoped and the design is sound: hooks are inert on static graphs (they only re-derive from
A few things worth addressing before merge:
Correctness
1.
2.
3.
Modularization
4. The "rewrite UBO + bump workgroup count" pair is repeated 4× (rms ×1, embedding ×1, rope ×2) with the identical shape:
wgpuQueueWriteBuffer(g.queue(), buf, 0, &p, sizeof(p));
g.dispatch_at(idx).workgroup_count_x =
    utils::compute_1d_workgroup_count(g.device(), threads, wg_size, name);
A small templated helper in
template <typename Params>
inline void rewrite_dispatch(
WebGPUGraph& g, WGPUBuffer buf, const Params& p,
size_t dispatch_idx, uint32_t threads, uint32_t wg_size,
const char* name) {
wgpuQueueWriteBuffer(g.queue(), buf, 0, &p, sizeof(p));
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(g.device(), threads, wg_size, name);
}
5. Pre-existing, optional:
6. Minor: the
Nothing here is blocking; items 1–3 are the ones I'd want resolved (1 is a trivial message fix; 2 is the one with real runtime-correctness implications if the graph ever decouples xq/xk).
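To make the modularization point in item 4 concrete, here is a minimal sketch of a single helper serving both rope dispatches. It is not the PR's code: `FakeBuffer`, `FakeDispatch`, `RopeParams`, `rewrite_dispatch_sketch`, and `resize_rope_sketch` are hypothetical stand-ins so the example compiles without WebGPU headers, the `memcpy` stands in for `wgpuQueueWriteBuffer`, and the "one thread per rotated pair" thread count is an assumption about the kernel. The real `compute_1d_workgroup_count` may also validate against device limits, which is omitted here.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for a GPU uniform buffer: plain host bytes.
struct FakeBuffer {
  std::vector<unsigned char> bytes;
};

// Hypothetical stand-in for a recorded dispatch entry.
struct FakeDispatch {
  uint32_t workgroup_count_x = 0;
};

// Hypothetical param block; the real UBO layout is not shown in the PR text.
struct RopeParams {
  uint32_t seq_len;
  uint32_t n_heads;
  uint32_t head_dim;
};

// Rewrites the param block and the 1D workgroup count in one place.
template <typename Params>
inline void rewrite_dispatch_sketch(
    FakeBuffer& buf, const Params& p, FakeDispatch& dispatch,
    uint32_t threads, uint32_t wg_size) {
  assert(wg_size > 0);
  buf.bytes.resize(sizeof(Params));
  // Stands in for wgpuQueueWriteBuffer(queue, buf, 0, &p, sizeof(p)).
  std::memcpy(buf.bytes.data(), &p, sizeof(Params));
  // Ceil-divide so a partial workgroup still gets dispatched.
  dispatch.workgroup_count_x = (threads + wg_size - 1) / wg_size;
}

// One hook updates both the xq and xk dispatches for the live sequence length.
inline void resize_rope_sketch(
    uint32_t seq_len, uint32_t q_heads, uint32_t k_heads, uint32_t head_dim,
    uint32_t wg_size, FakeBuffer& q_buf, FakeDispatch& q_disp,
    FakeBuffer& k_buf, FakeDispatch& k_disp) {
  const RopeParams q{seq_len, q_heads, head_dim};
  const RopeParams k{seq_len, k_heads, head_dim};
  // Assumption: one thread per (token, head, rotated pair).
  rewrite_dispatch_sketch(q_buf, q, q_disp,
                          seq_len * q_heads * (head_dim / 2), wg_size);
  rewrite_dispatch_sketch(k_buf, k, k_disp,
                          seq_len * k_heads * (head_dim / 2), wg_size);
}
```

With a helper of this shape, each hook body reduces to building its param struct and making one call per dispatch, so the four repeated write-then-bump pairs collapse to four one-line call sites.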
Stack from ghstack (oldest at bottom):
These ops baked their dispatch count, param UBO, and output dims at build() for the max seq-len. On a dynamic-shape graph at a smaller live S they would over-dispatch and leave the output sized at the max, so the resize engine could not actually shrink them.

This adds tensor resize hooks to rms_norm, embedding_q4gsw, and apply_rotary_emb. When an input is resized, each hook recomputes the live row/token count, rewrites the param UBO, updates the dispatch workgroup_count_x, and sets the output's cur_dims. The hook is inert until a resize happens, so static graphs are byte-identical.

Implementation:
- rms_norm: recompute num_rows from live cur_dims; out dims follow the input.
- embedding_q4gsw: recompute num_indices/total_blocks; out dims = indices dims + [embed_dim].
- apply_rotary_emb: add_rope_dispatch now returns its uniform handle; one hook rewrites both the xq and xk dispatches/UBOs for the live S and sets both outputs.
- own_uniform_buffer (the hook rewrites it) instead of releasing it at build.

Mirrors Vulkan per-op resize_*_node (recompute sizes + dispatch each execute). No kernel/WGSL/numerics change. Behavior-neutral on static graphs (hook only fires when live dims differ from max). quantized_linear and SDPA resize hooks land in following diffs; prepack needs none (constants are fixed-size).

@exported-using-ghexport
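The arithmetic a resize hook performs for rms_norm and embedding_q4gsw can be sketched as below. This is an illustration under stated assumptions, not the code in this diff: `RmsNormParams`, `Dispatch`, `workgroups_1d`, `live_num_rows`, `resize_rms_norm`, and `embedding_out_dims` are hypothetical names, the param layout is invented, and the sketch assumes one thread per row when sizing the dispatch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical param block and dispatch entry; not the ExecuTorch types.
struct RmsNormParams {
  uint32_t num_rows;
  uint32_t row_width;
};

struct Dispatch {
  uint32_t workgroup_count_x = 0;
};

// Ceil-divide the thread count by the workgroup size.
inline uint32_t workgroups_1d(uint32_t threads, uint32_t wg_size) {
  assert(wg_size > 0);
  return (threads + wg_size - 1) / wg_size;
}

// rms_norm normalizes over the last dim, so rows = product of the other dims.
inline uint32_t live_num_rows(const std::vector<int64_t>& cur_dims) {
  assert(!cur_dims.empty());
  uint64_t rows = 1;
  for (std::size_t i = 0; i + 1 < cur_dims.size(); ++i) {
    rows *= static_cast<uint64_t>(cur_dims[i]);
  }
  return static_cast<uint32_t>(rows);
}

// Recomputes params and dispatch size from the live input dims and returns
// the output dims, which follow the input for rms_norm.
inline std::vector<int64_t> resize_rms_norm(
    const std::vector<int64_t>& in_dims, uint32_t wg_size,
    RmsNormParams& params, Dispatch& dispatch) {
  params.num_rows = live_num_rows(in_dims);
  params.row_width = static_cast<uint32_t>(in_dims.back());
  // Assumption: one thread per row.
  dispatch.workgroup_count_x = workgroups_1d(params.num_rows, wg_size);
  return in_dims;
}

// embedding: out dims = indices dims + [embed_dim].
inline std::vector<int64_t> embedding_out_dims(
    const std::vector<int64_t>& indices_dims, int64_t embed_dim) {
  std::vector<int64_t> out = indices_dims;
  out.push_back(embed_dim);
  return out;
}
```

For a graph built at a max sequence length of 128 and resized to a live S of 5, a hook of this shape shrinks both the dispatch and the output dims instead of leaving them at the build-time maximum.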
Differential Revision: D109906096