Rewrite AuraFlowPatchEmbed.pe_selection_index_based_on_dim to be torch.compile compatible #11297

Merged

Conversation

AstraliteHeart
Contributor

What does this PR do?

Updates AuraFlowPatchEmbed.pe_selection_index_based_on_dim so that AuraFlowTransformer2DModel can be fully torch.compile'd.

The old and new code generate the same images, but I am not expert enough to know whether this has any negative performance impact or hidden caveats.
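In essence, the rewrite swaps the torch.narrow-based center crop of the positional-embedding index grid for plain slicing. A minimal self-contained sketch of the two equivalent selections (condensed from the comparison script later in this thread):

import torch

grid_size, h_p, w_p = 96, 64, 64  # 96x96 index grid, 64x64 patch window
start_h = (grid_size - h_p) // 2
start_w = (grid_size - w_p) // 2
pe_grid = torch.arange(grid_size * grid_size).view(grid_size, grid_size)

# Old: center-crop the index grid with torch.narrow
old = pe_grid.narrow(0, start_h, h_p).narrow(1, start_w, w_p).flatten()

# New: the same crop via slicing arithmetic, which torch.compile traces cleanly
new = pe_grid[start_h : start_h + h_p, start_w : start_w + w_p].flatten()

assert torch.equal(old, new)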

I've noticed some weirdness while fixing this issue:

AuraFlowTransformer2DModel in the docs has

pos_embed_max_size (int, defaults to 4096): Maximum positions to embed from the image latents.

and in the code

pos_embed_max_size: int = 1024,

but AFAIK for AuraFlow 0.3 it should actually be something like:

pos_embed_max_size=9216,
sample_size=96

Fixes # Originally filed in torch - (issue)

Before submitting

Who can review?

@cloneofsimo @sayakpaul @yiyixuxu

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member

pos_embed_max_size: int = 1024,

Feel free to update the docs, including the 0.3 note :)

The old and new code generate the same images, but I am not expert enough to know whether this has any negative performance impact or hidden caveats.

I think if we get the same numerical outputs with and without the changes, that should be more than enough.

@StrongerXi, wanna investigate this?

Member

@sayakpaul sayakpaul left a comment

Thanks a lot for your efforts here and also for testing torch.compile() with this.

Just as an FYI, we're working on #11085 to have better testing for torch.compile.

For my understanding, does this PR solve the recompilation issues for higher resolutions?

@AstraliteHeart
Contributor Author

AstraliteHeart commented Apr 12, 2025

@sayakpaul

Feel free to update the docs, including the 0.3 note :)

Should I update both pos_embed_max_size=9216 and sample_size=96? I think these are the right values based on the VAE/patch size. I may be confused here, but it looks like the model is configured for AF 0.2, so I am not sure why it works for 0.3, and whether updating them might break 0.2.

So maybe we should have some detection, or at least update the docs? I know some people prefer 0.2 right now.

I think if we get the same numerical outputs with and without the changes, that should be more than enough.

Are the current tests sufficient to confirm this, or should I add something extra first? Any comparisons I should run on my end?

For my understanding, does this PR solve the recompilation issues for higher resolutions?

Correct, with this change I am no longer seeing recompilations with AF loaded via GGUF.

@anijain2305
Contributor

Thanks for taking so much effort to enable torch.compile here. This workstream is truly amazing!

Cc @bobrenjc93 @laithsakka for review of the dynamic-shape-guards-related rewrite. Might be a good rewrite to document in the dynamic shapes manual.

@AstraliteHeart
Contributor Author
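Here is a standalone script comparing the original narrow-based selection against the new slicing-based selection across a range of resolutions: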

import torch
import rich

class TestPosEmbed:
    def __init__(self, pos_embed_max_size: int = 9216, embed_dim: int = 768, device='cpu'):
        """
        Initialize with a dummy positional embedding parameter.
        pos_embed_max_size must be a perfect square (here, 9216 yields a 96x96 grid).
        """
        self.pos_embed = torch.randn(1, pos_embed_max_size, embed_dim, device=device)

    def pe_selection_index_based_on_dim_orig(self, h_p: int, w_p: int) -> torch.Tensor:
        """
        Original implementation using torch.narrow.
        h_p: number of patches in height.
        w_p: number of patches in width.
        """
        total_pe = self.pos_embed.shape[1]
        grid_size = int(total_pe ** 0.5)
        assert grid_size * grid_size == total_pe, "pos_embed_max_size must be a perfect square"

        # Create a grid of indices.
        original_pe_indexes = torch.arange(total_pe, device=self.pos_embed.device).view(grid_size, grid_size)
        
        # Compute starting indices using Python arithmetic.
        starth = (grid_size - h_p) // 2
        startw = (grid_size - w_p) // 2
        
        # Use narrow to select the center region.
        narrowed = original_pe_indexes.narrow(0, starth, h_p)
        narrowed = narrowed.narrow(1, startw, w_p)
        return narrowed.flatten()

    def pe_selection_index_based_on_dim_new(self, h_p: int, w_p: int) -> torch.Tensor:
        """
        New implementation using inlined slicing arithmetic.
        h_p: number of patches in height.
        w_p: number of patches in width.
        """
        total_pe = int(self.pos_embed.shape[1])
        grid_size = int(total_pe ** 0.5)
        assert grid_size * grid_size == total_pe, "pos_embed_max_size must be a perfect square"

        # Compute starting indices (using Python ints).
        start_h = (grid_size - h_p) // 2
        start_w = (grid_size - w_p) // 2

        # Create a grid of indices.
        pe_grid = torch.arange(total_pe, device=self.pos_embed.device).view(grid_size, grid_size)
        # Select the central region and flatten.
        selected_pe = pe_grid[start_h: start_h + h_p, start_w: start_w + w_p].flatten()
        return selected_pe

def run_tests():
    torch.manual_seed(42)
    
    patch_size = 16
    # Use pos_embed_max_size = 9216 -> a 96x96 grid.
    pos_embed_max_size = 9216
    embed_dim = 768

    tester = TestPosEmbed(pos_embed_max_size=pos_embed_max_size, embed_dim=embed_dim, device='cpu')

    resolutions = [
        (224, 224),
        (224, 256),
        (256, 224),
        (256, 256),
        (256, 320),
        (320, 256),
        (384, 384),
        (384, 512),
        (512, 384),
        (512, 512),
        (640, 640),
        (768, 768),
        (1024, 1024),
        (1280, 1280),
        (1536, 1536),
        (1536, 1024),
        (1024, 1536),
        (1280, 768),
        (1536, 768),
        (768, 1536),
        (1536, 896),
        (896, 1536),
    ]

    for (height, width) in resolutions:
        h_p = height // patch_size
        w_p = width // patch_size

        orig_indices = tester.pe_selection_index_based_on_dim_orig(h_p, w_p)
        new_indices = tester.pe_selection_index_based_on_dim_new(h_p, w_p)
        
        match = torch.equal(orig_indices, new_indices)
        
        match_color = "green" if match else "red"
        rich.print(f"[cyan]Resolution: {height} x {width}[/cyan] | [yellow]Patch grid: {h_p} x {w_p}[/yellow] | Match: [{match_color}]{match}[/{match_color}]")

if __name__ == "__main__":
    run_tests()
[Screenshot, 2025-04-11: script output showing Match: True for each listed resolution]
@sayakpaul
Member

AFK currently. Please allow me some time to get back to you.

@sayakpaul
Member

The tests in #11297 (comment) are sufficient. Thanks!

Should I update both pos_embed_max_size=9216 and sample_size=96? I think these are the right values based on the VAE/patch size. I may be confused here, but it looks like the model is configured for AF 0.2, so I am not sure why it works for 0.3, and whether updating them might break 0.2.

Well, when from_pretrained() is called, the configs are read from the config.json. Similarly, for from_single_file(), these configs are automatically constructed from the state dict and the equivalent diffusers repository's config.json. So just updating the docs is fine IMO.

Does this answer your question?
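To illustrate (a minimal sketch; the subfolder="transformer" layout is the standard diffusers repository structure, assumed here rather than quoted from this thread):

import torch
from diffusers import AuraFlowTransformer2DModel

# The config values come from the checkpoint's config.json,
# not from the Python defaults in the class signature.
transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow-v0.3", subfolder="transformer", torch_dtype=torch.bfloat16
)
print(transformer.config.pos_embed_max_size)  # value from config.json, not the 1024 default
print(transformer.config.sample_size)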

@sayakpaul
Member

sayakpaul commented Apr 14, 2025

What I would also do is the following (perhaps in a separate PR):

Add a new test class/method in https://github.com/huggingface/diffusers/blob/main/tests/pipelines/aura_flow/test_pipeline_aura_flow.py that checks no recompilation is triggered when we go for higher resolutions. I believe we won't need a pre-trained checkpoint for this. We could use the dummy model from that file (transformer = AuraFlowTransformer2DModel(...)) and write our test case accordingly.

I can work on this and, when ready, ask for a review from you and @anijain2305. WDYT? Let me also know if this test case makes sense; a rough sketch of the idea follows below.
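A rough sketch of what such a test could look like (the tiny config values below are illustrative stand-ins for the dummy model in test_pipeline_aura_flow.py, and torch._dynamo.config.error_on_recompile turns any recompilation into a hard failure):

import torch
import torch.fx.experimental._config as fx_config
from diffusers import AuraFlowTransformer2DModel

torch._dynamo.config.error_on_recompile = True  # any recompile raises instead of happening silently
fx_config.use_duck_shape = False  # do not tie height/width symbols together

# Tiny illustrative config; the real test would reuse the dummy model
# from test_pipeline_aura_flow.py.
transformer = AuraFlowTransformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=4,
    num_mmdit_layers=1,
    num_single_dit_layers=1,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,
    out_channels=4,
    pos_embed_max_size=256,  # 16x16 grid -> max latent size 32 with patch_size=2
).eval()

compiled = torch.compile(transformer, fullgraph=True, dynamic=True)

prompt_embeds = torch.randn(1, 8, 32)
timestep = torch.tensor([1.0])
with torch.no_grad():
    for height, width in [(16, 16), (24, 16), (32, 24)]:  # growing, non-square latent sizes
        latents = torch.randn(1, 4, height, width)
        compiled(hidden_states=latents, encoder_hidden_states=prompt_embeds, timestep=timestep)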

Also, @AstraliteHeart, if possible it would be great to update the AuraFlow docs with a section on avoiding recompilations when using torch.compile() at higher resolutions. Alternatively, if you update the PR description with a code snippet, I can do it.

@sayakpaul sayakpaul requested a review from yiyixuxu April 14, 2025 04:49
@sayakpaul
Member

@yiyixuxu could you also review this PR? It helps make AuraFlow more compatible with torch.compile(), especially at higher resolutions, so that it doesn't trigger recompilations.

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks!


@bobrenjc93 bobrenjc93 left a comment

Looks great, thanks!

cc @laithsakka you probably want to include a section on reducing recompiles by rewriting tensor indexing operations, as in this PR, in your recompilations guide. Also, you should probably either write a separate OSS version or publish the internal version of https://docs.google.com/document/d/1QgQLVBNKSYMeNbG5sEz_pwffL9PlKRHKXMI4ft3H9gA/edit?tab=t.0#heading=h.a37bpg8ay2f4

@sayakpaul
Member

@AstraliteHeart just waiting for you to confirm a few of my comments above when you have time. We will then merge :)

@AstraliteHeart
Contributor Author

@sayakpaul

Feel free to update the docs, including the 0.3 note :)

Updated the docs to reflect the correct default values. I don't think we need a 0.3 note; I had assumed the values were not read from the model, which was incorrect (see below).

Well, when from_pretrained() is called, the configs are read from the config.json

I rechecked the values populated from the config, and you are correct.

I can work on this and when ready ask for a review you and @anijain2305. WDYT? LM also know if this test case makes sense.

I would never say "no" to someone volunteering to write tests, but lmk if you want me to work on that.

Also, @AstraliteHeart if possible, it would be great to update the docs of AuraFlow with a section on no recompilations when using torch.compile() on higher resolutions.

For the compilation example, I believe the only special thing needed right now is disabling torch.fx.experimental._config.use_duck_shape (duck shaping reuses a single symbol for dimensions that happen to have equal sizes, which can tie height and width together and trigger recompiles once they differ), and all resolutions then just work (with my fix).

import torch
import torch.fx.experimental._config
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, GGUFQuantizationConfig

torch.fx.experimental._config.use_duck_shape = False

transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
    transformer=transformer,
).to("cuda")

pipeline.transformer = torch.compile(pipeline.transformer, fullgraph=True, dynamic=True)

@AstraliteHeart just waiting for you to provide some confirmations to my comments above when you have time. We will then merge :)

Happy to get this merged (but please check the doc update just in case).

@yiyixuxu @bobrenjc93 thank you for having a look.

@sayakpaul
Member

Thank you!

Can I push directly to your branch to include the snippet in #11297 (comment) in the AuraFlow pipeline docs?
