Chroma as a FLUX.1 variant #11566
Conversation
Hi, is there a specific model repo this works with? I'm getting key errors loading the models from lodestones/Chroma and silveroxides/Chroma-GGUF. I installed the PR using:

```
$ pip uninstall diffusers
$ pip install git+https://github.com/huggingface/diffusers.git@refs/pull/11566/merge

$ grep chroma lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py
INVALID_VARIANT_ERRMSG = "`variant` must be `'flux' or `'chroma'`."
elif variant == "chroma":
elif variant == "chroma":
elif variant == "chroma":
norm_out_cls = AdaLayerNormContinuous if variant != "chroma" else AdaLayerNormContinuousPruned
is_chroma = isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
if not is_chroma:
if is_chroma:
if is_chroma:
if is_chroma:
```
I'm loading the models using:

```python
# ckpt_path = "/Volumes/SSD2TB/AI/caches/models/chroma-unlocked-v29-Q8_0.gguf"
ckpt_path = "/Volumes/SSD2TB/AI/caches/models/chroma-unlocked-v29.5.safetensors"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    # quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
    variant="chroma",
)
```

The full error is:

```
Traceback (most recent call last):
  File "/Volumes/SSD2TB/AI/Diffusers/gguf_chroma.py", line 48, in <module>
    transformer = FluxTransformer2DModel.from_single_file(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/loaders/single_file_model.py", line 370, in from_single_file
    diffusers_format_checkpoint = checkpoint_mapping_fn(
                                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/loaders/single_file_utils.py", line 2157, in convert_flux_transformer_checkpoint_to_diffusers
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
    ^^^^^^^^^^^^^^^
KeyError: 'time_in.in_layer.weight'
```
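The `KeyError` indicates the checkpoint is being routed through the stock Flux converter, which expects `time_in.*` keys that Chroma checkpoints don't carry (Chroma replaces the timestep/guidance embedders with a `distilled_guidance_layer`). A quick way to see which naming scheme a `.safetensors` file uses is to list its keys before loading; this is only a diagnostic sketch (the path is the one from the report above), not part of the PR:

```python
from safetensors import safe_open

ckpt_path = "/Volumes/SSD2TB/AI/caches/models/chroma-unlocked-v29.5.safetensors"

with safe_open(ckpt_path, framework="pt") as f:
    keys = list(f.keys())

# Flux checkpoints contain keys like "time_in.in_layer.weight";
# Chroma checkpoints instead carry "distilled_guidance_layer.*" keys.
print([k for k in keys if k.startswith("time_in")])
print([k for k in keys if k.startswith("distilled_guidance_layer")][:5])
```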
This is still a draft.
Hi @hameerabbasi, thank you so much for your PR and for taking up this integration. 👍🏽 It's looking good so far, but would it be possible to define a dedicated transformer model for Chroma? Given that the model is still being actively trained, there are potentially a lot of changes that could be introduced. It would be better to isolate them rather than add additional conditional logic to Flux.
Thank you for the awesome work on diffusers!
It certainly would, but I have my reasons for choosing to integrate it directly into the existing Flux transformer: the weights are certainly in flux (pun intended), but the architecture has been fixed for quite some time (maybe @lodestone-rock can confirm). Since the code encodes the architecture rather than the weights, I chose to do it this way. Additionally, backward compatibility is untouched, so Flux itself will be unaffected by this change.
@hameerabbasi @DN6
Here's a seemingly working `from_single_file` conversion function :) though I still cannot get it to run (errors in `time_text_embed`). I also agree with DN6 that introducing a separate class is cleaner than using `variant` (no other class does this, and `variant` normally implies a dtype option like fp16 or fp32 when downloading from the Hub).

`convert_chroma_transformer_checkpoint_to_diffusers`:
```python
# torch is already imported in diffusers/loaders/single_file_utils.py; it is repeated here so the snippet is self-contained.
import torch


def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    keys = list(checkpoint.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
    num_guidance_layers = list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1  # noqa: C401
    mlp_ratio = 4.0
    inner_dim = 3072

    # In the SD3 original implementation of AdaLayerNormContinuous, the linear projection output is split into
    # (shift, scale); in diffusers it is split into (scale, shift). Swap the halves of the linear projection
    # weights so the diffusers implementation can be used.
    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

    # guidance
    converted_state_dict["time_text_embed.embedder.in_proj.bias"] = checkpoint.pop(
        "distilled_guidance_layer.in_proj.bias"
    )
    converted_state_dict["time_text_embed.embedder.in_proj.weight"] = checkpoint.pop(
        "distilled_guidance_layer.in_proj.weight"
    )
    converted_state_dict["time_text_embed.embedder.out_proj.bias"] = checkpoint.pop(
        "distilled_guidance_layer.out_proj.bias"
    )
    converted_state_dict["time_text_embed.embedder.out_proj.weight"] = checkpoint.pop(
        "distilled_guidance_layer.out_proj.weight"
    )
    for i in range(num_guidance_layers):
        block_prefix = f"time_text_embed.embedder.layers.{i}."
        converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.in_layer.bias"
        )
        converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.in_layer.weight"
        )
        converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.out_layer.bias"
        )
        converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.layers.{i}.out_layer.weight"
        )
        converted_state_dict[f"time_text_embed.embedder.norms.{i}.weight"] = checkpoint.pop(
            f"distilled_guidance_layer.norms.{i}.scale"
        )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
        context_q, context_k, context_v = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.2.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # Q, K, V, mlp
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")

    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

    return converted_state_dict
```
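For reference, a minimal way to exercise a converter like this outside of `from_single_file` is to load the raw state dict and run the mapping directly. The snippet below is only a sketch: the checkpoint path is the one from the earlier report, and it assumes the function above is in scope.

```python
from safetensors.torch import load_file

ckpt_path = "/Volumes/SSD2TB/AI/caches/models/chroma-unlocked-v29.5.safetensors"

# Load the original Chroma checkpoint and map its keys to diffusers naming.
checkpoint = load_file(ckpt_path)
converted = convert_chroma_transformer_checkpoint_to_diffusers(checkpoint)

# Spot-check a few converted keys before trying to load them into a model.
print(len(converted))
print([k for k in converted if k.startswith("time_text_embed.embedder")][:5])
```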
For this the …
We are trying to enforce a single-file-per-model policy, which tends to be easier to maintain in the long run, so all the necessary components of a model would be defined in a single file, e.g. …
```python
def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
    attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
    for i, n_tokens in enumerate(length):
        n_tokens = torch.max(n_tokens + 1, max_sequence_length)
```
Does Chroma support tokens beyond the max length for T5? Wouldn't this operation result in a mask that is 512 tokens long with all True/1 for `n_tokens < max_sequence_length`?
Also, is it not possible to use the attention mask returned by the tokenizer, i.e. `text_input_ids.attention_mask`?
IIUC, Chroma needs an attention mask that's equivalent to `torch.cat([torch.ones(n_tokens + 1, dtype=torch.bool), torch.zeros(max_tokens - n_tokens - 1, dtype=torch.bool)])`, because it needs one unmasked `<pad>` token after the actual prompt. IIUC, `text_input_ids.attention_mask` will mask all the `<pad>` tokens.
The `torch.max` is there to handle the corner case where `n_tokens == max_sequence_length`.
`torch.max` doesn't work correctly here because it assumes `max_sequence_length` is the dimension (there are not 512 dimensions), but `max()` might work.
Though, the discussion here suggests the extra pad token is a mistake in the ComfyUI implementation, so `text_input_ids.attention_mask` should be fine.
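As a side note on the `torch.max` overload issue, a tiny standalone check (with a hypothetical token count) shows why the call as written can't do an element-wise comparison against a plain int, and what the two plain-Python alternatives would compute:

```python
import torch

n_tokens = torch.tensor(20)  # number of real prompt tokens (hypothetical value)
max_sequence_length = 512

# torch.max(tensor, int) dispatches to the "reduce over a dimension" overload,
# so the int is treated as a dim index rather than a value to compare against.
try:
    torch.max(n_tokens + 1, max_sequence_length)
except Exception as exc:  # dimension-out-of-range error for a 0-d tensor
    print(type(exc).__name__, exc)

# Python's built-in max()/min() make the intent explicit:
print(max(int(n_tokens) + 1, max_sequence_length))  # 512 -> would unmask every position, as noted above
print(min(int(n_tokens) + 1, max_sequence_length))  # 21  -> n_tokens + 1, clamped to the 512-token limit
```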
Hmm, https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training seems to suggest otherwise.
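For concreteness, here is a small standalone sketch of the two masking strategies being debated: the tokenizer-provided mask, and a variant that additionally unmasks one `<pad>` token after the prompt. The helper logic below is hypothetical and not part of the PR, and the T5 tokenizer checkpoint id is an assumption; it only illustrates the shapes involved.

```python
import torch
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")  # assumed tokenizer checkpoint
max_sequence_length = 512

text_inputs = tokenizer(
    ["a photo of a cat"],
    padding="max_length",
    max_length=max_sequence_length,
    truncation=True,
    return_tensors="pt",
)

# Strategy 1: the mask returned by the tokenizer (all <pad> tokens masked).
tokenizer_mask = text_inputs.attention_mask.bool()

# Strategy 2: additionally unmask one <pad> token after the prompt,
# clamped so the index never exceeds max_sequence_length.
lengths = text_inputs.attention_mask.sum(dim=1)  # real tokens (incl. EOS) per prompt
chroma_style_mask = torch.zeros_like(tokenizer_mask)
for i, n_tokens in enumerate(lengths):
    n_unmasked = min(int(n_tokens) + 1, max_sequence_length)
    chroma_style_mask[i, :n_unmasked] = True

print(tokenizer_mask.sum().item(), chroma_style_mask.sum().item())  # differ by at most 1 per prompt
```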
```python
self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
    factor=approximator_in_factor,
    hidden_dim=approximator_hidden_dim,
    out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
```
Should this not be 3072? Am I missing something? (It currently computes to 344, but that doesn't fit `distilled_guidance_layer.out_proj`, and Comfy sets it to 3072.) `mod_index_length` / `mod_proj.shape[0]` should be 344, though.
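For reference, the 344 figure follows directly from the block counts, assuming Chroma keeps FLUX.1's standard 19 double-stream and 38 single-stream blocks; whether `out_dim` should be this modulation count or the 3072 hidden width of `distilled_guidance_layer.out_proj` is exactly the point under discussion.

```python
num_layers = 19         # double-stream blocks (assumed, as in FLUX.1)
num_single_layers = 38  # single-stream blocks (assumed, as in FLUX.1)

# 3 modulation vectors per single block, 2 * 6 per double block, plus 2 for the final norm.
mod_index_length = 3 * num_single_layers + 2 * 6 * num_layers + 2
print(mod_index_length)  # 344
```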
Cool, so long as we can use …
This may help; it's a working (as in running) guidance embedder.

`CombinedTimestepTextProjChromaEmbeddings`:

```python
# Imports added for completeness. Timesteps and get_timestep_embedding come from diffusers.models.embeddings;
# ChromaApproximator is the approximator module introduced in this PR branch.
from typing import Optional

import torch
from torch import nn


class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
        super().__init__()

        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.embedder = ChromaApproximator(
            in_dim=factor * 4,
            out_dim=out_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
        )
        self.embedding_dim = embedding_dim

        self.register_buffer(
            "mod_proj",
            get_timestep_embedding(torch.arange(344), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
            persistent=False,
        )

    def forward(
        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
    ) -> torch.Tensor:
        mod_index_length = self.mod_proj.shape[0]
        timesteps_proj = self.time_proj(timestep)  # + self.time_proj(pooled_projections)
        print(timesteps_proj.dtype)
        print(self.mod_proj.dtype)

        if guidance is not None:
            guidance_proj = self.guidance_proj(guidance.repeat(timesteps_proj.shape[0]))
        else:
            guidance_proj = torch.zeros(
                (1, self.guidance_proj.num_channels),
                dtype=timesteps_proj.dtype,
                device=timesteps_proj.device,
            )

        mod_proj = self.mod_proj.unsqueeze(0).repeat(timesteps_proj.shape[0], 1, 1).to(
            dtype=timesteps_proj.dtype, device=timesteps_proj.device
        )
        timestep_guidance = torch.cat([timesteps_proj, guidance_proj], dim=1).repeat(1, mod_index_length, 1)
        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
        conditioning = self.embedder(input_vec)

        return conditioning
```

I found that the version currently in the PR is not working correctly, so I modified it. One big thing to note is that it doesn't use `pooled_projections`, since Chroma seemingly does not use CLIP, so the final pipeline likely should not include CLIP and should rename `text_encoder_2` to `text_encoder` (along with `tokenizer_2`). `pooled_projections` was reshaping the array in a weird way that broke stuff. I got the pipeline to "run" by making this modification and fixing the norm output by using `.squeeze(1)` (they currently output in the wrong shape), though I still get just noise regardless of …
@hameerabbasi How's it going? LMK if you need any help moving this PR forward.
Chroma is having a massive moment and is gaining momentum. @DN6, any chance the core team has bandwidth to help? Happy to provide support if needed to help get this across the finish line.
Hello! I intend to continue work on this, but my bandwidth is limited to some weekends. Please feel free to build on it if someone has more time.
@hameerabbasi I can push up some changes to your branch this week if that's cool?
No problem; if you have the bandwidth it's fine to take over the PR or branch off it.
@hameerabbasi @DN6: I've got the branch to a point where it's successfully generating images. Commits at hameerabbasi#1 if you want to merge them. I also left a script to show the entire end-to-end workflow. @DN6: This update doesn't incorporate the refactoring you asked for, but it will save you a significant amount of debugging time related to getting images to generate correctly.
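The end-to-end script itself isn't reproduced in this thread, but under the variant approach the flow being tested looks roughly like the sketch below. The local checkpoint path, base repo id, and the use of the stock `FluxPipeline` are all assumptions (whether the Flux pipeline and CLIP text encoder are kept at all is still being discussed above), so treat this as illustrative rather than as the PR's actual example.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Hypothetical local checkpoint path; variant="chroma" is the switch added by this PR.
ckpt_path = "/path/to/chroma-unlocked-v29.5.safetensors"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    torch_dtype=torch.bfloat16,
    variant="chroma",
)

# Reuse the remaining FLUX.1 components and swap in the Chroma transformer.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # assumed base repo
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe("a photo of a cat", num_inference_steps=30).images[0]
image.save("chroma_test.png")
```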
Thanks @iddl; I merged the changes. I'll do my best to move it to a new class in the coming week. Did you verify Flux checkpoints (of various kinds) still load?
@hameerabbasi I didn't verify that the Flux checkpoints still load, because I assumed that with the refactor we might just leave all of the Flux code eventually untouched, without the various …
Ah, there was a bunch of code deleted from the main path and put under …
Right, the layers under the …
I'm gonna start working on getting the pipeline separated :) Unsure whether it would be better to make a PR against @hameerabbasi's branch or open a new diffusers PR; I don't want to mess with the Flux commit history if I don't have to. @DN6?
Either is fine for me!
Thank you all for being so active here 😄 I'll start working on this today. @Ednaordinary, would it be possible for you to compare some outputs generated with @iddl's changes against ones generated using the default attention mask output from the T5 tokenizer? I'll get the new Chroma transformer model done in the meantime.
Sure! Though note iddl didn't change anything there, so it's the same as back in your initial review. If it's helpful, here's how far I've gotten in separating stuff out. The transformer was pretty easy; the pipeline is the harder part since we have to refactor …
Okay, so this is a bit confusing, but I wrapped my head around what's happening:
Decoupling the Chroma pipeline should make this much easier to deal with.
Oh @Ednaordinary, sorry, I didn't realise you worked on the transformer as well. I've pushed up the changes to the …