Skip to content

Add optional torchembed RoPE backend to apply_rotary_pos_emb#8052

Open
py-ai-dev wants to merge 5 commits into
deepspeedai:masterfrom
py-ai-dev:add-torchembed-rope-backend
Open

Add optional torchembed RoPE backend to apply_rotary_pos_emb#8052
py-ai-dev wants to merge 5 commits into
deepspeedai:masterfrom
py-ai-dev:add-torchembed-rope-backend

Conversation

@py-ai-dev

Copy link
Copy Markdown

Adds torchembed as an optional fused RoPE backend for deepspeed.sequence.layer.apply_rotary_pos_emb(), following the same pattern used in transformers and vLLM.

Changes

  • deepspeed/sequence/layer.py: Add try/except ImportError guard for torchembed._triton.fused_rope_forward. When torchembed is installed, the tensor is on CUDA, and rotary_dim is even, the function dispatches to the fused triton kernel instead of the PyTorch reference path.

  • setup.py: Add torchembed extras key (pip install deepspeed[torchembed]).

  • tests/unit/sequence/test_apply_rotary_pos_emb.py: Numerical correctness vs PyTorch reference across seq_len (1/17/128), dim (32/64/128), and various rotary_dim. Gradient flow test.

Implementation details

The torchembed kernel processes (*leading, seq_len, dim) tensors with RotaryEmbedding(use_fused=True), applying Neox-style RoPE via triton. The helper reshapes arbitrary leading dims, calls the kernel, and restores the original shape — transparent to callers.

Testing

pytest tests/unit/sequence/test_apply_rotary_pos_emb.py -v

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 3855dbb0c2

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread deepspeed/sequence/layer.py Outdated
@tohtana

tohtana commented Jul 1, 2026

Copy link
Copy Markdown
Collaborator

Hi @py-ai-dev,
Thank you for submitting a PR! I think this needs some changes.

  • setup.py adds the torchembed extra, but the current PyPI torchembed package does not provide torchembed._triton.
  • The added test file fails for most cases. The test creates t with shape [seq_len, 4, dim] and frequency tensors with shape [seq_len, rot_dim]. Those do not broadcast for seq_len=17, seq_len=128, or the gradient-flow case.

Please also consider a comment from Codex bot.

@py-ai-dev

py-ai-dev commented Jul 1, 2026

Copy link
Copy Markdown
Author

Hi @py-ai-dev, Thank you for submitting a PR! I think this needs some changes.

  • setup.py adds the torchembed extra, but the current PyPI torchembed package does not provide torchembed._triton.
  • The added test file fails for most cases. The test creates t with shape [seq_len, 4, dim] and frequency tensors with shape [seq_len, rot_dim]. Those do not broadcast for seq_len=17, seq_len=128, or the gradient-flow case.

Please also consider a comment from Codex bot.

Hello @tohtana, thank you very much for your review — I've worked on them including the one from Codex and addressed them as follows:

  1. torchembed._triton not on PyPI: correct, the published 0.2.4 predated the _triton module. I just cut and published torchembed 0.3.0, which includes it. Updated setup.py to require torchembed>=0.3.0 so the extra can't resolve to a version missing the fused kernel.

  2. Test broadcasting failures: confirmed — freqs_cos/freqs_sin were missing the singleton heads dim needed to broadcast against t of shape [seq_len, n_heads, dim], so most parametrizations (seq_len != 4) raised a RuntimeError instead of running. Fixed by unsqueezing that dim in the test; all 35 cases now pass.

  3. Codex bot (fused-path sequence dim): confirmed as a real bug — the old code assumed orig_shape[-2] was the sequence length, which breaks for callers like fpdt_layer.py that pass [b, l, nh, hd] tensors (where -2 is the heads dim, not seq). The fused path now only activates when freqs_cos's sequence dim unambiguously matches t's dim 0 (with no other non-broadcast dims in freqs), and correctly moves the sequence axis into the position torchembed's kernel expects. Any shape that doesn't match that contract falls back to the reference implementation.

Pushed as 5cc5e00 and 889efaf.


Also, to give some concrete motivation for this integration: I benchmarked the fused kernel on an NVIDIA GB10 across typical LLM shapes:

Shape (B,H,S,D) PyTorch torch.compile Triton Speedup
(1,32,2048,128) 1.40ms 0.61ms 0.34ms 4.1x
(1,32,4096,128) 2.95ms 1.21ms 0.63ms 4.7x
(1,32,8192,128) 5.94ms 2.47ms 1.29ms 4.6x

torch.compile reduces kernel launch overhead but can't eliminate the intermediate tensor allocations from chunk/cat in the rotate-half step. The fused Triton kernel reads each element exactly once and writes it once, with zero intermediates — a ~2x win over torch.compile and ~4x over pure PyTorch. The integration stays fully optional: torchembed is gated behind a try/except and falls back to the existing PyTorch path automatically for anyone who doesn't install it.

py-ai-dev and others added 4 commits July 1, 2026 08:36
- Add try/except ImportError guard for torchembed in sequence/layer.py
- Dispatch to fused triton kernel from apply_rotary_pos_emb() when
  torchembed is installed and tensor is on CUDA
- Add torchembed extras entry in setup.py
- Add tests: numerical correctness vs reference, gradient flow

Signed-off-by: py-ai-dev <py.oss.ml@gmail.com>
The fused path previously assumed orig_shape[-2] was the sequence
length, but apply_rotary_pos_emb's contract is [seq_length, ..., dim]
(seq at dim 0), so callers like fpdt_layer.py that pass [b, l, nh, hd]
tensors would have the fused kernel rotate against the wrong axis
while freqs_cos/sin still describe the true sequence length.

Only take the fused path when freqs' sequence dim unambiguously
matches t's dim 0 (and freqs carries no other non-broadcast dims),
then movedim the sequence axis to the position torchembed expects
before invoking the kernel. All other shapes fall back to the
reference implementation.

Also fixes a latent broadcasting bug in the new unit test: freqs_cos
lacked the singleton heads dim needed to legally broadcast against a
[seq_len, n_heads, dim] tensor, which made 24 of 35 parametrizations
fail before this fix.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Signed-off-by: py-ai-dev <py.oss.ml@gmail.com>
0.2.4 (the previously published version) predates the _triton
module, so the fused RoPE path silently never activates. torchembed
0.3.0 (published today) is the first release that includes it.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Signed-off-by: py-ai-dev <py.oss.ml@gmail.com>
@py-ai-dev py-ai-dev force-pushed the add-torchembed-rope-backend branch from fa7fde9 to a0db9c1 Compare July 1, 2026 15:36
Matches the pattern used by every other optional extra (triton, sd,
deepcompile, etc.) instead of inlining the dependency in setup.py.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Signed-off-by: py-ai-dev <py.oss.ml@gmail.com>
@tohtana

tohtana commented Jul 1, 2026

Copy link
Copy Markdown
Collaborator

Hi @py-ai-dev,
Thank you for the quick fixes! It looks the issues I mentioned are all addressed. However, I also found an issue on torchembed side.

The forward of RoPE should be like:

y0 = x0 * c - x1 * s
y1 = x1 * c + x0 * s

So the backward should be:

dx0 = g0 * c + g1 * s
dx1 = -g0 * s + g1 * c

However, torchembed reuses the forward function for backward. It will be:

g0 * c - g1 * s
g1 * c + g0 * s

Can you confirm this and add the correctness test of gradients? If it is a real bug, can you file an issue on torchembed repository?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

2 participants