float8 moe training conversion API prototype #2275
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2275
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit c80dfa6 with merge base 83663b8.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

class ScaledGroupedMMTensor(torch.Tensor):
@bdhirsh I suggested this Torch function approach to stay above autograd in order to call into our autograd func. AFAIK compile's torch_function subclass support has gotten much better, but I just wanted to double check with you.
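As a rough illustration of the approach being discussed (not the PR's exact implementation), a __torch_function__ override lets the subclass intercept the op before autograd records it and reroute it to the differentiable scaled version:

```python
import torch
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """Sketch: intercept grouped_mm above autograd and reroute to the scaled version."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Assumption for illustration: torch._grouped_mm is the op to swap out;
        # _scaled_grouped_mm is the autograd function that owns the fp8 backward.
        if func is torch._grouped_mm:
            return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
```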
+1. How's compile looking on this now?
- llama4 (what I tested with in torchtitan) cannot compile yet, because the backward hooks implementing the auxiliary loss do not play nicely with compile. Tianyu is aware of this and is working on a new approach, I believe.
- In this minimal GroupedExperts example, it cannot compile with fullgraph=True, because _grouped_mm is not traceable yet: `torch._dynamo.exc.Unsupported: ... Explanation: Dynamo does not know how to trace the builtin torch._VariableFunctionsClass._grouped_mm`. (I believe Brian is aware of this and working on it.)
- The minimal example can compile with graph breaks (a sketch of the two modes follows below).

This was tested with the latest nightly pytorch.
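To make the two modes concrete, here is a rough sketch (not from the PR) of how default compile vs. fullgraph=True behave on a module whose forward hits the untraceable _grouped_mm builtin; `experts`, `x`, and `offs` stand in for whatever the minimal GroupedExperts test constructs:

```python
import torch
from torch import nn
from torch._dynamo.exc import Unsupported


def compare_compile_modes(experts: nn.Module, x: torch.Tensor, offs: torch.Tensor):
    # Default mode: dynamo graph-breaks around the unsupported _grouped_mm call
    # and falls back to eager for that region, so this runs end to end.
    out = torch.compile(experts)(x, offs)

    # fullgraph=True refuses to graph-break, so the untraceable builtin surfaces
    # as a hard error instead of a silent fallback.
    try:
        torch.compile(experts, fullgraph=True)(x, offs)
    except Unsupported as e:
        print(f"fullgraph=True failed as expected: {e}")
    return out
```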
(1) Yep, torch_function subclass compile support has improved in the last few months (thanks to Ryan / Lazos).
(2) On the grouped_mm issue, pytorch/pytorch#153384 should help. I'm going to land it early next week, but it's a small change, so feel free to rebase on it if you want to test sooner.
return root_module

def convert_moe_to_float8_training(
You can also use the quantize_ API to match the rest of torchao. We want to eventually migrate the float8 training API to that.
Added todo for this. Hoping to get the functionality working w/ key parallelisms first and then will look into migrating to quantize_.
Updated to use quantize_ and added tests. The prototype works for the forward pass; however, the reference bf16 MoE hits an issue during backward, so I can't use it to validate the fp8 results until that is resolved.
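For readers skimming the thread, a minimal sketch of what a quantize_-based conversion call looks like; the MoETrainingConfig name/import path and the filter function here are assumptions for illustration, not necessarily the exact API added in this PR:

```python
import torch.nn as nn
from torchao.quantization import quantize_

# Assumed config class and import path, for illustration only.
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig


def convert_moe_to_float8_training(model: nn.Module) -> nn.Module:
    # Assumption: expert weights live under submodules whose FQN contains "experts".
    def module_filter_fn(module: nn.Module, fqn: str) -> bool:
        return "experts" in fqn

    # quantize_ walks the model and, for the selected modules, swaps nn.Parameter
    # data tensors for the ScaledGroupedMMTensor subclass so grouped_mm calls are
    # routed to the dynamic-quant scaled grouped mm prototype.
    quantize_(model, MoETrainingConfig(), filter_fn=module_filter_fn)
    return model
```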
Is the fp8 version training?
> Is the fp8 version training?

In the e2e llama4 training with this PR, yes.
However, in the unit test in this PR (test_moe_training_conversion.py), where I import and use the torchtitan MoE layer for testing, I hit this issue during backward in the bf16 model (it also happens in the fp8 version, but I mention bf16 since it rules out fp8-conversion-specific issues).
I hit the same issue whether I use just GroupedExperts or the whole MoE layer.
I might try importing the transformer itself and seeing if that works.
I managed to get it working by using an MSE loss instead of out.sum().backward(), but the root cause is still a mystery.
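Concretely, the workaround looks something like the following sketch; the model and input stand in for whatever the unit test builds:

```python
import torch
import torch.nn.functional as F
from torch import nn


def run_backward(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    out = model(x)
    # Original pattern that triggered the unexplained backward error:
    #   out.sum().backward()
    # Workaround: backprop through an MSE loss against a target instead.
    target = torch.randn_like(out)
    loss = F.mse_loss(out, target)
    loss.backward()
    return loss
```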
Some small comments, looks good, re-req me when ready
@drisspg Thanks, I addressed the comments.
Confirmed the failing tests are unrelated to this change (they are autoquant, sparsity, etc.); see below:

FAILED test/integration/test_integration.py::TestSubclass::test_int4_weight_only_quant_subclass_grouped_5_cuda - torch._inductor.exc.InductorError: AssertionError: -201885665299331/250000000000000
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
FAILED test/integration/test_integration.py::TestAutoQuant::test_autoquant_compile_09_cuda - RuntimeError: self.size(0) needs to be greater than 16, but got 1
FAILED test/integration/test_integration.py::TestAutoQuant::test_autoquant_compile_15_cuda - RuntimeError: self.size(0) needs to be greater than 16, but got 1
FAILED test/sparsity/test_sparse_api.py::TestQuantSemiSparse::test_quant_semi_sparse_compile_False - AssertionError: Tensor-likes are not close!
Stacked PRs:
float8 moe training conversion API prototype

quantize_ will recursively swap nn.Parameter data tensors to a tensor subclass, which has an override for grouped_mm => dynamic quant + scaled grouped mm prototype. Context: see the implementation of GroupedExperts here.

Testing

Limitations