Chroma as a FLUX.1 variant #11566


Draft · wants to merge 4 commits into main

Conversation

hameerabbasi
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Vargol

Vargol commented May 17, 2025

Hi, is there a specific model repo this works with? I'm getting key errors loading the models from lodestones/Chroma and silveroxides/Chroma-GGUF.

I installed the PR using:

$ pip uninstall diffusers
$ pip install git+https://github.com/huggingface/diffusers.git@refs/pull/11566/merge
$ grep chroma  lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py 
INVALID_VARIANT_ERRMSG = "`variant` must be `'flux' or `'chroma'`."
        elif variant == "chroma":
        elif variant == "chroma":
        elif variant == "chroma":
        norm_out_cls = AdaLayerNormContinuous if variant != "chroma" else AdaLayerNormContinuousPruned
        is_chroma = isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
        if not is_chroma:
            if is_chroma:
            if is_chroma:
        if is_chroma:

I'm loading the models using

#ckpt_path="/Volumes/SSD2TB/AI/caches/models/chroma-unlocked-v29-Q8_0.gguf"
ckpt_path="/Volumes/SSD2TB/AI/caches/models/chroma-unlocked-v29.5.safetensors"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
#   quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
    variant="chroma"
)

The full error is:

Traceback (most recent call last):
  File "/Volumes/SSD2TB/AI/Diffusers/gguf_chroma.py", line 48, in <module>
    transformer = FluxTransformer2DModel.from_single_file(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/loaders/single_file_model.py", line 370, in from_single_file
    diffusers_format_checkpoint = checkpoint_mapping_fn(
                                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/loaders/single_file_utils.py", line 2157, in convert_flux_transformer_checkpoint_to_diffusers
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
                                                                                ^^^^^^^^^^^^^^^
KeyError: 'time_in.in_layer.weight'
@nitinmukesh

@Vargol

This is still a draft.

@DN6
Collaborator

DN6 commented May 19, 2025

Hi @hameerabbasi Thank you so much for your PR and for taking up this integration. 👍🏽

It's looking good so far, but would it be possible to define a dedicated transformer model for Chroma, e.g. ChromaTransformer2DModel?

Given that the model is still being actively trained, there are potentially a lot of changes that could be introduced. It would be better to isolate them rather than add additional conditional logic to Flux.

@hameerabbasi
Contributor Author

hameerabbasi commented May 19, 2025

Hi @hameerabbasi Thank you so much for your PR and for taking up this integration. 👍🏽

Thank you for the awesome work on diffusers!

It's looking good so far, but would it be possible to define a dedicated transformer model for Chroma, e.g. ChromaTransformer2DModel?

It certainly would, but I have my reasons for choosing to integrate it directly into FluxTransformer2DModel:

  1. We get automatic support for LoRA, ControlNet, IP-Adapter, ...
  2. Little to no code duplication, which is always nice. Minimal diffs are easier to review.

Given that the model is still being actively trained, there are potentially a lot of changes that could be introduced. It would be better to isolate them rather than add additional conditional logic to Flux.

The weights are certainly in flux (pun intended), but the architecture has been fixed for quite some time (maybe @lodestone-rock can confirm). Since the code encodes the architecture rather than weights, I chose to do it this way.

Additionally, back-compat is untouched -- so Flux itself will be unaffected by this change.

@lodestone-rock

@hameerabbasi @DN6
I probably won't introduce any more architecture changes; it's already quite optimal in its current state,
so this PR should be sufficient.

@Ednaordinary
Contributor

Ednaordinary commented May 19, 2025

Here's a seemingly working from_single_file conversion function :) though I still can't get it to run (it errors in time_text_embed).

I also agree with DN6 that introducing a separate class is cleaner than using variant (no other class does this, and variant normally implies a dtype option like fp16 or fp32 when downloading from the Hub).
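For reference, here is a minimal illustration of the existing Hub `variant` convention that this usage would clash with; the repo id is only an example and assumes it publishes fp16 variant weights:

```python
import torch
from diffusers import DiffusionPipeline

# `variant` here selects alternate weight files on the Hub (e.g. *.fp16.safetensors),
# not a different architecture -- which is why overloading it for Chroma is confusing.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # example repo; any repo with fp16 shards works
    variant="fp16",
    torch_dtype=torch.float16,
)
```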

convert_chroma_transformer_checkpoint_to_diffusers

import torch


def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    keys = list(checkpoint.keys())

    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
    num_guidance_layers = list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1  # noqa: C401
    mlp_ratio = 4.0
    inner_dim = 3072

    # In the original implementation of AdaLayerNormContinuous (as in SD3), the linear projection output is split
    # into shift, scale; in diffusers it is split into scale, shift. swap_scale_shift reorders the weights so the
    # diffusers implementation can be used. (It is unused below, since Chroma prunes the final modulation layer.)
    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

    # guidance
    converted_state_dict["time_text_embed.embedder.in_proj.bias"] = checkpoint.pop(
        "distilled_guidance_layer.in_proj.bias"
    )
    converted_state_dict["time_text_embed.embedder.in_proj.weight"] = checkpoint.pop(
        "distilled_guidance_layer.in_proj.weight"
    )
    converted_state_dict["time_text_embed.embedder.out_proj.bias"] = checkpoint.pop(
        "distilled_guidance_layer.out_proj.bias"
    )
    converted_state_dict["time_text_embed.embedder.out_proj.weight"] = checkpoint.pop(
        "distilled_guidance_layer.out_proj.weight"
    )
    for i in range(num_guidance_layers):
        block_prefix = f"time_text_embed.embedder.layers.{i}."
        converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.in_layer.bias"
        )
        converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.in_layer.weight"
        )
        converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.out_layer.bias"
        )
        converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.out_layer.weight"
        )
        converted_state_dict[f"time_text_embed.embedder.norms.{i}.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.norms.{i}.scale"
        )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
        context_q, context_k, context_v = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.2.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # Q, K, V, mlp
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")

    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

    return converted_state_dict
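A hypothetical way to exercise this converter on a local single-file checkpoint (the path below is a placeholder, and the printed shapes are only a sanity check):

```python
from safetensors.torch import load_file

# Placeholder path; point this at a local Chroma single-file checkpoint.
state_dict = load_file("/path/to/chroma-checkpoint.safetensors")
converted = convert_chroma_transformer_checkpoint_to_diffusers(state_dict)

# Quick sanity check of a couple of remapped keys.
print(converted["time_text_embed.embedder.in_proj.weight"].shape)
print(converted["x_embedder.weight"].shape)
```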

@DN6
Collaborator

DN6 commented May 20, 2025

@hameerabbasi

We get automatic support for LoRA, ControlNet, IP-Adapter, ...

For this, if you add the appropriate mixins to ChromaTransformer2DModel, all these features will work, e.g. PeftAdapterMixin, FluxTransformer2DLoadersMixin.

Little to no code duplication, which is always nice. Minimal diffs are easier to review.

We are trying to enforce a single-file-per-model policy, which tends to be easier to maintain in the long run. So all the necessary components of a model would be defined in a single file, e.g. CombinedTimestepTextProjChromaEmbeddings and ChromaApproximator would be defined in the Chroma transformer file. Even if we have to define a new pipeline for Chroma, that is totally fine. Duplication isn't a concern. See: https://huggingface.co/docs/diffusers/v0.33.1/en/conceptual/philosophy#tweakable-contributor-friendly-over-abstraction
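For concreteness, a rough sketch of what such a dedicated class could look like (only an outline, not the final design; the mixin list mirrors FluxTransformer2DModel):

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin


class ChromaTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
):
    @register_to_config
    def __init__(self, num_layers: int = 19, num_single_layers: int = 38):
        super().__init__()
        # Blocks, embedders (CombinedTimestepTextProjChromaEmbeddings, ChromaApproximator), etc.
        # would be defined here in this single file, mirroring transformer_flux.py.
```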

def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
    attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
    for i, n_tokens in enumerate(length):
        n_tokens = torch.max(n_tokens + 1, max_sequence_length)
DN6 (Collaborator) commented May 20, 2025

Does Chroma support tokens beyond the max length for T5? Wouldn't this operation result in a mask that is 512 tokens in length with all True/1 for n_tokens < max_sequence_length?

Also, is it not possible to use the attention mask returned by the tokenizer, i.e. text_input_ids.attention_mask?

@hameerabbasi (Contributor, Author)

IIUC, Chroma needs an attention mask that's equivalent to torch.cat([torch.ones(n_tokens + 1, dtype=torch.bool), torch.zeros(max_tokens - n_tokens - 1, dtype=torch.bool)]), because it needs one unmasked <pad> token after the actual prompt. IIUC text_input_ids.attention_mask will mask all the <pad> tokens.

The torch.max is to handle the corner case where n_tokens == max_sequence_length.
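As a standalone illustration of that mask (the values below are made up; max_tokens corresponds to max_sequence_length):

```python
import torch

max_tokens = 512  # T5 max_sequence_length used by the pipeline
n_tokens = 7      # number of real prompt tokens (illustrative)

# Prompt tokens plus one trailing <pad> stay unmasked; the rest of the padding is masked out.
attn_mask = torch.cat(
    [
        torch.ones(n_tokens + 1, dtype=torch.bool),
        torch.zeros(max_tokens - n_tokens - 1, dtype=torch.bool),
    ]
)
assert attn_mask.shape[0] == max_tokens
assert int(attn_mask.sum()) == n_tokens + 1
```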

Ednaordinary (Contributor) commented May 20, 2025

torch.max doesn't work correctly here because it interprets max_sequence_length as the dim argument (there aren't 512 dimensions), but the Python builtin max() might work.

Though, the discussion here suggests the extra pad token is a mistake in the ComfyUI implementation, so text_input_ids.attention_mask should be fine.
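If the intent really is "unmask n_tokens + 1 positions, clamped to the mask width", a possible standalone rewrite (dropping self, and using the Python builtin for the clamp rather than torch.max) could look like this -- treat it as a sketch of one reading of the code, not a verified fix:

```python
import torch


def _get_chroma_attn_mask(length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
    attention_mask = torch.zeros(
        (length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device
    )
    for i, n_tokens in enumerate(length):
        # Clamp with the builtin so max_sequence_length is not misread as a `dim` argument.
        n_unmasked = min(int(n_tokens) + 1, max_sequence_length)
        attention_mask[i, :n_unmasked] = True
    return attention_mask
```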

hameerabbasi (Contributor, Author) commented May 20, 2025

self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
    factor=approximator_in_factor,
    hidden_dim=approximator_hidden_dim,
    out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
Ednaordinary (Contributor) commented May 20, 2025

Should this not be 3072? Am I missing something? (It currently computes to 344, but that doesn't fit distilled_guidance_layer.out_proj, and Comfy sets it to 3072.) mod_index_length / mod_proj.shape[0] should be 344, though.
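For reference, plugging in Flux.1's block counts (19 double blocks and 38 single blocks, which Chroma appears to keep) gives 344 for that expression, i.e. the number of modulation vectors rather than their width:

```python
num_layers = 19          # double-stream blocks (assumed, matching Flux.1)
num_single_layers = 38   # single-stream blocks (assumed, matching Flux.1)

out_dim = 3 * num_single_layers + 2 * 6 * num_layers + 2
print(out_dim)  # 344 -> mod_index_length; each modulation vector itself is inner_dim (3072) wide
```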

@hameerabbasi
Contributor Author

@hameerabbasi

We get automatic support for LoRA, ControlNet, IP-Adapter, ...

For this, if you add the appropriate mixins to ChromaTransformer2DModel, all these features will work, e.g. PeftAdapterMixin, FluxTransformer2DLoadersMixin.

Little to no code duplication, which is always nice. Minimal diffs are easier to review.

We are trying to enforce a single-file-per-model policy, which tends to be easier to maintain in the long run. So all the necessary components of a model would be defined in a single file, e.g. CombinedTimestepTextProjChromaEmbeddings and ChromaApproximator would be defined in the Chroma transformer file. Even if we have to define a new pipeline for Chroma, that is totally fine. Duplication isn't a concern. See: https://huggingface.co/docs/diffusers/v0.33.1/en/conceptual/philosophy#tweakable-contributor-friendly-over-abstraction

Cool, so long as we can use FluxTransformer2DLoadersMixin, it should be okay. I'll make a copy.

@Ednaordinary
Contributor

Ednaordinary commented May 22, 2025

This may help; it's a working (as in running) guidance embedder:

CombinedTimestepTextProjChromaEmbeddings

# Depends on torch, torch.nn, typing.Optional, and diffusers' Timesteps / get_timestep_embedding helpers,
# plus the ChromaApproximator defined elsewhere in this PR.
class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
        super().__init__()

        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.embedder = ChromaApproximator(
            in_dim=factor * 4,
            out_dim=out_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
        )
        self.embedding_dim = embedding_dim

        self.register_buffer(
            "mod_proj",
            get_timestep_embedding(torch.arange(344), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
            persistent=False,
        )

    def forward(
        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
    ) -> torch.Tensor:
        mod_index_length = self.mod_proj.shape[0]
        # pooled_projections is intentionally unused: Chroma does not appear to use CLIP.
        timesteps_proj = self.time_proj(timestep)
        if guidance is not None:
            guidance_proj = self.guidance_proj(guidance.repeat(timesteps_proj.shape[0]))
        else:
            guidance_proj = torch.zeros(
                (1, self.guidance_proj.num_channels),
                dtype=timesteps_proj.dtype,
                device=timesteps_proj.device,
            )
        mod_proj = self.mod_proj.unsqueeze(0).repeat(timesteps_proj.shape[0], 1, 1).to(
            dtype=timesteps_proj.dtype, device=timesteps_proj.device
        )
        timestep_guidance = torch.cat([timesteps_proj, guidance_proj], dim=1).repeat(1, mod_index_length, 1)
        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
        conditioning = self.embedder(input_vec)
        return conditioning

I found that the version currently in the PR is not working correctly, so I modified it. One large thing to note is that it doesn't use pooled_projections, since Chroma seemingly does not use CLIP; the final pipeline likely should not include CLIP and should rename text_encoder_2 to text_encoder (along with tokenizer_2). pooled_projections was reshaping the array in a weird way that broke things. I got the pipeline to "run" by making this modification and fixing the norm output with .squeeze(1) (they currently output in the wrong shape), though I still get just noise regardless of guidance_scale, true_cfg_scale, and guidance_embeds (in the transformer config).
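To make the CLIP-free layout concrete, here is a hypothetical sketch of how such a pipeline's components might be registered (ChromaPipeline does not exist in this PR; the renaming is the only point being illustrated):

```python
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel
from transformers import T5EncoderModel, T5TokenizerFast


class ChromaPipeline:  # sketch only; a real implementation would subclass DiffusionPipeline
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,  # was text_encoder_2 in FluxPipeline; CLIP is dropped entirely
        tokenizer: T5TokenizerFast,    # was tokenizer_2 in FluxPipeline
        transformer: FluxTransformer2DModel,
    ):
        self.scheduler = scheduler
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.transformer = transformer
```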

bad output (eyesore): [image: chroma]

@DN6
Collaborator

DN6 commented Jun 3, 2025

@hameerabbasi How's it going? LMK if you need any help moving this PR forward.

@ghunkins
Contributor

ghunkins commented Jun 4, 2025

Chroma is having a massive moment and gaining momentum. @DN6, any chance the core team has bandwidth to help? Happy to provide support if needed to get this across the finish line.

@hameerabbasi
Contributor Author

Hello! I intend to continue work on this, but my bandwidth is limited to some weekends. Please feel free to build on it if someone has more time.

@DN6
Collaborator

DN6 commented Jun 9, 2025

@hameerabbasi I can push up some changes to your branch this week if that's cool?

@hameerabbasi
Contributor Author

No problem; if you have the bandwidth it's fine to take over the PR or branch off it.

@iddl

iddl commented Jun 9, 2025

@hameerabbasi @DN6: I've got the branch to a point where it's successfully generating images. Commits at: hameerabbasi#1 if you want to merge them. I also left a script to show the entire end-to-end workflow.

@DN6: This update doesn't incorporate the refactoring you asked for, but it will save you a significant amount of debugging time related to getting images to generate correctly.

[image: Screenshot 2025-06-09 at 10.54.12 AM]

@hameerabbasi
Contributor Author

Thanks @iddl; I merged the changes. I'll do my best to move it to a new class in the coming week. Did you verify Flux checkpoints (of various kinds) still load?

@iddl

iddl commented Jun 10, 2025

@hameerabbasi I didn't verify that the Flux checkpoints still load, because I assumed that with the refactor we might eventually leave all of the Flux code untouched, without the various if variant == ... branches.

@hameerabbasi
Contributor Author

Ah -- there was a bunch of code deleted from the main path and put under if statements in your commit, which is why I asked.

@iddl

iddl commented Jun 10, 2025

Right, the layers under the if statements should be all the layers that Chroma dropped. Flux should still work; let me check tomorrow to confirm.

@Ednaordinary
Contributor

Ednaordinary commented Jun 10, 2025

I'm gonna start working on getting the pipeline separated :) Unsure whether it would be better to make a PR against @hameerabbasi's branch or open a new diffusers PR. I don't want to mess with the Flux commit history if I don't have to. @DN6?

@hameerabbasi
Contributor Author

hameerabbasi commented Jun 10, 2025

Either is fine for me!

@DN6
Collaborator

DN6 commented Jun 10, 2025

Thank you all for being so active here 😄

I'll start working on this today.

@Ednaordinary Would it be possible for you to compare some outputs from @iddl's changes with ones generated using the default attention mask output by the T5 tokenizer?

I'll get the new Chroma Transformer model done in the mean time.

@Ednaordinary
Contributor

Ednaordinary commented Jun 10, 2025

Sure! Though note iddl didn't change anything there, so it's the same as back in your initial review. If it's helpful, here's how far I've gotten in separating stuff out. The transformer was pretty easy; the pipeline is the harder part, since we have to refactor true_cfg_scale and rename text_encoder_2, tokenizer_2 -> text_encoder, tokenizer.

@Ednaordinary
Contributor

Ednaordinary commented Jun 10, 2025

Here's what I get with `text_inputs.attention_mask` vs `_get_chroma_attn_mask`

text_inputs.attention_mask: [image: chroma]

_get_chroma_attn_mask: [image: chroma_iddl]

Unsure why the difference is so big if the tokenizer's attention_mask differs by just a pad token.

@Ednaordinary
Contributor

Okay, so this is a bit confusing, but I've wrapped my head around what's happening:

  • _get_chroma_attn_mask doesn't work right now from what I can tell. The image above is with no attention mask
  • the representation above for text_inputs.attention_mask is (probably) correct
  • loading the pipeline with variant="chroma" is impossible because pipeline_loading_utils thinks it's a Hugging Face Hub variant
  • I was loading the transformer first and then loading the pipeline with transformer=transformer, so it thought the variant was still "flux" inside the transformer (see the loading sketch below)

decoupling the Chroma pipeline should make this much easier to deal with
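For anyone following along, the loading pattern described above looks roughly like this on the PR branch (paths and repo id are placeholders):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Pass variant="chroma" to the transformer loader only; the pipeline loader
# would interpret `variant` as a Hub weight variant like "fp16".
transformer = FluxTransformer2DModel.from_single_file(
    "/path/to/chroma-checkpoint.safetensors",  # placeholder path
    variant="chroma",
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # placeholder base repo for the remaining components
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
```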

@DN6
Collaborator

DN6 commented Jun 10, 2025

Oh @Ednaordinary, sorry, I didn't realise you worked on the transformer as well. I've pushed up the changes to the chroma branch in the diffusers repo if you want to include that.

Labels: None yet · 8 participants