Add optional torchembed RoPE backend to apply_rotary_pos_emb#8052
py-ai-dev wants to merge 5 commits into
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3855dbb0c2
Hi @py-ai-dev,
Please also consider the comment from the Codex bot.
Hello @tohtana, thank you very much for your review. I've addressed the comments, including the one from Codex, as follows:
Pushed as 5cc5e00 and 889efaf. Also, to give some concrete motivation for this integration, I benchmarked the fused kernel on an NVIDIA GB10 across typical LLM shapes:
- Add try/except ImportError guard for torchembed in sequence/layer.py
- Dispatch to fused triton kernel from apply_rotary_pos_emb() when torchembed is installed and tensor is on CUDA
- Add torchembed extras entry in setup.py
- Add tests: numerical correctness vs reference, gradient flow

Signed-off-by: py-ai-dev <py.oss.ml@gmail.com>
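The guard and dispatch condition described in this commit can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: the import path torchembed._triton.fused_rope_forward is taken from the PR description, while HAS_TORCHEMBED and should_use_fused are hypothetical names, and the even rotary_dim condition comes from the PR description rather than this commit message.

```python
# Optional-import guard. ImportError also covers ModuleNotFoundError, so a
# torchembed release that lacks the _triton submodule falls back silently.
try:
    from torchembed._triton import fused_rope_forward  # path as named in the PR
    HAS_TORCHEMBED = True
except ImportError:
    fused_rope_forward = None
    HAS_TORCHEMBED = False


def should_use_fused(is_cuda: bool, rotary_dim: int) -> bool:
    """Hypothetical predicate: fused path only if torchembed imported,
    the tensor lives on CUDA, and rotary_dim is even."""
    return HAS_TORCHEMBED and is_cuda and rotary_dim % 2 == 0


# CPU tensors and odd rotary dims always take the reference path.
print(should_use_fused(is_cuda=False, rotary_dim=64))
print(should_use_fused(is_cuda=True, rotary_dim=63))
```

Keeping the decision in a small predicate makes the fallback behavior easy to test without a GPU or the optional package installed.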
The fused path previously assumed orig_shape[-2] was the sequence length, but apply_rotary_pos_emb's contract is [seq_length, ..., dim] (seq at dim 0), so callers like fpdt_layer.py that pass [b, l, nh, hd] tensors would have the fused kernel rotate against the wrong axis while freqs_cos/sin still describe the true sequence length.

Only take the fused path when freqs' sequence dim unambiguously matches t's dim 0 (and freqs carries no other non-broadcast dims), then movedim the sequence axis to the position torchembed expects before invoking the kernel. All other shapes fall back to the reference implementation.

Also fixes a latent broadcasting bug in the new unit test: freqs_cos lacked the singleton heads dim needed to legally broadcast against a [seq_len, n_heads, dim] tensor, which made 24 of 35 parametrizations fail before this fix.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Signed-off-by: py-ai-dev <py.oss.ml@gmail.com>
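The shape gate described in this commit can be sketched on plain shape tuples. This is an illustration under assumptions, not the PR's code: fused_path_eligible and kernel_layout are hypothetical helpers, and requiring freqs to have the same rank as t is an assumption made here so that freqs' dim 0 lines up with t's dim 0 under right-aligned broadcasting.

```python
from typing import Sequence, Tuple


def fused_path_eligible(t_shape: Sequence[int], freqs_shape: Sequence[int]) -> bool:
    """Hypothetical gate: freqs' dim 0 must equal t's dim 0 (the sequence axis)
    and every dim between the first and last of freqs must be a broadcast 1."""
    if len(t_shape) < 2 or len(freqs_shape) != len(t_shape):
        return False
    seq_matches = freqs_shape[0] == t_shape[0]
    middle_is_broadcast = all(d == 1 for d in freqs_shape[1:-1])
    rot_dim_fits = freqs_shape[-1] <= t_shape[-1]
    return seq_matches and middle_is_broadcast and rot_dim_fits


def kernel_layout(t_shape: Sequence[int]) -> Tuple[int, ...]:
    """Shape after moving the sequence axis from dim 0 to dim -2,
    i.e. what torch.movedim(t, 0, -2) would produce."""
    t_shape = tuple(t_shape)
    return t_shape[1:-1] + (t_shape[0], t_shape[-1])


# [seq, b, nh, hd] with freqs [seq, 1, 1, hd]: eligible.
print(fused_path_eligible((128, 2, 8, 64), (128, 1, 1, 64)))
# [b, l, nh, hd] as passed by some callers: dim 0 is not seq, so fall back.
print(fused_path_eligible((2, 128, 8, 64), (128, 1, 1, 64)))
# freqs without the singleton heads dim does not line up with [seq, nh, hd].
print(fused_path_eligible((17, 8, 64), (17, 64)))
# Layout with the sequence axis second to last.
print(kernel_layout((128, 2, 8, 64)))
```

The third case mirrors the test bug mentioned above: a [seq_len, dim] freqs tensor aligns its seq_len against n_heads when broadcast from the right, which is only legal when the two happen to be equal.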
0.2.4 (the previously published version) predates the _triton module, so the fused RoPE path silently never activates. torchembed 0.3.0 (published today) is the first release that includes it.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Signed-off-by: py-ai-dev <py.oss.ml@gmail.com>
fa7fde9 to a0db9c1
Matches the pattern used by every other optional extra (triton, sd, deepcompile, etc.) instead of inlining the dependency in setup.py.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Signed-off-by: py-ai-dev <py.oss.ml@gmail.com>
Hi @py-ai-dev,
The forward of RoPE should be like:
So the backward should be:
However, torchembed reuses the forward function for backward. It will be:
Can you confirm this and add a correctness test for the gradients? If it is a real bug, can you file an issue on the torchembed repository?
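To make the concern concrete, here is a small pure-Python check, assuming the standard half-split (Neox-style) rotation that the PR description names. For a pair (x1, x2) the forward is y1 = x1·cos − x2·sin, y2 = x2·cos + x1·sin, so the gradient with respect to x is the same rotation with the sign of sin flipped. The functions below are illustrative helpers, not torchembed's API, and whether torchembed actually reuses its forward for backward is the reviewer's claim, not something this sketch verifies.

```python
import math


def rope_forward(x, cos, sin):
    """Half-split RoPE on one vector; cos and sin have len(x) // 2 entries."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    y1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    y2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    return y1 + y2


def rope_backward(grad_y, cos, sin):
    """Correct backward: the transposed rotation, i.e. forward with -sin."""
    return rope_forward(grad_y, cos, [-s for s in sin])


def numerical_grad(x, cos, sin, w, eps=1e-6):
    """Central differences of loss = sum(w_i * y_i) with respect to x."""
    def loss(v):
        return sum(wi * yi for wi, yi in zip(w, rope_forward(v, cos, sin)))
    grads = []
    for i in range(len(x)):
        hi, lo = list(x), list(x)
        hi[i] += eps
        lo[i] -= eps
        grads.append((loss(hi) - loss(lo)) / (2 * eps))
    return grads


thetas = [0.3, 1.1]
cos = [math.cos(t) for t in thetas]
sin = [math.sin(t) for t in thetas]
x = [1.0, -2.0, 0.5, 3.0]
w = [0.7, -0.2, 1.5, 0.4]  # upstream gradient dL/dy

expected = numerical_grad(x, cos, sin, w)
correct = rope_backward(w, cos, sin)
reused_forward = rope_forward(w, cos, sin)  # what reusing the forward would give

print(max(abs(a - b) for a, b in zip(expected, correct)))
print(max(abs(a - b) for a, b in zip(expected, reused_forward)))
```

A gradient correctness test in the PR could follow the same idea: compare the fused path's gradients against those of the reference implementation instead of only checking that gradients flow.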
Adds torchembed as an optional fused RoPE backend for deepspeed.sequence.layer.apply_rotary_pos_emb(), following the same pattern used in transformers and vLLM.

Changes
- deepspeed/sequence/layer.py: Add try/except ImportError guard for torchembed._triton.fused_rope_forward. When torchembed is installed, the tensor is on CUDA, and rotary_dim is even, the function dispatches to the fused triton kernel instead of the PyTorch reference path.
- setup.py: Add torchembed extras key (pip install deepspeed[torchembed]).
- tests/unit/sequence/test_apply_rotary_pos_emb.py: Numerical correctness vs PyTorch reference across seq_len (1/17/128), dim (32/64/128), and various rotary_dim. Gradient flow test.

Implementation details
The torchembed kernel processes (*leading, seq_len, dim) tensors with RotaryEmbedding(use_fused=True), applying Neox-style RoPE via triton. The helper reshapes arbitrary leading dims, calls the kernel, and restores the original shape, so the change is transparent to callers.

Testing