
float8 moe training conversion API prototype #2275


Merged
merged 1 commit into main on Jun 10, 2025

Conversation

@danielvegamyhre (Contributor) commented May 30, 2025

Stacked PRs:


float8 moe training conversion API prototype

  • quantize_ will recursively swap nn.Parameter data tensors for a tensor subclass that overrides grouped_mm to perform dynamic quantization + the scaled grouped mm prototype (see the usage sketch after this list). Context: see the implementation of GroupedExperts here.
  • The prototype was renamed from "scaled_grouped_mm" to "moe_training", since the scaled grouped mm is now just one building block in the larger MoE training prototype.
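
For context, here is a rough sketch of what the conversion call might look like. The config name and import path (MoETrainingConfig), the toy TinyExperts module, and the filter function are assumptions for illustration, not necessarily the exact API in this PR.

```python
import torch
import torch.nn as nn

from torchao.quantization import quantize_
# Assumed import path and config name for the MoE training prototype:
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig


class TinyExperts(nn.Module):
    """Stand-in for torchtitan's GroupedExperts: a 3D expert weight used via grouped mm."""

    def __init__(self, num_experts: int = 4, dim: int = 16):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim))


class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = TinyExperts()  # routed experts


model = ToyMoE()


def expert_filter(module: nn.Module, fqn: str) -> bool:
    # Only convert modules that actually call grouped_mm (the routed experts).
    return isinstance(module, TinyExperts)


# quantize_ walks the model and, for matching modules, swaps each nn.Parameter's
# data tensor for the ScaledGroupedMMTensor subclass, whose grouped_mm override
# performs dynamic float8 quantization + the scaled grouped mm.
quantize_(model, MoETrainingConfig(), filter_fn=expert_filter)
```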

Testing

  • Added unit tests that use the torchtitan Llama4 MoE layer to perform a forward + backward pass, then validate that SQNR is reasonable for the model outputs, input gradients, and parameter gradients when comparing the quantized model against the baseline bf16 model (see the sketch after this list).
  • Tested E2E training with torchtitan (see the PR with the MoE conversion API) and confirmed single-GPU training works as expected with Llama4.
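
As a sketch of the validation pattern described above (the SQNR helper, the MSE loss, and the 25 dB threshold are illustrative assumptions, not the exact values in the test file):

```python
import torch
import torch.nn.functional as F


def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means closer to the reference.
    return 20 * torch.log10(ref.norm() / (ref - test).norm())


def check_parity(ref_model: torch.nn.Module, quant_model: torch.nn.Module,
                 x: torch.Tensor, min_sqnr: float = 25.0) -> None:
    x_ref = x.detach().clone().requires_grad_(True)
    x_q = x.detach().clone().requires_grad_(True)

    out_ref, out_q = ref_model(x_ref), quant_model(x_q)
    assert sqnr_db(out_ref, out_q) > min_sqnr, "output SQNR too low"

    # Backward through both models with the same loss, then compare gradients.
    target = torch.randn_like(out_ref)
    F.mse_loss(out_ref, target).backward()
    F.mse_loss(out_q, target).backward()

    assert sqnr_db(x_ref.grad, x_q.grad) > min_sqnr, "input grad SQNR too low"
    for p_ref, p_q in zip(ref_model.parameters(), quant_model.parameters()):
        assert sqnr_db(p_ref.grad, p_q.grad) > min_sqnr, "param grad SQNR too low"
```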

Limitations

  • Only single-GPU training is supported. I tried FSDP with a degree of 2 and hit this issue, which seems related to a known issue that is being addressed.
  • The grouped_mm override is only applied to routed experts (see the condition here). For shared experts, I'll need to update the torchao prototype to support a 3D A tensor.
danielvegamyhre added a commit that referenced this pull request May 30, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
pytorch-bot (bot) commented May 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2275

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit c80dfa6 with merge base 83663b8:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 363236e to e84430e on May 30, 2025 02:38
@facebook-github-bot added the CLA Signed label on May 30, 2025
danielvegamyhre added a commit that referenced this pull request May 30, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from e84430e to 6d76d3d on May 30, 2025 02:39
danielvegamyhre added a commit that referenced this pull request May 30, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 6d76d3d to a71744c on May 30, 2025 03:17
@danielvegamyhre changed the title float8 moe training conversion API prototype on May 30, 2025
@danielvegamyhre added the topic: not user facing label on May 30, 2025
danielvegamyhre added a commit that referenced this pull request May 30, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from a71744c to b72eabc on May 30, 2025 04:02
@danielvegamyhre changed the title float8 moe training conversion API prototype (single GPU training) on May 30, 2025
@danielvegamyhre (Contributor Author)

@drisspg @vkuzo for review when you have a chance

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
Contributor

@bdhirsh I suggested this Torch function approach to stay above autograd in order to call into our autograd func. AFAIK compile's torch function subclass support has gotten much better, but I just wanted to double check w/ you.
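
For readers following along, a simplified sketch of the __torch_function__ approach being described (the real ScaledGroupedMMTensor in this PR handles more than shown here; using torch._grouped_mm as the dispatch target is an assumption based on the discussion below):

```python
import torch

# Import path taken from the diff excerpt above.
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """Tensor subclass that reroutes grouped mm calls to the float8 prototype.

    __torch_function__ runs above autograd, so the override can call the
    prototype's differentiable _scaled_grouped_mm and let it define both the
    forward and the backward.
    """

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch._grouped_mm:
            # Dynamic float8 quantization + scaled grouped mm happen inside
            # _scaled_grouped_mm.
            return _scaled_grouped_mm(*args, **kwargs)
        # Everything else falls back to the default torch.Tensor behavior.
        return super().__torch_function__(func, types, args, kwargs)
```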

Contributor

+1. How's compile looking on this now?

@danielvegamyhre (Contributor Author), May 30, 2025

  • Llama4 (what I tested with in torchtitan) cannot compile yet because the backward hooks implementing the auxiliary loss don't play nicely with compile. Tianyu is aware of this and is working on a new approach, I believe.
  • In this minimal GroupedExperts example, it cannot compile with fullgraph=True because _grouped_mm is not traceable yet: torch._dynamo.exc.Unsupported: .... Explanation: Dynamo does not know how to trace the builtin `torch._VariableFunctionsClass._grouped_mm`. (I believe Brian is aware of this and working on it.)
  • The minimal example can compile with graph breaks.

This was tested with the latest PyTorch nightly.

Contributor

(1) yep, torch_function subclass compile support has improved in the last few months (thanks to Ryan / Lazos)

(2) On the grouped_mm issue, pytorch/pytorch#153384 should help. I'm going to land it early next week, but it's a small change, so feel free to rebase on it if you want to test sooner.

return root_module


def convert_moe_to_float8_training(
Contributor

you can also use the quantize_ API to match the rest of torchao. We want to eventually migrate the float8 training API to that.

@danielvegamyhre (Contributor Author), May 30, 2025

Added a TODO for this. I'm hoping to get the functionality working with the key parallelisms first, and then I'll look into migrating to quantize_.

Contributor Author

Updated to use quantize_ and added tests. The forward pass works for both models; however, the reference bf16 MoE hits an issue during backward, so I can't use it to validate the fp8 results until that is resolved.

@drisspg (Contributor), Jun 9, 2025

Is the fp8 version training?

Contributor Author

Is the fp8 version training?

In the e2e llama4 training with this PR, yes.

However, in this PR's unit test test_moe_training_conversion.py, where I import and use the torchtitan MoE layer for testing, I hit this issue during backward in the bf16 model (it also happens in the fp8 version, but I mention bf16 since it rules out fp8-conversion-specific issues).

I hit the same issue whether I use just GroupedExperts or the whole MoE layer.

I might try just importing the transformer itself and seeing if that works.

Contributor Author

I managed to get it working by using an MSE loss instead of out.sum().backward(), but the root cause is still a mystery.
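
A minimal illustration of that loss-function swap (the nn.Linear stand-in replaces the torchtitan MoE layer, and the shapes are made up):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(16, 16)                  # stand-in for the MoE layer under test
x = torch.randn(8, 16, requires_grad=True)

out = model(x)
# out.sum().backward()                     # the pattern that hit the backward error
target = torch.randn_like(out)             # arbitrary reference of matching shape
F.mse_loss(out, target).backward()         # the MSE-loss path backpropagates fine
```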

danielvegamyhre added a commit that referenced this pull request May 30, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from b72eabc to a10b3a0 on May 30, 2025 18:03
danielvegamyhre added a commit that referenced this pull request May 30, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from a10b3a0 to 2b4361d on May 30, 2025 18:04
danielvegamyhre added a commit that referenced this pull request Jun 3, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 2b4361d to 70bc212 on June 3, 2025 20:01
@danielvegamyhre marked this pull request as draft on June 3, 2025 20:02
danielvegamyhre added a commit that referenced this pull request Jun 3, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 70bc212 to 0041063 on June 3, 2025 20:03
danielvegamyhre added a commit that referenced this pull request Jun 3, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 0041063 to cb6f8c0 on June 3, 2025 20:25
danielvegamyhre added a commit that referenced this pull request Jun 5, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
danielvegamyhre added a commit that referenced this pull request Jun 5, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
danielvegamyhre added a commit that referenced this pull request Jun 5, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 54d1bba to 1072d85 on June 5, 2025 21:00
danielvegamyhre added a commit that referenced this pull request Jun 9, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 1072d85 to afde31e on June 9, 2025 23:46
danielvegamyhre added a commit that referenced this pull request Jun 10, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from afde31e to 35a23ec on June 10, 2025 01:45
danielvegamyhre added a commit that referenced this pull request Jun 10, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 35a23ec to 123a4bc on June 10, 2025 01:55
@danielvegamyhre marked this pull request as ready for review on June 10, 2025 01:55
@danielvegamyhre (Contributor Author)

@vkuzo @drisspg this is ready for another look: I (1) migrated to the quantize_ API, (2) added unit tests using the torchtitan Llama4 MoE implementation, and (3) validated E2E training with torchtitan again as well.

danielvegamyhre added a commit that referenced this pull request Jun 10, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 123a4bc to 0c06aba on June 10, 2025 04:14
danielvegamyhre added a commit that referenced this pull request Jun 10, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 0c06aba to f9a42d6 on June 10, 2025 14:11
danielvegamyhre added a commit that referenced this pull request Jun 10, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from f9a42d6 to 0692c1f on June 10, 2025 14:24
@drisspg (Contributor) left a comment

Some small comments, looks good, re-req me when ready

danielvegamyhre added a commit that referenced this pull request Jun 10, 2025
stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre force-pushed the danielvegamyhre/stack/1 branch from 0692c1f to 0298cf9 on June 10, 2025 16:41
@danielvegamyhre (Contributor Author)

Some small comments, looks good, re-req me when ready

@drisspg Thanks, I addressed the comments

stack-info: PR: #2275, branch: danielvegamyhre/stack/1

migrate to quantize and add test

work on moe training test
@danielvegamyhre (Contributor Author)

Confirmed the failing tests are unrelated to this change (they are autoquant, sparsity, etc.); see below:

FAILED test/integration/test_integration.py::TestSubclass::test_int4_weight_only_quant_subclass_grouped_5_cuda - torch._inductor.exc.InductorError: AssertionError: -201885665299331/250000000000000

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
FAILED test/integration/test_integration.py::TestAutoQuant::test_autoquant_compile_09_cuda - RuntimeError: self.size(0) needs to be greater than 16, but got 1
FAILED test/integration/test_integration.py::TestAutoQuant::test_autoquant_compile_15_cuda - RuntimeError: self.size(0) needs to be greater than 16, but got 1
FAILED test/sparsity/test_sparse_api.py::TestQuantSemiSparse::test_quant_semi_sparse_compile_False - AssertionError: Tensor-likes are not close!
@danielvegamyhre merged commit b6bb7dc into main on Jun 10, 2025
18 of 19 checks passed