
[tests] help us test torch.compile() for impactful models #11430

Open
@sayakpaul

Description


#11085 added a test that checks whether a torch.compile()d model hits any graph breaks or recompilations.

We should add this test to the most impactful models to ensure our code is torch.compile()-friendly and can benefit from compilation. So far, we only test FluxTransformer2DModel. Below are some models I have in mind where this test should be added:

Steps to contribute

  1. Refer to [tests] add tests to check for graph breaks and recompilation in pipelines during torch.compile() #11085 to understand the changes needed in the test classes of the respective models.
  2. Specifically, try to simply add TorchCompileTesterMixin to the base classes of the test class in the respective model's test file, like so:
    class FluxTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
        ...
  3. Make the changes.
  4. Run the test locally on a machine with a GPU: RUN_SLOW=1 RUN_COMPILE=1 pytest tests/<PATH_TO_TEST_FILE> -k "test_torch_compile_recompilation_and_graph_break". Please confirm in the PR that this test passes. If it doesn't pass, let us know.
  5. Open the PR and mention this issue ([tests] help us test torch.compile() for impactful models #11430). Tag @sayakpaul and @DN6 for a review. Don't hesitate to ask for guidance/help if needed.

@DN6 any model classes I am missing?
