Skip to content

Add FluxPAGPipeline with support for PAG #11510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

tongyu0924
Copy link
Contributor

What does this PR do?

This PR adds support for Perturbed Attention Guidance (PAG) to the FluxPipeline

Fixes #11488

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6
Copy link
Collaborator

DN6 commented May 22, 2025

@tongyu0924 Is this ready for review? If so, could you move the pipeline under src/diffusers/pipelines/pag please and make the changes needed to import the pipeline?

@tongyu0924
Copy link
Contributor Author

@tongyu0924 Is this ready for review? If so, could you move the pipeline under src/diffusers/pipelines/pag please and make the changes needed to import the pipeline?

Done! The pipeline is now under src/diffusers/pipelines/pag. It's ready for review

@DN6
Copy link
Collaborator

DN6 commented May 26, 2025

Thank you @tongyu0924 👍🏽! Could we add the pipeline to the necessary init files

PAG Module:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pag/__init__.py

Pipelines Module:

_import_structure["pag"].extend(

Diffusers main init

"AnimateDiffPAGPipeline",

And then could you please add a fast test for the pipeline, similar to how it has been done here

class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixin):

@tongyu0924
Copy link
Contributor Author

Thank you @tongyu0924 👍🏽! Could we add the pipeline to the necessary init files

PAG Module: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pag/__init__.py

Pipelines Module:

_import_structure["pag"].extend(

Diffusers main init

"AnimateDiffPAGPipeline",

And then could you please add a fast test for the pipeline, similar to how it has been done here

class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixin):

I've added the pipeline to the necessary __init__.py files and added a fast test following the structure in test_pag_sd3.py.

Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay here @tongyu0924. I think we just need define the correct PAG Attn Processors for this pipeline and we should be good to go 👍🏽

Thanks for your patience.

Comment on lines 844 to 858
# do_true_pag = true_pag > 0
# (
# prompt_embeds,
# pooled_prompt_embeds,
# text_ids,
# ) = self.encode_prompt(
# prompt=prompt,
# prompt_2=prompt_2,
# prompt_embeds=prompt_embeds,
# pooled_prompt_embeds=pooled_prompt_embeds,
# device=device,
# num_images_per_prompt=num_images_per_prompt,
# max_sequence_length=max_sequence_length,
# lora_scale=lora_scale,
# )
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we remove this commented section please.

timesteps = scheduler.timesteps
return timesteps, num_inference_steps

class PAGIdentitySelfAttnProcessor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we would need to define a PAGFluxAttnProcessor_2_0 similar to how it is done for SD3's JointAttnProcessor

class PAGJointAttnProcessor2_0:

And since the pipeline supports true CFG we would also need to add a PAGCFGFluxAttnProcessor_2_0

class PAGCFGJointAttnProcessor2_0:

@nitinmukesh
Copy link

Please update the following if possible

diffusers/src/diffusers/pipelines/pag/pipeline_pag_flux.py

EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import FluxPipeline

    >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")
    >>> prompt = "A cat holding a sign that says hello world"
    >>> # Depending on the variant being used, the pipeline call will slightly vary.
    >>> # Refer to the pipeline documentation for more details.
    >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
    >>> image.save("flux.png")
    ```

"""

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
5 participants