
[wan2.2] follow-up #12024


Open · wants to merge 14 commits into base: main
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_wan.py
@@ -324,7 +324,7 @@ def forward(
):
timestep = self.timesteps_proj(timestep)
if timestep_seq_len is not None:
- timestep = timestep.unflatten(0, (1, timestep_seq_len))
+ timestep = timestep.unflatten(0, (-1, timestep_seq_len))
Review comment (Collaborator Author): so that it works with batch_size > 1


time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
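
A minimal sketch of why inferring the batch dimension with -1 matters (shapes are hypothetical, chosen only for illustration):

import torch

batch_size, seq_len, dim = 2, 6, 256
# the projected timestep arrives flattened as (batch_size * seq_len, dim)
timestep = torch.randn(batch_size * seq_len, dim)

# unflatten(0, (1, seq_len)) only restores a batch axis of 1 and fails whenever
# batch_size > 1; -1 lets PyTorch infer the batch dimension instead
restored = timestep.unflatten(0, (-1, seq_len))
assert restored.shape == (batch_size, seq_len, dim)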
12 changes: 8 additions & 4 deletions src/diffusers/pipelines/wan/pipeline_wan.py
@@ -125,15 +125,15 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):

model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]
_optional_components = ["transformer", "transformer_2"]

def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
- transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer: Optional[WanTransformer3DModel] = None,
transformer_2: Optional[WanTransformer3DModel] = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False, # Wan2.2 ti2v
@@ -526,7 +526,7 @@ def __call__(
device=device,
)

- transformer_dtype = self.transformer.dtype
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
@@ -536,7 +536,11 @@
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
- num_channels_latents = self.transformer.config.in_channels
+ num_channels_latents = (
+     self.transformer.config.in_channels
+     if self.transformer is not None
+     else self.transformer_2.config.in_channels
+ )
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
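With transformer now optional, a usage sketch of loading the pipeline without the high-noise transformer (the repo id is an assumption for illustration, not taken from this PR):

from diffusers import WanPipeline

# hypothetical: drop the high-noise transformer at load time; the pipeline
# then falls back to transformer_2 for dtype and in_channels, as patched above
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers",  # assumed checkpoint, for illustration
    transformer=None,
)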
10 changes: 6 additions & 4 deletions src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):

model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2", "image_encoder", "image_processor"]
_optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]

def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
- transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
image_processor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModel = None,
+ transformer: WanTransformer3DModel = None,
transformer_2: WanTransformer3DModel = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False,
@@ -669,12 +669,13 @@ def __call__(
)

# Encode image embedding
- transformer_dtype = self.transformer.dtype
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

if self.config.boundary_ratio is None and not self.config.expand_timesteps:
+ # only the wan 2.1 i2v transformer accepts image_embeds
+ if self.transformer is not None and self.transformer.config.image_dim is not None:
if image_embeds is None:
if last_image is None:
image_embeds = self.encode_image(image, device)
@@ -709,6 +710,7 @@ def __call__(
last_image,
)
if self.config.expand_timesteps:
+ # wan 2.2 5b i2v uses first_frame_mask to mask timesteps
latents, condition, first_frame_mask = latents_outputs
else:
latents, condition = latents_outputs
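A rough sketch of the first_frame_mask idea mentioned in the comment above (the shapes and the exact masking rule are assumptions for illustration, not this pipeline's implementation):

import torch

num_frames, t = 5, 999.0
# assume mask value 1 marks the clean conditioning (first) frame
first_frame_mask = torch.zeros(num_frames)
first_frame_mask[0] = 1.0

# expand the scalar timestep across frames, then zero it on the conditioning
# frame so only the remaining frames are treated as noisy
timestep = torch.full((num_frames,), t) * (1.0 - first_frame_mask)
# timestep -> tensor([0., 999., 999., 999., 999.])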
59 changes: 42 additions & 17 deletions tests/pipelines/wan/test_wan.py
@@ -13,8 +13,10 @@
# limitations under the License.

import gc
+ import tempfile
import unittest

+ import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

@@ -85,29 +87,13 @@ def get_dummy_components(self):
rope_max_seq_len=32,
)

- torch.manual_seed(0)

Review comment (Collaborator Author): revert the change in #12004
-     transformer_2 = WanTransformer3DModel(
-         patch_size=(1, 2, 2),
-         num_attention_heads=2,
-         attention_head_dim=12,
-         in_channels=16,
-         out_channels=16,
-         text_dim=32,
-         freq_dim=256,
-         ffn_dim=32,
-         num_layers=2,
-         cross_attn_norm=True,
-         qk_norm="rms_norm_across_heads",
-         rope_max_seq_len=32,
-     )

components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer_2": transformer_2,
"transformer_2": None,
}
return components

@@ -155,6 +141,45 @@ def test_inference(self):
def test_attention_slicing_forward_pass(self):
pass

+ # _optional_components includes transformer and transformer_2, but only transformer_2 is optional for this wan2.1 t2v pipeline
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+     optional_component = "transformer_2"
+
+     components = self.get_dummy_components()
+     components[optional_component] = None
+     pipe = self.pipeline_class(**components)
+     for component in pipe.components.values():
+         if hasattr(component, "set_default_attn_processor"):
+             component.set_default_attn_processor()
+     pipe.to(torch_device)
+     pipe.set_progress_bar_config(disable=None)
+
+     generator_device = "cpu"
+     inputs = self.get_dummy_inputs(generator_device)
+     torch.manual_seed(0)
+     output = pipe(**inputs)[0]
+
+     with tempfile.TemporaryDirectory() as tmpdir:
+         pipe.save_pretrained(tmpdir, safe_serialization=False)
+         pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+         for component in pipe_loaded.components.values():
+             if hasattr(component, "set_default_attn_processor"):
+                 component.set_default_attn_processor()
+         pipe_loaded.to(torch_device)
+         pipe_loaded.set_progress_bar_config(disable=None)
+
+     self.assertTrue(
+         getattr(pipe_loaded, optional_component) is None,
+         f"`{optional_component}` did not stay set to None after loading.",
+     )
+
+     inputs = self.get_dummy_inputs(generator_device)
+     torch.manual_seed(0)
+     output_loaded = pipe_loaded(**inputs)[0]
+
+     max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+     self.assertLess(max_diff, expected_max_difference)


@slow
@require_torch_accelerator