Context Parallel w/ Ring & Ulysses & Unified Attention #11941
Conversation
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I am going to review it very soon. But before I do, I would like to read a bit about unified attention. Simple searches returned results that didn't seem relevant. Hence the ask.
Unified CP is a generalization of performing Ulysses and Ring together; both of those methods become special cases of unified attention. Paper: https://arxiv.org/abs/2405.07719v3
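A minimal sketch of the idea (illustrative only, not this PR's API): the world is factored into a 2D mesh of ring_degree x ulysses_degree ranks. Ulysses groups exchange heads via all-to-all, ring groups rotate K/V blocks, and setting either degree to 1 recovers plain Ulysses or plain Ring attention.

def build_unified_mesh(world_size: int, ring_degree: int, ulysses_degree: int):
    # Factor `world_size` ranks into a (ring_degree x ulysses_degree) grid.
    assert ring_degree * ulysses_degree == world_size
    mesh = [
        [r * ulysses_degree + u for u in range(ulysses_degree)]
        for r in range(ring_degree)
    ]
    # Rows shard attention heads via all-to-all (Ulysses);
    # columns rotate K/V blocks peer-to-peer (Ring).
    ulysses_groups = mesh
    ring_groups = [[mesh[r][u] for r in range(ring_degree)] for u in range(ulysses_degree)]
    return ring_groups, ulysses_groups

# ring_degree=1 -> pure Ulysses; ulysses_degree=1 -> pure Ring.
ring_groups, ulysses_groups = build_unified_mesh(8, ring_degree=4, ulysses_degree=2)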
Looking really promising! Thank you, Aryan.
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
    cached_parameter_indices: Dict[str, int] = None
(nit): No need for this PR, but it might make sense to introduce help text for args like this.
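One possible way to do that, sketched with dataclasses.field metadata (the field name is taken from the hunk above; the help text itself is a guess at the field's purpose):

from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class ModuleForwardMetadata:
    cached_parameter_indices: Optional[Dict[str, int]] = field(
        default=None,
        metadata={"help": "Maps parameter names to their positional index in the module's forward signature."},
    )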
# HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
# diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
# available in pytorch to set the parallel context before/after the forward/backward pass.
# It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
# The previous/older implementation simply did this:
#     def new_forward(self, ...):
#         with _parallel_context(parallel_config):
#             return self.fn_ref.original_forward(*args, **kwargs)
# TODO: ask help from Pytorch team on how to improve this
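For reference, a rough sketch of the hook-based workaround the comment describes (illustrative, not the exact PR code; _parallel_context stands in for the real context manager referenced above): the context is entered in a forward pre-hook and exited in a forward hook, so the traced forward itself contains no context manager for Dynamo to trace.

import contextlib

import torch


def _parallel_context(parallel_config):
    # Placeholder for the PR's real context manager that installs
    # `parallel_config` for the attention dispatcher; no-op here.
    return contextlib.nullcontext()


def _attach_parallel_context(module: torch.nn.Module, parallel_config):
    state = {}

    def _pre_hook(mod, args):
        # Enter the parallel context before forward runs, outside the traced region.
        cm = _parallel_context(parallel_config)
        cm.__enter__()
        state["cm"] = cm

    def _post_hook(mod, args, output):
        # Exit the parallel context after forward completes.
        state.pop("cm").__exit__(None, None, None)
        return output

    module.register_forward_pre_hook(_pre_hook)
    module.register_forward_hook(_post_hook)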
Cc: @anijain2305 as an FYI.
@a-r-r-o-w we don't have to tackle the following in this PR, but is using fullgraph=True beneficial from the perspective of number of recompilations, latency, etc.? Totally okay if we don't know that yet. But in my experience, "fullgraph=True + more recompiles" sometimes performs worse than a somewhat "fragmented graph + fewer recompiles".
If this works out well, we can potentially upstream this into accelerate too, since it makes use of the new_forward pattern, and propagate it in our offloading hooks if necessary.
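One rough way to answer the fullgraph question empirically (a sketch, not something from this PR): compile the same module with and without fullgraph=True, surface recompilation logs, and time a few steps of each variant.

import torch
import torch._dynamo

torch._logging.set_logs(recompiles=True)  # log every recompilation and its reason


def bench(module, inputs, fullgraph):
    torch._dynamo.reset()
    compiled = torch.compile(module, fullgraph=fullgraph)
    for x in inputs[:2]:  # warmup: triggers compilation and any early recompiles
        compiled(x)
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for x in inputs:
        compiled(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)  # milliseconds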
I've gotten rid of this in favor of a context-manager-based parallelize function (for the time being). The previous hack is not good.
if is_tensor:
    output = [output]
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
    raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

output = list(output)
if is_tensor:
    output = [output]
...
output = list(output)

Is this intended?
Yes, calling list() on a list does not create a nested list, if that was the question. This is so we can convert a tuple output into an in-place modifiable list. Maybe I should do this to avoid confusion:

if is_tensor:
    output = (output,)
is_tensor = isinstance(output, torch.Tensor)

if is_tensor:
    output = [output]
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
    raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
(nit): maybe we could use a unified format for this check? I like the following approach:
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
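For example, the check and the normalization above could then be consolidated roughly as follows (a sketch of the suggestion, not the PR's code):

import torch


def _normalize_output(output):
    is_tensor = isinstance(output, torch.Tensor)
    is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
    if not (is_tensor or is_tensor_list):
        raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
    # Always return a mutable list so entries can be modified in place.
    return [output] if is_tensor else list(output)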
if backend not in cls._supports_context_parallel:
    raise ValueError(f"Backend {backend} is not registered.")
supports_context_parallel = cls._supports_context_parallel[backend]
is_degree_greater_than_1 = _AttentionBackendRegistry._parallel_config is not None and (
(nit): a comment to denote what degree means in this context would be helpful.
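For instance, something along these lines (the attribute names are my assumptions, inferred from ParallelConfig(ring_degree=...) used elsewhere in this thread):

# "degree" = total context-parallel degree, i.e. how many ranks the sequence is
# sharded across; with unified attention that is ring_degree * ulysses_degree.
parallel_config = _AttentionBackendRegistry._parallel_config
is_degree_greater_than_1 = parallel_config is not None and (
    parallel_config.ring_degree * parallel_config.ulysses_degree > 1
)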
return torch.empty_like(q), q.new_empty(lse_shape)


# ===== Autograd functions =====
A comment on why we're having to keep the autograds for forward and backward here would be helpful.
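For context (my reading, not text from the PR): the functional collectives used for the ring/Ulysses communication do not register autograd kernels, as the _c10d_functional::wait_tensor warnings in the backward run below show, so the parallel attention op has to spell out its forward and backward explicitly. A minimal illustrative shape of such an autograd function (names and signatures are assumptions, not the PR's code):

import torch


class _TemplatedRingAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, forward_op, backward_op):
        # forward_op runs the sharded attention plus its communication and returns
        # the output and the log-sum-exp needed to merge partial results in backward.
        out, lse = forward_op(q, k, v)
        ctx.save_for_backward(q, k, v, out, lse)
        ctx.backward_op = backward_op
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, out, lse = ctx.saved_tensors
        grad_q, grad_k, grad_v = ctx.backward_op(grad_out, q, k, v, out, lse)
        # Non-tensor inputs (the two callables) receive no gradient.
        return grad_q, grad_k, grad_v, None, None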
Pushed some changes to support ring attention backward. Running into a super weird error when comparing the gradients. The outputs match and the gradients match, except everything in the parallelized module has grad set to None.

Code:

import torch
from diffusers import FluxTransformer2DModel
from diffusers.models import ParallelConfig
from diffusers.hooks import ModelHook, HookRegistry


class OutputRequiresGradHook(ModelHook):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def post_forward(self, module, output):
        if torch.distributed.get_rank() == 0:
            if isinstance(output, torch.Tensor):
                requires_grad = output.requires_grad
            elif isinstance(output, (tuple, list)):
                requires_grad = [o.requires_grad for o in output if isinstance(o, torch.Tensor)]
            else:
                requires_grad = None
            print(self.name, requires_grad)
        return output


def main():
    dtype = torch.bfloat16
    device = "cuda"
    world_size = torch.distributed.get_world_size()

    in_channels = 16
    attention_head_dim = 128
    num_attention_heads = 4
    joint_attention_dim = 1024
    pooled_projection_dim = 256
    init_config = dict(
        patch_size=1,
        in_channels=in_channels,
        num_layers=2,
        num_single_layers=2,
        attention_head_dim=attention_head_dim,
        num_attention_heads=num_attention_heads,
        joint_attention_dim=joint_attention_dim,
        pooled_projection_dim=pooled_projection_dim,
        axes_dims_rope=[16, 56, 56],
    )

    torch.manual_seed(0)
    transformer_1 = FluxTransformer2DModel(**init_config)
    config = ParallelConfig(ring_degree=world_size)

    torch.manual_seed(0)
    transformer_2 = FluxTransformer2DModel(**init_config)

    transformer_1.to(device=device, dtype=dtype)
    transformer_1.requires_grad_(True)
    transformer_1.set_attention_backend("flash")
    transformer_1.train()

    transformer_2.to(device=device, dtype=dtype)
    transformer_2.requires_grad_(True)
    transformer_2.set_attention_backend("flash")
    transformer_2.train()

    for name, submodule in transformer_1.named_modules():
        registry = HookRegistry.check_if_exists_or_initialize(submodule)
        registry.register_hook(OutputRequiresGradHook(name), "output_requires_grad_hook")
    for name, submodule in transformer_2.named_modules():
        registry = HookRegistry.check_if_exists_or_initialize(submodule)
        registry.register_hook(OutputRequiresGradHook(name), "output_requires_grad_hook")

    batch_size = 1
    text_sequence_length = 128
    height = width = 64

    x = torch.randn((batch_size, height * width, in_channels), device=device, dtype=dtype)
    noise = torch.randn((batch_size, height * width, in_channels), device=device, dtype=dtype)
    hidden_states = x + noise
    encoder_hidden_states = torch.randn((batch_size, text_sequence_length, joint_attention_dim), device=device, dtype=dtype)
    pooled_prompt_embeds = torch.randn((batch_size, pooled_projection_dim), device=device, dtype=dtype)
    text_ids = torch.zeros((text_sequence_length, 3), device=device, dtype=dtype)
    image_ids = torch.randn((height * width, 3), device=device, dtype=dtype)
    timestep = torch.tensor([1.0]).to(device).expand(batch_size)

    with transformer_1.parallelize(config=config):
        out1 = transformer_1(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, pooled_projections=pooled_prompt_embeds, img_ids=image_ids, txt_ids=text_ids, timestep=timestep, return_dict=False)[0]
        loss1 = (out1.float() - noise.float()).pow(2).mean()
        loss1.backward()

    out2 = transformer_2(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, pooled_projections=pooled_prompt_embeds, img_ids=image_ids, txt_ids=text_ids, timestep=timestep, return_dict=False)[0]
    loss2 = (out2.float() - noise.float()).pow(2).mean()
    loss2.backward()

    torch.cuda.synchronize()

    torch.testing.assert_close(out1, out2, rtol=1e-2, atol=1e-2)
    for (name, param1), (name2, param2) in zip(transformer_1.named_parameters(), transformer_2.named_parameters()):
        if param1.grad is not None and param2.grad is not None:
            assert name == name2
            torch.testing.assert_close(param1.grad, param2.grad, rtol=5e-2, atol=5e-2)
        else:
            if torch.distributed.get_rank() == 0:
                print(f"Skipping gradient check for {name} as one of the gradients is None (param1.grad: {param1.grad is not None}, param2.grad: {param2.grad is not None})")


try:
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
    main()
except Exception as e:
    print(f"Error initializing distributed process group: {e}")
    raise
finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

Output:

(nightly-venv) aryan@hf-dgx-01:~/work/diffusers$ torchrun --nproc_per_node 2 dump7.py
W0730 07:24:10.460000 1504699 torch/distributed/run.py:766]
W0730 07:24:10.460000 1504699 torch/distributed/run.py:766] *****************************************
W0730 07:24:10.460000 1504699 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0730 07:24:10.460000 1504699 torch/distributed/run.py:766] *****************************************
Attention backends are an experimental feature and the API may be subject to change.
Attention backends are an experimental feature and the API may be subject to change.
Attention backends are an experimental feature and the API may be subject to change.
Attention backends are an experimental feature and the API may be subject to change.
`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
x_embedder True
time_text_embed.time_proj False
time_text_embed.timestep_embedder.linear_1 True
time_text_embed.timestep_embedder.act True
time_text_embed.timestep_embedder.linear_2 True
time_text_embed.timestep_embedder True
time_text_embed.text_embedder.linear_1 True
time_text_embed.text_embedder.act_1 True
time_text_embed.text_embedder.linear_2 True
time_text_embed.text_embedder True
time_text_embed True
context_embedder True
pos_embed [False, False]
transformer_blocks.0.norm1.silu True
transformer_blocks.0.norm1.linear True
transformer_blocks.0.norm1.norm True
transformer_blocks.0.norm1 [True, True, True, True, True]
transformer_blocks.0.norm1_context.silu True
transformer_blocks.0.norm1_context.linear True
transformer_blocks.0.norm1_context.norm True
transformer_blocks.0.norm1_context [True, True, True, True, True]
transformer_blocks.0.attn.to_q True
transformer_blocks.0.attn.to_k True
transformer_blocks.0.attn.to_v True
transformer_blocks.0.attn.add_q_proj True
transformer_blocks.0.attn.add_k_proj True
transformer_blocks.0.attn.add_v_proj True
transformer_blocks.0.attn.norm_q True
transformer_blocks.0.attn.norm_k True
transformer_blocks.0.attn.norm_added_q True
transformer_blocks.0.attn.norm_added_k True
transformer_blocks.0.attn.to_out.0 True
transformer_blocks.0.attn.to_out.1 True
transformer_blocks.0.attn.to_add_out True
transformer_blocks.0.attn [True, True]
transformer_blocks.0.norm2 True
transformer_blocks.0.ff.net.0.proj True
transformer_blocks.0.ff.net.0 True
transformer_blocks.0.ff.net.1 True
transformer_blocks.0.ff.net.2 True
transformer_blocks.0.ff True
transformer_blocks.0.norm2_context True
transformer_blocks.0.ff_context.net.0.proj True
transformer_blocks.0.ff_context.net.0 True
transformer_blocks.0.ff_context.net.1 True
transformer_blocks.0.ff_context.net.2 True
transformer_blocks.0.ff_context True
transformer_blocks.0 [True, True]
transformer_blocks.1.norm1.silu True
transformer_blocks.1.norm1.linear True
transformer_blocks.1.norm1.norm True
transformer_blocks.1.norm1 [True, True, True, True, True]
transformer_blocks.1.norm1_context.silu True
transformer_blocks.1.norm1_context.linear True
transformer_blocks.1.norm1_context.norm True
transformer_blocks.1.norm1_context [True, True, True, True, True]
transformer_blocks.1.attn.to_q True
transformer_blocks.1.attn.to_k True
transformer_blocks.1.attn.to_v True
transformer_blocks.1.attn.add_q_proj True
transformer_blocks.1.attn.add_k_proj True
transformer_blocks.1.attn.add_v_proj True
transformer_blocks.1.attn.norm_q True
transformer_blocks.1.attn.norm_k True
transformer_blocks.1.attn.norm_added_q True
transformer_blocks.1.attn.norm_added_k True
transformer_blocks.1.attn.to_out.0 True
transformer_blocks.1.attn.to_out.1 True
transformer_blocks.1.attn.to_add_out True
transformer_blocks.1.attn [True, True]
transformer_blocks.1.norm2 True
transformer_blocks.1.ff.net.0.proj True
transformer_blocks.1.ff.net.0 True
transformer_blocks.1.ff.net.1 True
transformer_blocks.1.ff.net.2 True
transformer_blocks.1.ff True
transformer_blocks.1.norm2_context True
transformer_blocks.1.ff_context.net.0.proj True
transformer_blocks.1.ff_context.net.0 True
transformer_blocks.1.ff_context.net.1 True
transformer_blocks.1.ff_context.net.2 True
transformer_blocks.1.ff_context True
transformer_blocks.1 [True, True]
single_transformer_blocks.0.norm.silu True
single_transformer_blocks.0.norm.linear True
single_transformer_blocks.0.norm.norm True
single_transformer_blocks.0.norm [True, True]
single_transformer_blocks.0.proj_mlp True
single_transformer_blocks.0.act_mlp True
single_transformer_blocks.0.attn.to_q True
single_transformer_blocks.0.attn.to_k True
single_transformer_blocks.0.attn.to_v True
single_transformer_blocks.0.attn.norm_q True
single_transformer_blocks.0.attn.norm_k True
single_transformer_blocks.0.attn True
single_transformer_blocks.0.proj_out True
single_transformer_blocks.0 [True, True]
single_transformer_blocks.1.norm.silu True
single_transformer_blocks.1.norm.linear True
single_transformer_blocks.1.norm.norm True
single_transformer_blocks.1.norm [True, True]
single_transformer_blocks.1.proj_mlp True
single_transformer_blocks.1.act_mlp True
single_transformer_blocks.1.attn.to_q True
single_transformer_blocks.1.attn.to_k True
single_transformer_blocks.1.attn.to_v True
single_transformer_blocks.1.attn.norm_q True
single_transformer_blocks.1.attn.norm_k True
single_transformer_blocks.1.attn True
single_transformer_blocks.1.proj_out True
single_transformer_blocks.1 [True, True]
norm_out.silu True
norm_out.linear True
norm_out.norm True
norm_out True
proj_out True
[True]
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/autograd/graph.py:824: UserWarning: _c10d_functional::wait_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/autograd/graph.py:824: UserWarning: _c10d_functional::all_gather_into_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/autograd/graph.py:824: UserWarning: _c10d_functional::wait_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/autograd/graph.py:824: UserWarning: _c10d_functional::all_gather_into_tensor: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
x_embedder True
time_text_embed.time_proj False
time_text_embed.timestep_embedder.linear_1 True
time_text_embed.timestep_embedder.act True
time_text_embed.timestep_embedder.linear_2 True
time_text_embed.timestep_embedder True
time_text_embed.text_embedder.linear_1 True
time_text_embed.text_embedder.act_1 True
time_text_embed.text_embedder.linear_2 True
time_text_embed.text_embedder True
time_text_embed True
context_embedder True
pos_embed [False, False]
transformer_blocks.0.norm1.silu True
transformer_blocks.0.norm1.linear True
transformer_blocks.0.norm1.norm True
transformer_blocks.0.norm1 [True, True, True, True, True]
transformer_blocks.0.norm1_context.silu True
transformer_blocks.0.norm1_context.linear True
transformer_blocks.0.norm1_context.norm True
transformer_blocks.0.norm1_context [True, True, True, True, True]
transformer_blocks.0.attn.to_q True
transformer_blocks.0.attn.to_k True
transformer_blocks.0.attn.to_v True
transformer_blocks.0.attn.add_q_proj True
transformer_blocks.0.attn.add_k_proj True
transformer_blocks.0.attn.add_v_proj True
transformer_blocks.0.attn.norm_q True
transformer_blocks.0.attn.norm_k True
transformer_blocks.0.attn.norm_added_q True
transformer_blocks.0.attn.norm_added_k True
transformer_blocks.0.attn.to_out.0 True
transformer_blocks.0.attn.to_out.1 True
transformer_blocks.0.attn.to_add_out True
transformer_blocks.0.attn [True, True]
transformer_blocks.0.norm2 True
transformer_blocks.0.ff.net.0.proj True
transformer_blocks.0.ff.net.0 True
transformer_blocks.0.ff.net.1 True
transformer_blocks.0.ff.net.2 True
transformer_blocks.0.ff True
transformer_blocks.0.norm2_context True
transformer_blocks.0.ff_context.net.0.proj True
transformer_blocks.0.ff_context.net.0 True
transformer_blocks.0.ff_context.net.1 True
transformer_blocks.0.ff_context.net.2 True
transformer_blocks.0.ff_context True
transformer_blocks.0 [True, True]
transformer_blocks.1.norm1.silu True
transformer_blocks.1.norm1.linear True
transformer_blocks.1.norm1.norm True
transformer_blocks.1.norm1 [True, True, True, True, True]
transformer_blocks.1.norm1_context.silu True
transformer_blocks.1.norm1_context.linear True
transformer_blocks.1.norm1_context.norm True
transformer_blocks.1.norm1_context [True, True, True, True, True]
transformer_blocks.1.attn.to_q True
transformer_blocks.1.attn.to_k True
transformer_blocks.1.attn.to_v True
transformer_blocks.1.attn.add_q_proj True
transformer_blocks.1.attn.add_k_proj True
transformer_blocks.1.attn.add_v_proj True
transformer_blocks.1.attn.norm_q True
transformer_blocks.1.attn.norm_k True
transformer_blocks.1.attn.norm_added_q True
transformer_blocks.1.attn.norm_added_k True
transformer_blocks.1.attn.to_out.0 True
transformer_blocks.1.attn.to_out.1 True
transformer_blocks.1.attn.to_add_out True
transformer_blocks.1.attn [True, True]
transformer_blocks.1.norm2 True
transformer_blocks.1.ff.net.0.proj True
transformer_blocks.1.ff.net.0 True
transformer_blocks.1.ff.net.1 True
transformer_blocks.1.ff.net.2 True
transformer_blocks.1.ff True
transformer_blocks.1.norm2_context True
transformer_blocks.1.ff_context.net.0.proj True
transformer_blocks.1.ff_context.net.0 True
transformer_blocks.1.ff_context.net.1 True
transformer_blocks.1.ff_context.net.2 True
transformer_blocks.1.ff_context True
transformer_blocks.1 [True, True]
single_transformer_blocks.0.norm.silu True
single_transformer_blocks.0.norm.linear True
single_transformer_blocks.0.norm.norm True
single_transformer_blocks.0.norm [True, True]
single_transformer_blocks.0.proj_mlp True
single_transformer_blocks.0.act_mlp True
single_transformer_blocks.0.attn.to_q True
single_transformer_blocks.0.attn.to_k True
single_transformer_blocks.0.attn.to_v True
single_transformer_blocks.0.attn.norm_q True
single_transformer_blocks.0.attn.norm_k True
single_transformer_blocks.0.attn True
single_transformer_blocks.0.proj_out True
single_transformer_blocks.0 [True, True]
single_transformer_blocks.1.norm.silu True
single_transformer_blocks.1.norm.linear True
single_transformer_blocks.1.norm.norm True
single_transformer_blocks.1.norm [True, True]
single_transformer_blocks.1.proj_mlp True
single_transformer_blocks.1.act_mlp True
single_transformer_blocks.1.attn.to_q True
single_transformer_blocks.1.attn.to_k True
single_transformer_blocks.1.attn.to_v True
single_transformer_blocks.1.attn.norm_q True
single_transformer_blocks.1.attn.norm_k True
single_transformer_blocks.1.attn True
single_transformer_blocks.1.proj_out True
single_transformer_blocks.1 [True, True]
norm_out.silu True
norm_out.linear True
norm_out.norm True
norm_out True
proj_out True
[True]
Skipping gradient check for single_transformer_blocks.1.proj_mlp.weight as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for single_transformer_blocks.1.proj_mlp.bias as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for single_transformer_blocks.1.proj_out.weight as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for single_transformer_blocks.1.proj_out.bias as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for norm_out.linear.weight as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for norm_out.linear.bias as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for proj_out.weight as one of the gradients is None (param1.grad: False, param2.grad: True)
Skipping gradient check for proj_out.bias as one of the gradients is None (param1.grad: False, param2.grad: True)

The gradients being None happen in both
Adds support for ring, ulysses and unified attention natively. For a minimal PoC, I've limited changes to Flux.
Supported attention backends with CP: cuDNN, FA2, Sage.
Requires #11916 to be merged first.
Minimal example
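A usage sketch distilled from the debugging script earlier in this thread (inference only; assumes a torchrun launch and the context-manager parallelize API from this PR; the checkpoint id and shapes are illustrative):

import torch

from diffusers import FluxTransformer2DModel
from diffusers.models import ParallelConfig

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")
transformer.set_attention_backend("flash")

config = ParallelConfig(ring_degree=torch.distributed.get_world_size())

batch, seq, text_seq = 1, 4096, 512
hidden_states = torch.randn(batch, seq, transformer.config.in_channels, device="cuda", dtype=torch.bfloat16)
encoder_hidden_states = torch.randn(batch, text_seq, transformer.config.joint_attention_dim, device="cuda", dtype=torch.bfloat16)
pooled_projections = torch.randn(batch, transformer.config.pooled_projection_dim, device="cuda", dtype=torch.bfloat16)
img_ids = torch.randn(seq, 3, device="cuda", dtype=torch.bfloat16)
txt_ids = torch.zeros(text_seq, 3, device="cuda", dtype=torch.bfloat16)
timestep = torch.ones(batch, device="cuda")

with transformer.parallelize(config=config), torch.no_grad():
    out = transformer(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        img_ids=img_ids,
        txt_ids=txt_ids,
        timestep=timestep,
        return_dict=False,
    )[0]

torch.distributed.destroy_process_group()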
Benchmarks
TODO
Explanation
Each model should define a _cp_plan attribute that contains information on how to shard/gather tensors at different stages of the forward. TODO
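The plan format is still marked TODO above; purely as a hypothetical illustration of the idea (structure and keys are mine, not the PR's), such a plan could map module names to the dimension along which their inputs/outputs are sharded or gathered:

# Hypothetical illustration only: shard the token streams along the sequence
# dimension (dim=1) when entering the transformer blocks, and gather the final
# projection back to the full sequence.
_cp_plan = {
    "transformer_blocks.0": {"hidden_states": 1, "encoder_hidden_states": 1},
    "proj_out": {"output": 1},
}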
Note: There were some merge conflicts that I'm not sure I resolved correctly. Some things may be broken. For this reason, I've removed training support and only tested inference. I'll update some of the TODOs tomorrow.