
Context Parallel w/ Ring & Ulysses & Unified Attention #11941


Draft
wants to merge 32 commits into main

Conversation

@a-r-r-o-w (Member) commented on Jul 16, 2025

Adds support for Ring, Ulysses, and Unified attention natively. For a minimal PoC, I've limited the changes to Flux.

Supported attention backends with CP: cuDNN, FA2, Sage.

Requires #11916 to be merged first.

Minimal example

import torch
from diffusers import FluxPipeline
from diffusers.models import ParallelConfig

try:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to(device)
    # pipe.transformer.parallelize(ring_degree=2)
    # pipe.transformer.set_attention_backend("_native_cudnn")
    pipe.transformer.set_attention_backend("flash")
    # pipe.transformer.compile(fullgraph=True)
    # pipe.transformer.compile(fullgraph=True, mode="max-autotune")

    prompt = "A cat holding a sign that says 'hello world'"
    
    # Must specify a generator so all ranks start with the same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    with pipe.transformer.parallelize(config=ParallelConfig(ulysses_degree=2)):
        image = pipe(prompt, num_inference_steps=2, guidance_scale=4.0, generator=generator).images[0]
        image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0, generator=generator).images[0]
    
    if rank == 0:
        image.save("output.png")

except Exception as e:
    print(f"An error occurred: {e}")
    torch.distributed.breakpoint()
    raise

finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
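
For reference, the script has to be launched with torchrun so that torch.distributed is initialized on every rank, e.g. torchrun --nproc_per_node 2 <script>.py (as in the log further below).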

Benchmarks

TODO

Explanation

Each model should define a _cp_plan attribute that contains information on how to shard/gather tensors at different stages of the forward.

TODO
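
As a rough sketch of the idea while the full explanation is TODO (the names and schema below are hypothetical, not the exact format introduced in this PR): keys would be submodule paths and values would describe how their tensors are split across ranks before the forward and gathered back afterwards, typically along the sequence dimension.

# Hypothetical illustration only; the concrete _cp_plan schema is defined by this PR and may differ.
_cp_plan = {
    "": {
        "hidden_states": {"op": "split", "dim": 1},          # shard image tokens across ranks
        "encoder_hidden_states": {"op": "split", "dim": 1},  # shard text tokens across ranks
    },
    "proj_out": {
        "output": {"op": "gather", "dim": 1},                # reassemble the full sequence at the end
    },
}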

Note: There were some merge conflicts that I'm not sure I resolved correctly. Some things may be broken. For this reason, I've removed training support and only tested inference. I'll update some of the TODOs tomorrow.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w added the "roadmap (Add to current release roadmap)" label on Jul 16, 2025
@sayakpaul (Member) commented:

I am going to review it very soon, but before I do I would like to read a bit about unified attention. Simple searches returned results that didn't seem relevant, hence the ask.

@a-r-r-o-w (Member, Author) commented:

Unified CP is a generalization that performs Ulysses and Ring together; both methods become special cases of Unified attention. Paper: https://arxiv.org/abs/2405.07719v3
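
For a concrete sense of how the two compose (a sketch using the ParallelConfig API from the example above; the degree values are just illustrative for an 8-GPU run): heads are exchanged all-to-all within groups of ulysses_degree ranks, KV blocks rotate around rings of ring_degree ranks, and setting either degree to 1 recovers the pure variant.

from diffusers.models import ParallelConfig

# ulysses_degree * ring_degree should equal the number of participating ranks (8 here).
unified = ParallelConfig(ulysses_degree=4, ring_degree=2)
ring_only = ParallelConfig(ring_degree=8)        # pure ring attention
ulysses_only = ParallelConfig(ulysses_degree=8)  # pure Ulysses attention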

@sayakpaul (Member) left a comment:

Looking really promising! Thank you, Aryan.

# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
    cached_parameter_indices: Dict[str, int] = None
Member:

(nit): No need for this PR, but might make sense to introduce help for args like this.
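
For instance, something like the standard dataclasses field metadata pattern could carry the help text (hypothetical wording on my part; the description of cached_parameter_indices is a guess):

from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass
class ModuleForwardMetadata:
    cached_parameter_indices: Optional[Dict[str, int]] = field(
        default=None,
        metadata={"help": "Cached mapping from forward argument names to their positional index."},
    )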

Comment on lines 108 to 116
# HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
# diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
# available in pytorch to set the parallel context before/after the forward/backward pass.
# It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
# The previous/older implementation simply did this:
# def new_forward(self, ...):
#     with _parallel_context(parallel_config):
#         return self.fn_ref.original_forward(*args, **kwargs)
# TODO: ask help from Pytorch team on how to improve this
Member:

Cc: @anijain2305 as an FYI.

@a-r-r-o-w we don't have to tackle the following in this PR, but is using fullgraph=True beneficial from the perspective of the number of recompilations, latency, etc.? Totally okay if we don't know that yet. But in my experience, sometimes "fullgraph=True + more recompiles" performs worse than a somewhat "fragmented graph + fewer recompiles".

If this works out well we can potentially upstream this into accelerate, too, as it makes use of the new_forward pattern, and propagate it to our offloading hooks if necessary.

Member (Author):

I've gotten rid of this in favor of a context-manager-based parallelize function (for the time being). The previous hack is not good.
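
For anyone reading along, the hook-based trick described in the quoted comment looked roughly like this (my reconstruction, not the actual diff; _set_parallel_context and _attach_parallel_context are stand-in names):

import torch

_PARALLEL_CONTEXT = None  # stand-in for the real parallel-config state

def _set_parallel_context(config) -> None:
    global _PARALLEL_CONTEXT
    _PARALLEL_CONTEXT = config

def _attach_parallel_context(module: torch.nn.Module, parallel_config) -> None:
    # Set the parallel context just before forward and clear it right after, using
    # PyTorch module hooks instead of a context manager so fullgraph=True tracing can work.
    def _enter(mod, args, kwargs):
        _set_parallel_context(parallel_config)

    def _exit(mod, args, kwargs, output):
        _set_parallel_context(None)
        return output

    module.register_forward_pre_hook(_enter, with_kwargs=True)
    module.register_forward_hook(_exit, with_kwargs=True)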

Comment on lines +239 to +244
if is_tensor:
    output = [output]
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
    raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

output = list(output)
Member:

if is_tensor:
   output = [output]
...
output = list(output)

Is this intended?

Member (Author):

Yes, calling list() on a list does not create a nested list, if that was the question. This is so we can convert a tuple output into an in-place modifiable list. Maybe I should do this to avoid confusion:

if is_tensor:
    output = (output,)
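
A tiny illustration of the point:

# list() never nests: applied to a list it just copies the elements, and applied to
# a tuple it produces an equal but in-place-modifiable list.
assert list([1, 2]) == [1, 2]   # not [[1, 2]]
assert list((1, 2)) == [1, 2]   # tuple -> mutable list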
Comment on lines +237 to +242
is_tensor = isinstance(output, torch.Tensor)

if is_tensor:
    output = [output]
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
    raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
Member:

(nit): maybe we could follow a unified format for this error check?

I like the following approach:

is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
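
Folding that into the earlier check could look something like this (a sketch of the suggestion, not the code in this PR; _normalize_output is a hypothetical helper name):

import torch

def _normalize_output(output):
    is_tensor = isinstance(output, torch.Tensor)
    is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
    if not (is_tensor or is_tensor_list):
        raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
    # Normalize to an in-place modifiable list in both cases.
    return [output] if is_tensor else list(output)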

if backend not in cls._supports_context_parallel:
raise ValueError(f"Backend {backend} is not registered.")
supports_context_parallel = cls._supports_context_parallel[backend]
is_degree_greater_than_1 = _AttentionBackendRegistry._parallel_config is not None and (
Member:

(nit): a comment to denote what degree means in this context would be helpful.
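
(For reference, assuming I'm reading the surrounding ParallelConfig usage correctly: "degree" here refers to the size of the context-parallel group along each axis, i.e. ulysses_degree / ring_degree, so a degree greater than 1 means the sequence is actually sharded across ranks and the backend must support context parallelism.)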

return torch.empty_like(q), q.new_empty(lse_shape)


# ===== Autograd functions =====
Member:

A comment on why we're having to keep the autograds for forward and backward here would be helpful.

@a-r-r-o-w (Member, Author) commented:

Pushed some changes to support the ring attention backward pass. Running into a super weird error when comparing the gradients: the outputs match and the gradients match, except that everything in the parallelized module after the last single transformer block has grad set to None (I don't even know how that's possible). I'm leaving the code and outputs here in case someone has insights.

Code:
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models import ParallelConfig
from diffusers.hooks import ModelHook, HookRegistry

class OutputRequiresGradHook(ModelHook):
    def __init__(self, name):
        super().__init__()
        self.name = name
    
    def post_forward(self, module, output):
        if torch.distributed.get_rank() == 0:
            if isinstance(output, torch.Tensor):
                requires_grad = output.requires_grad
            elif isinstance(output, (tuple, list)):
                requires_grad = [o.requires_grad for o in output if isinstance(o, torch.Tensor)]
            else:
                requires_grad = None
            print(self.name, requires_grad)
        return output

def main():
    dtype = torch.bfloat16
    device = "cuda"
    world_size = torch.distributed.get_world_size()
    
    in_channels = 16
    attention_head_dim = 128
    num_attention_heads = 4
    joint_attention_dim = 1024
    pooled_projection_dim = 256
    init_config = dict(
        patch_size=1,
        in_channels=in_channels,
        num_layers=2,
        num_single_layers=2,
        attention_head_dim=attention_head_dim,
        num_attention_heads=num_attention_heads,
        joint_attention_dim=joint_attention_dim,
        pooled_projection_dim=pooled_projection_dim,
        axes_dims_rope=[16, 56, 56],
    )
    
    torch.manual_seed(0)
    transformer_1 = FluxTransformer2DModel(**init_config)
    config = ParallelConfig(ring_degree=world_size)
    
    torch.manual_seed(0)
    transformer_2 = FluxTransformer2DModel(**init_config)

    transformer_1.to(device=device, dtype=dtype)
    transformer_1.requires_grad_(True)
    transformer_1.set_attention_backend("flash")
    transformer_1.train()

    transformer_2.to(device=device, dtype=dtype)
    transformer_2.requires_grad_(True)
    transformer_2.set_attention_backend("flash")
    transformer_2.train()

    for name, submodule in transformer_1.named_modules():
        registry = HookRegistry.check_if_exists_or_initialize(submodule)
        registry.register_hook(OutputRequiresGradHook(name), "output_requires_grad_hook")
    for name, submodule in transformer_2.named_modules():
        registry = HookRegistry.check_if_exists_or_initialize(submodule)
        registry.register_hook(OutputRequiresGradHook(name), "output_requires_grad_hook")
    
    batch_size = 1
    text_sequence_length = 128
    height = width = 64

    x = torch.randn((batch_size, height * width, in_channels), device=device, dtype=dtype)
    noise = torch.randn((batch_size, height * width, in_channels), device=device, dtype=dtype)
    hidden_states = x + noise
    encoder_hidden_states = torch.randn((batch_size, text_sequence_length, joint_attention_dim), device=device, dtype=dtype)
    pooled_prompt_embeds = torch.randn((batch_size, pooled_projection_dim), device=device, dtype=dtype)
    text_ids = torch.zeros((text_sequence_length, 3), device=device, dtype=dtype)
    image_ids = torch.randn((height * width, 3), device=device, dtype=dtype)
    timestep = torch.tensor([1.0]).to(device).expand(batch_size)

    with transformer_1.parallelize(config=config):
        out1 = transformer_1(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, pooled_projections=pooled_prompt_embeds, img_ids=image_ids, txt_ids=text_ids, timestep=timestep, return_dict=False)[0]
        loss1 = (out1.float() - noise.float()).pow(2).mean()
        loss1.backward()
        
    out2 = transformer_2(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, pooled_projections=pooled_prompt_embeds, img_ids=image_ids, txt_ids=text_ids, timestep=timestep, return_dict=False)[0]
    loss2 = (out2.float() - noise.float()).pow(2).mean()
    loss2.backward()
    
    torch.cuda.synchronize()
    torch.testing.assert_close(out1, out2, rtol=1e-2, atol=1e-2)
    
    for (name, param1), (name2, param2) in zip(transformer_1.named_parameters(), transformer_2.named_parameters()):
        if param1.grad is not None and param2.grad is not None:
            assert name == name2
            torch.testing.assert_close(param1.grad, param2.grad, rtol=5e-2, atol=5e-2)
        else:
            if torch.distributed.get_rank() == 0:
                print(f"Skipping gradient check for {name} as one of the gradients is None (param1.grad: {param1.grad is not None}, param2.grad: {param2.grad is not None})")


try:
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
    main()
except Exception as e:
    print(f"Error initializing distributed process group: {e}")
    raise
finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
Output:
(nightly-venv) aryan@hf-dgx-01:~/work/diffusers$ torchrun --nproc_per_node 2 dump7.py
W0730 07:24:10.460000 1504699 torch/distributed/run.py:766]
W0730 07:24:10.460000 1504699 torch/distributed/run.py:766] *****************************************
W0730 07:24:10.460000 1504699 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0730 07:24:10.460000 1504699 torch/distributed/run.py:766] *****************************************
Attention backends are an experimental feature and the API may be subject to change.
Attention backends are an experimental feature and the API may be subject to change.
Attention backends are an experimental feature and the API may be subject to change.
Attention backends are an experimental feature and the API may be subject to change.
`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
x_embedder True
time_text_embed.time_proj False
time_text_embed.timestep_embedder.linear_1 True
time_text_embed.timestep_embedder.act True
time_text_embed.timestep_embedder.linear_2 True
time_text_embed.timestep_embedder True
time_text_embed.text_embedder.linear_1 True
time_text_embed.text_embedder.act_1 True
time_text_embed.text_embedder.linear_2 True
time_text_embed.text_embedder True
time_text_embed True
context_embedder True
pos_embed [False, False]
transformer_blocks.0.norm1.silu True
transformer_blocks.0.norm1.linear True
transformer_blocks.0.norm1.norm True
transformer_blocks.0.norm1 [True, True, True, True, True]
transformer_blocks.0.norm1_context.silu True
transformer_blocks.0.norm1_context.linear True
transformer_blocks.0.norm1_context.norm True
transformer_blocks.0.norm1_context [True, True, True, True, True]
transformer_blocks.0.attn.to_q True
transformer_blocks.0.attn.to_k True
transformer_blocks.0.attn.to_v True
transformer_blocks.0.attn.add_q_proj True
transformer_blocks.0.attn.add_k_proj True
transformer_blocks.0.attn.add_v_proj True
transformer_blocks.0.attn.norm_q True
transformer_blocks.0.attn.norm_k True
transformer_blocks.0.attn.norm_added_q True
transformer_blocks.0.attn.norm_added_k True
transformer_blocks.0.attn.to_out.0 True
transformer_blocks.0.attn.to_out.1 True
transformer_blocks.0.attn.to_add_out True
transformer_blocks.0.attn [True, True]
transformer_blocks.0.norm2 True
transformer_blocks.0.ff.net.0.proj True
transformer_blocks.0.ff.net.0 True
transformer_blocks.0.ff.net.1 True
transformer_blocks.0.ff.net.2 True
transformer_blocks.0.ff True
transformer_blocks.0.norm2_context True
transformer_blocks.0.ff_context.net.0.proj True
transformer_blocks.0.ff_context.net.0 True
transformer_blocks.0.ff_context.net.1 True
transformer_blocks.0.ff_context.net.2 True
transformer_blocks.0.ff_context True
transformer_blocks.0 [True, True]
transformer_blocks.1.norm1.silu True
transformer_blocks.1.norm1.linear True
transformer_blocks.1.norm1.norm True
transformer_blocks.1.norm1 [True, True, True, True, True]
transformer_blocks.1.norm1_context.silu True
transformer_blocks.1.norm1_context.linear True
transformer_blocks.1.norm1_context.norm True
transformer_blocks.1.norm1_context [True, True, True, True, True]
transformer_blocks.1.attn.to_q True
transformer_blocks.1.attn.to_k True
transformer_blocks.1.attn.to_v True
transformer_blocks.1.attn.add_q_proj True
transformer_blocks.1.attn.add_k_proj True
transformer_blocks.1.attn.add_v_proj True
transformer_blocks.1.attn.norm_q True
transformer_blocks.1.attn.norm_k True
transformer_blocks.1.attn.norm_added_q True
transformer_blocks.1.attn.norm_added_k True
transformer_blocks.1.attn.to_out.0 True
transformer_blocks.1.attn.to_out.1 True
transformer_blocks.1.attn.to_add_out True
transformer_blocks.1.attn [True, True]
transformer_blocks.1.norm2 True
transformer_blocks.1.ff.net.0.proj True
transformer_blocks.1.ff.net.0 True
transformer_blocks.1.ff.net.1 True
transformer_blocks.1.ff.net.2 True
transformer_blocks.1.ff True
transformer_blocks.1.norm2_context True
transformer_blocks.1.ff_context.net.0.proj True
transformer_blocks.1.ff_context.net.0 True
transformer_blocks.1.ff_context.net.1 True
transformer_blocks.1.ff_context.net.2 True
transformer_blocks.1.ff_context True
transformer_blocks.1 [True, True]
single_transformer_blocks.0.norm.silu True
single_transformer_blocks.0.norm.linear True
single_transformer_blocks.0.norm.norm True
single_transformer_blocks.0.norm [True, True]
single_transformer_blocks.0.proj_mlp True
single_transformer_blocks.0.act_mlp True
single_transformer_blocks.0.attn.to_q True
single_transformer_blocks.0.attn.to_k True
single_transformer_blocks.0.attn.to_v True
single_transformer_blocks.0.attn.norm_q True
single_transformer_blocks.0.attn.norm_k True
single_transformer_blocks.0.attn True
single_transformer_blocks.0.proj_out True
single_transformer_blocks.0 [True, True]
single_transformer_blocks.1.norm.silu True
single_transformer_blocks.1.norm.linear True
single_transformer_blocks.1.norm.norm True
single_transformer_blocks.1.norm [True, True]
single_transformer_blocks.1.proj_mlp True
single_transformer_blocks.1.act_mlp True
single_transformer_blocks.1.attn.to_q True
single_transformer_blocks.1.attn.to_k True
single_transformer_blocks.1.attn.to_v True
single_transformer_blocks.1.attn.norm_q True
single_transformer_blocks.1.attn.norm_k True
single_transformer_blocks.1.attn True
single_transformer_blocks.1.proj_out True
single_transformer_blocks.1 [True, True]
norm_out.silu True
norm_out.linear True
norm_out.norm True
norm_out True
proj_out True
 [True]
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/autograd/graph.py:824: UserWarning: _c10d_functional::wait_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/autograd/graph.py:824: UserWarning: _c10d_functional::all_gather_into_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/autograd/graph.py:824: UserWarning: _c10d_functional::wait_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/autograd/graph.py:824: UserWarning: _c10d_functional::all_gather_into_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
x_embedder True
time_text_embed.time_proj False
time_text_embed.timestep_embedder.linear_1 True
time_text_embed.timestep_embedder.act True
time_text_embed.timestep_embedder.linear_2 True
time_text_embed.timestep_embedder True
time_text_embed.text_embedder.linear_1 True
time_text_embed.text_embedder.act_1 True
time_text_embed.text_embedder.linear_2 True
time_text_embed.text_embedder True
time_text_embed True
context_embedder True
pos_embed [False, False]
transformer_blocks.0.norm1.silu True
transformer_blocks.0.norm1.linear True
transformer_blocks.0.norm1.norm True
transformer_blocks.0.norm1 [True, True, True, True, True]
transformer_blocks.0.norm1_context.silu True
transformer_blocks.0.norm1_context.linear True
transformer_blocks.0.norm1_context.norm True
transformer_blocks.0.norm1_context [True, True, True, True, True]
transformer_blocks.0.attn.to_q True
transformer_blocks.0.attn.to_k True
transformer_blocks.0.attn.to_v True
transformer_blocks.0.attn.add_q_proj True
transformer_blocks.0.attn.add_k_proj True
transformer_blocks.0.attn.add_v_proj True
transformer_blocks.0.attn.norm_q True
transformer_blocks.0.attn.norm_k True
transformer_blocks.0.attn.norm_added_q True
transformer_blocks.0.attn.norm_added_k True
transformer_blocks.0.attn.to_out.0 True
transformer_blocks.0.attn.to_out.1 True
transformer_blocks.0.attn.to_add_out True
transformer_blocks.0.attn [True, True]
transformer_blocks.0.norm2 True
transformer_blocks.0.ff.net.0.proj True
transformer_blocks.0.ff.net.0 True
transformer_blocks.0.ff.net.1 True
transformer_blocks.0.ff.net.2 True
transformer_blocks.0.ff True
transformer_blocks.0.norm2_context True
transformer_blocks.0.ff_context.net.0.proj True
transformer_blocks.0.ff_context.net.0 True
transformer_blocks.0.ff_context.net.1 True
transformer_blocks.0.ff_context.net.2 True
transformer_blocks.0.ff_context True
transformer_blocks.0 [True, True]
transformer_blocks.1.norm1.silu True
transformer_blocks.1.norm1.linear True
transformer_blocks.1.norm1.norm True
transformer_blocks.1.norm1 [True, True, True, True, True]
transformer_blocks.1.norm1_context.silu True
transformer_blocks.1.norm1_context.linear True
transformer_blocks.1.norm1_context.norm True
transformer_blocks.1.norm1_context [True, True, True, True, True]
transformer_blocks.1.attn.to_q True
transformer_blocks.1.attn.to_k True
transformer_blocks.1.attn.to_v True
transformer_blocks.1.attn.add_q_proj True
transformer_blocks.1.attn.add_k_proj True
transformer_blocks.1.attn.add_v_proj True
transformer_blocks.1.attn.norm_q True
transformer_blocks.1.attn.norm_k True
transformer_blocks.1.attn.norm_added_q True
transformer_blocks.1.attn.norm_added_k True
transformer_blocks.1.attn.to_out.0 True
transformer_blocks.1.attn.to_out.1 True
transformer_blocks.1.attn.to_add_out True
transformer_blocks.1.attn [True, True]
transformer_blocks.1.norm2 True
transformer_blocks.1.ff.net.0.proj True
transformer_blocks.1.ff.net.0 True
transformer_blocks.1.ff.net.1 True
transformer_blocks.1.ff.net.2 True
transformer_blocks.1.ff True
transformer_blocks.1.norm2_context True
transformer_blocks.1.ff_context.net.0.proj True
transformer_blocks.1.ff_context.net.0 True
transformer_blocks.1.ff_context.net.1 True
transformer_blocks.1.ff_context.net.2 True
transformer_blocks.1.ff_context True
transformer_blocks.1 [True, True]
single_transformer_blocks.0.norm.silu True
single_transformer_blocks.0.norm.linear True
single_transformer_blocks.0.norm.norm True
single_transformer_blocks.0.norm [True, True]
single_transformer_blocks.0.proj_mlp True
single_transformer_blocks.0.act_mlp True
single_transformer_blocks.0.attn.to_q True
single_transformer_blocks.0.attn.to_k True
single_transformer_blocks.0.attn.to_v True
single_transformer_blocks.0.attn.norm_q True
single_transformer_blocks.0.attn.norm_k True
single_transformer_blocks.0.attn True
single_transformer_blocks.0.proj_out True
single_transformer_blocks.0 [True, True]
single_transformer_blocks.1.norm.silu True
single_transformer_blocks.1.norm.linear True
single_transformer_blocks.1.norm.norm True
single_transformer_blocks.1.norm [True, True]
single_transformer_blocks.1.proj_mlp True
single_transformer_blocks.1.act_mlp True
single_transformer_blocks.1.attn.to_q True
single_transformer_blocks.1.attn.to_k True
single_transformer_blocks.1.attn.to_v True
single_transformer_blocks.1.attn.norm_q True
single_transformer_blocks.1.attn.norm_k True
single_transformer_blocks.1.attn True
single_transformer_blocks.1.proj_out True
single_transformer_blocks.1 [True, True]
norm_out.silu True
norm_out.linear True
norm_out.norm True
norm_out True
proj_out True
 [True]
Skipping gradient check for single_transformer_blocks.1.proj_mlp.weight as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for single_transformer_blocks.1.proj_mlp.bias as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for single_transformer_blocks.1.proj_out.weight as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for single_transformer_blocks.1.proj_out.bias as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for norm_out.linear.weight as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for norm_out.linear.bias as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for proj_out.weight as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for proj_out.bias as one of the gradients is None (param1.grad: False, param2.grad: True)

The gradients being None happens with both the _native_cudnn and flash backends. I've only tested on A100. I'll work on a more minimal repro for the PyTorch team in case I'm unable to figure it out, but I suspect it will be out of scope for them, as this is a custom, modified CP implementation.

Labels: roadmap (Add to current release roadmap)
3 participants